diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index d430fc312..e3c409f68 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -156,22 +156,8 @@ jobs: - name: Build PTOAS run: | export PATH="${PY_PATH}/bin:$PATH" - cd $PTO_SOURCE_DIR - cmake -C "$PTO_SOURCE_DIR/cmake/LinuxHardeningCache.cmake" -G Ninja \ - -S . \ - -B build \ - -DLLVM_DIR=$LLVM_BUILD_DIR/lib/cmake/llvm \ - -DMLIR_DIR=$LLVM_BUILD_DIR/lib/cmake/mlir \ - -DPython3_ROOT_DIR=${PY_PATH} \ - -DPython3_EXECUTABLE=${PY_PATH}/bin/python \ - -DPython3_FIND_STRATEGY=LOCATION \ - -Dpybind11_DIR=$(${PY_PATH}/bin/python -m pybind11 --cmakedir) \ - -DMLIR_PYTHON_PACKAGE_DIR=${LLVM_BUILD_DIR}/tools/mlir/python_packages/mlir_core \ - -DPTOAS_RELEASE_VERSION_OVERRIDE=${PTOAS_VERSION} \ - -DCMAKE_INSTALL_PREFIX=${PTO_INSTALL_DIR} \ - -DCMAKE_BUILD_TYPE=Release - ninja -C build - ninja -C build install + PTOAS_RELEASE_VERSION_OVERRIDE="${PTOAS_VERSION}" \ + pip install . --no-build-isolation - name: Create Python wheel if: false diff --git a/.github/workflows/build_wheel_mac.yml b/.github/workflows/build_wheel_mac.yml index 0e370a33d..809424447 100644 --- a/.github/workflows/build_wheel_mac.yml +++ b/.github/workflows/build_wheel_mac.yml @@ -154,22 +154,8 @@ jobs: - name: Build PTOAS run: | - cd $PTO_SOURCE_DIR - cmake -G Ninja \ - -S . \ - -B build \ - -DLLVM_DIR=$LLVM_BUILD_DIR/lib/cmake/llvm \ - -DMLIR_DIR=$LLVM_BUILD_DIR/lib/cmake/mlir \ - -DPython3_ROOT_DIR=${PY_PATH} \ - -DPython3_EXECUTABLE=$(which python) \ - -DPython3_FIND_STRATEGY=LOCATION \ - -Dpybind11_DIR=$(python -m pybind11 --cmakedir) \ - -DMLIR_PYTHON_PACKAGE_DIR=${LLVM_BUILD_DIR}/tools/mlir/python_packages/mlir_core \ - -DPTOAS_RELEASE_VERSION_OVERRIDE=${PTOAS_VERSION} \ - -DCMAKE_INSTALL_PREFIX=${PTO_INSTALL_DIR} \ - -DCMAKE_BUILD_TYPE=Release - ninja -C build - ninja -C build install + PTOAS_RELEASE_VERSION_OVERRIDE="${PTOAS_VERSION}" \ + pip install . --no-build-isolation - name: Create Python wheel if: false diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36d2ac3c4..8651ab2d3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -200,20 +200,9 @@ jobs: - name: Build PTOAS run: | - export PYBIND11_CMAKE_DIR="$(python3 -m pybind11 --cmakedir)" - cmake -C "${GITHUB_WORKSPACE}/cmake/LinuxHardeningCache.cmake" -G Ninja -S . -B build \ - -DLLVM_DIR="${LLVM_DIR}/lib/cmake/llvm" \ - -DMLIR_DIR="${LLVM_DIR}/lib/cmake/mlir" \ - -DPython3_EXECUTABLE=python3 \ - -DPython3_FIND_STRATEGY=LOCATION \ - -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DMLIR_PYTHON_PACKAGE_DIR="${LLVM_DIR}/tools/mlir/python_packages/mlir_core" \ - -DCMAKE_INSTALL_PREFIX="${PTO_INSTALL_DIR}" \ - -DCMAKE_BUILD_TYPE=Release - ninja -C build ptoas - ninja -C build ptobc - ninja -C build install + # LLVM_BUILD_DIR is the env var read by the build backend (_ptoas_build_backend.py). + # PTO_INSTALL_DIR is already set at the job level. + LLVM_BUILD_DIR="${LLVM_DIR}" pip install . --no-build-isolation - name: Run lit tests shell: bash @@ -398,17 +387,8 @@ jobs: shell: bash run: | set -euo pipefail - export PYBIND11_CMAKE_DIR="$(python3 -m pybind11 --cmakedir)" - cmake -G Ninja -S . -B build \ - -DLLVM_DIR="${LLVM_DIR}/lib/cmake/llvm" \ - -DMLIR_DIR="${LLVM_DIR}/lib/cmake/mlir" \ - -DPython3_EXECUTABLE=python3 \ - -DPython3_FIND_STRATEGY=LOCATION \ - -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DMLIR_PYTHON_PACKAGE_DIR="${LLVM_DIR}/tools/mlir/python_packages/mlir_core" \ - -DCMAKE_BUILD_TYPE=Release - ninja -C build ptoas + # LLVM_BUILD_DIR is the env var read by the build backend (_ptoas_build_backend.py). + LLVM_BUILD_DIR="${LLVM_DIR}" pip install . --no-build-isolation - name: Resolve simulator environment shell: bash diff --git a/.gitignore b/.gitignore index 61f15f6b0..db19e142a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,13 @@ # Build artifacts build/ +build_plain/ +build_plan/ install/ + +# TileLang ST standalone build outputs (see temp_docs/standalone_st.md) +test/tilelang_st/npu/a5/src/st/build/ +test/tilelang_st/npu/a5/src/st/build_plain/ +test/tilelang_st/npu/a5/src/st/build_plan/ cmake-build-*/ CMakeFiles/ CMakeCache.txt @@ -49,6 +56,7 @@ venv/ dist/ # Logs/temp +tmp/ *.log *.tmp *.swp @@ -56,6 +64,13 @@ dist/ .cache/ .pytest_cache/ +# PTODSL JIT / msprof simulator artifacts +.ptodsl_jit/ +.ptodsl_cache/ +msprof_res/ +ptodsl/examples/jit/.cache/ +ptodsl/examples/jit/msprof_res/ + # Remote/NPU validation artifacts /payload/ /payload.tgz diff --git a/README_en.md b/README_en.md new file mode 100644 index 000000000..b7a060a0f --- /dev/null +++ b/README_en.md @@ -0,0 +1,236 @@ +# ptoas (PTO Assembler & Optimizer) + +## 1. Introduction + +**ptoas** is a specialized compiler toolchain built on top of **LLVM/MLIR (llvmorg-19.1.7)** *(Commit cd708029e0b2869e80abe31ddb175f7c35361f90)*, designed specifically for **PTO Bytecode** (Programming Tiling Operator Bytecode). + +Acting as the bridge between upper-level AI frameworks and underlying NPU/GPGPU/CPU hardware, `ptoas` is built in an **Out-of-Tree** architecture and provides complete C++ and Python interfaces. Its primary responsibilities include: + +1. **IR Parsing & Verification**: Parses `.pto` input files and verifies the semantic correctness of PTO Dialect operations (Ops). +2. **Compilation & Optimization (Passes)**: Executes optimization passes targeting the Da Vinci Architecture, such as operator fusion and automatic synchronization insertion. +3. **Code Generation (Lowering)**: Supports lowering PTO IR to `EmitC` / `Linalg` dialects, ultimately generating code that calls the `pto-isa` C++ library. +4. **Python Bindings**: Provides seamlessly integrated Python modules. Through integration with MLIR Core bindings, frameworks such as **PyPTO**, **TileLang**, and **CuTile** can build, manipulate, and compile PTO Bytecode directly from Python. + +--- + +## 2. Directory Structure + +```text +PTOAS/ +├── include/ +│ └── PTO/ # PTO Dialect headers and TableGen definitions (.td) +├── lib/ +│ ├── PTO/ # Dialect core implementation (IR) and Pass logic (Transforms) +│ ├── CAPI/ # C language interface exposure +│ └── Bindings/Python/ # Python Binding C++ implementation (Pybind11) +├── python/ # Python module build scripts and helper code +├── test/ +│ └── samples/ # Test cases +├── tools/ +│ ├── ptoas/ # ptoas command-line tool entry point (Output: ptoas) +│ └── ptobc/ # ptobc command-line tool entry point (Output: ptobc) +└── CMakeLists.txt # Top-level build configuration +``` + +--- + +## 3. Build Instructions + +⚠️ **Important**: This project strictly requires **LLVM llvmorg-19.1.7**. + +### 3.0 Environment Variable Configuration + +To simplify the build process, **first modify and run the following commands according to your environment**. Subsequent steps reference these variables directly. + +```bash +# ================= Configuration (edit here) ================= +# Set your workspace root directory +# (recommended: a dedicated directory for LLVM and PTOAS) +export WORKSPACE_DIR=$HOME/llvm-workspace + +# LLVM source and build paths +export LLVM_SOURCE_DIR=$WORKSPACE_DIR/llvm-project +export LLVM_BUILD_DIR=$LLVM_SOURCE_DIR/build-shared + +# PTOAS source and install paths +export PTO_SOURCE_DIR=$WORKSPACE_DIR/PTOAS +export PTO_INSTALL_DIR=$PTO_SOURCE_DIR/install +# ============================================================= + +# Create the workspace directory +mkdir -p $WORKSPACE_DIR +``` + +### 3.1 Prerequisites + +* **OS**: Linux (Ubuntu 20.04+ recommended) +* **Compiler**: GCC >= 9 or Clang (C++17 support required) +* **Build System**: CMake >= 3.20, Ninja +* **Python**: 3.8+ +* **Python Packages**: `pybind11`, `numpy` + +```bash +python3 -m pip install pybind11==2.12.0 numpy +``` + +> **Note**: The current LLVM/MLIR Python bindings are not compatible with `pybind11` 3.x. +> If you encounter errors like `def_property family does not currently support keep_alive` +> when building LLVM, run the downgrade command above first. + +### 3.2 Step 1: Build LLVM/MLIR (Dependency) + +Download the LLVM source, check out the `llvmorg-19.1.7` tag, and build with **shared libraries** to ensure correct linking for Python bindings. + +```bash +# 1. Clone LLVM +cd $WORKSPACE_DIR +git clone https://github.com/llvm/llvm-project.git +cd $LLVM_SOURCE_DIR + +# 2. [Critical] Check out llvmorg-19.1.7 +git checkout llvmorg-19.1.7 + +# 3. Configure CMake (build shared libs with Python bindings enabled) +cmake -G Ninja -S llvm -B $LLVM_BUILD_DIR \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DBUILD_SHARED_LIBS=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE=$(which python3) \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_TARGETS_TO_BUILD="host" + +# 4. Build LLVM (this step takes a long time) +ninja -C $LLVM_BUILD_DIR +``` + +### 3.3 Step 2: Build PTOAS (Out-of-Tree) + +Clone the PTOAS source and build against the LLVM 19 you just compiled. + +```bash +# 1. Clone PTOAS +cd $WORKSPACE_DIR +git clone https://gitcode.com/cann/pto-as.git PTOAS +cd $PTO_SOURCE_DIR + +# 2. Build and install via pip +# The build backend (pyproject.toml) drives CMake + Ninja automatically. +pip install . +``` + +This produces the same artifacts as a manual CMake build: + +```text +# CLI tools +$PTO_SOURCE_DIR/build/tools/ptoas/ptoas +$PTO_SOURCE_DIR/build/tools/ptobc/ptobc + +# Native extension installed into the MLIR Python package +$LLVM_BUILD_DIR/tools/mlir/python_packages/mlir_core/ +└── mlir + └── _mlir_libs + └── _pto.cpython-*.so + +# Python dialect files +$PTO_INSTALL_DIR/ +└── mlir + └── dialects + ├── pto.py + └── _pto_ops_gen.py +``` + +### 3.4 Step 3: Python Editable Install (Optional, for Python development) + +If you want to develop and test Python code against the in-tree build without reinstalling after every C++ change, use an **editable install**. + +```bash +pip install -e . --no-build-isolation +``` + +> **Why `--no-build-isolation`?** Without this flag, pip uses a temporary virtual environment for the build, records its pybind11 path in `CMakeCache.txt`, then deletes the venv — breaking any subsequent `ninja` reconfigure. + +If you previously ran `pip install -e .` without the flag and your build is now broken, fix the existing `CMakeCache.txt` with: + +```bash +cmake -B build -Dpybind11_DIR=$(python3 -m pybind11 --cmakedir) +``` + +--- + +## 4. Usage + +### 4.1 Command-Line Interface (CLI) + +```bash +# Parse and print PTO IR +ptoas test/lit/pto/empty_func.pto + +# Run the AutoSyncInsert pass +ptoas test/lit/pto/empty_func.pto --enable-insert-sync -o outputfile.cpp + +# Specify target hardware architecture (A3 / A5) +ptoas test/lit/pto/empty_func.pto --pto-arch=a5 -o outputfile.cpp + +# Specify build level (level3 disables PlanMemory/InsertSync) +ptoas test/lit/pto/empty_func.pto --pto-level=level3 -o outputfile.cpp + +# Print the current ptoas release version +ptoas --version +``` + +### 4.2 Python API + +After configuring the environment variables, the PTO Dialect is loaded as part of `mlir.dialects`. + +```python +from mlir.ir import Context, Module, Location +# [Key] Import pto from mlir.dialects — the standard pattern for out-of-tree bindings +from mlir.dialects import pto + +with Context() as ctx, Location.unknown(): + pto.register_dialect(ctx, load=True) + module = Module.create() + print("PTO Dialect registered successfully!") +``` + +### 4.3 Running Tests + +```bash +# Run Python binding tests +cd $PTO_SOURCE_DIR/test/samples/MatMul/ +python3 ./tmatmulk.py > ./tmatmulk.pto + +# Run ptoas tests +$PTO_SOURCE_DIR/build/tools/ptoas/ptoas ./tmatmulk.pto -o ./tmatmulk.cpp +``` + +### 4.4 On-Board Validation + +This flow generates NPU validation test cases from the `.cpp` files produced by ptoas (under `test/samples/`) and runs them on an NPU. The example below reuses `MatMul/tmatmulk.cpp` generated in section 4.3. + +> For compile-only validation on a machine without an NPU card, see [docs/no_npu_compile_only_guide_zh.md](docs/no_npu_compile_only_guide_zh.md). + +```bash +# 1) Generate the npu_validation test directory +# (creates npu_validation/ under the current sample directory) + +# A2/A3 example: +python3 test/npu_validation/scripts/generate_testcase.py \ + --input test/samples/MatMul/tmatmulk.cpp \ + --run-mode npu \ + --soc-version Ascend910B1 + +# A5 example: +python3 test/npu_validation/scripts/generate_testcase.py \ + --input test/samples/MatMul/tmatmulk.cpp \ + --run-mode npu \ + --soc-version Ascend950 + +# 2) Run validation (run.sh requires no additional arguments) +test/samples/MatMul/npu_validation/tmatmulk/run.sh +``` + +Notes: +- `test/samples/MatMul/npu_validation/tmatmulk/` will contain `tmatmulk_kernel.cpp`, `main.cpp`, `golden.py`, `compare.py`, `run.sh`, and `CMakeLists.txt`. +- `golden.py` generates random inputs by default; outputs default to all zeros (only the count, shape, and data type of inputs/outputs match the kernel parameters). +- `compare.py` compares `golden*.bin` against `output*.bin` and reports an error if they differ. diff --git a/_ptoas_build_backend.py b/_ptoas_build_backend.py new file mode 100644 index 000000000..9cf83e445 --- /dev/null +++ b/_ptoas_build_backend.py @@ -0,0 +1,252 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +PEP 517 build backend for ptoas. + +Runs the CMake/Ninja build (assuming LLVM is already built), then delegates +wheel packaging to docker/create_wheel.sh. + +Environment variables (all optional): + LLVM_BUILD_DIR Path to LLVM build dir + (default: /llvm-workspace/llvm-project/build-shared) + PTO_INSTALL_DIR Install prefix (default: /install) + PTOAS_PYTHON_PACKAGE_VERSION Wheel version override +""" +from __future__ import annotations + +import base64 +import glob +import hashlib +import io +import os +import shutil +import subprocess +import sys +import zipfile +from pathlib import Path + +_REPO = Path(__file__).parent.resolve() +_LLVM_BUILD_DIR = Path( + os.environ.get("LLVM_BUILD_DIR", + "/llvm-workspace/llvm-project/build-shared") +) +_PTO_INSTALL_DIR = Path( + os.environ.get("PTO_INSTALL_DIR", str(_REPO / "install")) +) +_BUILD_DIR = _REPO / "build" +_MLIR_PY_PKG = ( + _LLVM_BUILD_DIR / "tools" / "mlir" / "python_packages" / "mlir_core" +) + + +def get_requires_for_build_wheel(config_settings=None): + return ["setuptools>=68", "wheel", "pybind11<3"] + + +def get_requires_for_build_editable(config_settings=None): + return ["setuptools>=68", "wheel", "pybind11<3"] + + +def get_requires_for_build_sdist(config_settings=None): + return [] + + +def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None): + """Return wheel metadata without running the full build.""" + import email.message + + version = os.environ.get("PTOAS_PYTHON_PACKAGE_VERSION", "0.1.0") + dist_info = Path(metadata_directory) / f"ptoas-{version}.dist-info" + dist_info.mkdir(parents=True, exist_ok=True) + + meta = email.message.Message() + meta["Metadata-Version"] = "2.1" + meta["Name"] = "ptoas" + meta["Version"] = version + meta["Summary"] = "PTO Assembler & Optimizer" + meta["Requires-Python"] = ">=3.9" + meta["License"] = "Apache-2.0" + meta["Requires-Dist"] = "numpy" + (dist_info / "METADATA").write_text(str(meta)) + (dist_info / "WHEEL").write_text( + "Wheel-Version: 1.0\nGenerator: _ptoas_build_backend\n" + "Root-Is-Purelib: True\nTag: py3-none-any\n" + ) + return dist_info.name + + +prepare_metadata_for_build_editable = prepare_metadata_for_build_wheel + + +def build_sdist(sdist_directory, config_settings=None): + raise NotImplementedError( + "ptoas does not support sdist. Use `pip install .` to build a wheel." + ) + + +def _cmake_configure_and_build(): + """CMake configure + Ninja build + install.""" + _BUILD_DIR.mkdir(exist_ok=True) + + pybind11_dir = subprocess.check_output( + [sys.executable, "-m", "pybind11", "--cmakedir"], text=True + ).strip() + + cmake_cmd = [ + "cmake", "-GNinja", + f"-S{_REPO}", f"-B{_BUILD_DIR}", + "-DCMAKE_BUILD_TYPE=Release", + f"-DLLVM_DIR={_LLVM_BUILD_DIR}/lib/cmake/llvm", + f"-DMLIR_DIR={_LLVM_BUILD_DIR}/lib/cmake/mlir", + f"-DPython3_ROOT_DIR={sys.prefix}", + f"-DPython3_EXECUTABLE={sys.executable}", + "-DPython3_FIND_STRATEGY=LOCATION", + f"-Dpybind11_DIR={pybind11_dir}", + f"-DMLIR_PYTHON_PACKAGE_DIR={_MLIR_PY_PKG}", + f"-DCMAKE_INSTALL_PREFIX={_PTO_INSTALL_DIR}", + ] + + release_version = os.environ.get("PTOAS_RELEASE_VERSION_OVERRIDE", "") + if release_version: + cmake_cmd.append(f"-DPTOAS_RELEASE_VERSION_OVERRIDE={release_version}") + + hardening_cache = _REPO / "cmake" / "LinuxHardeningCache.cmake" + if hardening_cache.exists(): + cmake_cmd.insert(1, f"-C{hardening_cache}") + + subprocess.check_call(cmake_cmd) + subprocess.check_call(["ninja", "-C", str(_BUILD_DIR)]) + subprocess.check_call(["ninja", "-C", str(_BUILD_DIR), "install"]) + + +def _install_dialect_files(): + """Copy PTO dialect .py files and TileLang resources into the MLIR package dir.""" + dialects_src = _PTO_INSTALL_DIR / "mlir" / "dialects" + dialects_dst = _MLIR_PY_PKG / "mlir" / "dialects" + if dialects_src.exists() and dialects_dst.exists(): + for f in dialects_src.glob("*.py"): + shutil.copy2(f, dialects_dst / f.name) + + tilelang_src = _PTO_INSTALL_DIR / "tilelang_dsl" + tileops_src = _PTO_INSTALL_DIR / "share" / "ptoas" / "TileOps" + if tilelang_src.exists(): + dst = _MLIR_PY_PKG / "tilelang_dsl" + if dst.exists(): + shutil.rmtree(dst) + shutil.copytree(tilelang_src, dst) + if tileops_src.exists(): + dst = _MLIR_PY_PKG / "TileOps" + if dst.exists(): + shutil.rmtree(dst) + shutil.copytree(tileops_src, dst) + + +def build_wheel(wheel_directory, config_settings=None, metadata_directory=None): + _cmake_configure_and_build() + + env = os.environ.copy() + env.update({ + "PTO_SOURCE_DIR": str(_REPO), + "PTO_INSTALL_DIR": str(_PTO_INSTALL_DIR), + "LLVM_BUILD_DIR": str(_LLVM_BUILD_DIR), + }) + subprocess.check_call( + ["bash", str(_REPO / "docker" / "create_wheel.sh")], + env=env, + ) + + wheels = sorted( + glob.glob(str(_MLIR_PY_PKG / "dist" / "ptoas-*.whl")), + key=os.path.getmtime, + ) + if not wheels: + raise RuntimeError( + f"No ptoas-*.whl found in {_MLIR_PY_PKG / 'dist'} after build." + ) + + wheel_path = Path(wheels[-1]) + dest = Path(wheel_directory) / wheel_path.name + shutil.copy2(wheel_path, dest) + return dest.name + + +def build_editable(wheel_directory, config_settings=None, metadata_directory=None): + """PEP 660 editable install. + + Builds the C++ extensions in-place, then produces a minimal wheel that + installs a .pth file pointing sys.path at the build tree. No files are + copied into site-packages except the .pth file itself. + """ + _cmake_configure_and_build() + + # Copy dialect .py files so `from mlir.dialects import pto` works + _install_dialect_files() + + version = os.environ.get("PTOAS_PYTHON_PACKAGE_VERSION", "0.1.0") + + # Paths that must be on sys.path for the package to be importable + pth_paths = [ + # mlir.* namespace + _pto.so (installed there by CMake) + str(_MLIR_PY_PKG), + # _pto.so output directory (CMAKE_LIBRARY_OUTPUT_DIRECTORY) + str(_BUILD_DIR / "python" / "pto"), + # handwritten Python sources (pto/dialects/pto.py, etc.) + str(_REPO / "python"), + # ptodsl pure-Python sub-package + str(_REPO / "ptodsl"), + ] + + pth_content = "\n".join(pth_paths) + "\n" + pth_filename = "ptoas-editable.pth" + + # ---- Build the editable wheel (a zip with .pth + dist-info) ---- + tag = f"py3-none-any" + wheel_name = f"ptoas-{version}-{tag}.whl" + wheel_path = Path(wheel_directory) / wheel_name + + dist_info_dir = f"ptoas-{version}.dist-info" + + def _sha256_record(data: bytes) -> str: + digest = hashlib.sha256(data).digest() + b64 = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + return f"sha256={b64}" + + pth_bytes = pth_content.encode() + wheel_meta = ( + "Wheel-Version: 1.0\n" + "Generator: _ptoas_build_backend\n" + "Root-Is-Purelib: True\n" + f"Tag: {tag}\n" + "Build: editable\n" + ).encode() + metadata_content = ( + "Metadata-Version: 2.1\n" + "Name: ptoas\n" + f"Version: {version}\n" + "Summary: PTO Assembler & Optimizer\n" + "Requires-Python: >=3.9\n" + "License: Apache-2.0\n" + "Requires-Dist: numpy\n" + ).encode() + + record_lines = [ + f"{pth_filename},{_sha256_record(pth_bytes)},{len(pth_bytes)}", + f"{dist_info_dir}/WHEEL,{_sha256_record(wheel_meta)},{len(wheel_meta)}", + f"{dist_info_dir}/METADATA,{_sha256_record(metadata_content)},{len(metadata_content)}", + f"{dist_info_dir}/RECORD,,", + ] + record_content = "\n".join(record_lines).encode() + + with zipfile.ZipFile(wheel_path, "w", compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr(pth_filename, pth_bytes) + zf.writestr(f"{dist_info_dir}/WHEEL", wheel_meta) + zf.writestr(f"{dist_info_dir}/METADATA", metadata_content) + zf.writestr(f"{dist_info_dir}/RECORD", record_content) + + return wheel_name diff --git a/docker/create_wheel.sh b/docker/create_wheel.sh index 2145fb9e7..a3ecf38ee 100755 --- a/docker/create_wheel.sh +++ b/docker/create_wheel.sh @@ -51,6 +51,9 @@ rm -rf "${PY_PACKAGE_DIR}/tilelang_dsl" "${PY_PACKAGE_DIR}/TileOps" cp -R "${PTO_INSTALL_DIR}/tilelang_dsl" "${PY_PACKAGE_DIR}/tilelang_dsl" cp -R "${PTO_INSTALL_DIR}/share/ptoas/TileOps" "${PY_PACKAGE_DIR}/TileOps" +# Copy ptodsl into the wheel so it is always shipped with ptoas +cp -R "${PTO_SOURCE_DIR}/ptodsl/ptodsl" "${PY_PACKAGE_DIR}/ptodsl" + # Copy platform-specific setup.py to package directory. # On macOS, use setup_mac.py and rename it to setup.py in the build dir. SETUP_TEMPLATE="${PTO_SOURCE_DIR}/docker/setup.py" diff --git a/docs/designs/issue417_ptr_entry_plan_zh.md b/docs/designs/issue417_ptr_entry_plan_zh.md new file mode 100644 index 000000000..a91613704 --- /dev/null +++ b/docs/designs/issue417_ptr_entry_plan_zh.md @@ -0,0 +1,343 @@ +# Issue 417: 统一 PTODSL Kernel Entry 为 ptr + int 的分析与计划 + +## 背景 + +Issue: https://github.com/mouliangyu/PTOAS/issues/417 + +当前 `pto-dsl-impl` 分支的 Python DSL 主要按 A5 芯片能力设计。A5 有向量寄存器和较完整的 vreg/tile 编程模型,因此现有文档和示例大量使用: + +```python +@pto.jit(target="a5") +def kernel(A: pto.tensor_spec(rank=2, dtype=pto.f32)): + a = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) +``` + +这个入口形式把 Python host tensor 的 ABI 信息写在 `tensor_spec(rank=..., dtype=...)` annotation 中。它适合当前 A5-oriented DSL 的开发体验,但不是跨芯片的最小共同 ABI。 + +A2/A3 没有 A5 的向量寄存器模型,kernel 更自然也更接近已有 C++ 写法: + +```cpp +__global__ AICORE void kernel(__gm__ float *x, int32_t batch) +``` + +因此 issue 希望把 Python DSL 的 kernel entry 统一到 `ptr + int`: + +```python +N = 1024 # static, closure + +@pto.jit +def my_kernel(x_ptr: pto.ptr(pto.f32, "gm"), batch: pto.i32): + x = pto.make_tensor_view( + x_ptr, + shape=[batch, N], + strides=[N, 1], + ) +``` + +核心原则: + +- kernel entry 只暴露裸 GM pointer 和运行时 scalar; +- 静态 shape 轴通过 Python closure / constexpr 固定; +- 动态 shape 轴作为 `pto.i32`/`pto.i64` 等 scalar 入参; +- 不在函数参数 annotation 中写 rank; +- rank 由 `make_tensor_view(..., shape=[...])` 的 `len(shape)` 推断; +- launch 侧从 torch tensor 或其他 host tensor 取 shape,再作为额外 Python int 传给 kernel; +- 入口 ABI 对齐 A2/A3/A5 的共同下限,也对齐现有 C++ kernel 的 `__gm__ T* + int32_t` 风格。 + +## 当前代码现状 + +当前代码已经具备一部分基础能力: + +1. `make_tensor_view` 已经从 `shape` 推断 rank + 位置:`ptodsl/ptodsl/_ops.py` + `make_tensor_view(ptr, shape=..., strides=...)` 内部使用 `rank = len(shape)`,并从 pointer element type 推 tensor view dtype。 + +2. runtime scalar entry 已经存在 + 位置:`ptodsl/ptodsl/_kernel_signature.py` + `pto.i32`、`pto.f32` 等 annotation 会被解析成 `RuntimeScalarParameterSpec`。 + +3. 裸 pointer entry 的下游结构已经存在 + 位置: + - `ptodsl/ptodsl/_kernel_signature.py`: `DeviceParameterSpec` + - `ptodsl/ptodsl/_runtime/codegen.py`: `DeviceParameterSpec` codegen 分支 + - `ptodsl/ptodsl/_runtime/launch.py`: `DeviceParameterSpec` launch marshaling 分支 + +4. 但 `@pto.jit` 入口目前不接受 `pto.ptr(...)` + `parse_jit_kernel_signature` 只接受 `pto.tensor_spec(...)` 和 runtime scalar。`pto.ptr(...)` 会走到非法 annotation 诊断路径。测试里也有“entry annotation 使用 `pto.ptr(...)` 应报错”的预期。 + +结论:这个 issue 不是从零实现 ptr ABI,而是要把已有的 `DeviceParameterSpec` 通路正式接到 public `@pto.jit` entry 上,并迁移测试/示例来验证跨芯片共同 ABI。 + +## Issue 提出的问题 + +当前 `tensor_spec(rank=..., dtype=...)` 入口存在几个问题: + +1. rank 泄漏到 entry annotation + issue 明确要求省略函数参数类型里的 rank。rank 应由 kernel body 里的 `make_tensor_view` 根据 `shape` 推断。 + +2. 动态 shape 表达不统一 + `tensor_spec` 会把 host tensor 的 shape/strides 自动展开为 entry ABI metadata。issue 希望动态轴显式成为 kernel 参数,例如 `batch: pto.i32`,静态轴留在 closure 中。 + +3. A5-oriented 设计不适合作为 A2/A3 的共同入口 + A5 可以在 DSL 内部大量使用 vreg/tile,但 A2/A3 没有向量寄存器,入口层必须回到 pointer + scalar 这个共同模型。 + +4. C++ 到 Python DSL 迁移成本偏高 + 现有 C++ kernel 多数是 `__gm__ T* + int32_t`。如果 Python DSL entry 也采用 ptr + int,迁移时只需要把 body 内部用 `make_tensor_view` 建描述符,而不是先改成 `tensor_spec` 风格。 + +## 要解决的点 + +本 issue 的完成标准应包括: + +1. `@pto.jit` public entry 支持 `pto.ptr(...)` + - `x: pto.ptr(pto.f32, "gm")` 应被解析为一个 device pointer 参数; + - IR entry 参数应是 `!pto.ptr`; + - launch wrapper 应把 Python tensor / raw pointer 转为 `void*` 或 typed pointer 后传入。 + +2. `@pto.jit` public entry 支持 ptr 与 runtime scalar 混排 + - 例如 `x_ptr: pto.ptr(...), batch: pto.i32, cols: pto.i32`; + - scalar 参数在 IR entry 中保持对应整数类型; + - launch 侧接受 Python int 并按 annotation marshaling。 + +3. `make_tensor_view` 是 tensor rank 的唯一推断点 + - 不需要也不允许在 ptr annotation 中携带 rank; + - `shape=[batch, N]` 的长度决定 tensor view rank; + - shape 可以混合 runtime scalar 和 static Python int。 + +4. 保持 A5 能力不被破坏 + - A5 内部仍可使用 vreg/tile/subkernel; + - ptr entry 只是 entry ABI 统一,不要求删除 A5 内部 vreg 模型; + - 原有 `tensor_spec` 是否保留兼容,需要单独决策。建议第一阶段保留,先把 issue 的新入口跑通。 + +5. 给 A2/A3 留出自然入口 + - target 为 `a2`/`a3` 的 DSL kernel 可以直接使用 ptr + int; + - 不依赖 vreg-only 的 entry abstraction; + - 后续 A2/A3 lowering 若受限,应在 body ops / target legality 层诊断,而不是 entry ABI 层阻塞。 + +## 解决方案 + +### 1. 接通 ptr entry parsing + +修改 `ptodsl/ptodsl/_kernel_signature.py`: + +- 新增 `_is_supported_device_parameter_annotation(annotation)`; +- 判断 annotation 是否是 `_PtrDescriptor` 或已解析的 PTO `PtrType`; +- 对 ptr annotation 生成 `DeviceParameterSpec`; +- 保持 runtime scalar 生成 `RuntimeScalarParameterSpec`; +- 保持 `tensor_spec` 兼容,至少第一阶段不删除。 + +预期解析顺序: + +```text +tensor_spec(...) -> TensorSpecParameterSpec +pto.ptr(...) -> DeviceParameterSpec +pto.i32/f32/... -> RuntimeScalarParameterSpec +otherwise -> diagnostic +``` + +### 2. 明确 ptr API 形式 + +当前代码已有: + +```python +pto.ptr(pto.f32, "gm") +pto.ptr(pto.f32, pto.MemorySpace.GM) +``` + +Issue 示例写法是: + +```python +pto.ptr(dtype=pto.f32) +``` + +建议分两步: + +1. 第一阶段使用已有 API:`pto.ptr(pto.f32, "gm")`,避免扩大改动面; +2. 第二阶段补兼容糖:`pto.ptr(dtype=pto.f32, space="gm")` 或 `pto.ptr(dtype=pto.f32, address_space="gm")`。 + +不建议直接修改 `pto.ptr` 的全局默认 address space。当前 `pto.ptr(elem, space="ub")` 可能已被内部 UB pointer 场景依赖。entry 示例里应显式写 `"gm"`,避免因为默认值不同导致 ABI 错误。 + +### 3. 复用现有 launch/codegen + +`DeviceParameterSpec` 在 codegen 和 launch 中已经有处理分支: + +- codegen 生成 `__gm__ T *param`; +- host wrapper 接收 `T *param`; +- launch marshaling 用 `_as_void_ptr` 支持 `ctypes.c_void_p`、integer pointer、以及带 `.data_ptr()` 的 torch tensor。 + +因此第一阶段不需要重写 launch,只需要补测试确认: + +```python +compiled[grid, stream](x_tensor, int(x_tensor.shape[0])) +``` + +会被 marshal 为: + +```text +x_tensor.data_ptr(), c_int32(batch) +``` + +### 4. 保持 `make_tensor_view` 作为 body 内显式转换 + +现有 `make_tensor_view` 已经符合 issue 方向。要补充测试覆盖: + +```python +N = 1024 + +@pto.jit(target="a5") +def ptr_dynamic_shape_probe(x: pto.ptr(pto.f32, "gm"), batch: pto.i32): + x_view = pto.make_tensor_view(x, shape=[batch, N], strides=[N, 1]) +``` + +断言重点: + +- entry function 参数包含 `!pto.ptr` 和 `i32`; +- `pto.make_tensor_view` 的 shape 包含 `%arg1` 和 static constant `1024`; +- tensor view rank 为 2; +- annotation 中没有 `rank=`。 + +### 5. 更新诊断 + +当前 `pto.ptr(...)` entry annotation 会被诊断为非法。需要调整: + +- 删除或改写“`@pto.jit` 不支持 `pto.ptr(...)` entry”的测试; +- 新增非法指针类型诊断,例如 storage-only dtype 不允许作为 ptr element; +- 对非 GM entry pointer 是否允许做明确策略。 + +建议策略: + +- public launch entry 第一阶段只允许 GM pointer; +- UB/MAT/LEFT/RIGHT 等非 GM pointer 仍只用于 kernel body / subkernel boundary; +- 如果用户在 public `@pto.jit` entry 写 `pto.ptr(pto.f32, "ub")`,给出清晰错误:public launch entry expects GM pointer。 + +这样可以避免 host 侧传入普通 tensor data pointer 却被标成 UB pointer 的错误。 + +### 6. 更新示例和文档 + +优先更新最小闭环示例: + +- `ptodsl/examples/jit/tadd_launch.py` +- `ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md` +- 相关 docs fixture / docs-as-test + +迁移原则: + +旧风格: + +```python +@pto.jit(target="a5") +def add(A: pto.tensor_spec(rank=2, dtype=pto.f32), O: pto.tensor_spec(rank=2, dtype=pto.f32)): + a = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) +``` + +新风格: + +```python +N = 1024 + +@pto.jit(target="a5") +def add(A: pto.ptr(pto.f32, "gm"), O: pto.ptr(pto.f32, "gm"), batch: pto.i32): + a = pto.make_tensor_view(A, shape=[batch, N], strides=[N, 1]) + o = pto.make_tensor_view(O, shape=[batch, N], strides=[N, 1]) +``` + +launch 侧: + +```python +add.compile()[grid, stream](A, O, int(A.shape[0])) +``` + +## 实施计划 + +### Phase 1: 最小功能闭环 + +目标:一个 ptr + int dynamic shape kernel 可以 compile,并生成预期 IR。 + +改动: + +1. 修改 `parse_jit_kernel_signature`,支持 `pto.ptr(..., "gm")` entry; +2. 新增 frontend compile test; +3. 修改原有 ptr entry diagnostics test; +4. 验证 `make_tensor_view` shape 混合 runtime scalar 和 static int。 + +验收: + +```bash +python -m pytest test/python/ptodsl_jit_compile.py test/python/ptodsl_jit_diagnostics.py +``` + +如本地缺 MLIR/PTOAS Python 环境,则至少跑目标单测或记录无法运行原因。 + +### Phase 2: launch 侧闭环 + +目标:torch tensor / pointer + Python int 可以调用 compiled kernel wrapper。 + +改动: + +1. 添加 `_marshal_launch_args` 单测,覆盖 `DeviceParameterSpec + RuntimeScalarParameterSpec`; +2. 验证 `tensor.data_ptr()` 被用于 ptr 参数; +3. 验证 Python int 按 `pto.i32` marshaling; +4. 如有真实 NPU 环境,再跑一个小型 JIT launch demo。 + +验收: + +- `compiled[grid, stream](tensor, int(tensor.shape[0]))` 不再要求 `tensor_spec`; +- launch wrapper 生成参数顺序为 `grid, stream, ptr, batch`。 + +### Phase 3: target 兼容策略 + +目标:明确 A5 与 A2/A3 的边界。 + +改动: + +1. public entry ptr ABI 对所有 target 可用; +2. A5 继续允许 body 内使用 vreg/tile; +3. A2/A3 不支持的 vreg/tile op 由 target legality 或 lowering 阶段诊断; +4. 文档说明 ptr + int 是跨芯片推荐 entry ABI。 + +验收: + +- `@pto.jit(target="a5")` ptr entry 测试通过; +- 新增 `target="a3"` 或 `target="a2"` 的 compile-only entry test,如果当前 toolchain 支持; +- 若 toolchain 暂不支持 A2/A3 Python DSL compile,则文档标注 pending,不阻塞 Phase 1。 + +### Phase 4: 文档和示例迁移 + +目标:把 public-facing 示例从 `tensor_spec` 主路径迁移到 ptr + int。 + +改动: + +1. 更新 quick start / kernel entry 文档; +2. 更新至少一个 JIT launch 示例; +3. docs-as-test fixture 对齐; +4. 保留 `tensor_spec` 兼容说明,或者标为 legacy/compat。 + +验收: + +- 文档中的主推荐写法是 `pto.ptr(..., "gm") + pto.i32`; +- `tensor_spec` 不再被描述为唯一 public entry ABI。 + +## 风险与待确认问题 + +1. `tensor_spec` 是否最终删除 + 建议不在本 issue 中删除。先保留兼容,避免大规模 docs/tests 迁移阻塞 ptr ABI 落地。 + +2. `pto.ptr(dtype=pto.f32)` 是否必须第一阶段支持 + issue 示例用了这个写法,但当前实现已有 `pto.ptr(pto.f32, "gm")`。建议第一阶段先支持现有 API,第二阶段补 keyword 兼容。 + +3. public entry 是否只允许 GM pointer + 建议只允许 GM pointer。host tensor 的 `.data_ptr()` 是全局内存地址,标成 UB/MAT pointer 不合理。 + +4. shape scalar 用 `i32` 还是 `index/i64` + issue 示例用 `pto.i32`,但当前 `tensor_spec` metadata 使用 `int64_t`。建议遵循 issue:动态轴参数由用户 annotation 决定,示例使用 `pto.i32`;如需要大 shape,可用 `pto.i64` 或 `pto.index`。 + +5. strides 是否也要作为 runtime int 参数 + issue 核心是动态 shape 轴,但 `make_tensor_view` 也需要 strides。第一阶段示例可用 static/closure strides;非 contiguous tensor 的通用支持需要用户额外传 stride 参数,或者在更高层 wrapper 自动传入。 + +## 推荐开发顺序 + +1. 先实现 `@pto.jit` ptr entry parsing; +2. 加最小 compile test,确认 IR entry 与 `make_tensor_view` 正确; +3. 加 launch marshaling test; +4. 改一个 JIT 示例; +5. 更新文档; +6. 再考虑 `pto.ptr(dtype=...)` keyword 兼容和 `tensor_spec` legacy 策略。 + diff --git a/docs/designs/ptodsl-tiletrace-poc-proposal.md b/docs/designs/ptodsl-tiletrace-poc-proposal.md new file mode 100644 index 000000000..6b66b23ff --- /dev/null +++ b/docs/designs/ptodsl-tiletrace-poc-proposal.md @@ -0,0 +1,179 @@ +# ptodsl `vpto` POC Proposal + +## Background + +Today we have two very different authoring paths for VPTO-related Python DSLs: + +- `ptodsl` executes Python directly and builds IR through tracing-style wrappers. +- `tilelang-dsl` captures Python source as AST, then runs `frontend_ast -> semantic -> lowering`. + +This split is especially visible for tile templates such as +[`lib/TileOps/tadd_template.py`](/home/zhangzhendong/ptoas-workspace/PTOAS/lib/TileOps/tadd_template.py), +whose body is conceptually simple but currently depends on the full AST frontend. + +For the longer-term direction, we want VPTO-level authoring to converge on the +same tracing-style route as `ptodsl`, while preserving as much of the +TileLang-style surface as practical. + +## Problem Statement + +The current AST route gives us good source diagnostics and broad surface +coverage, but it also has clear costs: + +- Every new surface feature needs to be added in three layers: + frontend node building, semantic typing, and text lowering. +- Reusing mature `ptodsl` builder idioms is difficult because authored Python is + no longer the execution model. +- Simple tile templates still pay the cost of a compiler frontend even when the + kernel body is static, structured, and already close to the desired VPTO form. + +For team discussion, the concrete question is: + +Can we execute a TileLang-style tile template directly and emit useful VPTO IR +without going through AST capture? + +## Proposal + +Introduce an experimental `ptodsl.vpto` namespace as a tracing-oriented POC for +TileLang-style tile templates. + +### Design Goals + +- Reuse the authored Python function body directly. +- Keep the POC independent from `tilelang-dsl` internals. +- Preserve the most recognizable TileLang surface where it is cheap: + `@pto.vkernel`, `Tile`, `dst.element_type`, `dst.valid_shape`, + `tile[row, col:]`, `get_lanes`, `make_mask`, `vlds`, `vadd`, `vsts`. +- Keep the implementation minimal and explicit enough that the team can judge + whether the tracing route is viable before we invest in broader migration. + +### Non-Goals for This POC + +- No attempt to replace `tilelang-dsl` in-place. +- No matcher, multi-dtype registry, template slots, inline-proc, or cube + surface. +- No source-diagnostic parity with the AST frontend. +- No requirement to generalize beyond the minimal pybinding-backed subset + needed for `tadd_template.py`. + +## POC Scope + +The POC is intentionally limited to a single template shape: + +- Target template: `tadd_template.py` +- Supported parameter kind: bare static 2D `Tile` +- Supported control flow: explicit structured `for_()` builders, with optional + `vecscope()` when the author wants to spell it directly +- Supported ops: `make_mask`, `vlds`, `vadd`, `vsts` +- Supported lowering shape: nested `scf.for`, `pto.tile_buf_addr`, and vector + micro-ops, with optional `pto.vecscope` + +This means the first implementation validates the core idea: + +1. specialize bare `Tile` parameters with static shape + dtype +2. execute the authored Python body directly +3. trace tile slice accesses such as `src0[row, col:]` +4. emit structured VPTO IR with `scf.for` and no AST capture + +## Why This Cut Is Useful + +This is not yet the final architecture, but it answers the most important +migration question with low implementation risk: + +- If the POC is too awkward even for `tadd_template.py`, we should not try to + move the main TileLang route onto tracing. +- If the POC stays small and readable, then we have evidence that a tracing + backend can carry at least a meaningful subset of tile templates. + +This cut also forces one important architectural decision early: + +- The tracing route should standardize on explicit builder-style control flow. + Reconstructing `scf.for` from raw Python `for range(...)` would pull us back + toward AST capture or source transformation, which defeats the purpose of the + experiment. + +## Proposed Architecture + +Add a new lightweight module: + +- [`ptodsl/ptodsl/vpto.py`](/home/zhangzhendong/ptoas-workspace/PTOAS/ptodsl/ptodsl/vpto.py) + +Core pieces: + +- `Tile` annotation marker +- `TileSpec(shape, dtype, memory_space="ub")` +- `@vkernel(target="a5", op="pto.tadd")` +- `TracingKernelDescriptor.specialize(...)` +- proxy `Tile` arguments that expose: + - `.element_type` + - `.valid_shape` + - `tile[row, col:]` +- a trace builder that emits structured MLIR objects through Python bindings + +The key idea is that `tile[row, col:]` is not lowered from AST. Instead, it is +captured at runtime through a proxy object and immediately converted into a +traced tile-slice value. + +## Expected Output Shape + +For the `tadd_template.py`-style kernel body, the POC emits: + +- tile-buffer arguments +- nested `scf.for` for rows and columns +- `pto.tile_buf_addr` for each referenced tile +- `pto.plt_b32` +- `pto.vlds` +- `pto.vadd` +- `pto.vsts` +- `scf.yield` for loop-carried `remained` + +This is intentionally close to the already documented tile-op expand form, but +keeps structured control flow instead of concretely unrolling the loops. + +## Tradeoffs + +### Advantages + +- Very small implementation surface for the first proof point. +- No dependency on AST parsing or source capture. +- Easy to compare source body and emitted IR side by side. +- Makes it clear which parts of TileLang syntax are “real execution” versus + “frontend-only sugar”. +- Produces IR that is much closer to a future scalable frontend than the + original fully unrolled POC. + +### Limitations + +- No rich diagnostics or semantic model yet. +- No integration with the existing `tilelang-dsl` package entrypoint. +- Current output is deliberately narrow and only covers the pybinding-backed + operations needed by the first POC template. +- Control flow currently needs explicit structured `for_()` builders instead of + raw Python `for range(...)`. `vecscope()` can still be used, but is not a + hard requirement in the POC. + +These are acceptable for the first experiment because the goal is not feature +completeness; it is to validate the tracing execution model on a real tile +template. + +## Rollout Path If The POC Works + +If the POC proves maintainable, the next steps should be: + +1. Add a slightly broader vector subset beyond `tadd`. +2. Replace any remaining POC-specific glue with shared `ptodsl`/MLIR builders. +3. Introduce a reusable runtime contract layer for dtype, mask, and slice + checks. +4. Decide whether `tilelang-dsl` should: + - keep AST frontend and optionally target the tracing backend, or + - expose a parallel tracing-first authoring mode for static templates. + +## Deliverables In This Change + +- This proposal document +- an experimental `ptodsl.vpto` namespace +- a minimal `tadd_template.py`-oriented POC example + +The change is intentionally framed as a team discussion artifact plus a narrow +executable proof-of-concept, not as a replacement plan for the current +TileLang frontend. diff --git a/lib/Bindings/Python/PTOModule.cpp b/lib/Bindings/Python/PTOModule.cpp index a13b39cad..1c8ae1c0d 100644 --- a/lib/Bindings/Python/PTOModule.cpp +++ b/lib/Bindings/Python/PTOModule.cpp @@ -18,6 +18,7 @@ #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/CAPI/IR.h" #include "pto-c/Dialect/PTO.h" #include "mlir-c/IR.h" #include "PTO/IR/PTO.h" @@ -27,6 +28,8 @@ #include "mlir/IR/BuiltinTypes.h" namespace py = pybind11; using namespace mlir::python::adaptors; +using llvm::cast; +using llvm::isa; static std::vector toInt64Vector(const py::sequence &seq) { std::vector out; @@ -63,6 +66,14 @@ static py::list shapeToPyList(const int64_t *data, intptr_t n) { return lst; } +static py::object wrapAttributeAs(const py::module_ &m, const char *className, + MlirAttribute attr) { + if (mlirAttributeIsNull(attr)) + return py::none(); + py::object cls = m.attr(className); + return cls.attr("__call__")(attr); +} + void populatePTODialectSubmodule(pybind11::module &m); void populatePTODialectSubmodule(pybind11::module &m) { (void)m; @@ -705,6 +716,61 @@ static void bindPTOModule(pybind11::module &m) { return mlirPTOPtrTypeGetMemorySpace(self); }); + mlir_type_subclass( + m, "VRegType", + [](MlirType type) -> bool { return isa(unwrap(type)); }) + .def_classmethod( + "get", + [](py::object cls, int64_t elementCount, MlirType elementType, + MlirContext context) -> py::object { + context = inferContextFromElementType(context, elementType); + MlirType t = wrap( + mlir::pto::VRegType::get( + unwrap(context), elementCount, unwrap(elementType))); + return cls.attr("__call__")(t); + }, + py::arg("cls"), py::arg("element_count"), py::arg("element_type"), + py::arg("context") = py::none()) + .def_property_readonly( + "element_count", + [](MlirType self) -> int64_t { + return cast(unwrap(self)).getElementCount(); + }) + .def_property_readonly( + "element_type", + [](MlirType self) -> MlirType { + return wrap(cast(unwrap(self)).getElementType()); + }); + + mlir_type_subclass( + m, "MaskType", + [](MlirType type) -> bool { return isa(unwrap(type)); }) + .def_classmethod( + "get", + [](py::object cls, std::string granularity, MlirContext context) -> py::object { + MlirType t = wrap( + mlir::pto::MaskType::get(unwrap(context), granularity)); + return cls.attr("__call__")(t); + }, + py::arg("cls"), py::arg("granularity"), + py::arg("context") = py::none()) + .def_property_readonly( + "granularity", + [](MlirType self) -> std::string { + return cast(unwrap(self)).getGranularity().str(); + }); + + mlir_type_subclass( + m, "AlignType", + [](MlirType type) -> bool { return isa(unwrap(type)); }) + .def_classmethod( + "get", + [](py::object cls, MlirContext context) -> py::object { + MlirType t = wrap(mlir::pto::AlignType::get(unwrap(context))); + return cls.attr("__call__")(t); + }, + py::arg("cls"), py::arg("context") = py::none()); + mlir_type_subclass( m, "AsyncSessionType", [](MlirType type) -> bool { return mlirPTOTypeIsAAsyncSessionType(type); }) @@ -976,7 +1042,76 @@ static void bindPTOModule(pybind11::module &m) { if (mlirPTOTypeIsATileBufType(t)) return cls(t); return py::none(); }, - py::arg("cls"), py::arg("type")); + py::arg("cls"), py::arg("type")) + .def_property_readonly( + "rank", + [](MlirType self) -> intptr_t { + return static_cast( + cast(unwrap(self)).getRank()); + }) + .def_property_readonly( + "element_type", + [](MlirType self) -> MlirType { + return wrap(cast(unwrap(self)).getElementType()); + }) + .def_property_readonly( + "memory_space", + [m](MlirType self) -> py::object { + MlirAttribute attr = + wrap(cast(unwrap(self)).getMemorySpace()); + return wrapAttributeAs(m, "AddressSpaceAttr", attr); + }) + .def_property_readonly( + "shape", + [](MlirType self) -> py::list { + auto shape = cast(unwrap(self)).getShape(); + return shapeToPyList(shape.data(), static_cast(shape.size())); + }) + .def_property_readonly( + "valid_shape", + [](MlirType self) -> py::list { + auto validShape = cast(unwrap(self)).getValidShape(); + return shapeToPyList(validShape.data(), static_cast(validShape.size())); + }) + .def_property_readonly( + "blayout_attr", + [m](MlirType self) -> py::object { + MlirAttribute attr = + wrap(cast(unwrap(self)).getBLayoutAttr()); + return wrapAttributeAs(m, "BLayoutAttr", attr); + }) + .def_property_readonly( + "slayout_attr", + [m](MlirType self) -> py::object { + MlirAttribute attr = + wrap(cast(unwrap(self)).getSLayoutAttr()); + return wrapAttributeAs(m, "SLayoutAttr", attr); + }) + .def_property_readonly( + "blayout_value", + [](MlirType self) -> int32_t { + return cast(unwrap(self)).getBLayoutValueI32(); + }) + .def_property_readonly( + "slayout_value", + [](MlirType self) -> int32_t { + return cast(unwrap(self)).getSLayoutValueI32(); + }) + .def_property_readonly( + "pad_value", + [](MlirType self) -> int32_t { + return cast(unwrap(self)).getPadValueI32(); + }) + .def_property_readonly( + "compact_mode", + [](MlirType self) -> int32_t { + return cast(unwrap(self)).getCompactModeI32(); + }) + .def_property_readonly( + "s_fractal_size", + [](MlirType self) -> int32_t { + return cast(unwrap(self)).getSFractalSizeI32(); + }); populatePTODialectSubmodule(m); } diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 68ded38e7..ec3a8c4fe 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -2255,12 +2255,6 @@ void mlir::pto::annotatePTOEntryFunctions(ModuleOp module) { LogicalResult AllocTileOp::verify() { auto ty = getResult().getType(); // TileBufType - - Type elemTy = ty.getElementType(); - if (isPTOLowPrecisionType(elemTy)) - return emitOpError() << "result dtype " << elemTy - << " is not supported by pto.alloc_tile yet"; - if (failed(verifyTileBufLayoutConstraints(*this, ty, "result"))) return failure(); diff --git a/lib/PTO/IR/VPTO.cpp b/lib/PTO/IR/VPTO.cpp index 44d18c890..d16e6ceda 100644 --- a/lib/PTO/IR/VPTO.cpp +++ b/lib/PTO/IR/VPTO.cpp @@ -115,6 +115,12 @@ static bool isMaskGranularityAdjacentWidening(StringRef inputGranularity, (inputGranularity == "b16" && resultGranularity == "b32"); } +static bool isMaskGranularityAdjacentNarrowing(StringRef inputGranularity, + StringRef resultGranularity) { + return (inputGranularity == "b16" && resultGranularity == "b8") || + (inputGranularity == "b32" && resultGranularity == "b16"); +} + LogicalResult PTOLoadOp::verify() { return verifyVPTOScalarAccessTypes(getOperation(), getPtr().getType(), getValue().getType(), "load"); @@ -4250,8 +4256,17 @@ LogicalResult PpackOp::verify() { if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) return failure(); - if (getPart() != "LOWER") - return emitOpError("currently supports only LOWER part"); + if (!isSupportedPartToken(getPart())) + return emitOpError("requires part to be LOWER or HIGHER"); + auto inputMaskType = cast(getInput().getType()); + auto resultMaskType = cast(getResult().getType()); + StringRef inputGranularity = inputMaskType.getGranularity(); + StringRef resultGranularity = resultMaskType.getGranularity(); + if (inputGranularity != resultGranularity && + !isMaskGranularityAdjacentNarrowing(inputGranularity, resultGranularity)) { + return emitOpError( + "requires result mask granularity to match the input or narrow by one step"); + } return success(); } @@ -4259,8 +4274,8 @@ LogicalResult PunpackOp::verify() { if (failed(verifyMaskTypeLike(*this, getInput().getType(), "input type")) || failed(verifyMaskTypeLike(*this, getResult().getType(), "result type"))) return failure(); - if (getPart() != "LOWER") - return emitOpError("currently supports only LOWER part"); + if (!isSupportedPartToken(getPart())) + return emitOpError("requires part to be LOWER or HIGHER"); auto inputMaskType = cast(getInput().getType()); auto resultMaskType = cast(getResult().getType()); StringRef inputGranularity = inputMaskType.getGranularity(); diff --git a/ptodsl/README.md b/ptodsl/README.md new file mode 100644 index 000000000..0f017bfe6 --- /dev/null +++ b/ptodsl/README.md @@ -0,0 +1,397 @@ +# ptodsl — PTO Python IR Builders + +A lightweight, pip-installable DSL package for building PTO MLIR IR modules +in Python. PTODSL kernels are ordinary Python functions decorated with +`@pto.jit`. Type annotations carry PTO +types as lazy descriptors, and control-flow maps 1-to-1 to MLIR operations. + +--- + +## Directory layout + +``` +ptodsl/ +├── ptodsl/ # pip-installable package +│ ├── __init__.py # exports: pto, scalar +│ ├── pto.py # main PTO DSL namespace +│ ├── scalar.py # top-level scalar.* helper namespace +│ ├── _bootstrap.py # MLIR path setup + context factory +│ ├── _types.py # lazy dtype descriptors and type constructors +│ ├── _ops.py # PTO operation wrappers +│ ├── _control_flow.py # for_, if_, yield_ context managers +│ ├── _jit.py # @pto.jit decorator +│ ├── _tracing/ # shared tracing runtime building blocks +│ └── _tile_template_tracing.py # internal tile-template tracing implementation +├── examples/ +│ ├── tadd_lowlevel.py # TADD – raw MLIR Python binding calls +│ ├── tadd_dsl.py # TADD – @pto.jit DSL style +│ ├── softmax_lowlevel.py # Softmax – raw MLIR Python binding calls +│ └── softmax_dsl.py # Softmax – @pto.jit DSL style +├── pyproject.toml # pip install -e . +└── README.md +``` + +--- + +## Prerequisites + +```bash +# Install ptoas (first time only) +cd $PTOAS_REPO_ROOT # e.g. export PTOAS_REPO_ROOT=/workdir/ptoas_a5 +bash quick_install.sh + +# Set up environment in every new shell +source set_ptoas_env.sh +``` + +--- + +## Install the package + +```bash +cd $PTOAS_REPO_ROOT/ptodsl +pip install -e . +``` + +--- + +## JIT examples + +`ptodsl/examples/jit/` contains self-contained `@pto.jit` examples that cover +both compile-only and end-to-end launch flows. + +### Prerequisites for launch examples + +- `ptoas` + `ptodsl` installed as above +- CANN 9.0+ with `ASCEND_HOME_PATH` set +- For end-to-end launch: `torch`, `torch_npu`, `numpy` +- `bisheng` on `PATH` + +Set up the environment in each new shell: + +```bash +cd $PTOAS_REPO_ROOT +source set_ptoas_env.sh +source "${ASCEND_HOME_PATH}/bin/setenv.bash" +``` + +For CPU simulation with `msprof`, the wrapper script below will set the +simulator library path and `ulimit` for you. The normal PTOAS + CANN shell +setup above is still required. + +### `tadd_launch.py` + +Single script: kernel definition, compile, launch, and accuracy check. +Equivalent IR to the TileLang ST `tadd.pto` testcase. + +Compile-only: + +```bash +python3 ptodsl/examples/jit/tadd_launch.py --emit-mlir +``` + +Expected: MLIR containing `@TADD_f32_16x64` and `@TADD_f32_32x32`. + +Optional PTOAS frontend smoke: + +```bash +python3 ptodsl/examples/jit/tadd_launch.py --emit-mlir > /tmp/tadd_dsl.mlir +ptoas --emit-pto-ir /tmp/tadd_dsl.mlir -o - | head +``` + +End-to-end under the `msprof` CPU simulator: + +```bash +scripts/sim_dsl.sh ptodsl/examples/jit/tadd_launch.py +``` + +Expected output: + +```text +PASS f32_16x64 compile=0.024s launch=35.193s +PASS f32_32x32 compile=0.022s launch=35.926s +All cases passed. +``` + +Direct run on a real NPU: + +```bash +python3 ptodsl/examples/jit/tadd_launch.py +``` + +### `flash_attention_softmax_launch.py` + +Launchable row-wise softmax demo. The kernel surface is the ordinary +`scores -> out` contract, while the implementation preloads the score matrix to +UB and then uses a packed online-softmax recurrence so one NPU can stream +64-row packs sequentially from UB. + +Compile-only: + +```bash +python3 ptodsl/examples/jit/flash_attention_softmax_launch.py --emit-mlir +``` + +End-to-end under the `msprof` CPU simulator: + +```bash +scripts/sim_dsl.sh ptodsl/examples/jit/flash_attention_softmax_launch.py +``` + +Expected output: + +```text +PASS rows64_seq128 +PASS rows81_seq96 +All cases passed. +``` + +Direct run on a real NPU: + +```bash +python3 ptodsl/examples/jit/flash_attention_softmax_launch.py +``` + +### Launch artifacts + +- `~/.cache/ptodsl/` — JIT-compiled kernel `.so` cache +- `build/msprof_res/` — `msprof` simulator trace output + +--- + +## Running regression checks + +```bash +cd $PTOAS_REPO_ROOT +python3 test/python/ptodsl_jit_compile.py +python3 test/python/ptodsl_jit_diagnostics.py +python3 test/python/ptodsl_subkernel_diagnostics.py +python3 test/python/ptodsl_flash_attention_demo_compile.py +python3 test/python/ptodsl_ptoas_frontend_verify.py +python3 test/python/ptodsl_docs_as_test.py +``` + +Expected output: + +``` +ptodsl_jit_compile: PASS +ptodsl_jit_diagnostics: PASS +ptodsl_subkernel_diagnostics: PASS +ptodsl_flash_attention_demo_compile: PASS +ptodsl_ptoas_frontend_verify: PASS +ptodsl_docs_as_test: PASS +``` + +`ptodsl_docs_as_test.py` is the docs-as-test regression for the PTODSL user +guide under `ptodsl/docs/user_guide/`. It scans every Python fenced code block +and requires each one to be explicitly classified with either +`ptodsl-doc-test` or `ptodsl-doc-pending` metadata. + +- `mode="compile"` blocks are executed as-authored and must pass the PTODSL + compile-only path, MLIR verify, and shared PTOAS frontend validation. +- `mode="compile_fragment"` blocks are embedded into explicit test fixtures so + representative partial snippets can be compiled under a declared outer + kernel context instead of relying on hidden heuristic context synthesis. +- `ptodsl-doc-pending` marks snippets the manual intends to treat as contract + later, but which are still blocked on missing implementation or missing test + harness support. + +Run it directly while editing the manual: + +```bash +cd $PTOAS_REPO_ROOT +python3 test/python/ptodsl_docs_as_test.py +``` + +When it fails, the diagnostic includes the Markdown path, starting line number, +and target symbol so the drift can be fixed in the manual instead of searching +through generated IR logs. + +These PTODSL regressions are intentionally complementary: + +- `ptodsl_jit_compile.py` protects canonical authored compile probes and + lowering contracts for the public PTODSL surface. +- `ptodsl_flash_attention_demo_compile.py` protects the bundled + `ptodsl/examplesflash_attention_sketch.py` authored demo as a stable end-to-end + contract. +- `ptodsl_ptoas_frontend_verify.py` protects the handoff from PTODSL-emitted + MLIR into standalone `ptoas` frontend verification. +- `ptodsl_docs_as_test.py` protects the user manual itself: documented + self-contained examples must still compile, fixture-backed partial fragments + must still compile inside their declared context, and explicitly marked + pending snippets remain visible as docs/test debt. + +`ptodsl_docs_as_test.py` is not a replacement for the authored compile/demo +regressions above. It reuses the same compile-only and frontend-validation +boundaries, but its job is to keep `ptodsl/docs/user_guide/` honest rather than +to redefine the canonical demo contracts. + +The legacy `ptodsl/check_ir.py` script has been retired. PTODSL validation now +lives under `test/python/` so every regression shares the same bootstrap, +public surface, and canonical authored targets as the tracing/JIT +implementation. + +--- + +## DSL-style API quick reference + +```python +from ptodsl import pto, scalar +s = scalar # arith shorthand alias +``` + +`pto` is the main DSL namespace. `scalar` is a separate top-level helper +namespace for runtime scalar load/store, arithmetic helpers, and scalar math; +it is intentionally not exported as `pto.scalar`. + +### Kernel decorator + +```python +@pto.jit(name="MyKernel", kernel_kind="vector", target="a5") +def MyKernel(): + ... + +@pto.jit(name="Softmax", kernel_kind="vector", target="a5") +def Softmax( + X: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + ... + +print(MyKernel) # prints MLIR text +mod = MyKernel.mlir_module() # returns mlir.ir.Module +``` + +`@pto.jit` now emits a flat aicore launch-entry module by default. The traced +entry function carries the `pto.aicore` attribute and lives directly under the +top-level module, which matches the runtime-launch path and merged-MLIR example +flow. + +PTODSL v1 keeps the public `@pto.jit` entry ABI intentionally narrow: + +- positional parameters are Python-native tensors declared with + `pto.tensor_spec(...)` +- positional runtime scalars use PTO scalar annotations such as `pto.i32`, + `pto.f32`, and `pto.i1`, while launch-time values remain ordinary Python + scalars +- keyword-only parameters annotated with `pto.constexpr` are compile-time + specialization knobs + +Typed pointers such as `pto.ptr(...)` remain valid PTODSL surface types inside +kernel bodies and explicit-mode sub-kernels, but they are not the recommended +host-visible `@pto.jit` parameter contract. + +Additional layered kernel entry modes and shared compute decorators are also +exported on the public surface: `@pto.jit(mode="auto")`, +`@pto.jit(mode="explicit")`, `@pto.cube`, `@pto.simd`, and `@pto.simt`. + +### Type descriptors (lazy – safe to use in annotations) + +| Expression | MLIR type | +|---|---| +| `pto.float32` | `f32` | +| `pto.int32` | `i32` | +| `pto.int64` | `i64` | +| `pto.index` | `index` | +| `pto.ptr(pto.float32, "gm")` | `!pto.ptr` | +| `pto.ptr(pto.float32, "ub")` | `!pto.ptr` | + +### Type constructors (eager – require active context) + +```python +vf32 = pto.vreg_type(64, pto.float32) # !pto.vreg<64xf32> +tile_col = pto.alloc_tile(shape=[8, 1], dtype=pto.float32, blayout="ColMajor") +tile_w = pto.alloc_tile(shape=[8, 128], dtype=pto.float32) +``` + +### Constants + +```python +c0 = pto.const(0) # index +c1_i32 = pto.const(1, dtype=pto.int32) +c64_i64= pto.const(64, dtype=pto.int64) +``` + +### Control flow + +```python +with pto.simd(): # pto.simd { … } + ... + +with pto.for_(c0, c16, step=c1) as i: # simple scf.for + ... # scf.yield inserted automatically + +loop = pto.for_(c0, c128, step=c64).carry(lhs=a, rhs=b) +with loop: + x = loop.lhs + y = loop.rhs + ... + loop.update(lhs=nx, rhs=ny) +fx = loop.final("lhs") +fy = loop.final("rhs") + +with pto.if_(has_rows) as br: # simple scf.if + with br.then_: + ... + +with pto.if_(has_chunk) as br: + with br.then_: + br.assign(x=merged_max, y=merged_sum) + with br.else_: + br.assign(x=running_max, y=running_sum) +x = br.x +y = br.y +``` + +### Scalar arithmetic (`s = scalar`) + +```python +s.muli(a, b) # arith.muli +s.addi(a, b) # arith.addi +s.subi(a, b) # arith.subi +s.index_cast(val) # arith.index_cast → index +s.index_cast(pto.int32, val) # arith.index_cast → i32 +(a > b) # scalar compare → pto.i1 +(a <= b) # scalar compare → pto.i1 +s.select(cond, t, f) # arith.select +``` + +### PTO operations + +```python +pto.castptr(addr, ptr_type) # pto.castptr +pto.addptr(ptr, offset) # pto.addptr +pto.vlds(ptr, offset) # pto.vlds, result vreg inferred from ptr element type +pto.vbr(scalar) # pto.vbr, scalar broadcast -> vreg +pto.vsts(v, ptr, offset, mask) # pto.vsts +pto.plt_b32(scalar) # → (mask, scalar_out) +pto.pset_b32("PAT_ALL") # pto.pset_b32 → mask +pto.vbitcast(v, dtype) # pto.vbitcast +pto.pbitcast(mask, mask_type) # pto.pbitcast +pto.vadd(a, b, mask) # infers result type from a.type +pto.vmul / vmax / vdiv / vcmax / vcadd / vdup / vexpdif # similarly +pto.make_tensor_view(ptr, shape=…, strides=…) # type inferred +pto.partition_view(tv, offsets=…, sizes=…) # type inferred +pto.alloc_tile(shape=…, dtype=…, memory_space=…, valid_shape=…, addr=…) # authored surface +pto.tile.load(part, tile) +pto.tile.store(tile, part) +tile.as_ptr() / view.as_ptr() +pto.get_block_idx() # → i64 +pto.set_flag("MTE2", "V", event_id=0) +pto.wait_flag("MTE2", "V", event_id=0) +pto.pipe_barrier(pto.Pipe.ALL) +``` + +## How the IR check works + +``` +generated IR ──┐ + ├── Module.parse() → canonical string ──── == ──── PASS/FAIL +reference .pto ──┘ (strips comments, normalises SSA names and attr order) +``` + +Constant declaration order is preserved after the round-trip; builders must +emit constants in the same order as the reference. The diff output makes any +mismatch immediately visible. diff --git a/ptodsl/docs/user_guide/01-introduction.md b/ptodsl/docs/user_guide/01-introduction.md new file mode 100644 index 000000000..9de4d02e9 --- /dev/null +++ b/ptodsl/docs/user_guide/01-introduction.md @@ -0,0 +1,222 @@ +# 1. Introduction + +**PTO** is a virtual instruction set designed for the Ascend NPU — a hardware-abstracted programming model that exposes the full capability of the Cube, Vector, and Scalar compute units through a unified operation set. **PTODSL** is the Python frontend for PTO. It wraps the PTO instruction set in a Python-embedded DSL with tracing-based compilation, so you can write PTO programs using familiar Python syntax. Under the hood, PTODSL traces your kernel function into PTO IR, which the PTOAS compiler then lowers, optimizes, and emits as NPU executables. In short: PTO defines the *what* (the instruction set), PTODSL provides the *how* (the authoring experience), and together they give you direct access to all three NPU compute units without leaving Python. + +## 1.1 Target hardware + +The Ascend NPU is organized around three compute units and a shared on-chip buffer, connected through the Memory Transfer Engine (MTE): + +``` + ┌─────────────────────────┐ + │ Global Memory (GM) │ + │ (off-chip HBM) │ + └────────────┬──────────────┘ + │ + ┌──────────┴──────────┐ + │ MTE (DMA engine) │ + └──────────┬──────────┘ + │ + ┌────────────┴──────────────┐ + │ Unified Buffer (UB) │ + │ (on-chip scratchpad) │ + └──┬───────────┬──────────┬──┘ + │ │ │ + ┌────────┴──┐ ┌─────┴──────┐ │ + │ LEFT/RIGHT│ │ │ │ + │ /ACC/BIAS│ │ Vector │ │ + │ │ │ (SIMD) │ │ + │ Cube │ │ │ │ + │ │ └────────────┘ │ + └───────────┘ │ + ┌──────────┴──┐ + │ SIMT │ + │ (scalar PG) │ + └─────────────┘ +``` + +| Unit | Role | Typical workload | +|------|------|------------------| +| **Cube** | Matrix multiplication | GEMM, convolution | +| **SIMD** | Row-wise vector math | activation, normalization, reduction | +| **SIMT** | Scalar-programmable unit | pointwise tile walks, metadata | + +- **Global Memory (GM)** is off-chip HBM. All input and output tensors reside here. +- **Unified Buffer (UB)** is the on-chip scratchpad shared by all three compute units. Tile buffers and intermediate results live here during kernel execution. +- **MTE** (Memory Transfer Engine) handles DMA transfers between GM and UB, and between UB regions. +- **Cube** has its own private on-chip buffers — LEFT, RIGHT, ACC, and BIAS — for staging matrix operands and accumulators. +- **SIMD** executes row-wise vector instructions directly on UB-resident data. +- **SIMT** is a scalar-programmable processor group that executes scalar instructions across many work-items in parallel. It is well-suited for per-element control logic, tile boundary metadata, and pointwise blends. + +PTODSL gives you direct access to all three units and the data-movement +surfaces around them, without abstracting away the hardware boundaries. + +## 1.2 Authoring model + +PTODSL's public kernel model is **one entry point, two modes**: + +``` +Python Wrapper L0 user-facing wrapper (NumPy, torch-npu, pure Python) + └─ @pto.jit(mode="auto") tile-first authoring, compiler-managed staging + └─ @pto.jit(mode="explicit") micro-instruction authoring, user-managed staging + ├─ Tile Ops tile.load, tile.store, tile.add, ... + ├─ MTE Ops mte_load / mte_store / mte_gm_ub / ... + ├─ @pto.cube matrix products (mad, mte_l1_l0a, mte_l0c_ub, ...) + ├─ @pto.simd row-wise vector math (vlds, vadd, vexp, vsts, ...) + └─ @pto.simt scalar-like compute (lds, sts, pointwise blends, ...) +``` + +### Python wrapper + +The outermost layer is plain Python. It handles ergonomic runtime concerns: allocating output tensors, extracting shapes and strides from framework tensors, compiling the JIT kernel, and launching it. Because the wrapper is just Python, you can freely mix in NumPy, torch-npu, or any other Python framework for pre- and post-processing, data preparation, or composing multiple kernel launches. It knows nothing about NPU internals — it is just a convenience function that most end users will call. + + +```python +def flash_attention(Q, K, V, *, O=None, causal=False): + if O is None: + O = pto.empty_like(Q) + compiled = flash_attention_kernel.compile( + BLOCK_Q=128, BLOCK_KV=128, CAUSAL=causal + ) + compiled[batch * heads, stream](Q, K, V, O) + return O +``` + +### `@pto.jit` — the kernel entry + +Decorating a function with `@pto.jit` marks it as a launchable PTO kernel. This decoration means: + +- **Compilation**: the function body is traced once to record all PTO instructions, then lowered through the PTOAS compiler pipeline into an optimized NPU executable. +- **Caching**: compiled kernels are cached by specialization key (function identity + tensor ABI signature + constexpr parameter values), so repeated calls with the same configuration skip recompilation. +- **Launch binding**: the compiled kernel can be invoked with a grid and stream — `compiled[grid, stream](args...)` — which launches the executable on the NPU with the given SPMD grid. + +The parameters of a `@pto.jit` function are Python-native tensors (not PTODSL-specific descriptors). In PTODSL v1, their ABI contract is declared with `pto.tensor_spec(...)` in the function signature; this is a compile-time annotation, not a runtime object the Python wrapper must construct. The kernel body materializes `TensorView` descriptors from the runtime tensors via `make_tensor_view`, then partitions the problem with `partition_view`. Compile-time constants are declared as keyword-only arguments with `pto.constexpr`: + + +```python +from ptodsl import pto + + +@pto.jit(target="a5") +def flash_attention_kernel( + Q: pto.tensor_spec(rank=4, dtype=pto.f32), + K: pto.tensor_spec(rank=4, dtype=pto.f32), + V: pto.tensor_spec(rank=4, dtype=pto.f32), + O: pto.tensor_spec(rank=4, dtype=pto.f32), + *, + BLOCK_Q: pto.constexpr = 128, + BLOCK_KV: pto.constexpr = 128, + CAUSAL: pto.constexpr = False, +): + # ... tile allocation, block partitioning, and sub-kernel dispatch ... + return +``` + +`@pto.jit` is the only host-visible kernel entry. Its `mode` selects the +programming model: + +- `mode="auto"` (the default) is **tile-centric**. You allocate tiles, partition + GM views, use Tile Ops (`tile.load`, `tile.store`, `tile.add`, ...), and call + compute sub-kernels. The compiler manages staging and scheduling around the + tile abstraction. +- `mode="explicit"` is **tile + micro-instruction**. You keep the same tile + surface from `auto`, but also gain access to the full micro-instruction + set — MTE ops (`mte_load`, `mte_store`, ...), explicit synchronization, + and direct pointer manipulation — so you can reach below the tile abstraction + and control individual instructions when needed. + +In both modes, `@pto.jit` is where you allocate tiles (`alloc_tile`) and use +Tile Ops. The difference is that `explicit` additionally opens up the +micro-instruction surface — MTE ops, explicit sync, and pointer-level +control — so you can mix tile operations with hand-authored instructions in +the same kernel. + +The SPMD launch contract is also owned here: the runtime grid (e.g., `batch * heads` blocks) is declared at the call site, and block/subblock indices are queried via `pto.get_block_idx()` and friends. + +### Sub-kernels — `@pto.cube` / `@pto.simd` / `@pto.simt` + +These are hardware-bound compute sub-kernels, each mapped to a specific NPU compute unit: + +- **`@pto.cube`** consumes UB tiles and explicit cube-local scratch (LEFT, RIGHT, ACC, BIAS). Typical operations: `mad`, `mte_l1_l0a`, `mte_l1_l0b`, `mte_l0c_ub`. + +- **`@pto.simd`** operates on vector registers (`vreg`). Typical operations: `vlds`, `vadd`, `vexp`, `vcgmax`, `vsts`. Vector registers never cross the simd function boundary — persistent state is written back to UB tiles. + +- **`@pto.simt`** is a scalar-programmable processor group that executes scalar instructions across many work-items in parallel. Typical operations: `lds`, `sts`, scalar arithmetic and comparison. Well-suited for per-element tile walks, boundary metadata, and pointwise blends. + +Each can be invoked as a named decorated function (`@pto.cube` / +`@pto.simd` / `@pto.simt`) or inline as a context manager +(`with pto.cube():`, `with pto.simd():`, `with pto.simt():`). + +The boundary contract is strict: vreg values do not escape a simd kernel, cube-local state does not leak into UB, and data crosses layer boundaries only through UB-backed tiles or typed UB pointers. + +## 1.3 Tracing execution model + +PTODSL uses a **tracing** compilation model. When you call `kernel.compile(...)`, PTODSL executes your Python function body once to record every PTO instruction into an intermediate representation — this pass is called *tracing*. The traced IR is then lowered and optimized into device code. Once compiled, invoking `compiled[grid, stream](args...)` launches the already-built device code directly on the NPU. + +This has one critical implication for how you write control flow and scalar logic: + +- **Python native control flow** (`for`, `if`, Python arithmetic) runs at trace time. A `for i in range(4)` loop gets unrolled — the device code contains four copies of the body, not a loop instruction. An `if` branch condition is evaluated at trace time, and only the taken branch is recorded. + +- **`pto.for_` / `pto.if_`** are recorded as structured control-flow IR. They preserve loop and branch semantics into the compiler pipeline, where the PTOAS compiler may further optimize them — unrolling, folding, or keeping them as runtime control flow depending on what is known at compile time. + +- **Python scalar expressions** (`alpha * x`, `1.0 / sqrt(d)`) are evaluated at trace time and their results are baked into the IR as constants — the compiler never sees the original expression. + +- **PTO scalar instructions** (`scalar.load(...)`, `scalar.max(...)`, `scalar.exp(...)`) are recorded as scalar IR and enter the compiler pipeline, where they may be constant-folded or lowered to runtime scalar operations depending on whether their inputs are compile-time known. + +A simple rule of thumb: **Python constructs are resolved before the compiler sees them. PTO constructs are recorded into IR and the compiler decides.** + +Chapter 5 (Control Flow) and Chapter 6 (Scalar & Pointer Operations) cover this in detail. + +## 1.4 A worked example + +The flash attention kernel from Section 1.2 is not just an architectural diagram — it is a complete, runnable design sketch distributed with PTODSL (`examples/flash_attention_sketch.py`). Here is how the layers map to actual code: + +**Top-level `@pto.jit` schedule** allocates tiles for the Q block, KV block, +online-softmax state (m/l/o ping-pong tiles), and cube-local scratch. It loops +over Q blocks (outer `pto.for_`) and KV blocks (inner `pto.for_` with carry +state), and uses `tile.load`/`tile.store` at the GM boundary. + +**`mode="explicit"` orchestration path** stages the current K and V blocks with +`mte_load`, issues `pipe_barrier(Pipe.ALL)` at phase boundaries, then +sequences four sub-kernel calls: `qk_matmul` (cube), +`online_softmax_rows` (simd), `pv_matmul` (cube), `blend_output_rows` (simt). + +**`@pto.cube`** performs `mte_l1_l0a` / `mte_l1_l0b` / `mad` / +`mte_l0c_ub` for both QK^T and P@V products. + +**`@pto.simd`** implements the online softmax update: per-row max, exp, sum, +and alpha/beta computation using vector ops (`vlds`, `vcgmax`, `vexp`, +`vcgadd`, `vsts`). + +**`@pto.simt`** blends the old and new output accumulators with per-element +`lds`/`sts` and scalar arithmetic. + +Chapter 11 walks through this example in full detail. + +## 1.5 Reading guide + +| If you are... | Start with... | +|---------------|---------------| +| New to PTODSL | Chapter 2 (Quick Start), then Chapter 3 (Kernel Entries) | + +| Writing your first kernel | Chapter 2 → Chapter 4 (Type System) → Chapter 5 (Control Flow) | +| Looking up a specific operation | Chapters 6–10 (organized by topic) | +| Understanding the flash attention reference | Chapter 11 | + +**Chapter overview:** + +| Chapter | Topic | +|---------|-------| +| 1 | Introduction (this chapter) | +| 2 | Quick Start — a minimal working kernel | +| 3 | Kernel entry and sub-kernels: `@pto.jit(mode=...)`, `@pto.cube`, `@pto.simd`, `@pto.simt` | +| 4 | Type system and buffer management: scalars, tiles, views, allocation | +| 5 | Control flow: trace-time Python vs device-side `pto.for_` / `pto.if_` | +| 6 | Scalar and pointer operations | +| 7 | Data movement: tile loads/stores, DMA, vector loads/stores, cube data movement | +| 8 | Compute operations: tile-level, vector, and cube arithmetic | +| 9 | Predicate and mask operations | +| 10 | Synchronization: barriers, flags, memory fences | +| 11 | Flash attention walkthrough | +| 12 | Additional examples | +| 13 | Migration from the old `@pto.vkernel`/`@pto.ckernel` API | +| 14 | Common errors and compatibility notes | diff --git a/ptodsl/docs/user_guide/02-quick-start.md b/ptodsl/docs/user_guide/02-quick-start.md new file mode 100644 index 000000000..8ee4aa5b8 --- /dev/null +++ b/ptodsl/docs/user_guide/02-quick-start.md @@ -0,0 +1,258 @@ +# 2. Quick Start + +This chapter walks through a minimal but complete PTODSL kernel — a tiled copy from one GM tensor to another — covering the essential concepts you need to start writing your own kernels. + +## 2.1 A first kernel: tiled copy + + +```python +from ptodsl import pto + + +@pto.jit(target="a5") +def tile_copy( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + """Copy one 2D tensor tile from A to O.""" + + rows = A.shape[0] + cols = A.shape[1] + + # Describe the GM tensors. + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + + # Allocate UB tiles for one row-strip block. + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + + # Partition the GM views to cover the current logical slice. + a_part = pto.partition_view(a_view, offsets=[0, 0], sizes=[rows, cols]) + o_part = pto.partition_view(o_view, offsets=[0, 0], sizes=[rows, cols]) + + # Load from GM into UB, then store back out. + pto.tile.load(a_part, a_tile) + pto.tile.store(o_tile, o_part) +``` + +Let us step through each piece. + +### The entry point + +```python +@pto.jit(target="a5") +def tile_copy(A, O, *, BLOCK: pto.constexpr = 128): +``` + +`@pto.jit` marks this function as a launchable PTO kernel. The positional parameters `A` and `O` are Python-native tensors — they arrive from NumPy, torch-npu, or any framework that provides a shape and strides. Their ABI contract is declared with `pto.tensor_spec(...)`. The keyword-only argument `BLOCK` is a compile-time constant declared with `pto.constexpr`; the compiler specializes the kernel for each tile width. + +### Describing GM tensors + +```python +a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) +``` + +`make_tensor_view` wraps a Python tensor into a `TensorView` — a descriptor that tells the kernel how to address the tensor in global memory. You provide the logical shape and the stride (in elements) of each dimension. + +### Allocating on-chip buffers + +```python +a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) +``` + +`alloc_tile` reserves space in the Unified Buffer (UB). A `Tile` is a 2D buffer that lives on-chip during kernel execution. Every tile has a `shape` and a `dtype`. + +### Partitioning GM views + +```python +a_part = pto.partition_view(a_view, offsets=[0, 0], sizes=[rows, cols]) +``` + +`partition_view` creates a sub-view of a `TensorView` at a given offset and size. It describes *which part* of the GM tensor a `tile.load` or `tile.store` should operate on. For this simple whole-tensor example the offset is zero and the size matches the logical tensor extent; in a blocked kernel you would slide the offset through a loop. + +### Moving data: tile.load and tile.store + +```python +pto.tile.load(a_part, a_tile) # GM → UB +pto.tile.store(o_tile, o_part) # UB → GM +``` + +`tile.load` copies a block of data from GM (described by a partition) into a UB tile. `tile.store` copies a UB tile back to GM. These are **Tile Ops** — they operate on entire tile buffers at once. + +### Why start with copy + +```python +pto.tile.load(a_part, a_tile) +pto.tile.store(o_tile, o_part) +``` + +A copy kernel strips the example down to the essential PTODSL boundary objects: + +- host tensors entering `@pto.jit` +- `TensorView` descriptors over GM tensors +- UB `Tile` allocation +- `PartitionTensorView` slices +- tile-level movement with `tile.load` / `tile.store` + +Once these pieces are clear, arithmetic and sub-kernel orchestration become much easier to layer on. + +## 2.2 A blocked version with a loop + +The kernel above touches one logical slice directly. To introduce device-side control flow, we can iterate over the rows of a 2D tensor and copy one row-strip at a time: + + +```python +from ptodsl import pto + + +@pto.jit(target="a5") +def blocked_copy( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + rows = A.shape[0] + cols = A.shape[1] + + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + + tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + + with pto.for_(0, rows, step=1) as row: + a_part = pto.partition_view(a_view, offsets=[row, 0], sizes=[1, cols]) + o_part = pto.partition_view(o_view, offsets=[row, 0], sizes=[1, cols]) + + pto.tile.load(a_part, tile) + pto.tile.store(tile, o_part) +``` + +Here `rows` and `cols` are dynamic — they come from `A.shape` and can differ across launches. The loop bound depends on `rows`, so `pto.for_` records a structured loop in the IR rather than unrolling at trace time. The `BLOCK` parameter stays `constexpr` because it is a tuning knob, not data-dependent. Chapter 5 covers this distinction in detail. + +## 2.3 Compile and launch + +Once the kernel is defined, you compile it and then launch it: + + +```python +# Compile once, cache the result. +compiled = blocked_copy.compile(BLOCK=128) + +# Allocate or obtain input/output tensors (NumPy, torch-npu, ...). +import numpy as np +A = np.random.randn(4, 128).astype(np.float32) +O = np.empty_like(A) + +# Launch on the NPU. +compiled[1, None](A, O) +``` + +- `.compile(**constexprs)` traces the kernel body, lowers it through the PTOAS pipeline, and returns a compiled handle. Repeated calls with the same tensor ABI contract and constexpr configuration hit the cache. +- `compiled[grid, stream](args...)` launches the compiled kernel. `grid` is the number of SPMD blocks; `stream` is the NPU stream (or `None` for the default). + +## 2.4 SPMD launch + +For workloads that can be parallelized across multiple blocks, specify a grid: + +```python +# Process batch * heads slices in parallel. +compiled[batch * heads, stream](Q, K, V, O) +``` + +Inside the kernel, each block queries its index: + +```python +block_idx = pto.get_block_idx() +block_num = pto.get_block_num() +``` + +This lets you map different data slices to different blocks — for example, one block per (batch, head) pair in flash attention. + +## 2.5 Adding sub-kernels and explicit orchestration + +The examples above used Tile Ops (`tile.load` / `tile.store` here, and +arithmetic Tile Ops in later chapters), which operate on entire tiles at once. +When you need finer control — for instance, writing a custom softmax or an +activation that maps directly to vector hardware — you can keep the same +`@pto.jit` entry and add sub-kernels. If you also need micro-instruction control, +switch that kernel to `mode="explicit"`: + + +```python +# SIMD sub-kernel — vector instructions on individual rows. +@pto.simd +def add_rows(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, + rows: pto.index, cols: pto.index): + VEC = pto.elements_per_vreg(pto.f32) + with pto.for_(0, rows, step=1) as r: + col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) + with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + a_vec = pto.vlds(a_tile[r, c:]) + b_vec = pto.vlds(b_tile[r, c:]) + o_vec = pto.vadd(a_vec, b_vec, mask) + pto.vsts(o_vec, o_tile[r, c:], mask) + col_loop.update(remained=remained) + +# Single kernel entry in explicit mode — micro-instruction staging plus SIMD sub-kernel. +@pto.jit(target="a5", mode="explicit") +def vec_add_micro( + A: pto.tensor_spec(rank=1, dtype=pto.f32), + B: pto.tensor_spec(rank=1, dtype=pto.f32), + O: pto.tensor_spec(rank=1, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + N = A.shape[0] + a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + + num_blocks = (N + BLOCK - 1) // BLOCK + with pto.for_(0, num_blocks, step=1) as i: + offset = i * BLOCK + this_block = scalar.min(N - offset, BLOCK) + a_part = pto.partition_view(a_view, offsets=[offset], sizes=[this_block]) + b_part = pto.partition_view(b_view, offsets=[offset], sizes=[this_block]) + o_part = pto.partition_view(o_view, offsets=[offset], sizes=[this_block]) + row_bytes = this_block * pto.bytewidth(pto.f32) + pto.mte_load(a_part.as_ptr(), a_tile.as_ptr(), 0, row_bytes, + nburst=(1, 0, 0)) + pto.mte_load(b_part.as_ptr(), b_tile.as_ptr(), 0, row_bytes, + nburst=(1, 0, 0)) + pto.pipe_barrier(pto.Pipe.ALL) + add_rows(a_tile, b_tile, o_tile, 1, this_block) + pto.pipe_barrier(pto.Pipe.ALL) + pto.mte_store(o_tile.as_ptr(), o_part.as_ptr(), row_bytes, + nburst=(1, 0, 0)) +``` + +- **`@pto.jit(mode="explicit")`**: allocates tiles, partitions the GM views, + loops over blocks, and directly authors the micro-instruction schedule for + each block. + +- **`@pto.simd` sub-kernel**: the top-level kernel calls a SIMD sub-kernel + for the row-wise vector work while keeping instruction staging in the + explicit entry body. + +- **Inside `@pto.simd`**: the outer `pto.for_` iterates over rows, the inner + `pto.for_` iterates over column chunks of the hardware vector width + (`elements_per_vreg`). Each iteration loads a vector-width slice into a + `vreg`, does the addition under a mask (for tail elements), and stores the + result back. Both loops are recorded as structured control flow IR — the + compiler decides whether to keep them or unroll them. + +The same pattern also has an `auto` counterpart: keep `@pto.jit` in its +default mode and replace the explicit `mte_*` sequence with `tile.load` / +`tile.store`. Chapter 3 covers the full entry model; Chapters 7–10 cover each +operation family in detail. diff --git a/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md new file mode 100644 index 000000000..4ec1fc0fb --- /dev/null +++ b/ptodsl/docs/user_guide/03-kernel-entry-and-subkernels.md @@ -0,0 +1,690 @@ +# 3. Kernel Entry and Sub-Kernels + +PTODSL provides one host-visible kernel decorator (`@pto.jit`) and three +compute-unit sub-kernel decorators (`@pto.cube`, `@pto.simd`, `@pto.simt`), +plus matching context managers for inline use. This chapter covers the kernel +entry, the two programming models, sub-kernel reference, parameter contracts, +and boundary constraints. + +## 3.1 `@pto.jit` — the only kernel entry + +Decorator overview: + +```text +@pto.jit(mode="auto") tile-first authoring, compiler-managed staging +@pto.jit(mode="explicit") micro-instruction authoring, user-managed staging +@pto.cube Cube-unit matrix sub-kernel +@pto.simd SIMD-unit vector sub-kernel +@pto.simt SIMT-unit scalar sub-kernel +``` + +### Role + +`@pto.jit` marks a function as a launchable PTO kernel. It owns compilation +(tracing + lowering), caching, and runtime launch binding. This is the only +decorator that can be invoked directly from the host; the compute-unit +decorators define sub-kernels that are called from within `@pto.jit`. + +### Signature + + +```python +@pto.jit(target="a5", mode="auto") +def kernel_name( + tensor_arg_1: pto.tensor_spec(rank=1, dtype=pto.f32), # Python-native tensor (positional) + tensor_arg_2: pto.tensor_spec(rank=1, dtype=pto.f32), # Python-native tensor (positional) + *, + CONST_A: pto.constexpr = 128, # compile-time constant (keyword-only) + CONST_B: pto.constexpr = 64, # compile-time constant (keyword-only) +): + # ... tensor views, tile allocation, and kernel logic ... + return +``` + +### How to declare and pass parameters + +A `@pto.jit` kernel accepts three kinds of parameters. Each has a distinct role, +position in the signature, and way to supply the value: + +| Parameter kind | Position | Annotation | Pass the value at | +|---|---|---|---| +| **Tensor** | positional (before `*`) | `pto.tensor_spec(rank=N, dtype=...)` | launch time | +| **Runtime scalar** | positional (before `*`) | `pto.i32`, `pto.f32`, `pto.i1`, etc. | launch time | +| **Compile-time constant** | keyword-only (after `*`) | `pto.constexpr = ` | compile time | + +#### 1. Tensor parameters + +Declare a positional parameter with `pto.tensor_spec(rank=..., dtype=...)`. +At launch time, pass a **Python-native tensor** — a NumPy array, a torch-npu +tensor, or any object with `.shape`, `.dtype`, `.strides` (or `.stride()`), and +a data pointer (`.data_ptr()` or `.ptr`): + +```python +@pto.jit(target="a5") +def my_kernel( + X: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), +): + # Inside the body, access shape/strides/dtype directly: + rows, cols = X.shape[0], X.shape[1] + # Then wrap with make_tensor_view(...) to build a GM descriptor: + x_view = pto.make_tensor_view(X, shape=X.shape, strides=X.strides) +``` + +#### 2. Runtime scalar parameters + +Declare a positional parameter with a PTO scalar annotation (`pto.i32`, +`pto.f32`, `pto.i1`, etc.). At launch time, pass an ordinary Python +`int`, `float`, or `bool`: + +```python +@pto.jit(target="a5") +def my_kernel( + X: pto.tensor_spec(rank=2, dtype=pto.f32), + n: pto.i32, # pass an int at launch + alpha: pto.f32, # pass a float at launch +): + # Scalars arrive as PTO values and can be used directly in + # index math, loop bounds, comparisons, and sub-kernel calls: + limit = n // 2 +``` + +#### 3. Compile-time constants + +Declare after `*` with `pto.constexpr` and a default value. +Pass the value to `.compile(...)` — **not** at launch time: + +```python +@pto.jit(target="a5") +def my_kernel( + X: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + # BLOCK is a Python value at trace time — use it for tile shapes, + # unrolled loops, or dtype arguments: + tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) +``` + +The compiler specializes the kernel for each combination of constexpr values. +Once compiled, the values are baked in — they cannot change between launches of +the same compiled instance. To use a different value, call `.compile(...)` again. + +### Full example: declare and launch + +Bringing all three kinds together: + +```python +@pto.jit(target="a5", mode="auto") +def scaled_bias_add( + X: pto.tensor_spec(rank=2, dtype=pto.f32), # tensor + O: pto.tensor_spec(rank=2, dtype=pto.f32), # tensor + alpha: pto.f32, # runtime scalar + bias: pto.f32, # runtime scalar + *, + BLOCK: pto.constexpr = 128, # compile-time constant +): + rows, cols = X.shape[0], X.shape[1] + # ... use alpha, bias, BLOCK inside the kernel body ... + return +``` + +```python +# Step 1 — compile: constexpr values go to .compile() +compiled = scaled_bias_add.compile(BLOCK=64) + +# Step 2 — launch: tensors and runtime scalars go to compiled[grid, stream](...) +import numpy as np +X = np.random.randn(4, 128).astype(np.float32) +O = np.empty_like(X) +compiled[1, None](X, O, 2.0, 1.0) # alpha=2.0, bias=1.0 +``` + +### What is NOT accepted at the entry + +The following types are intentionally **not** accepted as `@pto.jit` parameters: + +- `pto.ptr(...)` — typed pointers are available inside the kernel body and + across sub-kernel boundaries, but not at the host/kernel entry. +- `Tile`, `PartitionTensorView`, `VReg` — these are created inside the kernel + body, not passed from the host. + +They are valid **inside** the kernel and across sub-kernel calls, just not at +the public host/kernel boundary. + +### `mode`: auto vs explicit + +`mode` is a keyword on the decorator, not a function parameter. It selects the +programming model: + +- `mode="auto"` (the default) is **tile-centric**. You write kernels in terms + of tiles and Tile Ops. The compiler manages staging, scheduling, and + synchronization around the tile abstraction. +- `mode="explicit"` adds the full **micro-instruction** surface — MTE ops, + explicit synchronization, and direct pointer manipulation — on top of + everything available in `auto`. + +`mode` changes what you can write **inside the kernel body**. It does **not** +change the recommended host-visible entry ABI: both modes use the same +`tensor_spec(...)` + runtime scalar + `constexpr` contract at the `@pto.jit` +boundary. + +Section 3.2 covers the two models in detail. + +### Compilation and launch + + +```python +import numpy as np + + +# Compile (traces the body, lowers through PTOAS, caches the result) +compiled = kernel_name.compile(CONST_A=128, CONST_B=64) + +# Allocate or obtain concrete tensors that match the declared host ABI. +A = np.random.randn(4, 128).astype(np.float32) +O = np.empty_like(A) + +# Launch on NPU +compiled[grid, stream](A, O) +``` + +- `.compile(**constexprs)` — traces the kernel body with the given constexpr + values, lowers the IR, and returns a compiled handle. Subsequent calls with + the same specialization key (function identity, tensor ABI signature, + constexpr values) hit the cache. +- `compiled[grid, stream](args...)` — launches the compiled kernel. `grid` is + the number of SPMD blocks (an integer); `stream` is the NPU stream (`None` + for default). + +### SPMD built-ins + +Available inside a `@pto.jit` body: + +| Built-in | Returns | Description | +|----------|---------|-------------| +| `pto.get_block_idx()` | `int` | Index of the current block (0-based) | +| `pto.get_block_num()` | `int` | Total number of blocks in the grid | +| `pto.get_subblock_idx()` | `int` | Index of the current sub-block | +| `pto.get_subblock_num()` | `int` | Total number of sub-blocks | + +### Typical body (auto mode) + +```python +@pto.jit(target="a5", mode="auto") +def my_kernel( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + B: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + rows = A.shape[0] + cols = A.shape[1] + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + b_view = pto.make_tensor_view(B, shape=B.shape, strides=B.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + + with pto.for_(0, rows, step=1) as row: + a_part = pto.partition_view(a_view, offsets=[row, 0], sizes=[1, cols]) + b_part = pto.partition_view(b_view, offsets=[row, 0], sizes=[1, cols]) + o_part = pto.partition_view(o_view, offsets=[row, 0], sizes=[1, cols]) + + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) + pto.tile.add(a_tile, b_tile, o_tile) + pto.tile.store(o_tile, o_part) +``` + +### Custom sub-kernels + +When Tile Ops don't cover the computation you need — a custom softmax, a +specialized activation, per-element blending — you write a sub-kernel in +`@pto.simd`, `@pto.simt`, or `@pto.cube` and call it directly from +`@pto.jit`. In auto mode, data movement stays with Tile Ops +(`tile.load`/`tile.store`) and PTOAS handles the synchronization between Tile +Ops and the sub-kernel: + + +```python +@pto.simd +def add_rows( + a_tile: pto.Tile, + b_tile: pto.Tile, + o_tile: pto.Tile, + rows: pto.index, + cols: pto.index, +): + VEC = pto.elements_per_vreg(pto.f32) + initial_remained = scalar.index_cast(pto.i32, cols) + with pto.for_(0, rows, step=1) as r: + col_loop = pto.for_(0, cols, step=VEC).carry(remained=initial_remained) + with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + a_vec = pto.vlds(a_tile[r, c:]) + b_vec = pto.vlds(b_tile[r, c:]) + o_vec = pto.vadd(a_vec, b_vec, mask) + pto.vsts(o_vec, o_tile[r, c:], mask) + col_loop.update(remained=remained) + +@pto.jit(target="a5", mode="auto") +def my_kernel( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + B: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + rows = A.shape[0] + cols = A.shape[1] + a_view = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + b_view = pto.make_tensor_view(B, shape=B.shape, strides=B.strides) + o_view = pto.make_tensor_view(O, shape=O.shape, strides=O.strides) + + a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + + with pto.for_(0, rows, step=1) as row: + a_part = pto.partition_view(a_view, offsets=[row, 0], sizes=[1, cols]) + b_part = pto.partition_view(b_view, offsets=[row, 0], sizes=[1, cols]) + o_part = pto.partition_view(o_view, offsets=[row, 0], sizes=[1, cols]) + + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) + + add_rows(a_tile, b_tile, o_tile, 1, cols) + + pto.tile.store(o_tile, o_part) +``` + +Sub-kernels are the mechanism for custom compute in PTODSL — when Tile Ops +cover your needs, you don't need one; when they don't, a sub-kernel gives you +direct access to the hardware unit. In auto mode, a sub-kernel's parameters +are restricted to `Tile` and PTO scalar types — the compiler owns staging and +sync. In explicit mode, sub-kernels may also accept `PartitionTensorView` and +`pto.ptr` parameters, matching the richer type surface available there. +This richer pointer surface belongs to the **in-kernel orchestration and +sub-kernel boundary**, not to the public `@pto.jit` host entry ABI. +Section 3.3 covers each sub-kernel decorator in detail. + +## 3.2 Programming models: auto vs explicit + +`@pto.jit` exposes a single entry with two programming models. The entry's +host ABI, compilation flow, and launch mechanism are identical in both — the +difference is what you can write inside the kernel body. + +### `mode="auto"` — tile-centric + +In auto mode you think in tiles. You allocate tiles, partition GM views, move +data with `tile.load` and `tile.store`, compute with Tile Ops like +`tile.add` and `tile.exp`, and call sub-kernels for hardware-specific compute. +The compiler handles the lowering of tiles to micro-instructions: inferring +staging, inserting synchronization between Tile Ops and sub-kernels, and +managing tile-level scheduling. + +Use auto mode for the majority of kernels. It gives you the full performance +of the NPU without requiring you to reason about instruction-level ordering. + +### `mode="explicit"` — tile + micro-instruction + +Explicit mode extends auto mode with direct micro-instruction access. You keep +everything available in auto — tiles, Tile Ops, sub-kernels — and additionally +gain access to MTE ops, explicit synchronization, and pointer manipulation. +When you need precise control over individual instructions and phase ordering, +you can drop below the tile abstraction without leaving the `@pto.jit` entry. + +The richer type surface also applies to sub-kernels: in auto mode, a +sub-kernel's parameters are restricted to `Tile` and PTO scalar types; in +explicit mode they may also accept `PartitionTensorView` and `pto.ptr`, +matching the types available in the enclosing orchestration code. Organize +orchestration logic into helper functions that accept these types: + + +```python +def my_orchestration_helper( + part: pto.PartitionTensorView, # GM partition descriptors + tile: pto.Tile, # UB tile buffers + scratch: pto.Tile, # cube-local scratch (LEFT, RIGHT, ...) + ptr: pto.ptr(pto.f32, pto.MemorySpace.UB), # typed UB pointers + scalar_value: pto.i32, # PTO scalar values +): + return +``` + +**Typical pattern**: GM↔UB movement uses ptr-based `mte_load`/`mte_store` +rather than `tile.load`/`tile.store`. The user places `pipe_barrier` at phase +boundaries and explicitly sequences sub-kernel calls: + + +```python +def process_block(q_tile, k_part, v_part, k_tile, v_tile, + s_tile, o_tile, o_part, rows: pto.i32, cols: pto.i32): + in_row_bytes = cols * pto.bytewidth(pto.f16) + out_row_bytes = cols * pto.bytewidth(pto.f32) + gm_row_stride = k_part.strides[0] * pto.bytewidth(pto.f16) + ub_row_stride = k_tile.shape[1] * pto.bytewidth(pto.f16) + + # Stage current block from GM to UB + pto.mte_load(k_part.as_ptr(), k_tile.as_ptr(), 0, in_row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) + pto.mte_load(v_part.as_ptr(), v_tile.as_ptr(), 0, in_row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) + pto.pipe_barrier(pto.Pipe.ALL) + + # Dispatch sub-kernels + qk_matmul(q_tile, k_tile, s_tile) + pto.pipe_barrier(pto.Pipe.ALL) + + online_softmax(s_tile, o_tile, rows, cols) + pto.pipe_barrier(pto.Pipe.ALL) + + # Write result back + pto.mte_store(o_tile.as_ptr(), o_part.as_ptr(), out_row_bytes, + nburst=(rows, ub_row_stride, gm_row_stride)) +``` + +Sub-kernel calls and inline sub-kernel scopes (`with pto.simd():`, etc.) work +identically in both modes. + +### Choosing between modes + +| | `mode="auto"` | `mode="explicit"` | +|---|---|---| +| Abstraction | Tiles | Tiles + micro-instructions | +| Data movement | `tile.load` / `tile.store` | `mte_load` / `mte_store` (ptr-based) | +| Sync | Compiler-managed | User-authored | +| Use case | Most kernels | Hand-tuned instruction scheduling | + +Start with auto. Move to explicit when you need to control the exact sequence +of micro-instructions — for example, to overlap DMA and compute with +double-buffering, or to hand-optimize a phase boundary that the compiler +doesn't fuse as aggressively as you need. + +## 3.3 Sub-kernels + +Sub-kernels are functions decorated with `@pto.cube`, `@pto.simd`, or +`@pto.simt` that execute on a specific NPU compute unit. They can be invoked +in two ways: + +1. **As decorated functions** — reusable, named sub-kernels called from + `@pto.jit`. +2. **As context managers** (`with pto.cube():`, etc.) — inline blocks for + one-off snippets (see Section 3.4). + +### 3.3.1 `@pto.cube` — Cube unit + +**Role**: `@pto.cube` marks a function that executes on the Cube unit (matrix +multiplication engine). It consumes UB-resident tiles and explicit cube-local +scratch buffers. + +**Signature**: + + +```python +@pto.cube +def my_cube_kernel( + input_tile: pto.Tile, # UB tile (source data) + output_tile: pto.Tile, # UB tile (destination) + left_scratch: pto.Tile, # LEFT buffer (cube-local) + right_scratch: pto.Tile, # RIGHT buffer (cube-local) + acc_scratch: pto.Tile, # ACC buffer (cube-local) +): + return +``` + +All parameters are `Tile` references. Tiles marked as cube-local must be +allocated with the appropriate `memory_space` (e.g., `pto.MemorySpace.LEFT`, +`pto.MemorySpace.ACC`). + +**Typical body**: + + +```python +@pto.cube +def qk_matmul( + q_tile: pto.Tile, + k_tile: pto.Tile, + q_l0a: pto.Tile, + k_l0b: pto.Tile, + s_acc: pto.Tile, + s_tile: pto.Tile, +): + m = q_tile.valid_shape[0] + k = q_tile.valid_shape[1] + n = k_tile.valid_shape[1] + + pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) +``` + +Cube-local state (LEFT, RIGHT, ACC, BIAS) never leaks into UB — it is the +caller's responsibility to allocate scratch buffers and pass them in +explicitly. + +**Invocation modes**: can be called from `@pto.jit` in either mode, or used +inline with `with pto.cube():` (Section 3.4). + +### 3.3.2 `@pto.simd` — SIMD unit + +**Role**: `@pto.simd` marks a function that executes on the SIMD unit (vector +engine). It operates on vector registers (`vreg`) loaded from UB tiles and +stores results back to UB tiles. Vector registers are local to the function +and never cross its boundary. + +**Signature**: + + +```python +@pto.simd +def my_simd_kernel( + input_tile: pto.Tile, # UB tile + output_tile: pto.Tile, # UB tile + rows: pto.i32, # PTO scalar + cols: pto.i32, # PTO scalar +): + return +``` + +Parameters are UB `Tile` references and PTO scalar values (`pto.i32`, +`pto.f32`, etc.). Scalar parameters may come from `lds` reads or compile-time +constants. + +**Typical body**: + + +```python +@pto.simd +def add_rows(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, + rows: pto.index, cols: pto.index): + VEC = pto.elements_per_vreg(pto.f32) + initial_remained = scalar.index_cast(pto.i32, cols) + with pto.for_(0, rows, step=1) as r: + col_loop = pto.for_(0, cols, step=VEC).carry(remained=initial_remained) + with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + a_vec = pto.vlds(a_tile[r, c:]) + b_vec = pto.vlds(b_tile[r, c:]) + o_vec = pto.vadd(a_vec, b_vec, mask) + pto.vsts(o_vec, o_tile[r, c:], mask) + col_loop.update(remained=remained) +``` + +The boundary contract: `vreg` values (`a_vec`, `b_vec`, `o_vec`) are local to +the function. The only way to persist data across a `@pto.simd` call is to +write it back to a UB tile via `vsts` (or `psts`, etc.). + +**Invocation modes**: can be called from `@pto.jit` in either mode, or used +inline with `with pto.simd():` (Section 3.4). + +### 3.3.3 `@pto.simt` — SIMT unit + +**Role**: `@pto.simt` marks a function that executes on the SIMT unit. SIMT +(Single Instruction, Multiple Threads) is a programming model where you write +instructions in scalar syntax, and the hardware executes them in parallel +across many threads — analogous to how a GPU SM runs a CUDA kernel. Each +instruction appears to operate on a single element (`lds`, `sts`, `a + b`), +but the same instruction is issued across a large number of work-items +simultaneously. + +**Signature**: + + +```python +@pto.simt +def my_simt_kernel( + tile: pto.Tile, # UB tile + ptr: pto.ptr(pto.f32, pto.MemorySpace.UB), # typed UB pointer + scalar_value: pto.i32, # PTO scalar +): + return +``` + +**Typical body**: + + +```python +@pto.simt +def blend_output_rows( + o_prev_tile: pto.Tile, pv_tile: pto.Tile, + alpha_tile: pto.Tile, beta_tile: pto.Tile, + o_next_tile: pto.Tile, + row_start: pto.i32, row_stop: pto.i32, valid_dim: pto.i32, +): + with pto.for_(row_start, row_stop, step=1) as row: + alpha = scalar.load(alpha_tile[row, 0]) + beta = scalar.load(beta_tile[row, 0]) + with pto.for_(0, valid_dim, step=1) as col: + o_prev = scalar.load(o_prev_tile[row, col]) + pv_val = scalar.load(pv_tile[row, col]) + o_next = alpha * o_prev + beta * pv_val + scalar.store(o_next, o_next_tile[row, col]) +``` + +SIMT kernels read and write individual scalar elements from tiles. The unit +executes the same scalar instruction across many work-items in parallel, making +it efficient for per-element operations. + +**Invocation modes**: can be called from `@pto.jit` in either mode, or used +inline with `with pto.simt():` (Section 3.4). + +## 3.4 Inline context manager syntax + +In addition to the decorator form, each sub-kernel unit provides a context +manager: `with pto.cube():`, `with pto.simd():`, and `with pto.simt():`. These +open inline blocks without requiring a separate named function — useful for +quick prototyping, one-off hardware-unit snippets, or code that is too small to +extract. Inline scopes are supported in top-level `@pto.jit` bodies. + +### Syntax + + +```python +with pto.simd(): + a_vec = pto.vlds(a_tile[r, c:]) + b_vec = pto.vlds(b_tile[r, c:]) + o_vec = pto.vadd(a_vec, b_vec, mask) + pto.vsts(o_vec, o_tile[r, c:], mask) +``` + + +```python +with pto.simt(): + alpha = scalar.load(alpha_tile[row, 0]) + beta = scalar.load(beta_tile[row, 0]) + o_next = alpha * o_prev + beta * pv_val + scalar.store(o_next, o_next_tile[row, col]) +``` + + +```python +with pto.cube(): + pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) +``` + +### Semantics + +- Inside the `with` block, instructions execute on the corresponding hardware + unit. +- `vreg` values created inside `with pto.simd():` are scoped to the block — + they do not escape. +- Cube-local scratch (`l0a`, `l0b`, `acc`) must be allocated by the caller + before entering the block. +- The context manager form is equivalent to an inline anonymous sub-kernel. The + compiler treats it identically to a named `@pto.simd` / `@pto.cube` / + `@pto.simt` function. + +### Comparison + +| | Decorator form | Context manager form | +|---|---|---| +| Reuse | Named, callable from multiple sites | Inline, single-use | +| Readability | Good for complex, multi-step logic | Good for short (3-10 line) snippets | +| Testing | Can be unit-tested independently | Tested only through the enclosing kernel | +| Cube-local args | Explicit parameters | Captured from enclosing scope | + +The two forms can be freely mixed in the same `@pto.jit` body. + +## 3.5 Boundary contracts + +Data crosses decorator boundaries only through UB-backed tiles or typed UB +pointers: + +| Boundary | Allowed | +|----------|---------| +| Host → `@pto.jit` | Python-native tensors | +| `@pto.jit(mode="auto")` → sub-kernel | `Tile`, PTO scalars (compiler handles staging + sync) | +| `@pto.jit(mode="explicit")` → sub-kernel | `Tile`, `PartitionTensorView`, `pto.ptr`, PTO scalars | +| `@pto.jit` → `with pto.{cube,simd,simt}:` | `Tile` captured from enclosing scope | +| Sub-kernel → sub-kernel | Not allowed (go through UB tiles via the caller) | +| `@pto.simd` → caller | Only via `vsts`/`psts` to UB tiles; `vreg` cannot escape | +| Cube-local → UB | Only via `mte_l0c_ub`; LEFT/RIGHT/ACC/BIAS are private | + +## 3.6 `pto.constexpr` + +`pto.constexpr` marks a `@pto.jit` keyword-only parameter as a compile-time +constant. The compiler specializes the kernel for each combination of constexpr +values, and the compiled artifact is cached by specialization key together with +the kernel's tensor ABI contract. + + +```python +@pto.jit(target="a5") +def kernel( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, + DTYPE: pto.constexpr = pto.f32, +): + # ... use BLOCK / DTYPE in tile shapes, loop bounds, or dtype-specialized paths ... + return +``` + +- Must appear as a keyword-only argument (after `*`). +- Must have a default value. +- Must be provided at `.compile()` time if the caller needs to override the + default. +- Cannot change between launches of the same compiled instance — compile a new + variant for a different value. + +`pto.constexpr` parameters can be used anywhere in the kernel body where a +Python value is expected: tile shapes, loop bounds that are known at compile +time, dtype arguments, etc. They are evaluated at trace time, so `for i in +range(BLOCK)` would unroll `BLOCK` times. + +In contrast, values derived from runtime tensor shapes (e.g., `A.shape[0]`) +are dynamic — they vary per launch and should be used with `pto.for_` to +produce device-side loops. diff --git a/ptodsl/docs/user_guide/04-type-system-and-buffer.md b/ptodsl/docs/user_guide/04-type-system-and-buffer.md new file mode 100644 index 000000000..6f1dcd8a9 --- /dev/null +++ b/ptodsl/docs/user_guide/04-type-system-and-buffer.md @@ -0,0 +1,264 @@ +# 4. Type System and Buffer Management + +This chapter covers every type you can use in a PTODSL kernel, plus the operations for managing buffers in global memory (GM) and on-chip Unified Buffer (UB). + +## 4.1 Scalar types + +### Numeric scalar types + +| DSL Type | Description | Bit Width | +|----------|-------------|-----------| +| `pto.i1` | Boolean | 1 | +| `pto.i8` | 8-bit signless integer | 8 | +| `pto.si8` | 8-bit signed integer | 8 | +| `pto.ui8` | 8-bit unsigned integer | 8 | +| `pto.i16` | 16-bit signless integer | 16 | +| `pto.si16` | 16-bit signed integer | 16 | +| `pto.ui16` | 16-bit unsigned integer | 16 | +| `pto.i32` | 32-bit signless integer | 32 | +| `pto.si32` | 32-bit signed integer | 32 | +| `pto.ui32` | 32-bit unsigned integer | 32 | +| `pto.i64` | 64-bit signless integer | 64 | +| `pto.si64` | 64-bit signed integer | 64 | +| `pto.ui64` | 64-bit unsigned integer | 64 | +| `pto.f16` | Half-precision float | 16 | +| `pto.bf16` | Brain float 16 | 16 | +| `pto.f32` | Single-precision float | 32 | + +Python literals are automatically typed by the tracer: `bool` → `pto.i1`, `int` → context-dependent (typically `pto.i32` or `pto.i64`), `float` → `pto.f32`. + +For explicit typing, use type constructors: + + +```python +x = pto.i32(1024) +y = pto.ui16(7) +z: pto.i32 = 1024 +``` + +### Low-precision types (storage only) + +The following types are **storage-only**: they may only appear as element types when constructing `Tile`, `TensorView`, and `PartitionTensorView` values for storage and data movement. They **cannot** be used to construct scalars, vectors, pointers, or `tensor_spec(...)` ABI contracts. Use them to reduce memory bandwidth; convert to a compute-capable type before arithmetic. + +| DSL Type | Description | +|----------|-------------| +| `pto.hif8` | HiFloat8 format | +| `pto.f4e1m2x2` | 4-bit float (E1M2, 2-wide packed) | +| `pto.f4e2m1x2` | 4-bit float (E2M1, 2-wide packed) | +| `pto.f8e4m3` | 8-bit float (E4M3) | +| `pto.f8e5m2` | 8-bit float (E5M2) | + +These types can be used when constructing on-chip tiles and view descriptors: + + +```python +lp_tile = pto.alloc_tile(shape=[128, 64], dtype=pto.f8e4m3) +fp4_tile = pto.alloc_tile(shape=[64, 32], dtype=pto.f4e2m1x2) +``` + +Constructing a scalar, vector, pointer, or host tensor ABI contract with a low-precision type is **not supported** — `pto.f8e4m3(1.0)`, `pto.vreg_type(64, pto.f8e4m3)`, `pto.ptr(pto.f8e4m3)`, and `pto.tensor_spec(rank=2, dtype=pto.f8e4m3)` will raise an error. Load data as the storage type, then convert to a compute-capable type before arithmetic. + +### Integer literal guidance + +Prefer plain integer literals. Hex string literals are reserved for explicit bit-pattern authoring: + + +```python +count = pto.i32(1024) +delta = pto.i16(-12) +hi_bit = pto.i32("0x80000000") # bit-pattern: -2147483648 +``` + +### Floating-point literal forms + + +```python +a = pto.f16(-1.5) +b = pto.f32("inf") +c = pto.f32("-inf") +d = pto.f32("nan") +# Bit-pattern hex +f16_neg_inf = pto.f16("0xFC00") +``` + +## 4.2 Vector register type + +Vector registers hold a fixed 256-byte payload. `pto.vreg(dtype)` infers the element count automatically: + +| `dtype` | Result | Elements | +|---------|--------|----------| +| `pto.f32` / `pto.i32` / ... | `vreg<64xT>` | 64 | +| `pto.f16` / `pto.bf16` / `pto.i16` / ... | `vreg<128xT>` | 128 | +| `pto.i8` / `pto.si8` / `pto.ui8` | `vreg<256xT>` | 256 | + +Constraint: `element_count × bitwidth(dtype) = 2048`. + +Use `pto.elements_per_vreg(dtype)` to query the element count: + + +```python +lanes = pto.elements_per_vreg(pto.f32) # 64 +``` + +### vbitcast + +Reinterpret the bits of a vector register as a different element type: + + +```python +fvec = pto.vlds(ptr, offset) # !pto.vreg<64xf32> +ivec = pto.vbitcast(fvec, pto.i32) # !pto.vreg<64xi32> +f16_vec = pto.vbitcast(fvec, pto.f16) # !pto.vreg<128xf16> +``` + +`vbitcast` preserves the exact bit pattern (type punning). Use `vcvt` for numeric value conversion. + +## 4.3 Mask (predicate) types + +Masks are typed by bit granularity and must match the vector element width: + +| DSL Type | Granularity | Used with | +|----------|-------------|-----------| +| `pto.mask_b8` | 8-bit | `i8`, `si8`, `ui8` | +| `pto.mask_b16` | 16-bit | `f16`, `bf16`, `i16`, `si16`, `ui16` | +| `pto.mask_b32` | 32-bit | `f32`, `i32`, `si32`, `ui32` | + +### Constructing masks + +Use `make_mask` to generate a mask from a pattern or scalar — it automatically selects the correct bit width from the element dtype: + + +```python +active = pto.make_mask(pto.f16, pto.MaskPattern.ALL) # pattern-based full mask +tail_mask, _ = pto.make_mask(pto.f32, tail_count) # load mask from tail count scalar +``` + +The bit-width-specific `pset_b32` and `plt_b32` forms are also available: + +```python +active = pto.pset_b32("PAT_ALL") +one_mask, _ = pto.plt_b32(c1_i32) +``` + +### Reinterpreting masks + +`pbitcast` reinterprets a mask register at a different granularity: + + +```python +mask_b16 = pto.pbitcast(mask_b8, pto.mask_b16) +``` + +## 4.4 Pointer types + +Pointers combine an element type and a memory space: + + +```python +ptr_gm = pto.ptr(pto.f32, pto.MemorySpace.GM) +ptr_ub = pto.ptr(pto.f16, pto.MemorySpace.UB) +``` + +### MemorySpace enum + +| Enum Value | Description | +|------------|-------------| +| `MemorySpace.GM` | Global Memory (off-chip HBM) | +| `MemorySpace.UB` | Unified Buffer (on-chip scratchpad) | +| `MemorySpace.MAT` | Cube L1 / cbuf staging buffer | +| `MemorySpace.LEFT` | Cube L0A left-operand buffer | +| `MemorySpace.RIGHT` | Cube L0B right-operand buffer | +| `MemorySpace.ACC` | Cube L0C accumulator buffer | +| `MemorySpace.BIAS` | Cube bias table buffer | + +## 4.5 TensorView + +`TensorView` is a descriptor for a tensor in Global Memory. Create one inside a `@pto.jit` body with `make_tensor_view`: + + +```python +@pto.jit(target="a5") +def kernel( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + tv = pto.make_tensor_view(A, shape=A.shape, strides=A.strides) + return +``` + +`make_tensor_view` wraps a Python-native tensor. You provide the logical shape and the stride of each dimension in **elements** (not bytes). The resulting `TensorView` can be partitioned for `tile.load`/`tile.store`. + +### TensorView attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Logical dimensions (up to 5D) | +| `element_type` | `Type` | Element dtype (e.g., `pto.f32`) | +| `strides` | `tuple[int, ...]` | Stride of each dimension, in elements | + +Strides support non-contiguous tensors. Pass `strides=A.strides` from the source tensor for the default row-major layout, or supply explicit strides for sub-views. Use `tv.as_ptr()` to obtain a typed GM pointer for use with MTE Ops in explicit-mode orchestration. + +## 4.6 PartitionTensorView + +`partition_view` creates a sub-view of a TensorView at a given offset and size. It describes *which part* of the GM tensor a `tile.load` or `tile.store` should operate on: + + +```python +part = pto.partition_view(tv, offsets=[row_offset, 0], sizes=[BLOCK, dim]) +``` + +The result is a `PartitionTensorView` — a lightweight descriptor, not a data buffer. It carries the partition's shape, strides, and element type (inherited from the source TensorView). Use `part.as_ptr()` to obtain a typed GM pointer for MTE Ops in explicit-mode orchestration. + +## 4.7 Tile + +A `Tile` is an on-chip buffer allocated in UB or cube-local memory. Allocate tiles with `alloc_tile`: + + +```python +# UB tile +a_tile = pto.alloc_tile(shape=[BLOCK, dim], dtype=pto.f32) + +# Logical column tile +m_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") + +# Cube-local scratch with explicit memory space +q_l0a = pto.alloc_tile(shape=[Br, dim], dtype=pto.f16, memory_space=pto.MemorySpace.LEFT) +s_acc = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, memory_space=pto.MemorySpace.ACC) +``` + +`alloc_tile` returns a `Tile` object. The `shape` must be a compile-time constant. The default memory space is UB. +For narrow logical column tiles such as `[Br, 1]`, author them with +`blayout="ColMajor"`. Row-major none-box tiles are validated against a 32-byte +physical row-alignment rule. + +For packed types (`pto.f4e1m2x2`, `pto.f4e2m1x2`), `shape` dimensions refer to the number of **packed** elements, each containing 2 f4 values. For example, `alloc_tile(shape=[128, 64], dtype=pto.f4e1m2x2)` allocates a 128×64 tile of packed elements, holding 128×64×2 individual 4-bit floats. The same applies to TensorView shapes when the tensor spec uses a packed dtype. + +### Tile attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `shape` | `tuple[int, ...]` | Physical tile dimensions (compile-time constant) | +| `element_type` | `Type` | Element dtype | +| `memory_space` | `MemorySpace` | Where the tile lives (UB, LEFT, RIGHT, ACC, BIAS) | +| `valid_shape` | `tuple[int, ...]` | Logical data region, ≤ `shape` in each dimension | + +### Tile methods + +| Method | Description | +|--------|-------------| +| `tile.fill(value)` | Fill the entire tile with a scalar value | +| `tile.as_ptr()` | Obtain a typed pointer to the tile's base address | + + +```python +m_prev_tile.fill(float("-inf")) +l_prev_tile.fill(0.0) + +rows = q_tile.valid_shape[0] +cols = k_tile.valid_shape[1] +meta_tile.valid_shape = [pto.const(1), pto.const(2)] +tail_tile.valid_shape = [rows] + +meta_ptr = meta_tile.as_ptr() +``` diff --git a/ptodsl/docs/user_guide/05-control-flow.md b/ptodsl/docs/user_guide/05-control-flow.md new file mode 100644 index 000000000..ee300efe7 --- /dev/null +++ b/ptodsl/docs/user_guide/05-control-flow.md @@ -0,0 +1,254 @@ +# 5. Control Flow + +PTODSL uses a **tracing** compilation model. When you call `kernel.compile(...)`, PTODSL executes your Python function body once to record every PTO instruction — this pass is called *tracing*. The recorded program is then lowered and optimized into device code. Once compiled, launching the kernel runs the already-built device code directly on the NPU. + +This has one critical implication for how you write loops and branches: + +- **Python native `for`/`if`** runs at trace time. A `for i in range(4)` loop gets unrolled — the device code contains four copies of the body, not a loop instruction. An `if` condition is evaluated at trace time, and only the taken branch is recorded. +- **`pto.for_` / `pto.if_`** produce device-side control flow. The loop bound or branch condition can be a runtime value, and the hardware will execute the loop or take the branch dynamically. + +**Simple rule: Python control flow = trace time (compile-time). `pto.*` control flow = device-side (runtime).** + +## 5.1 Python native `for` — trace-time unrolling + +When you write a plain Python `for` loop inside a kernel body, Python executes it immediately during tracing. Each iteration records its instructions separately, so the device code gets a linear sequence with the body repeated: + +```python +@pto.jit(target="a5") +def unrolled_kernel(A, O, *, N: pto.constexpr): + a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + # N is constexpr, so range(N) is known at trace time. + # The loop unrolls: the device gets N copies of the body. + for i in range(N): + a_part = pto.partition_view(a_view, offsets=[i], sizes=[1]) + o_part = pto.partition_view(o_view, offsets=[i], sizes=[1]) + a_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[1], dtype=pto.f32) + pto.tile.load(a_part, a_tile) + pto.tile.add(a_tile, a_tile, o_tile) + pto.tile.store(o_tile, o_part) +``` + +This works when the loop bound is a compile-time constant (like a `constexpr` parameter). But if `N` comes from a tensor shape and varies per launch, `range(N)` would trace a different number of iterations each time — you would get a cache miss and recompilation on every new value. For dynamic bounds, use `pto.for_`. + +## 5.2 `pto.for_` — device-side loops + +`pto.for_` records a structured loop that executes on the device. Its bound can be any expression involving runtime values (tensor shapes, scalar computations, block indices), and the compiler may optimize it further — unrolling when the bound is known at compile time, or keeping it as a runtime loop otherwise. + +### Basic form + + +```python +with pto.for_(start, stop, step=step) as iv: + pto.tile.load(pto.partition_view(a_view, offsets=[iv, 0], sizes=[1, cols]), tile) +``` + +- `start`, `stop`, `step` are PTO scalar expressions. They are evaluated on the device. +- The loop body executes `(stop - start + step - 1) // step` times. +- Use with `step=1` unless you need a strided iteration. + +Compare the two approaches: + + +```python +# Trace-time unrolling — BLOCK must be constexpr +for i in range(BLOCK): + pto.tile.load(pto.partition_view(a_view, offsets=[0, 0], sizes=[1, cols]), tile) + +# Device-side loop — num_blocks can be dynamic +with pto.for_(0, num_blocks, step=1) as i: + pto.tile.load(pto.partition_view(a_view, offsets=[i, 0], sizes=[1, cols]), tile) +``` + +### Nested loops + + +```python +with pto.for_(0, rows, step=1) as r: + with pto.for_(0, cols, step=1) as c: + val = scalar.load(tile[r, c]) +``` + +Both loops execute on the device. The outer loop bound `rows` and inner loop bound `cols` can be runtime values. + +### Loop with carry state + +When a loop needs to propagate state from one iteration to the next, use the `.carry(...)` method. This is the PTODSL equivalent of a loop that accumulates or updates variables across iterations. The following self-contained kernel is the smallest compileable carry example used by the docs-as-test harness: + + +```python +@pto.jit(target="a5") +def carry_loop_probe(*, BLOCK: pto.constexpr = 128): + m_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + l_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_prev = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + m_next = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + l_next = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + o_next = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) + + m_prev.fill(0.0) + l_prev.fill(0.0) + o_prev.fill(0.0) + + kv_loop = pto.for_(0, 4, step=1).carry(m=m_prev, l=l_prev, o=o_prev) + with kv_loop: + kv_loop.m.fill(1.0) + kv_loop.l.fill(2.0) + kv_loop.o.fill(3.0) + kv_loop.update(m=m_next, l=l_next, o=o_next) + + final_o = kv_loop.final("o") + final_o.fill(4.0) +``` + +`.carry(name=initial_value)` declares named state variables that are passed from one iteration to the next. Inside the loop body, access the current value with `loop.name`. At the end of the body, call `loop.update(name=new_value)` to set what the next iteration receives. After the loop exits, `loop.final("name")` retrieves the value from the last iteration. + +This pattern is central to algorithms like online softmax, where each KV block updates running statistics (row max, sum, output accumulator). The ping-pong tile pattern — allocating two tiles and swapping them each iteration — is the idiomatic way to manage this state: + + +```python +# Allocate ping-pong state tiles +m_prev = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") +m_next = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") +l_prev = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") +l_next = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, blayout="ColMajor") + +# Initialize prev tiles +m_prev.fill(float("-inf")) +l_prev.fill(0.0) + +loop = pto.for_(0, num_blocks, step=1).carry(m=m_prev, l=l_prev) +with loop: + m_cur = loop.m + l_cur = loop.l + + m_next.fill(1.0) + l_next.fill(2.0) + + loop.update(m=m_next, l=l_next) +``` + +### Chunked inner loop with carry (tail handling) + +For SIMD kernels that process data in vector-width chunks, use a carry loop to track the remaining element count across column iterations: + + +```python +VEC = pto.elements_per_vreg(pto.f32) +col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) +with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(tile[r, c:]) + # ... operate under mask ... + pto.vsts(vec, out_tile[r, c:], mask) + col_loop.update(remained=remained) +``` + +`make_mask(dtype, n)` returns two values: the predicate mask for the current chunk and the updated remaining count. Passing the updated count back via `col_loop.update(remained=...)` feeds it into the next iteration, so each chunk correctly computes how many elements are left. If `n` is an `index`, the updated remaining count stays an `index`; PTODSL hides the hardware `i32` tail-mask bookkeeping internally. + +## 5.3 `pto.if_` — device-side conditionals + +`pto.if_` records a device-side conditional branch. Unlike a Python `if`, the condition can be a runtime PTO scalar, and both branches are recorded into the program so the hardware can choose at runtime. + +The condition must be a PTO scalar value (e.g., the result of a comparison like `a > b` or a value loaded from a tile). Python booleans evaluated at trace time should use a plain `if` instead. + +### Recommended block structure + +PTODSL should treat one device-side conditional as one explicit branch object. +The recommended surface is: + +```python +with pto.if_(cond) as br: + with br.then_: + ... + with br.else_: + ... +``` + +This keeps the `if` / `else` pairing explicit. The `else_` branch is optional +for side-effect-only conditionals. + +### Automatic named merge across branches + +When a value must flow out of both branches, PTODSL should merge by explicit +name. Each branch assigns the same output names with `br.assign(...)`, and the +merged results are read back from the branch handle after the conditional: + +```python +@pto.simt +def conditional_scale( + tile: pto.Tile, + threshold: pto.f32, + scale: pto.f32, + rows: pto.i32, + cols: pto.i32, +): + with pto.for_(0, rows, step=1) as r: + with pto.for_(0, cols, step=1) as c: + val = scalar.load(tile[r, c]) + big = val > threshold + + with pto.if_(big) as br: + with br.then_: + br.assign(val=val * scale) + with br.else_: + br.assign(val=val) + + val = br.val + scalar.store(val, tile[r, c]) +``` + +In this example, both branches define the merged value named `val`. After the +conditional closes, `br.val` is the SSA-merged result seen by downstream code. +This surface avoids explicit result-type declarations and explicit +`pto.yield_(...)` in user code while still keeping the merge contract explicit. + +## 5.4 `pto.constexpr` and tracing + +`pto.constexpr` parameters (Section 3.8) are compile-time constants. They are fixed at `.compile()` time and cannot change between launches of the same compiled kernel. Because their values are known during tracing, they interact naturally with Python control flow: + +```python +@pto.jit(target="a5") +def kernel( + A, + *, + BLOCK: pto.constexpr = 128, + NUM_BLOCKS: pto.constexpr = 8, + UNROLL: pto.constexpr = False, +): + N = A.shape[0] + num_blocks = (N + BLOCK - 1) // BLOCK + + # N and num_blocks are runtime values derived from tensor metadata. + # They can drive device-side control flow such as pto.for_(...), + # but they are not Python integers and cannot be used in range(...). + with pto.for_(0, num_blocks, step=1) as i: + ... + + if UNROLL: + # Trace-time: UNROLL and NUM_BLOCKS are both known during tracing. + # Each iteration records separately, so the loop is fully unrolled. + for i in range(NUM_BLOCKS): + ... + else: + # The non-unrolled path can still use a device-side loop whose bound + # is a constexpr value captured into the traced program. + with pto.for_(0, NUM_BLOCKS, step=1) as i: + ... +``` + +This lets you write a single kernel that specializes into different strategies based on constexpr knobs, while still using runtime tensor metadata for device-side control flow. + +## 5.5 Summary + +| Construct | When evaluated | Use for | +|-----------|---------------|---------| +| Python `for` | Trace time | Bounds known at compile time (constexpr), deliberate unrolling | +| Python `if` | Trace time | Conditions known at compile time, variant selection | +| `pto.for_` | Device-side | Dynamic bounds, runtime loop counts | +| `pto.for_(...).carry(...)` | Device-side | Loops with accumulated state across iterations | +| `pto.if_` | Device-side | Runtime conditions, data-dependent branching | diff --git a/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md new file mode 100644 index 000000000..85617e17e --- /dev/null +++ b/ptodsl/docs/user_guide/06-scalar-and-pointer-ops.md @@ -0,0 +1,421 @@ +# 6. Scalar and Pointer Operations + +Chapter 5 established the rule: Python constructs are resolved at trace time, PTO constructs produce device-side behavior. This chapter applies that distinction to scalars and pointers — when to use a plain Python number, when to use a top-level `scalar.*` helper, and how to work with typed pointers. + +## 6.1 Python scalars vs PTO scalars + +A **Python scalar** is any value computed by Python during tracing: a literal (`3.14159`), a constexpr parameter (`BLOCK`), or an arithmetic expression built only from compile-time-known values (`1.0 / sqrt(128)`). These are evaluated at trace time and their results are baked into the device code as constants. + +A **PTO scalar** is a value that lives on the device at runtime. It comes from a `scalar.load` read, a device-side computation (`scalar.max`, `scalar.exp`), a runtime query (`pto.get_block_idx()`), or `@pto.jit` tensor metadata such as `A.shape[0]` / `A.strides[1]`. PTO scalars flow through the recorded program and are not resolved until the kernel executes. The helper functions that operate on them live in the top-level `scalar` namespace, not under `pto.*`. + +### The mixed expression + +In practice, a single expression can mix both kinds: + +```python +alpha * o_prev + beta * pv_val +# ^ Python float (trace-time constant, e.g. 1.0 / sqrt(dim)) +# ^ PTO scalar (loaded from tile at runtime) +# ^ PTO scalar (loaded from tile at runtime) +``` + +`alpha` is a Python float computed from compile-time information — it becomes an immediate constant in the device code. `o_prev` and `pv_val` are PTO scalars read from tiles at runtime. The `*` and `+` operators are recorded as device-side multiply-add instructions. The tracer sees the whole expression and produces the appropriate device instructions, embedding the constant operand where possible. + +### Rule of thumb + +| If the value... | Use... | Example | +|-----------------|--------|---------| +| Is known at compile time | Python scalar | `BLOCK`, `1.0 / sqrt(128)` | +| Comes from device memory | PTO scalar | `scalar.load(tile[r, c])` | +| Depends on a runtime value | PTO scalar | `scalar.max(m_prev, row_max)` | +| Comes from tensor metadata at the `@pto.jit` boundary | PTO scalar | `A.shape[0]`, `Q.strides[2]` | +| Is a block/subblock index | PTO scalar | `pto.get_block_idx()` | + +When in doubt, ask: *can this value change between launches of the same compiled kernel?* If yes, it must be a PTO scalar. + +## 6.2 Scalar access: load and store + +`scalar.load` reads a single scalar element from a typed pointer or tile location. `scalar.store` writes a scalar back. These are the canonical scalar memory ops for SIMT authoring. The offset is counted in elements, not bytes. + +#### `scalar.load(ptr: PtrType, offset: Index) -> ScalarType` + +**Description**: Loads one scalar element from a typed pointer at the given element offset. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Typed pointer (`pto.ptr`) or the result of `tile.as_ptr()` | +| `offset` | `Index` | Element displacement from `ptr` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `value` | `ScalarType` | The loaded scalar, matching the pointer's element type | + +**Tile-index form** — the preferred syntax when loading from a tile: + + +```python +val = scalar.load(tile[row, col]) +``` + +`tile[row, col]` selects one element. Row and column indices are PTO scalars (or Python integers that the tracer promotes). This form is equivalent to computing the pointer and offset from the tile's base address and layout. + +**Pointer forms**: + + +```python +val = scalar.load(ptr, offset) # explicit offset +val = scalar.load(ptr + offset) # pointer arithmetic shorthand +``` + +--- + +#### `scalar.store(value: ScalarType, ptr: PtrType, offset: Index) -> None` + +**Description**: Stores one scalar element to a typed pointer at the given element offset. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `value` | `ScalarType` | Scalar value to write | +| `ptr` | `PtrType` | Typed destination pointer | +| `offset` | `Index` | Element displacement from `ptr` | + +**Returns**: None (side-effect operation). + +**Tile-index form**: + + +```python +scalar.store(value, tile[row, col]) +``` + +**Pointer forms**: + + +```python +scalar.store(value, ptr, offset) +``` + +--- + +### Typical SIMT usage + +`scalar.load` and `scalar.store` are the primary data access pattern inside `@pto.simt` kernels. Each `load`/`store` operates on one element per work-item, but the SIMT unit executes the same instruction across many work-items in parallel: + + +```python +@pto.simt +def blend_output_rows( + o_prev_tile: pto.Tile, pv_tile: pto.Tile, + alpha_tile: pto.Tile, beta_tile: pto.Tile, + o_next_tile: pto.Tile, + row_start: pto.i32, row_stop: pto.i32, valid_dim: pto.i32, +): + with pto.for_(row_start, row_stop, step=1) as row: + alpha = scalar.load(alpha_tile[row, 0]) + beta = scalar.load(beta_tile[row, 0]) + with pto.for_(0, valid_dim, step=1) as col: + o_prev = scalar.load(o_prev_tile[row, col]) + pv_val = scalar.load(pv_tile[row, col]) + o_next = alpha * o_prev + beta * pv_val + scalar.store(o_next, o_next_tile[row, col]) +``` + +When writing to a raw pointer (e.g., a small metadata buffer obtained via `as_ptr()`), use the pointer-plus-offset form. The following self-contained kernel is the smallest compileable pointer-offset example: + + +```python +from ptodsl import pto, scalar + + +@pto.jit(target="a5") +def scalar_pointer_offset_probe(): + meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 3]) + meta_ptr = meta_tile.as_ptr() + + scalar.store(0, meta_ptr, 0) + scalar.store(1, meta_ptr, 1) + scalar.store(2, meta_ptr + 2) + + row_start = scalar.load(meta_ptr, 0) + row_stop = scalar.load(meta_ptr, 1) + valid_cols = scalar.load(meta_ptr + 2) + + _ = row_start + _ = row_stop + _ = valid_cols +``` + +## 6.3 Scalar arithmetic and comparisons + +### Python operators for basic arithmetic + +Addition, subtraction, multiplication, and division of PTO scalars use standard Python syntax. The tracer records the corresponding device-side instructions automatically: + + +```python +o_next = alpha * o_prev + beta * pv_val # multiply-add +l_scaled = l_prev * scalar.exp(m_prev - m_next) # subtraction inside exp +step = (N + BLOCK - 1) // BLOCK # Python int arithmetic (trace-time) +``` + +When both operands are PTO scalars (loaded from device memory or produced by another device-side op), `+`, `-`, `*`, `/` produce device-side arithmetic instructions. When one operand is a Python scalar (trace-time constant), the tracer embeds it as an immediate. + +### Math functions: `scalar.*` + +Non-trivial scalar math functions live under the top-level `scalar` namespace (imported as `from ptodsl import scalar`). They are intentionally separate from the `pto.*` namespace: + +#### `scalar.max(a: ScalarType, b: ScalarType) -> ScalarType` + +**Description**: Returns the maximum of two scalars. + +#### `scalar.min(a: ScalarType, b: ScalarType) -> ScalarType` + +**Description**: Returns the minimum of two scalars. + +#### `scalar.exp(x: ScalarType) -> ScalarType` + +**Description**: Exponential, e^x. + +#### `scalar.log(x: ScalarType) -> ScalarType` + +**Description**: Natural logarithm. + +#### `scalar.sqrt(x: ScalarType) -> ScalarType` + +**Description**: Square root. + +#### `scalar.abs(x: ScalarType) -> ScalarType` + +**Description**: Absolute value. + + +```python +lo = scalar.min(m_prev, row_max) +mag = scalar.abs(m_prev - row_max) +ln = scalar.log(threshold + 1.0) +root = scalar.sqrt(threshold + 4.0) +``` + +### Comparisons + +**Description**: PTO scalars use Python's native comparison operators. The tracer records the corresponding device-side comparison instruction and returns a `pto.i1` result. + +| Operator | Predicate (signed) | Predicate (unsigned) | Predicate (float) | +|----------|---------------------|-----------------------|--------------------| +| `>` | `sgt` | `ugt` | `ogt` | +| `<` | `slt` | `ult` | `olt` | +| `==` | `eq` | `eq` | `oeq` | +| `!=` | `ne` | `ne` | `one` | +| `>=` | `sge` | `uge` | `oge` | +| `<=` | `sle` | `ule` | `ole` | + +**Example**: + + +```python +m_next = scalar.max(m_prev, row_max) +l_scaled = l_prev * scalar.exp(m_prev - m_next) +need_scale = val > threshold # pto.i1 result +is_zero_mask = val == threshold +in_range = (val >= threshold) & (val <= row_max) +``` + +For readability in files with many scalar operations, use the top-level `scalar` namespace directly: + + +```python +m_next = scalar.max(m_prev, row_max) +l_scaled = l_prev * scalar.exp(m_prev - m_next) +``` + +These are the scalar-path counterparts of the vector math operations covered in Chapter 8. Use them inside `@pto.simt` kernels and in explicit-mode orchestration code where you need to compute a loop bound or a scalar coefficient from runtime data. + +## 6.4 Pointer operations + +Typed pointers (Section 4.4) carry both an element type and a memory space. This section covers the operations that create and manipulate them. + +### Obtaining pointers: as_ptr() + +Tiles and tensor views expose their base address via `as_ptr()`: + + +```python +gm_ptr = partition.as_ptr() # GM pointer from a PartitionTensorView +ub_ptr = tile.as_ptr() # UB pointer from a Tile +``` + +`as_ptr()` is the preferred way to get a typed pointer from a high-level descriptor. The result carries the correct element type and memory space from the source. + +--- + +#### `pto.addptr(ptr: PtrType, offset: Index) -> PtrType` + +**Description**: Advances a pointer by a number of elements (not bytes). + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ptr` | `PtrType` | Source pointer | +| `offset` | `Index` | Number of elements to advance | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `new_ptr` | `PtrType` | Pointer advanced by `offset` elements | + +**Example**: + + +```python +ptr = pto.addptr(base_ptr, 1024) +``` + +The `+` shorthand on pointers also counts in elements, not bytes. + +--- + +#### `pto.castptr(address: Index, ptr_type: Type) -> PtrType` + +**Description**: Creates a typed pointer from an integer address or reinterprets a pointer as a different type. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `address` | `Index` | Integer address or existing pointer value | +| `ptr_type` | `Type` | Target pointer type, e.g. `pto.ptr(pto.f32, pto.MemorySpace.UB)` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `ptr` | `PtrType` | Typed pointer value | + +This is an advanced operation. Prefer `as_ptr()` when the source already carries type information. + + +```python +ptr = pto.castptr(addr, pto.ptr(pto.i32, pto.MemorySpace.UB)) +``` + +## 6.5 Compile-time queries + +These functions return values that are known at trace time from type information or hardware constants. + +#### `pto.bytewidth(dtype: Type) -> int` + +**Description**: Returns the size in bytes of a single element of the given data type. The result is a Python `int` evaluated at trace time. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Data type, e.g. `pto.f32`, `pto.f16`, `pto.i8` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `size` | `int` | Element size in bytes | + +**Example**: + + +```python +bw = pto.bytewidth(pto.f32) # 4 +bw = pto.bytewidth(pto.f16) # 2 +bw = pto.bytewidth(pto.i8) # 1 +``` + +--- + +#### `pto.elements_per_vreg(dtype: Type) -> int` + +**Description**: Returns how many elements of `dtype` fit in one 256-byte vector register. The result is a Python `int` evaluated at trace time. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Data type, e.g. `pto.f32`, `pto.f16`, `pto.i8` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `elems` | `int` | Number of elements per vector register | + +**Example**: + + +```python +vec = pto.elements_per_vreg(pto.f32) # 64 +vec = pto.elements_per_vreg(pto.f16) # 128 +vec = pto.elements_per_vreg(pto.i8) # 256 +``` + +This is the standard stride for chunking column loops in SIMD kernels: + + +```python +VEC = pto.elements_per_vreg(pto.f32) +with pto.for_(0, cols, step=VEC) as c: + ... +``` + +## 6.6 Per-element tile traversal in @pto.simt + +`@pto.simt` kernels are the natural home for per-element scalar work. A typical pattern uses nested `pto.for_` loops to walk over a tile row by row, column by column: + + +```python +@pto.simt +def elementwise_scale( + src_tile: pto.Tile, + dst_tile: pto.Tile, + scale: pto.f32, + rows: pto.i32, + cols: pto.i32, +): + with pto.for_(0, rows, step=1) as r: + with pto.for_(0, cols, step=1) as c: + val = scalar.load(src_tile[r, c]) + scaled = val * scale + scalar.store(scaled, dst_tile[r, c]) +``` + +This reads each element from `src_tile`, multiplies by `scale`, and writes to `dst_tile`. The SIMT unit executes the body in parallel across work-items, so this scalar-looking code achieves high throughput — each work-item handles a different `(r, c)` pair. + +For operations that need per-row metadata alongside per-element computation, lift the row-level scalar out of the inner loop: + + +```python +@pto.simt +def blend_with_per_row_coeffs( + o_prev_tile: pto.Tile, + pv_tile: pto.Tile, + alpha_tile: pto.Tile, # [rows, 1] — one coefficient per row + beta_tile: pto.Tile, # [rows, 1] + o_next_tile: pto.Tile, + rows: pto.i32, + cols: pto.i32, +): + with pto.for_(0, rows, step=1) as r: + alpha = scalar.load(alpha_tile[r, 0]) # read once per row + beta = scalar.load(beta_tile[r, 0]) # read once per row + with pto.for_(0, cols, step=1) as c: + o_prev = scalar.load(o_prev_tile[r, c]) + pv_val = scalar.load(pv_tile[r, c]) + o_next = alpha * o_prev + beta * pv_val + scalar.store(o_next, o_next_tile[r, c]) +``` + +This hoists `alpha` and `beta` out of the inner loop — the row coefficients are loaded once and broadcast across all columns in that row. diff --git a/ptodsl/docs/user_guide/07-data-movement-ops.md b/ptodsl/docs/user_guide/07-data-movement-ops.md new file mode 100644 index 000000000..072019741 --- /dev/null +++ b/ptodsl/docs/user_guide/07-data-movement-ops.md @@ -0,0 +1,1001 @@ +# 7. Data Movement Operations + +This chapter covers every operation that moves data between memory spaces in PTODSL — tile-level transfers, DMA micro-instructions, vector loads and stores, and cube data movement. Operations are organized by abstraction level: tile ops for auto mode, DMA orchestration for explicit mode, vector memory ops on the SIMD unit, and cube memory ops on the Cube unit. + +## 7.1 Tile-level movement: tile.load and tile.store + +Tile ops move entire blocks between Global Memory and the Unified Buffer in a single call. They are the primary data movement interface inside `@pto.jit`. + +#### `pto.tile.load(partition: PartitionTensorView, tile: Tile) -> None` + +**Description**: Copies data from a GM partition into a UB tile. The transfer size is determined by the partition's `sizes` and the tile's shape — they must be compatible. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `partition` | `PartitionTensorView` | Source region in GM | +| `tile` | `Tile` | Destination buffer in UB | + +**Returns**: None (side-effect operation). + +**Example**: + + +```python +a_part = pto.partition_view(a_view, offsets=[offset, 0], sizes=[1, cols]) +a_tile = pto.alloc_tile(shape=[1, BLOCK], dtype=pto.f32) +pto.tile.load(a_part, a_tile) +``` + +--- + +#### `pto.tile.store(tile: Tile, partition: PartitionTensorView) -> None` + +**Description**: Copies data from a UB tile back to a GM partition. The tile's `valid_shape` determines how many elements are written; elements outside `valid_shape` are not stored. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile` | `Tile` | Source buffer in UB | +| `partition` | `PartitionTensorView` | Destination region in GM | + +**Returns**: None (side-effect operation). + +**Example**: + + +```python +pto.tile.store(o_tile, o_part) +``` + +--- + +Both `tile.load` and `tile.store` operate at **tile granularity** — they are the idiomatic choice inside `@pto.jit` loops. When you need finer control over DMA scheduling, switch to +`mode="explicit"` and use the DMA micro-instructions covered in the next section. + +## 7.2 DMA micro-instructions (explicit mode) + +Inside explicit-mode orchestration, data movement between memory spaces is expressed with grouped DMA instructions on typed pointers. There are four operations covering the four data-movement directions: + +| Operation | Direction | Stride unit | Padding | +|-----------|-----------|-------------|---------| +| `pto.mte_gm_ub` | GM → UB | bytes | Supported | +| `pto.mte_ub_gm` | UB → GM | bytes | — (de-padded on read) | +| `pto.mte_ub_ub` | UB → UB | 32B units | — | +| `pto.mte_ub_l1` | UB → L1 | 32B units | — | + +All four share a common structure: a required innermost `nburst(...)` group that defines the repeated burst transfer, plus optional outer `loop(...)` groups for multi-level repetition. `pto.mte_gm_ub` additionally supports `pad(...)` for UB row padding. + +### 7.2.1 GM → UB: `pto.mte_gm_ub` + +#### `pto.mte_gm_ub(gm_src: PtrType, ub_dst: PtrType, l2_cache_ctl: int, len_burst: int, *, nburst: tuple[int, int, int], loops: list[tuple[int, int, int]] | None = None, pad: tuple[ScalarType, int, int] | tuple[ScalarType] | None = None) -> None` + +**Description**: Grouped DMA transfer from Global Memory to Unified Buffer. `nburst(...)` defines the innermost repeated burst (count, source stride in bytes, destination stride in bytes). Optional `loop(...)` groups add outer repetition levels. Optional `pad(...)` fills the gap between `len_burst` and `dst_stride` up to the 32B-aligned boundary. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `gm_src` | `PtrType` (gm) | GM source pointer | +| `ub_dst` | `PtrType` (ub) | UB destination pointer (must be 32B-aligned) | +| `l2_cache_ctl` | `int` | L2 cache allocate control (2 bits) | +| `len_burst` | `int` | Contiguous bytes transferred per burst row | +| `nburst` | `tuple[int, int, int]` | `(n_burst, src_stride, dst_stride)` — innermost burst group (required) | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional outer loop groups, each `(count, src_stride, dst_stride)`. Ordered inner to outer | +| `pad` | `tuple[ScalarType, int, int]` or `tuple[ScalarType]` or `None` | Optional padding: `(pad_value, left_count, right_count)` or `(pad_value,)`. Omitted counts default to 0 | + +**Returns**: None (side-effect operation). + +**Constraints**: +- `nburst` is always required. +- `loop` groups are ordered from inner (wrapping `nburst`) to outer. +- If `pad` specifies either left or right count, both must be provided. + +**Example** — load a 32×32 f32 tile from contiguous GM into contiguous UB: + + +```python +pto.mte_gm_ub(gm_src, ub_dst, 0, 128, + nburst=(32, 128, 128)) +# 32 rows, 128 bytes per row, contiguous in both GM and UB +``` + +**Example** — load a 64×128 f16 tile from a larger GM matrix (1024×512) into UB: + + +```python +pto.mte_gm_ub(gm_src, ub_dst, 0, 256, + nburst=(64, 1024, 256)) +# 64 rows of 256 bytes each. +# GM: each row is 1024 bytes apart (full matrix row stride). +# UB: rows are packed contiguously (256-byte stride). +``` + +**Example** — load with padding (100 valid f16 columns into a 128-wide UB tile): + + +```python +pto.mte_gm_ub(gm_src, ub_dst, 0, 200, + nburst=(64, 200, 256), + pad=(0.0, 0, 0)) +# 64 rows, 200 valid bytes per row, 256-byte UB stride. +# Gap (56 bytes) between len_burst and dst_stride is zero-padded. +``` + +**Example** — multi-level loop: load 4 batches of 8×128 f16 tiles: + + +```python +pto.mte_gm_ub(gm_src, ub_dst, 0, 256, + nburst=(8, 256, 256), + loops=[(4, 2048, 2048)]) +# Innermost: 8 rows × 256B (one tile). +# Outer loop: 4 iterations, each advancing 2048 bytes in both GM and UB. +``` + +--- + +### 7.2.2 UB → GM: `pto.mte_ub_gm` + +#### `pto.mte_ub_gm(ub_src: PtrType, gm_dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int], loops: list[tuple[int, int, int]] | None = None) -> None` + +**Description**: Grouped DMA transfer from Unified Buffer to Global Memory. The MTE reads `len_burst` bytes from each UB row (skipping any padding), writing only valid data to GM. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ub_src` | `PtrType` (ub) | UB source pointer (must be 32B-aligned) | +| `gm_dst` | `PtrType` (gm) | GM destination pointer | +| `len_burst` | `int` | Contiguous bytes transferred per burst row | +| `nburst` | `tuple[int, int, int]` | `(n_burst, src_stride, dst_stride)` — innermost burst group (required) | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional outer loop groups, ordered inner to outer | + +**Returns**: None (side-effect operation). + +**Example** — store a 32×32 f32 tile from UB to GM: + + +```python +pto.mte_ub_gm(ub_src_f32, gm_dst_f32, 128, + nburst=(32, 128, 128)) +``` + +**Example** — store a 64×128 f16 tile back to a larger GM matrix: + + +```python +pto.mte_ub_gm(ub_src, gm_dst, 256, + nburst=(64, 256, 1024)) +# UB: contiguous rows (256-byte stride). +# GM: rows spaced at 1024-byte intervals (full matrix width). +``` + +--- + +### 7.2.3 UB → UB: `pto.mte_ub_ub` + +#### `pto.mte_ub_ub(ub_src: PtrType, ub_dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int]) -> None` + +**Description**: Grouped UB-to-UB copy. Stride and gap values are in units of 32 bytes. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ub_src` | `PtrType` (ub) | UB source pointer (must be 32B-aligned) | +| `ub_dst` | `PtrType` (ub) | UB destination pointer (must be 32B-aligned) | +| `len_burst` | `int` | Burst length in units of 32 bytes | +| `nburst` | `tuple[int, int, int]` | `(n_burst, src_gap, dst_gap)` — count, source gap, destination gap (all in 32B units) | + +**Returns**: None (side-effect operation). + +Each burst copies `len_burst * 32` bytes. The next burst starts at `src + (len_burst + src_gap) * 32` and `dst + (len_burst + dst_gap) * 32`. + +**Example**: + + +```python +pto.mte_ub_ub(ub_src, ub_dst, 8, + nburst=(16, 0, 4)) +# 16 bursts, each copying 8×32=256 bytes. +# Source: contiguous (src_gap=0). +# Destination: 4×32=128-byte gap between bursts. +``` + +--- + +### 7.2.4 UB → L1: `pto.mte_ub_l1` + +#### `pto.mte_ub_l1(ub_src: PtrType, l1_dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int]) -> None` + +**Description**: Grouped UB-to-L1 (CBUF) copy. Identical structure to `mte_ub_ub` but the destination is L1 cube buffer space. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ub_src` | `PtrType` (ub) | UB source pointer (must be 32B-aligned) | +| `l1_dst` | `PtrType` (l1) | L1 destination pointer (must be 32B-aligned) | +| `len_burst` | `int` | Burst length in units of 32 bytes | +| `nburst` | `tuple[int, int, int]` | `(n_burst, src_gap, dst_gap)` — all in 32B units | + +**Returns**: None (side-effect operation). + +--- + +### 7.2.5 The nburst / loop / pad model + +All grouped DMA operations follow a nested-loop execution model. `nburst` is the innermost group; each `loop` wraps the previous group as an outer iteration level. + +For `mte_gm_ub` and `mte_ub_gm`, strides are **byte distances** from the start of one burst row to the start of the next: + +``` +GM → UB (nburst only): + + for r in range(n_burst): + memcpy(ub_dst + r * dst_stride, + gm_src + r * src_stride, + len_burst) + if pad enabled: + memset(ub_dst + r * dst_stride + len_burst, + pad_value, + dst_stride_aligned - len_burst) +``` + +Each additional `loop(count, src_stride, dst_stride)` adds one outer `for` level that advances both base pointers by the corresponding strides. + +For `mte_ub_ub` and `mte_ub_l1`, the parameters are in **32-byte units**. Each burst copies `len_burst * 32` bytes, and the next burst starts at `src + (len_burst + src_gap) * 32` / `dst + (len_burst + dst_gap) * 32`. + +**UB address alignment**: For all four operations, every UB address (source and destination) must be 32-byte aligned. The `pad(...)` on `mte_gm_ub` ensures each UB row is padded to the 32B-aligned boundary of `dst_stride`, so subsequent rows stay aligned. + +### 7.2.6 Typical explicit-mode DMA pattern + + +```python +# Inside a @pto.jit(mode="explicit") body: +def process_block(k_part, v_part, k_tile, v_tile, o_tile, o_part, + rows: pto.i32, cols: pto.i32): + # Stage K and V blocks from GM to UB + pto.mte_gm_ub(k_part.as_ptr(), k_tile.as_ptr(), 0, + cols * pto.bytewidth(pto.f16), + nburst=(rows, cols * pto.bytewidth(pto.f16), + cols * pto.bytewidth(pto.f16))) + pto.mte_gm_ub(v_part.as_ptr(), v_tile.as_ptr(), 0, + cols * pto.bytewidth(pto.f16), + nburst=(rows, cols * pto.bytewidth(pto.f16), + cols * pto.bytewidth(pto.f16))) + pto.pipe_barrier(pto.Pipe.ALL) + + # ... compute on tiles ... + + pto.pipe_barrier(pto.Pipe.ALL) + pto.mte_ub_gm(o_tile.as_ptr(), o_part.as_ptr(), + cols * pto.bytewidth(pto.f32), + nburst=(rows, cols * pto.bytewidth(pto.f32), + cols * pto.bytewidth(pto.f32))) +``` + +## 7.3 Vector loads (simd) + +Inside `@pto.simd`, data moves between UB tiles and vector registers (`vreg`). Vector loads read a contiguous chunk of a tile row into a `vreg`; the chunk size equals the hardware vector width for the element type (e.g., 64 elements for `f32`, 128 for `f16`). + +### Tile-index syntax + +All vector load and store operations support the element-indexing syntax, which eliminates manual byte-offset calculation: + + +```python +vec = pto.vlds(tile[row, col:]) # load from row, starting at column col +``` + + +```python +vec = pto.vlds(tile[start:]) # 1D tile, starting at element start +``` + +The compiler automatically computes the byte offset from the tile's shape, element type, and layout. The `:` indicates a full vector-width range — the number of elements loaded is `elements_per_vreg(dtype)`. + +--- + +#### `pto.vlds(tile[row, col:], dist: VLoadDist | None = None) -> VRegType` +#### `pto.vlds(tile[start:], dist: VLoadDist | None = None) -> VRegType` +#### `pto.vlds(buf: PtrType, offset: Index, dist: VLoadDist | None = None) -> VRegType` + +**Description**: Stateless vector load from UB. Reads one vector-width slice. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | Tile index | 2D tile row with starting column (vector-width range) | +| `tile[start:]` | Tile index | 1D tile with starting element (vector-width range) | +| `buf` | `PtrType` (UB) | Pointer to buffer in UB (pointer form) | +| `offset` | `Index` | Element offset (pointer form) | +| `dist` | `VLoadDist` or `None` | Optional load distribution: `NORM` (default), `UNPK_B8`/`UNPK_B16`/`UNPK_B32`, `BRC_B8`/`BRC_B16`/`BRC_B32` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Loaded vector register | + +--- + +#### `pto.vldsx2(tile[row, col:], dist: DeinterleaveDist) -> (VRegType, VRegType)` +#### `pto.vldsx2(tile[start:], dist: DeinterleaveDist) -> (VRegType, VRegType)` +#### `pto.vldsx2(buf: PtrType, offset: Index, dist: DeinterleaveDist) -> (VRegType, VRegType)` + +**Description**: Dual vector load with deinterleave (AoS → SoA). Loads interleaved data and deinterleaves into two vectors. + +PTODSL accepts both pointer-based forms and tile-slice forms. The tile-slice +spellings are PTODSL surface sugar; the pointer form `buf[offset] + dist` is +the canonical form. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | Tile index | 2D tile row with starting column (vector-width range) | +| `tile[start:]` | Tile index | 1D tile with starting element (vector-width range) | +| `buf` | `PtrType` (UB) | Pointer to buffer in UB (pointer form) | +| `offset` | `Index` | Element offset (pointer form) | +| `dist` | `DeinterleaveDist` | `DINTLV_B8` / `DINTLV_B16` / `DINTLV_B32` (alternating elements) or `BDINTLV` (block deinterleave) | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `low` | `VRegType` | Even-indexed elements | +| `high` | `VRegType` | Odd-indexed elements | + +--- + +#### `pto.vldas(tile[row, col:]) -> AlignType` +#### `pto.vldas(tile[start:]) -> AlignType` +#### `pto.vldas(buf: PtrType) -> AlignType` + +**Description**: Primes the alignment buffer for a subsequent unaligned load stream. Returns alignment state consumed by `vldus`. + +PTODSL accepts both pointer-based forms and tile-slice forms. The tile-slice +spellings are PTODSL surface sugar; the pointer form is the canonical form. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | Tile index | 2D tile row with starting column | +| `tile[start:]` | Tile index | 1D tile with starting element | +| `buf` | `PtrType` | Pointer to buffer in UB | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `align` | `AlignType` | Alignment state for use with `vldus` | + +--- + +#### `pto.vldus(tile[row, col:], align: AlignType) -> (VRegType, AlignType)` +#### `pto.vldus(tile[start:], align: AlignType) -> (VRegType, AlignType)` +#### `pto.vldus(buf: PtrType, align: AlignType) -> (VRegType, AlignType)` + +**Description**: Unaligned load with alignment state threading. Requires alignment state from `vldas` or a previous `vldus`. + +PTODSL accepts both pointer-based forms and tile-slice forms. The tile-slice +spellings are PTODSL surface sugar; the pointer form is the canonical form. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col:]` | Tile index | 2D tile row with starting column (vector-width range) | +| `tile[start:]` | Tile index | 1D tile with starting element (vector-width range) | +| `buf` | `PtrType` (UB) | Pointer to buffer in UB (pointer form) | +| `align` | `AlignType` | Alignment state from `vldas` or previous `vldus` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Assembled vector | +| `align_out` | `AlignType` | Updated alignment state for next load | +**Example**: + + +```python +align = pto.vldas(tile[row, col:]) +vec, align = pto.vldus(tile[row, col:], align) +``` + + +```python +align = pto.vldas(tile[start:]) +vec, align = pto.vldus(tile[start:], align) +``` + +--- + +#### `pto.vsld(tile[row, col], stride: StrideMode) -> VRegType` +#### `pto.vsld(tile[pos], stride: StrideMode) -> VRegType` +#### `pto.vsld(buf: PtrType, offset: Index, stride: StrideMode) -> VRegType` + +**Description**: Strided scalar load with broadcast. Loads a single element using a strided access pattern and broadcasts to all vector lanes. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col]` | Tile index | 2D single-element index | +| `tile[pos]` | Tile index | 1D single-element index | +| `stride` | `StrideMode` | `S3_B16`, `S4_B64`, `S8_B32`, or `S2_B64` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Broadcast vector | + +--- + +#### `pto.vgather2(buf: PtrType, offsets: Index, mask: MaskType) -> VRegType` + +**Description**: Indexed gather from UB using per-lane offsets. Only masked-on +lanes participate. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `PtrType` (UB) | Source buffer | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `mask` | `MaskType` | Predicate mask gating lane participation | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +--- + +#### `pto.vgather2_bc(buf: PtrType, offsets: Index, mask: MaskType) -> VRegType` + +**Description**: Indexed gather with mask. Masked-off lanes are zero-filled. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `PtrType` (UB) | Source buffer | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `mask` | `MaskType` | Mask gating lane participation | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +--- + +#### `pto.vgatherb(buf: PtrType, offsets: Index, mask: MaskType) -> VRegType` + +**Description**: Block gather load. Participating lanes gather 32-byte blocks +from UB using byte offsets. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `PtrType` (UB) | Source buffer | +| `offsets` | `Index` | Per-block byte offsets | +| `mask` | `MaskType` | `b32` predicate controlling which blocks participate | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Gathered vector | + +--- + +#### `pto.vsldb(tile[row, col], block_stride: Index, repeat_stride: Index, mask: MaskType) -> VRegType` +#### `pto.vsldb(tile[pos], block_stride: Index, repeat_stride: Index, mask: MaskType) -> VRegType` +#### `pto.vsldb(buf: PtrType, block_stride: Index, repeat_stride: Index, mask: MaskType) -> VRegType` + +**Description**: Block-strided load. The source is interpreted as a sequence of +32-byte blocks addressed by `repeat_stride + blk * block_stride`. Masked-off +blocks are zero-filled. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `block_stride` | `Index` | 16-bit block stride field | +| `repeat_stride` | `Index` | 16-bit repeat stride field | +| `mask` | `MaskType` | Mask controlling which blocks participate | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `vec` | `VRegType` | Block-strided vector | + +## 7.4 Vector stores (simd) + +Vector stores write `vreg` contents back to UB tiles. Like loads, they support tile-index syntax. + +#### `pto.vsts(vec: VRegType, tile[row, col:], mask: MaskType, dist: VStoreDist | None = None) -> None` +#### `pto.vsts(vec: VRegType, tile[start:], mask: MaskType, dist: VStoreDist | None = None) -> None` +#### `pto.vsts(vec: VRegType, buf: PtrType, offset: Index, mask: MaskType, dist: VStoreDist | None = None) -> None` + +**Description**: Stateless vector store to UB. The mask gates writes for the +distributions that use predicate masking. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Vector to store | +| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | +| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | +| `offset` | `Index` | Element offset (pointer form) | +| `mask` | `MaskType` | Predicate mask gating writes | +| `dist` | `VStoreDist` or `None` | Store distribution token. When omitted, PTODSL defaults to `NORM_B32` on the current surface. | + +**Returns**: None (side-effect operation). + +**Distribution families**: + +| Family | Notes | +|--------|-------| +| `NORM_B8` / `NORM_B16` / `NORM_B32` | Contiguous vector store | +| `1PT_B8` / `1PT_B16` / `1PT_B32` | First-element-only store; predicate is ignored | +| `PK_B16` / `PK_B32` / `PK_B64` | Packed store families | +| `PK4_B32` | 4-way packed store | +| `MRG4CHN_B8` | 4-channel merge store | +| `MRG2CHN_B8` / `MRG2CHN_B16` | 2-channel merge store | + +--- + +#### `pto.psts(mask: MaskType, buf: PtrType, offset: Index, *, dist: PredicateDist = PredicateDist.NORM) -> None` + +**Description**: Predicate store. Writes the packed predicate payload of `mask` +to UB memory. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate payload to store | +| `buf` | `PtrType` (UB) | Destination buffer | +| `offset` | `Index` | Byte offset | +| `dist` | `PredicateDist` | Predicate payload layout. PTODSL defaults to `NORM` on the current surface. | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vstsx2(low: VRegType, high: VRegType, tile[row, col:], dist: InterleaveDist, mask: MaskType) -> None` +#### `pto.vstsx2(low: VRegType, high: VRegType, tile[start:], dist: InterleaveDist, mask: MaskType) -> None` +#### `pto.vstsx2(low: VRegType, high: VRegType, buf: PtrType, offset: Index, dist: InterleaveDist, mask: MaskType) -> None` + +**Description**: Dual interleaving store (SoA → AoS). Interleaves two vectors +into one destination. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `low` | `VRegType` | First vector (even elements) | +| `high` | `VRegType` | Second vector (odd elements) | +| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | +| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | +| `offset` | `Index` | Element offset (pointer form) | +| `dist` | `InterleaveDist` | `INTLV_B8` / `INTLV_B16` / `INTLV_B32` | +| `mask` | `MaskType` | Parameter retained for call-shape regularity; for the `INTLV_B*` family it does not affect the stored result | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vsstb(tile[row, col], block_stride: Index, repeat_stride: Index, mask: MaskType) -> None` +#### `pto.vsstb(tile[pos], block_stride: Index, repeat_stride: Index, mask: MaskType) -> None` +#### `pto.vsstb(buf: PtrType, block_stride: Index, repeat_stride: Index, mask: MaskType) -> None` + +**Description**: Block-strided store. Stores 32-byte source blocks to a +block-strided UB destination. Masked-off blocks do not write memory. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `tile[row, col]` | Tile index | 2D starting element | +| `tile[pos]` | Tile index | 1D starting element | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | +| `block_stride` | `Index` | 16-bit block stride field | +| `repeat_stride` | `Index` | 16-bit repeat stride field | +| `mask` | `MaskType` | Mask controlling which blocks participate | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vstar(align: AlignType, tile[row, col:]) -> None` +#### `pto.vstar(align: AlignType, tile[start:]) -> None` +#### `pto.vstar(align: AlignType, buf: PtrType) -> None` + +**Description**: Flush alignment state to memory. Commits buffered tail bytes from an unaligned store stream. Consumes the alignment state. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `AlignType` | Pending store-alignment state | +| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | +| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vstas(align: AlignType, tile[row, col:], offset: Index) -> None` +#### `pto.vstas(align: AlignType, tile[start:], offset: Index) -> None` +#### `pto.vstas(align: AlignType, buf: PtrType, offset: Index) -> None` + +**Description**: Scalar-register-offset form of alignment-state flush. Same buffered-tail semantics as `vstar` with an explicit scalar offset. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align` | `AlignType` | Pending store-alignment state | +| `tile[row, col:]` | Tile index | 2D destination (vector-width range) | +| `tile[start:]` | Tile index | 1D destination (vector-width range) | +| `buf` | `PtrType` (UB) | Destination buffer (pointer form) | +| `offset` | `Index` | Element offset (all forms) | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.vscatter(vec: VRegType, buf: PtrType, offsets: Index, mask: MaskType) -> None` + +**Description**: Indexed scatter to UB. Stores vector lanes to irregular locations using per-lane offsets. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Source vector to scatter | +| `buf` | `PtrType` (UB) | Destination buffer | +| `offsets` | `Index` | Per-lane element offsets (vector register) | +| `mask` | `MaskType` | Predicate mask gating lane participation | + +**Returns**: None (side-effect operation). + +--- + +### Stateful store family + +For streaming unaligned stores with explicit alignment threading: + +#### `pto.vstus(align_in: AlignType, offset: Index, vec: VRegType, buf: PtrType) -> AlignType` + +**Description**: Scalar-offset unaligned store. Returns updated alignment state for the next store in the stream. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming store-alignment state | +| `offset` | `Index` | Scalar displacement | +| `vec` | `VRegType` | Vector to store | +| `buf` | `PtrType` (UB) | Destination buffer | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `AlignType` | Updated buffered-tail state | + +--- + +#### `pto.vstur(align_in: AlignType, vec: VRegType, buf: PtrType, mode: PostUpdate = PostUpdate.OFF) -> AlignType` + +**Description**: Register-update unaligned store. Updates only residual alignment state without base pointer update. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming store-alignment state | +| `vec` | `VRegType` | Vector to store | +| `buf` | `PtrType` (UB) | Destination buffer | +| `mode` | `PostUpdate` | `PostUpdate.OFF` (default) or `PostUpdate.ON` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `AlignType` | Updated buffered-tail state | + +--- + +#### `pto.pstu(align_in: AlignType, mask: MaskType, buf: PtrType) -> (AlignType, PtrType)` + +**Description**: Predicate unaligned store with alignment state threading. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming store-alignment state | +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `PtrType` (UB) | Destination buffer | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `AlignType` | Updated alignment state | +| `base_out` | `PtrType` | Post-update base pointer | + +--- + +**Unaligned store stream pattern** — prime, thread, flush: + + +```python +align = pto.init_align() +vec0 = pto.vlds(ub_src_f32, pto.const(0)) +align = pto.vstur(align, vec0, ub_dst_f32, pto.PostUpdate.OFF) +align = pto.vstus(align, pto.const(32), vec0, ub_dst_f32) +pto.vstas(align, ub_dst_f32, pto.const(64)) +``` + +### Distribution enums reference + +| Enum | Values | Used with | +|------|--------|-----------| +| `VLoadDist` | `NORM`, `UNPK_B8`, `UNPK_B16`, `UNPK_B32`, `BRC_B8`, `BRC_B16`, `BRC_B32`, `US_B8`, `US_B16`, `DS_B8`, `DS_B16` | `vlds` | +| `VStoreDist` | `NORM_B8`, `NORM_B16`, `NORM_B32`, `1PT_B8`, `1PT_B16`, `1PT_B32`, `PK_B16`, `PK_B32`, `PK_B64`, `PK4_B32`, `MRG4CHN_B8`, `MRG2CHN_B8`, `MRG2CHN_B16` | `vsts` | +| `DeinterleaveDist` | `DINTLV_B8`, `DINTLV_B16`, `DINTLV_B32`, `BDINTLV` | `vldsx2` | +| `InterleaveDist` | `INTLV_B8`, `INTLV_B16`, `INTLV_B32` | `vstsx2` | +| `StrideMode` | `S3_B16`, `S4_B64`, `S8_B32`, `S2_B64` | `vsld` | +| `PostUpdate` | `OFF`, `ON` | `vstur` | + +## 7.5 Cube data movement (cube) + +Inside `@pto.cube`, data flows through a hierarchy of private buffers: GM → L1 (cbuf) → L0A/L0B (operand buffers) → L0C (accumulator) → UB or back to GM. + +### Staging: GM → L1 and L1 → UB + +#### `pto.mte_gm_l1(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0), loops: list[tuple[int, int, int]] | None = None) -> None` + +**Description**: Structured GM-to-L1 (cbuf) data movement. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (GM) | Global Memory source pointer | +| `dst` | `PtrType` (L1) | L1 (cbuf) destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_stride, dst_stride)` | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional nested loop parameters | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_gm_l1_frac(src: PtrType, dst: PtrType, mode: FractalMode, *, shape: tuple[int, int], src_layout: tuple[int, int], dst_group: tuple[int, int, int, int], ctrl: tuple[int, bool]) -> None` + +**Description**: Fractal GM-to-L1 load for specialized layouts (`ND2NZ`, `DN2NZ`). + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (GM) | Global Memory source pointer | +| `dst` | `PtrType` (L1) | L1 destination pointer | +| `mode` | `FractalMode` | `ND2NZ` or `DN2NZ` | +| `shape` | `tuple[int, int]` | `(n_value, d_value)` | +| `src_layout` | `tuple[int, int]` | `(inner_stride, outer_stride)` | +| `dst_group` | `tuple[int, int, int, int]` | `(group_count, loop2_stride, loop3_stride, loop4_stride)` | +| `ctrl` | `tuple[int, bool]` | `(l2_cache_ctrl, smallc0_en)` | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_l1_ub(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0), loops: list[tuple[int, int, int]] | None = None) -> None` + +**Description**: Structured L1 (cbuf) to UB data movement. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L1) | L1 source pointer | +| `dst` | `PtrType` (UB) | UB destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_stride, dst_stride)` | +| `loops` | `list[tuple[int, int, int]]` or `None` | Optional nested loop parameters | + +**Returns**: None (side-effect operation). + +--- + +### Operand loading: L1 → L0A / L0B + +#### `pto.mte_l1_l0a(src: PtrType, dst: PtrType, m: int, k: int) -> None` + +**Description**: Structured L1-to-L0A (left-operand buffer) load. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L1) | L1 source pointer | +| `dst` | `PtrType` (L0A) | L0A destination pointer | +| `m` | `int` | M dimension size | +| `k` | `int` | K dimension size | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_l1_l0b(src: PtrType, dst: PtrType, k: int, n: int, *, transpose: bool = False) -> None` + +**Description**: Structured L1-to-L0B (right-operand buffer) load. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L1) | L1 source pointer | +| `dst` | `PtrType` (L0B) | L0B destination pointer | +| `k` | `int` | K dimension size | +| `n` | `int` | N dimension size | +| `transpose` | `bool` | Whether to load in transposed order | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_l1_l0a_mx(src: PtrType, dst: PtrType, m: int, k: int) -> None` +#### `pto.mte_l1_l0b_mx(src: PtrType, dst: PtrType, k: int, n: int) -> None` + +**Description**: MX-mode variants of `mte_l1_l0a` and `mte_l1_l0b` for MX-capable dtypes. Parameters same as their non-MX counterparts. + +--- + +### Bias loading + +#### `pto.mte_l1_bias(src: PtrType, dst: PtrType, len_burst: int, *, nburst: tuple[int, int, int] = (1, 0, 0)) -> None` + +**Description**: Structured L1 (cbuf) to bias table load. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L1) | L1 source pointer | +| `dst` | `PtrType` (BIAS) | Bias table destination pointer | +| `len_burst` | `int` | Burst length in bytes | +| `nburst` | `tuple[int, int, int]` | `(count, src_gap, dst_gap)` | + +**Returns**: None (side-effect operation). + +--- + +### Accumulator writeback: L0C → L1 / GM / UB + +#### `pto.mte_l0c_l1(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, mode: FractalMode = FractalMode.NZ2ND, loop0_src_stride: int | None = None, split: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` + +**Description**: Structured L0C (acc) to L1 (cbuf) writeback. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L0C) | L0C accumulator source pointer | +| `dst` | `PtrType` (L1) | L1 destination pointer | +| `m` | `int` | M dimension size | +| `n` | `int` | N dimension size | +| `src_stride` | `int` | Source stride | +| `dst_stride` | `int` | Destination stride | +| `mode` | `FractalMode` | `NZ2ND` (default), `NZ2DN`, or `NZ2NZ` | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_l0c_gm(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, *, sid: int = 0, l2_cache_ctrl: int = 0, mode: FractalMode = FractalMode.NZ2ND, loop0_src_stride: int | None = None, split: int | None = None, loop3: tuple[int, int, int] | None = None) -> None` + +**Description**: Structured L0C (acc) to GM writeback. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L0C) | L0C accumulator source pointer | +| `dst` | `PtrType` (gm) | GM destination pointer | +| `m` | `int` | M dimension size | +| `n` | `int` | N dimension size | +| `src_stride` | `int` | Source stride | +| `dst_stride` | `int` | Destination stride | +| `sid` | `int` | Stream ID (default 0) | +| `l2_cache_ctrl` | `int` | L2 cache control (default 0) | +| `mode` | `FractalMode` | `NZ2ND` (default), `NZ2DN`, or `NZ2NZ` | +| `loop0_src_stride` | `int` or `None` | Loop level 0 source stride | +| `split` | `int` or `None` | Split parameter | +| `loop3` | `tuple[int, int, int]` or `None` | Loop level 3 parameters | + +**Returns**: None (side-effect operation). + +--- + +#### `pto.mte_l0c_ub(src: PtrType, dst: PtrType, m: int, n: int, src_stride: int, dst_stride: int, sub_blockid: int = 0, *, dst_mode: str = "single") -> None` + +**Description**: Structured L0C (acc) directly to UB. This is the most common writeback path for cube kernels that feed results into subsequent processing. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `PtrType` (L0C) | L0C accumulator source pointer | +| `dst` | `PtrType` (ub) | UB destination pointer | +| `m` | `int` | M dimension size | +| `n` | `int` | N dimension size | +| `src_stride` | `int` | Source stride | +| `dst_stride` | `int` | Destination stride | +| `sub_blockid` | `int` | Sub-block ID (default 0) | +| `dst_mode` | `str` | Destination mode, currently `"single"` by default | + +**Returns**: None (side-effect operation). + +--- + +### Cube data movement quick reference + +| Data Flow | Operation | Src Space | Dst Space | +|-----------|-----------|-----------|-----------| +| GM → L1 | `mte_gm_l1` | gm | l1 | +| GM → L1 (fractal) | `mte_gm_l1_frac` | gm | l1 | +| L1 → UB | `mte_l1_ub` | l1 | ub | +| L1 → L0A | `mte_l1_l0a` | l1 | l0a | +| L1 → L0B | `mte_l1_l0b` | l1 | l0b | +| L1 → L0A (MX) | `mte_l1_l0a_mx` | l1 | l0a | +| L1 → L0B (MX) | `mte_l1_l0b_mx` | l1 | l0b | +| L1 → Bias | `mte_l1_bias` | l1 | bt | +| L0C → L1 | `mte_l0c_l1` | l0c | l1 | +| L0C → GM | `mte_l0c_gm` | l0c | gm | +| L0C → UB | `mte_l0c_ub` | l0c | ub | + +### Typical cube dataflow in a matmul + +A full cube matmul (`@pto.cube`) follows this dataflow pattern: + + +```python +@pto.cube +def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): + m = q_tile.valid_shape[0] + k = q_tile.valid_shape[1] + n = k_tile.valid_shape[0] + + pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) # UB tile → L0A + pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) # UB tile → L0B + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) # L0A × L0B → L0C + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) # L0C → UB tile +``` + +At the cube micro-op boundary, PTODSL currently uses explicit typed pointers. `tile.as_ptr()` materializes the pointer view for UB and cube-local scratch buffers, while the surrounding sub-kernel surface still uses `Tile` values for metadata such as `valid_shape`. diff --git a/ptodsl/docs/user_guide/08-compute-operations.md b/ptodsl/docs/user_guide/08-compute-operations.md new file mode 100644 index 000000000..0da510042 --- /dev/null +++ b/ptodsl/docs/user_guide/08-compute-operations.md @@ -0,0 +1,840 @@ +# 8. Compute Operations + +Chapters 6 and 7 covered scalars, pointers, and data movement. This chapter covers everything that actually *computes* — arithmetic, math functions, reductions, comparisons, and matrix multiplication — organized by abstraction level: tile ops (L1), vector ops (L3 SIMD), and cube ops (L3 cube). + +## 8.1 Tile-level compute (L1) + +Tile compute ops are the primary arithmetic surface inside `@pto.jit`. They operate on `Tile` buffers in UB and follow a consistent pattern: each op reads one or more source tiles, optionally a scalar, and writes a destination tile. Shapes and valid regions must be compatible across all operands. + +### 8.1.1 Binary tile-tile arithmetic + +Element-wise operations between two tiles of the same shape. + +#### `pto.tile.add(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.sub(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.mul(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.max(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.min(src0: Tile, src1: Tile, dst: Tile) -> None` + +**Description**: Element-wise `dst[i,j] = src0[i,j] src1[i,j]`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | First source tile | +| `src1` | `Tile` | Second source tile | +| `dst` | `Tile` | Destination tile (must be pre-allocated, shape-compatible) | + +**Returns**: None (writes to `dst`). + +**Example**: + +```python +pto.tile.add(a_tile, b_tile, o_tile) +pto.tile.mul(scale_tile, data_tile, scaled_tile) +``` + +--- + +#### `pto.tile.div(src0: Tile, src1: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` + +**Description**: Element-wise division. `precision_mode` can be `DEFAULT` or `HIGH_PRECISION` (f16/f32 only). + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | Numerator tile | +| `src1` | `Tile` | Denominator tile | +| `dst` | `Tile` | Destination tile | +| `precision_mode` | `PrecisionMode` | `DEFAULT` (default) or `HIGH_PRECISION` | + +**Returns**: None. + +--- + +### 8.1.2 Tile-scalar arithmetic + +Element-wise operations between a tile and a scalar. + +#### `pto.tile.adds(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.subs(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.muls(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.maxs(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.mins(src: Tile, scalar: ScalarType, dst: Tile) -> None` + +**Description**: Element-wise `dst[i,j] = src[i,j] scalar`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile | +| `scalar` | `ScalarType` | Scalar operand (Python number or PTO scalar) | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +--- + +#### `pto.tile.divs(src: Tile, scalar: ScalarType, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` + +**Description**: Element-wise tile-scalar division: `dst[i,j] = src[i,j] / scalar`. + +--- + +### 8.1.3 Unary math + +Single-source element-wise math functions. + +#### `pto.tile.exp(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tile.log(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tile.sqrt(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tile.rsqrt(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` +#### `pto.tile.recip(src: Tile, dst: Tile, *, precision_mode: PrecisionMode = PrecisionMode.DEFAULT) -> None` + +**Description**: Element-wise `exp`, `ln`, `sqrt`, `1/sqrt`, `1/x`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile | +| `dst` | `Tile` | Destination tile | +| `precision_mode` | `PrecisionMode` | `DEFAULT` or `HIGH_PRECISION` | + +**Returns**: None. + +--- + +#### `pto.tile.abs(src: Tile, dst: Tile) -> None` +#### `pto.tile.neg(src: Tile, dst: Tile) -> None` + +**Description**: Element-wise absolute value and negation. No precision mode attribute. + +--- + +### 8.1.4 Activation + +#### `pto.tile.relu(src: Tile, dst: Tile) -> None` + +**Description**: `dst[i,j] = max(0, src[i,j])`. Supported on f16, f32, i32. + +#### `pto.tile.lrelu(src: Tile, slope: float, dst: Tile) -> None` + +**Description**: Leaky ReLU — `dst[i,j] = src[i,j] >= 0 ? src[i,j] : slope * src[i,j]`. + +--- + +### 8.1.5 Row and column reductions + +Reductions collapse one dimension of a 2D tile, producing a tile with one row or one column. + +#### Row reductions + +#### `pto.tile.rowsum(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` +#### `pto.tile.rowmax(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` +#### `pto.tile.rowmin(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` +#### `pto.tile.rowprod(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` +#### `pto.tile.rowargmax(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` +#### `pto.tile.rowargmin(src: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` + +**Description**: For each row `i`, reduce across columns: `dst[i, 0] = _j src[i, j]`. `tile.rowargmax`/`tile.rowargmin` return the column index of the extremum. In the public PTODSL wrapper, `tmp` is optional; when omitted, PTODSL allocates a matching scratch tile automatically. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile (`[rows, cols]`) | +| `dst` | `Tile` | Destination tile (`[rows, 1]`) | +| `tmp` | `Tile | None` | Optional scratch tile for intermediate reduction state; when omitted, PTODSL synthesizes a matching scratch tile automatically | + +**Returns**: None. + +--- + +#### Column reductions + +#### `pto.tile.colsum(src: Tile, dst: Tile) -> None` +#### `pto.tile.colmax(src: Tile, dst: Tile) -> None` +#### `pto.tile.colmin(src: Tile, dst: Tile) -> None` +#### `pto.tile.colprod(src: Tile, dst: Tile) -> None` + +**Description**: For each column `j`, reduce across rows: `dst[0, j] = _i src[i, j]`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile (`[rows, cols]`) | +| `dst` | `Tile` | Destination tile (`[1, cols]`) | + +**Returns**: None. + +--- + +### 8.1.6 Broadcast and expansion + +Expansion ops take a narrow source (scalar, row vector, or column vector) and broadcast it to a full tile shape. They are useful for applying per-row or per-column coefficients to a tile. + +#### Scalar broadcast + +#### `pto.tile.expands(scalar: ScalarType, dst: Tile) -> None` + +**Description**: `dst[i,j] = scalar` — fills every element of `dst` with the same scalar value. + +--- + +#### Row expansion + +#### `pto.tile.rowexpand(src: Tile, dst: Tile) -> None` + +**Description**: `dst[row, col] = src[row, 0]` — broadcasts each row's single value across all columns of `dst`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile (`[rows, 1]`) | +| `dst` | `Tile` | Destination tile (`[rows, cols]`) | + +**Returns**: None. + +--- + +#### Column expansion + +#### `pto.tile.colexpand(src: Tile, dst: Tile) -> None` + +**Description**: `dst[row, col] = src[0, col]` — broadcasts each column's single value across all rows of `dst`. + +--- + +#### Row-expand arithmetic + +These combine broadcasting with an arithmetic operation: `src1` is a per-row coefficient tile (`[rows, 1]`) that gets expanded row-wise before the element-wise op with `src0`. + +| Op | Semantics | +|----|-----------| +| `pto.tile.rowexpandadd(src0, src1, dst)` | `dst = src0 + expand_rows(src1)` | +| `pto.tile.rowexpandsub(src0, src1, dst)` | `dst = src0 - expand_rows(src1)` | +| `pto.tile.rowexpandmul(src0, src1, dst)` | `dst = src0 * expand_rows(src1)` | +| `pto.tile.rowexpanddiv(src0, src1, dst)` | `dst = src0 / expand_rows(src1)` (f-only) | +| `pto.tile.rowexpandmax(src0, src1, dst)` | `dst = max(src0, expand_rows(src1))` | +| `pto.tile.rowexpandmin(src0, src1, dst)` | `dst = min(src0, expand_rows(src1))` | +| `pto.tile.rowexpandexpdif(src0, src1, dst)` | `dst = exp(src0 - expand_rows(src1))` (f-only) | + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | Full-shape source tile (`[rows, cols]`) | +| `src1` | `Tile` | Per-row coefficient tile (`[rows, 1]`) | +| `dst` | `Tile` | Destination tile (`[rows, cols]`) | + +**Returns**: None. + +**Example** — apply per-row scale and bias: + +```python +# alpha_tile: [rows, 1], beta_tile: [rows, 1], data_tile: [rows, cols] +pto.tile.rowexpandmul(data_tile, alpha_tile, scaled_tile) +pto.tile.rowexpandadd(scaled_tile, beta_tile, result_tile) +``` + +--- + +#### Column-expand arithmetic + +Same pattern as row-expand arithmetic, but `src1` is a per-column coefficient tile (`[1, cols]`): + +| Op | Semantics | +|----|-----------| +| `pto.tile.colexpandadd(src0, src1, dst)` | `dst = src0 + expand_cols(src1)` | +| `pto.tile.colexpandsub(src0, src1, dst)` | `dst = src0 - expand_cols(src1)` | +| `pto.tile.colexpandmul(src0, src1, dst)` | `dst = src0 * expand_cols(src1)` | +| `pto.tile.colexpanddiv(src0, src1, dst)` | `dst = src0 / expand_cols(src1)` (f-only) | +| `pto.tile.colexpandmax(src0, src1, dst)` | `dst = max(src0, expand_cols(src1))` | +| `pto.tile.colexpandmin(src0, src1, dst)` | `dst = min(src0, expand_cols(src1))` | +| `pto.tile.colexpandexpdif(src0, src1, dst)` | `dst = exp(src0 - expand_cols(src1))` (f-only) | + +--- + +### 8.1.7 Selection + +#### `pto.tile.sel(mask: Tile, src0: Tile, src1: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` + +**Description**: Element-wise ternary: `dst[i,j] = mask[i,j] ? src0[i,j] : src1[i,j]`. The `mask` is an integer tile where zero means false and non-zero means true. `tmp` is an optional scratch tile override; when omitted, PTODSL synthesizes any architecture-specific scratch tile automatically. + +#### `pto.tile.sels(mask: Tile, src: Tile, scalar: ScalarType, dst: Tile, *, tmp: Tile | None = None) -> None` + +**Description**: Element-wise select with scalar fallback: `dst[i,j] = mask[i,j] ? src[i,j] : scalar`. As with `tile.sel`, `tmp` is optional and PTODSL synthesizes any required scratch tile automatically when it is omitted. + +--- + +### 8.1.8 Type conversion + +#### `pto.tile.cvt(src: Tile, dst: Tile, *, rmode: RoundMode = RoundMode.NONE) -> None` + +**Description**: Element-wise type conversion. The destination tile's `dtype` determines the target type. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile | +| `dst` | `Tile` | Destination tile (with target dtype) | +| `rmode` | `RoundMode` | Rounding mode: `NONE`, `RINT`, `ROUND`, `FLOOR`, `CEIL`, `TRUNC`, `ODD`, `CAST_RINT` | + +**Returns**: None. + +--- + +### 8.1.9 Bitwise ops + +Bitwise operations on integer tiles (i8, i16, i32, etc.). All follow the standard `(src, dst)` or `(src0, src1, dst)` pattern. + +#### Unary bitwise + +#### `pto.tile.bit_not(src: Tile, dst: Tile) -> None` + +**Description**: Element-wise bitwise NOT: `dst[i,j] = ~src[i,j]`. Integer types only. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile (integer dtype) | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +--- + +#### Binary bitwise (tile-tile) + +#### `pto.tile.bit_and(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.bit_or(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.bit_shl(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.bit_shr(src0: Tile, src1: Tile, dst: Tile) -> None` + +**Description**: Element-wise bitwise `dst[i,j] = src0[i,j] src1[i,j]`. Integer types only. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | First source tile | +| `src1` | `Tile` | Second source tile | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +--- + +#### `pto.tile.bit_xor(src0: Tile, src1: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` + +**Description**: Element-wise bitwise XOR. Requires an additional scratch buffer `tmp` of the same type as `dst`. When `tmp` is omitted, PTODSL synthesizes a matching scratch tile automatically. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | First source tile | +| `src1` | `Tile` | Second source tile | +| `dst` | `Tile` | Destination tile | +| `tmp` | `Tile | None` | Optional scratch tile; when omitted, PTODSL synthesizes one automatically | + +**Returns**: None. + +--- + +#### Binary bitwise (tile-scalar) + +#### `pto.tile.bit_ands(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.bit_ors(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.bit_shls(src: Tile, scalar: ScalarType, dst: Tile) -> None` +#### `pto.tile.bit_shrs(src: Tile, scalar: ScalarType, dst: Tile) -> None` + +**Description**: Element-wise `dst[i,j] = src[i,j] scalar`. The scalar is broadcast to all elements. Integer types only. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile | +| `scalar` | `ScalarType` | Scalar operand (Python int or PTO scalar) | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +--- + +#### `pto.tile.bit_xors(src: Tile, scalar: ScalarType, dst: Tile, *, tmp: Tile | None = None) -> None` + +**Description**: Element-wise bitwise XOR with scalar. Requires an additional scratch buffer `tmp` of the same type as `dst`. When `tmp` is omitted, PTODSL synthesizes a matching scratch tile automatically. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile | +| `scalar` | `ScalarType` | Scalar operand | +| `dst` | `Tile` | Destination tile | +| `tmp` | `Tile | None` | Optional scratch tile; when omitted, PTODSL synthesizes one automatically | + +**Returns**: None. + +--- + +### 8.1.10 Partial elementwise ops + +Partial elementwise ops compute over the **intersection** of the valid regions of two source tiles. This allows element-wise arithmetic between tiles that have different `valid_shape`s — only the overlapping area is computed. + +#### `pto.tile.partadd(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.partmul(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.partmax(src0: Tile, src1: Tile, dst: Tile) -> None` +#### `pto.tile.partmin(src0: Tile, src1: Tile, dst: Tile) -> None` + +**Description**: Element-wise `dst[i,j] = src0[i,j] src1[i,j]` over the intersection of `src0.valid_shape` and `src1.valid_shape`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `Tile` | First source tile (may have a partial valid region) | +| `src1` | `Tile` | Second source tile (may have a partial valid region) | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +**Example** — adding tiles with different valid regions: + +```python +# a_tile: valid_shape = [64, 32], b_tile: valid_shape = [64, 64] +# The partial add only operates on the intersection: 64 columns × min(32, 64) = 32 columns +pto.tile.partadd(a_tile, b_tile, result_tile) +``` + +--- + +### 8.1.11 Fill/padding + +Fill-padding ops copy a source tile's valid region into a destination tile, filling the remaining physical elements (outside `src.valid_shape`) with a configured pad value. The pad value is specified at tile allocation time via the tile's `PadValue` attribute (`Null`, `Zero`, `Max`, or `Min`). + +#### `pto.tile.fillpad(src: Tile, dst: Tile) -> None` + +**Description**: Copies `src`'s valid region into `dst` and fills extra elements of `dst` with the pad value configured on `dst`'s type. The `dst` physical shape must be at least as large as `src.valid_shape`. + +#### `pto.tile.fillpad_expand(src: Tile, dst: Tile) -> None` + +**Description**: Like `fillpad`, but the destination tile may have a different shape in the partition/tensor view. The src valid region is copied and the expanded area is filled with the pad value. Useful when expanding a tile into a larger buffer for downstream processing. + +#### `pto.tile.fillpad_inplace(src: Tile, dst: Tile) -> None` + +**Description**: In-place variant of `fillpad`. `src` and `dst` may refer to the same tile buffer, padding the tile's own valid region in place. + +**Parameters** (all three ops): + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src` | `Tile` | Source tile (with valid region to copy) | +| `dst` | `Tile` | Destination tile (carries `PadValue` attribute set at allocation) | + +**Returns**: None. + +**Example** — padding a partial tile to full shape: + +```python +# tile has valid_shape [32, 16] in a physical buffer of [32, 32] +# pad=Zero at allocation time fills extra columns with zeros +pto.tile.fillpad(partial_tile, padded_tile) +``` + +--- + +### 8.1.12 Tile compute quick reference + +| Category | Operations | +|----------|------------| +| Binary tile-tile | `tile.add`, `tile.sub`, `tile.mul`, `tile.div`, `tile.max`, `tile.min` | +| Tile-scalar | `tile.adds`, `tile.subs`, `tile.muls`, `tile.divs`, `tile.maxs`, `tile.mins` | +| Unary math | `tile.exp`, `tile.log`, `tile.sqrt`, `tile.rsqrt`, `tile.recip`, `tile.abs`, `tile.neg` | +| Activation | `tile.relu`, `tile.lrelu` | +| Row reductions | `tile.rowsum`, `tile.rowmax`, `tile.rowmin`, `tile.rowprod`, `tile.rowargmax`, `tile.rowargmin` | +| Column reductions | `tile.colsum`, `tile.colmax`, `tile.colmin`, `tile.colprod` | +| Broadcast | `tile.expands`, `tile.rowexpand`, `tile.colexpand` | +| Row-expand arith | `tile.rowexpandadd`, `tile.rowexpandsub`, `tile.rowexpandmul`, `tile.rowexpanddiv`, `tile.rowexpandmax`, `tile.rowexpandmin`, `tile.rowexpandexpdif` | +| Col-expand arith | `tile.colexpandadd`, `tile.colexpandsub`, `tile.colexpandmul`, `tile.colexpanddiv`, `tile.colexpandmax`, `tile.colexpandmin`, `tile.colexpandexpdif` | +| Selection | `tile.sel`, `tile.sels` | +| Type conversion | `tile.cvt` | +| Bitwise | `tile.bit_not`, `tile.bit_and`, `tile.bit_or`, `tile.bit_xor`, `tile.bit_shl`, `tile.bit_shr`, `tile.bit_ands`, `tile.bit_ors`, `tile.bit_xors`, `tile.bit_shls`, `tile.bit_shrs` | +| Partial elementwise | `tile.partadd`, `tile.partmul`, `tile.partmax`, `tile.partmin` | +| Fill/padding | `tile.fillpad`, `tile.fillpad_expand`, `tile.fillpad_inplace` | + +--- + +## 8.2 Vector compute (L3 — `@pto.simd`) + +Vector compute ops operate on `VRegType` values inside `@pto.simd` sub-kernels. Every vector op takes a `MaskType` predicate that gates which lanes participate; masked-off lanes produce an unspecified result (use the result only where the mask is true, or feed it to a masked store). + +All vector ops in this section follow the pattern established in Section 7.3 for tile-index and pointer-form addressing. The signatures below use the vector-register form — tile-index forms load into `vreg` first, then compute. + +### 8.2.1 Unary vector ops + +#### `pto.vexp(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vln(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vsqrt(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vabs(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vneg(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vrec(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vrsqrt(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vrelu(vec: VRegType, mask: MaskType) -> VRegType` +#### `pto.vnot(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise unary operation under mask. `vrec` = reciprocal, `vrsqrt` = inverse square root, `vrelu` = `max(0, x)`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask (granularity must match element type) | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result vector | + +**Example**: + + +```python +exp_vec = pto.vexp(s_row, col_mask) +``` + +--- + +### 8.2.2 Binary vector ops + +#### `pto.vadd(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` +#### `pto.vsub(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` +#### `pto.vmul(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` +#### `pto.vdiv(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` +#### `pto.vmax(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` +#### `pto.vmin(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` + +**Description**: Element-wise binary operation: `result[i] = v0[i] v1[i]` for lanes where `mask[i]` is true. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `v0` | `VRegType` | First operand vector | +| `v1` | `VRegType` | Second operand vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result vector | + +--- + +**Bitwise binary ops** (integer types only): + +| Op | Semantics | +|----|-----------| +| `pto.vand(v0, v1, mask) -> VRegType` | `v0 & v1` | +| `pto.vor(v0, v1, mask) -> VRegType` | `v0 \| v1` | +| `pto.vxor(v0, v1, mask) -> VRegType` | `v0 ^ v1` | +| `pto.vshl(vec, shift, mask) -> VRegType` | `vec << shift` (per-element) | +| `pto.vshr(vec, shift, mask) -> VRegType` | `vec >> shift` (per-element) | + +--- + +### 8.2.3 Vector-scalar ops + +#### `pto.vadds(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vsubs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vmuls(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vmaxs(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` +#### `pto.vmins(vec: VRegType, scalar: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Element-wise `result[i] = vec[i] scalar`. The scalar is broadcast to all active lanes. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `scalar` | `ScalarType` | Scalar operand (uniform across all lanes) | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Result vector | + +**Example** — subtract row max from score row (online softmax): + + +```python +s_shifted = pto.vsubs(s_row, m_next, col_mask) +``` + +--- + +#### `pto.vlrelu(vec: VRegType, alpha: ScalarType, mask: MaskType) -> VRegType` + +**Description**: Leaky ReLU — `vec[i] >= 0 ? vec[i] : alpha * vec[i]`. + +--- + +### 8.2.4 Full-vector and group reductions + +#### Full-vector reductions + +#### `pto.vcadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Full-vector sum reduction. Result placed in lane 0. + +#### `pto.vcmax(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Full-vector max with argmax. Result lane 0 = max value, lane 1 = max index. + +#### `pto.vcmin(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Full-vector min with argmin. Result lane 0 = min value, lane 1 = min index. + +--- + +#### Group reductions (per-VLane) + +These reduce within each hardware vector lane group (typically 8 groups per vector). Useful when a vector register holds multiple independent sub-vectors that need separate reductions. + +#### `pto.vcgadd(vec: VRegType, mask: MaskType) -> ScalarType` +#### `pto.vcgmax(vec: VRegType, mask: MaskType) -> ScalarType` +#### `pto.vcgmin(vec: VRegType, mask: MaskType) -> ScalarType` + +**Description**: Per-group sum, max, or min. The underlying vector reduction places each group's result in the first lane of that group; the ptodsl surface extracts lane 0 and returns it as a runtime scalar. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `vec` | `VRegType` | Input vector | +| `mask` | `MaskType` | Predicate mask | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `ScalarType` | Lane-0 scalar extracted from the grouped reduction result | + +**Example** — row max and row sum from online softmax: + + +```python +row_max = pto.vcgmax(s_row, col_mask) # grouped reduction, surfaced as a runtime scalar +row_sum = pto.vcgadd(p_row, col_mask) # grouped reduction, surfaced as a runtime scalar +``` + +--- + +#### `pto.vcpadd(vec: VRegType, mask: MaskType) -> VRegType` + +**Description**: Inclusive prefix sum (scan). `result[i] = sum_{k=0}^{i} vec[k]` for active lanes. f16 and f32 only. + +--- + +### 8.2.5 Fused and compound ops + +These combine an arithmetic operation with a math function or activation in a single instruction. + +#### `pto.vexpdif(vec: VRegType, max_vec: VRegType, mask: MaskType, *, part: PartMode = PartMode.ODD) -> VRegType` + +**Description**: `exp(vec[i] - max_vec[i])` — the stable softmax numerator. `part` controls which half of the vector is computed: `EVEN` or `ODD`. The result keeps the same `VRegType` as the input vector. + +--- + +#### `pto.vaxpy(alpha: ScalarType, x: VRegType, y: VRegType, mask: MaskType) -> VRegType` + +**Description**: Fused multiply-add: `alpha * x[i] + y[i]`. + +--- + +#### `pto.vaddrelu(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` + +**Description**: `max(0, v0[i] + v1[i])` — fused add + ReLU. + +#### `pto.vsubrelu(v0: VRegType, v1: VRegType, mask: MaskType) -> VRegType` + +**Description**: `max(0, v0[i] - v1[i])` — fused sub + ReLU. + +--- + +### 8.2.6 Comparison and selection + +#### `pto.vcmp(v0: VRegType, v1: VRegType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Element-wise comparison producing a predicate mask. `seed_mask` selects which lanes participate; the result inherits its granularity (e.g., `mask_b32` for f32). + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `v0` | `VRegType` | First operand | +| `v1` | `VRegType` | Second operand | +| `seed_mask` | `MaskType` | Seed mask gating participation | +| `cmp_mode` | `CmpMode` | `EQ`, `NE`, `LT`, `LE`, `GT`, `GE` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `pred` | `MaskType` | Result predicate mask | + +--- + +#### `pto.vcmps(vec: VRegType, scalar: ScalarType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Vector-scalar comparison. Same semantics as `vcmp` with a uniform scalar second operand. + +--- + +#### `pto.vsel(true_v: VRegType, false_v: VRegType, mask: MaskType) -> VRegType` + +**Description**: Per-lane select: `mask[i] ? true_v[i] : false_v[i]`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `true_v` | `VRegType` | Values when mask is true | +| `false_v` | `VRegType` | Values when mask is false | +| `mask` | `MaskType` | Selection predicate | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `VRegType` | Selected vector | + +--- + +### 8.2.7 Vector compute quick reference + +| Category | Operations | +|----------|------------| +| Unary | `vexp`, `vln`, `vsqrt`, `vabs`, `vneg`, `vrec`, `vrsqrt`, `vrelu`, `vnot` | +| Binary | `vadd`, `vsub`, `vmul`, `vdiv`, `vmax`, `vmin`, `vand`, `vor`, `vxor`, `vshl`, `vshr` | +| Vector-scalar | `vadds`, `vsubs`, `vmuls`, `vmaxs`, `vmins`, `vlrelu` | +| Broadcast | `vbr`, `vdup` | +| Full reduction | `vcadd`, `vcmax`, `vcmin` | +| Group reduction | `vcgadd`, `vcgmax`, `vcgmin` | +| Scan | `vcpadd` | +| Fused | `vexpdif`, `vaxpy`, `vaddrelu`, `vsubrelu` | +| Compare/select | `vcmp`, `vcmps`, `vsel` | +| Conversion | `vbitcast`, `pbitcast` | + +--- + +## 8.3 Cube compute (L3 — `@pto.cube`) + +The Cube unit performs matrix multiplication. Its operands are typed pointers into cube-local buffers — L0A (left operand), L0B (right operand), L0C (accumulator), and BIAS. Cube data movement (`mte_l1_l0a`, `mte_l1_l0b`, `mte_l0c_ub`, etc.) was covered in Section 7.5; this section covers the compute instruction itself. + +### 8.3.1 Matrix multiply: `pto.mad` + +#### `pto.mad(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int) -> None` + +**Description**: Zero-initialized matrix multiply: `dst[M×N] = lhs[M×K] * rhs[K×N]`. `lhs` is an L0A pointer, `rhs` is an L0B pointer, `dst` is an L0C pointer. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `lhs` | `PtrType` (L0A) | Left operand matrix (M × K) | +| `rhs` | `PtrType` (L0B) | Right operand matrix (K × N) | +| `dst` | `PtrType` (L0C) | Destination accumulator (M × N) | +| `m` | `int` | M dimension size | +| `k` | `int` | K dimension (inner/reduction dimension) | +| `n` | `int` | N dimension size | + +**Returns**: None (writes to `dst` in L0C). + +--- + +#### `pto.mad_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int) -> None` + +**Description**: Accumulating matrix multiply: `dst[M×N] += lhs[M×K] * rhs[K×N]`. `dst` must already hold a prior accumulation result. + +--- + +#### `pto.mad_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, n: int, k: int) -> None` + +**Description**: Bias-initialized matrix multiply: `dst[M×N] = lhs[M×K] * rhs[K×N] + bias[M×N]`. `bias` is a BIAS pointer. + +--- + +#### `pto.mad_mx(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int) -> None` + +**Description**: MX-format zero-initialized matrix multiply. This variant is intended for MX-enabled operand formats such as f8 payloads with their associated scale data already staged into cube-local buffers. + +--- + +#### `pto.mad_mx_acc(lhs: PtrType, rhs: PtrType, dst: PtrType, m: int, n: int, k: int) -> None` + +**Description**: MX-format accumulating matrix multiply: `dst[M×N] += lhs[M×K] * rhs[K×N]`. + +--- + +#### `pto.mad_mx_bias(lhs: PtrType, rhs: PtrType, dst: PtrType, bias: PtrType, m: int, n: int, k: int) -> None` + +**Description**: MX-format bias-initialized matrix multiply: `dst[M×N] = lhs[M×K] * rhs[K×N] + bias[M×N]`. + +--- + +### 8.3.2 Typical cube matmul pattern + +A full cube matmul follows a three-stage pattern: stage operands into L0A/L0B, compute, write back to UB. + + +```python +@pto.cube +def qk_matmul(q_tile, k_tile, q_l0a, k_l0b, s_acc, s_tile): + m = q_tile.valid_shape[0] + k = q_tile.valid_shape[1] + n = k_tile.valid_shape[1] + + # Stage: source tiles → L0A / L0B + pto.mte_l1_l0a(q_tile.as_ptr(), q_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(k_tile.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) + + # Compute: L0A × L0B → L0C + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) + + # Writeback: L0C → UB + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) +``` + +The `mte_l1_l0a`/`mte_l1_l0b` stage operands from the authored source tiles into cube-local buffers. `mad` performs the matrix multiply into L0C. `mte_l0c_ub` writes the result back to a UB tile for downstream processing. At this micro-op layer, the operands are explicit pointer views obtained with `.as_ptr()`. + +--- + +### 8.3.3 Cube compute quick reference + +| Operation | Semantics | +|-----------|-----------| +| `pto.mad(lhs, rhs, dst, m, n, k)` | `dst = lhs * rhs` (zero-init) | +| `pto.mad_acc(lhs, rhs, dst, m, n, k)` | `dst += lhs * rhs` (accumulating) | +| `pto.mad_bias(lhs, rhs, dst, bias, m, n, k)` | `dst = lhs * rhs + bias` | +| `pto.mad_mx(lhs, rhs, dst, m, n, k)` | MX-format zero-init matmul | +| `pto.mad_mx_acc(lhs, rhs, dst, m, n, k)` | MX-format accumulating matmul | +| `pto.mad_mx_bias(lhs, rhs, dst, bias, m, n, k)` | MX-format bias-init matmul | + +MX variants require MX-enabled dtypes (f8) and pre-loaded scale payloads. For most users, the standard `mad`, `mad_acc`, and `mad_bias` are the primary interface. diff --git a/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md b/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md new file mode 100644 index 000000000..0362a285d --- /dev/null +++ b/ptodsl/docs/user_guide/09-predicate-and-mask-ops.md @@ -0,0 +1,424 @@ +# 9. Predicate and Mask Operations + +Vector operations on the SIMD unit execute across many lanes in parallel — but not all lanes always hold valid data. The last chunk of a row may be shorter than the hardware vector width; a row-wise reduction may need to skip padding elements. **Predicate masks** are the mechanism that gates which lanes participate in an operation. + +This chapter covers mask types, mask creation, logical manipulation, reorganization, and load/store. Comparison operations that *produce* masks from vector data (`vcmp`, `vcmps`) are also covered here, since masks are their primary output. + +## 9.1 Mask types + +The hardware predicate register is a 256-bit register. PTODSL exposes three typed views of it, differing in how many elements each bit represents: + +| Mask type | ALU width | Lanes | Used with vector types | +|-----------|-----------|-------|----------------------| +| `pto.mask_b8` | 8-bit | 256 | `i8` vectors | +| `pto.mask_b16` | 16-bit | 128 | `f16`, `bf16`, `i16` vectors | +| `pto.mask_b32` | 32-bit | 64 | `f32`, `i32` vectors | + +A mask and the vector it gates must share the same granularity: a `mask_b32` gates an `f32` vector (64 lanes), not an `f16` vector (128 lanes). + +**Zeroing predication**: when a lane is masked off, the operation produces zero in that lane. This is the gating model for all vector compute ops in Chapter 8. + +## 9.2 Mask creation: `pto.make_mask` + +The recommended front door for creating masks is `pto.make_mask`. It dispatches to the right underlying op based on its arguments. + +#### `pto.make_mask(dtype: Type, value: int-like | MaskPattern) -> MaskType | (MaskType, int-like)` + +**Description**: Creates a predicate mask of the granularity matching `dtype`. When `value` is an integer-like scalar (typically a remaining-element count in a chunked loop), returns a tuple `(mask, remaining)`. When `value` is a `MaskPattern`, returns just the mask. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `dtype` | `Type` | Element type to infer mask granularity from (e.g., `pto.f32` → `mask_b32`, `pto.f16` → `mask_b16`) | +| `value` | `int-like` or `MaskPattern` | Either a remaining-element count or a pattern token | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | The created mask | +| `remained` | `int-like` | Updated remaining count (only when `value` is an integer-like scalar); its scalar kind is preserved, so an `index` remainder stays an `index` | + +**Example** — chunked SIMD loop with tail handling: + + +```python +VEC = pto.elements_per_vreg(pto.f32) +col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) +with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + vec = pto.vlds(tile[r, c:]) + # ... operate under mask ... + pto.vsts(vec, out_tile[r, c:], mask) + col_loop.update(remained=remained) +``` + +`make_mask` generates a tail mask from the remaining count: the first `min(remained, VL)` lanes are active, and `remained` is decremented by `VL` for the next iteration. On the final partial chunk, fewer than `VL` lanes are active. PTODSL handles the hardware `i32` tail-mask operand internally, so loop-carried `index` metadata can flow through `make_mask` without manual casts. + +--- + +When the mask pattern is known at compile time, pass a `MaskPattern` instead: + + +```python +full_mask = pto.make_mask(pto.f32, pto.MaskPattern.ALL) +``` + +This is equivalent to calling the granularity-specific ops described below. + +--- + +## 9.3 Granularity-specific creation ops + +When you need explicit control over the mask granularity, use these ops directly. + +### 9.3.1 Pattern-based: `pset_b*` and `pge_b*` + +`pset` generates a mask from a named pattern. `pge` generates a tail mask where the first N lanes are active (N encoded in the pattern). + + +```python +full_mask = pto.pset_b32(pto.MaskPattern.ALL) +``` + + +```python +mask8 = pto.pset_b8(pto.MaskPattern.ALL) +mask16 = pto.pset_b16(pto.MaskPattern.ALL) +``` + +#### `pto.pset_b8(pattern: MaskPattern) -> pto.mask_b8` +#### `pto.pset_b16(pattern: MaskPattern) -> pto.mask_b16` +#### `pto.pset_b32(pattern: MaskPattern) -> pto.mask_b32` + +**Description**: Creates a mask from a pattern token. `PAT_ALL` sets all lanes active; other patterns set a subset. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pattern` | `MaskPattern` | Pattern token: `ALL`, `ALLF`, `H`, `Q`, `VL1`–`VL128`, `M3`, `M4` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Mask with lanes set per the pattern | + +--- + +#### `pto.pge_b8(pattern: MaskPattern) -> pto.mask_b8` +#### `pto.pge_b16(pattern: MaskPattern) -> pto.mask_b16` +#### `pto.pge_b32(pattern: MaskPattern) -> pto.mask_b32` + +**Description**: Tail mask — `mask[i] = (i < N) ? 1 : 0`, where N is encoded in the pattern. Typically uses `VL*` patterns. + +--- + +### 9.3.2 Scalar-driven: `plt_b*` + +`plt` generates a tail mask from a live `i32` scalar — the idiomatic choice for dynamic tail handling when not using `make_mask`. + + +```python +mask, remained = pto.plt_b32(remained) +``` + +#### `pto.plt_b8(scalar: pto.i32) -> (pto.mask_b8, pto.i32)` +#### `pto.plt_b16(scalar: pto.i32) -> (pto.mask_b16, pto.i32)` +#### `pto.plt_b32(scalar: pto.i32) -> (pto.mask_b32, pto.i32)` + +**Description**: Generates a tail mask where the first `min(scalar, VL)` lanes are active, and returns `scalar - min(scalar, VL)` as the updated remaining count. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `scalar` | `pto.i32` | Remaining element count | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Tail mask (first N lanes active) | +| `scalar_out` | `pto.i32` | Updated remaining = `max(0, scalar - VL)` | + +`VL` is 256 for `b8`, 128 for `b16`, and 64 for `b32`. + +--- + +## 9.4 Mask logical operations + + +```python +merged = pto.pand(src0, src1, gate) +``` + +Once created, masks can be combined with bitwise logical ops. All take a gating mask that selects which lanes participate; inactive lanes are zeroed in the result. + +#### `pto.pand(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` +#### `pto.por(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` +#### `pto.pxor(src0: MaskType, src1: MaskType, mask: MaskType) -> MaskType` + +**Description**: Bitwise AND / OR / XOR of two masks, gated by a third mask: `dst[i] = mask[i] ? (src0[i] src1[i]) : 0`. All three masks must share the same granularity. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `src0` | `MaskType` | First source mask | +| `src1` | `MaskType` | Second source mask | +| `mask` | `MaskType` | Gating mask (lanes where false produce 0) | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `result` | `MaskType` | Combined mask | + +--- + +#### `pto.pnot(src: MaskType, mask: MaskType) -> MaskType` + +**Description**: Bitwise NOT under gate: `dst[i] = mask[i] ? (~src[i]) : 0`. + +--- + +#### `pto.psel(src0: MaskType, src1: MaskType, sel: MaskType) -> MaskType` + +**Description**: Per-lane mask select: `dst[i] = sel[i] ? src0[i] : src1[i]`. All lanes participate directly — there is no additional gating beyond `sel` itself. + +--- + +## 9.5 Mask reorganization + +These ops reshape masks between granularities and layouts without changing the underlying 256-bit register image (except pack/unpack, which remap bits). + +#### `pto.pbitcast(mask: MaskType, to_type: MaskType) -> MaskType` + +**Description**: Bitwise reinterpretation of a mask at a different granularity. The 256-bit predicate register image is unchanged; only the lane count and element-width interpretation change. + +**Example**: + + +```python +# Reinterpret a b16 mask as b32 +mask32 = pto.pbitcast(mask16, pto.mask_b32) +``` + +--- + +#### `pto.ppack(mask: MaskType, part: PredicatePart) -> MaskType` + +**Description**: Narrowing pack — keeps one bit out of each adjacent 2-bit group from the source, packing them into the selected half (`LOWER` or `HIGHER`) of the result. The other half is zero-filled. + +#### `pto.punpack(mask: MaskType, part: PredicatePart) -> MaskType` + +**Description**: Widening unpack — reads the selected half of the source, zero-extends each 1-bit element into a 2-bit group in the result. + + +```python +packed_hi = pto.ppack(mask32, pto.PredicatePart.HIGHER) +unpacked_hi = pto.punpack(packed_hi, pto.PredicatePart.HIGHER) +``` + +--- + +#### `pto.pintlv_b8(src0: pto.mask_b8, src1: pto.mask_b8) -> (pto.mask_b8, pto.mask_b8)` +#### `pto.pintlv_b16(src0: pto.mask_b16, src1: pto.mask_b16) -> (pto.mask_b16, pto.mask_b16)` +#### `pto.pintlv_b32(src0: pto.mask_b32, src1: pto.mask_b32) -> (pto.mask_b32, pto.mask_b32)` + +**Description**: Interleave two masks element-wise. Returns `(low, high)` where `low[i] = src0[i]` and `high[i] = src1[i]` at each interleaved position. + +#### `pto.pdintlv_b8(src0: pto.mask_b8, src1: pto.mask_b8) -> (pto.mask_b8, pto.mask_b8)` +#### `pto.pdintlv_b16(src0: pto.mask_b16, src1: pto.mask_b16) -> (pto.mask_b16, pto.mask_b16)` +#### `pto.pdintlv_b32(src0: pto.mask_b32, src1: pto.mask_b32) -> (pto.mask_b32, pto.mask_b32)` + +**Description**: Deinterleave — the inverse of `pintlv`. Takes interleaved data in two masks and separates even/odd elements. + +--- + +## 9.6 Comparisons: producing masks from vectors + +Vector comparisons produce predicate masks from vector data. The result can feed into mask logical ops, `vsel`, or gated stores. + +#### `pto.vcmp(v0: VRegType, v1: VRegType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Element-wise vector-vector comparison: `dst[i] = seed_mask[i] ? (v0[i] v1[i]) : 0`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `v0` | `VRegType` | First operand vector | +| `v1` | `VRegType` | Second operand vector | +| `seed_mask` | `MaskType` | Seed mask gating which lanes participate | +| `cmp_mode` | `CmpMode` | `EQ`, `NE`, `LT`, `LE`, `GT`, `GE` | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `pred` | `MaskType` | Result predicate mask (inherits granularity from operands) | + +--- + +#### `pto.vcmps(vec: VRegType, scalar: ScalarType, seed_mask: MaskType, cmp_mode: CmpMode) -> MaskType` + +**Description**: Vector-scalar comparison: `dst[i] = seed_mask[i] ? (vec[i] scalar) : 0`. The scalar is broadcast to all lanes. + +**Example** — threshold a vector: + + +```python +big = pto.vcmps(scores, threshold, seed, pto.CmpMode.GT) +# big[i] = 1 where scores[i] > threshold +``` + +--- + +**Tile-level comparisons** (`pto.tile.cmp`, `pto.tile.cmps`) compare two tiles and write packed predicate bytes into an `i8` destination tile. They are used when the comparison result needs to be stored to UB for later selection (`tile.sel`) or cross-kernel communication. + +--- + +## 9.7 Mask load and store + +Masks can be persisted to and loaded from UB memory, enabling cross-stage predicate communication. + +### 9.7.1 Predicate loads + +#### `pto.plds(buf: PtrType, offset: Index, *, dist: PredicateDist = PredicateDist.NORM) -> MaskType` + +**Description**: Load a predicate mask from UB memory at the given byte offset. The mask granularity is determined by the pointer element type of `buf` (`ui8`/`ui16`/`ui32` -> `mask_b8`/`mask_b16`/`mask_b32`). + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `buf` | `PtrType` (UB) | Source buffer | +| `offset` | `Index` | Byte offset | +| `dist` | `PredicateDist` | `NORM` (load VL/8 packed bytes), `US` (upsample), `DS` (downsample) | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `mask` | `MaskType` | Loaded predicate mask | + +--- + +### 9.7.2 Predicate stores + +#### `pto.psts(mask: MaskType, buf: PtrType, offset: Index, *, dist: PredicateDist = PredicateDist.NORM) -> None` + +**Description**: Store a predicate mask to UB memory. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `MaskType` | Predicate mask to store | +| `buf` | `PtrType` (UB) | Destination buffer | +| `offset` | `Index` | Byte offset | +| `dist` | `PredicateDist` | `NORM` (store VL/8 packed bytes) or `PK` (pack to VL/16 bytes) | + +**Returns**: None. + +--- + +### 9.7.3 Unaligned predicate store + +#### `pto.pstu(align_in: AlignType, mask: MaskType, buf: PtrType) -> (AlignType, PtrType)` + +**Description**: Unaligned predicate store with alignment state threading. Threads the `align` state through a stream of stores, ensuring tail bytes are correctly buffered. This op currently supports only `mask_b16` and `mask_b32`; the base pointer type is determined by the mask granularity (`ui16` for `b16`, `ui32` for `b32`). + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `align_in` | `AlignType` | Incoming alignment state (from `init_align` or previous `pstu`) | +| `mask` | `MaskType` | Predicate mask to store (`mask_b16` or `mask_b32` only) | +| `buf` | `PtrType` (UB) | Destination buffer | + +**Returns**: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `align_out` | `AlignType` | Updated alignment state | +| `base_out` | `PtrType` | Post-update base pointer | + + +## 9.8 How masks gate vector operations + +Every vector compute op in Chapter 8 takes a mask as its last operand. The contract is consistent: + +- For **unary ops** (`vexp`, `vabs`, etc.): `dst[i] = mask[i] ? f(src[i]) : 0` +- For **binary ops** (`vadd`, `vmul`, etc.): `dst[i] = mask[i] ? (lhs[i] rhs[i]) : 0` +- For **vector stores** (`vsts`): `dst[i] = mask[i] ? src[i]` — masked-off lanes are not written +- For **reductions** (`vcadd`, `vcgmax`, etc.): only lanes where `mask[i]` is true contribute to the result + +The mask granularity must match the vector element type. Using a `mask_b16` with an `f32` vector (or vice versa) is an error. + +**Typical pattern** — tail-safe vector processing: + + +```python +VEC = pto.elements_per_vreg(pto.f32) +with pto.for_(0, rows, step=1) as r: + col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) + with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + + vec = pto.vlds(tile[r, c:]) + vec = pto.vexp(vec, mask) + pto.vsts(vec, out_tile[r, c:], mask) + + col_loop.update(remained=remained) +``` + +The `mask` gates the `vexp` (masked-off lanes produce 0) and the `vsts` (masked-off lanes are not written). `col_loop` carries the remaining count across iterations, so the final partial chunk correctly masks only the valid tail elements. + +--- + +## 9.9 Tile-level mask operations + +When working at the tile level (L1, `@pto.jit`), masks are carried in `i8` tile buffers holding packed predicate bytes. The key consumer of tile-level masks is `tile.sel`. + +#### `pto.tile.sel(mask: Tile, src0: Tile, src1: Tile, dst: Tile, *, tmp: Tile | None = None) -> None` + +**Description**: Element-wise ternary select: `dst[i,j] = mask[i,j] ? src0[i,j] : src1[i,j]`. `mask` is an integer tile (typically `i8`) where zero means false. `tmp` is an optional scratch tile override; when omitted, PTODSL synthesizes any architecture-specific scratch tile automatically. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `mask` | `Tile` | Integer mask tile (zero = false) | +| `src0` | `Tile` | True-branch source tile | +| `src1` | `Tile` | False-branch source tile | +| `tmp` | `Tile \| None` | Optional scratch tile override | +| `dst` | `Tile` | Destination tile | + +**Returns**: None. + +--- + +#### `pto.tile.sels(mask: Tile, src: Tile, scalar: ScalarType, dst: Tile, *, tmp: Tile | None = None) -> None` + +**Description**: Element-wise select with scalar fallback: `dst[i,j] = mask[i,j] ? src[i,j] : scalar`. As with `tile.sel`, `tmp` is optional and PTODSL synthesizes any required scratch tile automatically when it is omitted. + +--- + +## 9.10 Enum reference + +| Enum | Values | Used with | +|------|--------|-----------| +| `MaskPattern` | `ALL`, `ALLF`, `H`, `Q`, `VL1`–`VL128`, `M3`, `M4` | `pset_b*`, `pge_b*`, `make_mask` | +| `CmpMode` | `EQ`, `NE`, `LT`, `LE`, `GT`, `GE` | `vcmp`, `vcmps` | +| `PredicateDist` (load) | `NORM`, `US`, `DS` | `plds` | +| `PredicateDist` (store) | `NORM`, `PK` | `psts` | +| `PredicatePart` | `LOWER`, `HIGHER` | `ppack`, `punpack` | diff --git a/ptodsl/docs/user_guide/10-sync-ops.md b/ptodsl/docs/user_guide/10-sync-ops.md new file mode 100644 index 000000000..de727ee89 --- /dev/null +++ b/ptodsl/docs/user_guide/10-sync-ops.md @@ -0,0 +1,451 @@ +# 10. Synchronization Operations + +Chapters 7 and 8 covered data movement and computation. This chapter covers the synchronization primitives that keep those operations correctly ordered across the NPU's concurrent hardware pipelines. + +The Ascend NPU executes work across multiple independent pipelines — MTE (DMA), Vector, and Cube — each with its own instruction stream. Synchronization operations coordinate these pipelines: a DMA must finish loading data before the vector unit starts computing on it; a matrix multiply must complete before the result is stored. These operations are available in both `mode="auto"` and `mode="explicit"` when the kernel needs them. Without correct synchronization, pipelines race, and results are undefined. + +## 10.1 Enum types for synchronization + +PTODSL provides three enum types for type-safe specification of synchronization parameters. + +### `BarrierType` + +Memory barrier types used with `pto.mem_bar`. Each value specifies which category of prior instruction must complete before which category of subsequent instruction may proceed. + +| Member | Meaning | +|--------|---------| +| `VV_ALL` | All vector ops before → all vector ops after | +| `VST_VLD` | Vector stores before → vector loads after | +| `VLD_VST` | Vector loads before → vector stores after | +| `VST_VST` | Vector stores before → vector stores after | +| `VS_ALL` | All vector ops before → all scalar ops after | +| `VST_LD` | Vector stores before → scalar loads after | +| `VLD_ST` | Vector loads before → scalar stores after | +| `VST_ST` | Vector stores before → scalar stores after | +| `SV_ALL` | All scalar ops before → all vector ops after | +| `ST_VLD` | Scalar stores before → vector loads after | +| `LD_VST` | Scalar loads before → vector stores after | +| `ST_VST` | Scalar stores before → vector stores after | + +The naming convention: `V` = vector, `S` = scalar, `ST` = store, `LD` = load. `VST_VLD` reads "Vector STore before Vector LoaD." + +### `Pipe` + +Hardware pipeline identifiers used with `pto.set_flag`, `pto.wait_flag`, and `pto.pipe_barrier`. + +| Member | Pipeline | +|--------|----------| +| `S` | Scalar / control pipeline | +| `V` | Vector pipeline (SIMD) | +| `M` | Matrix / Cube pipeline | +| `MTE1` | Memory Transfer Engine 1 | +| `MTE2` | Memory Transfer Engine 2 | +| `MTE3` | Memory Transfer Engine 3 | +| `MTE4` | Memory Transfer Engine 4 | +| `ALL` | All pipelines (for barrier operations) | + +The most commonly used pipes in synchronization are `MTE2` (GM ↔ UB DMA), `MTE3` (UB ↔ UB DMA), `V` (vector compute), and `M` (matrix compute). + +### `event_id` + +Event identifiers for pipeline synchronization flags. The hardware provides 8 event IDs (`0`–`7`) per pipeline pair, supporting up to 8 concurrent in-flight DMA/compute sequences. + +In PTODSL, `event_id` may be either: + +- a Python integer literal in `0`–`7` +- a runtime index-like PTO scalar value + +Events are per-pipeline-pair: the same `event_id=0` used between `MTE2 → V` is independent from `event_id=0` used between `MTE3 → V`. + +--- + +## 10.2 Pipeline synchronization: `set_flag`, `wait_flag`, `pipe_barrier` + +Pipeline synchronization is the primary mechanism for ordering work across pipelines. The pattern is always **signal then wait**: the producer pipeline sets a flag when its work is done; the consumer pipeline waits on that flag before proceeding. + +### `pto.set_flag(pipe_from, pipe_to, *, event_id=0)` + +**Description**: Sets a synchronization flag between two hardware pipelines. The producing pipeline signals that work up to this point is complete. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `Pipe` | Source pipeline — the pipeline that has completed its work | +| `pipe_to` | `Pipe` | Destination pipeline — the pipeline being notified | +| `event_id` | `int` or index-like PTO scalar | Event identifier for this specific synchronization point (`0`–`7`) | + +**Returns**: None (side-effect operation). + +**Example**: + + +```python +# MTE2 has finished loading tile data — signal Vector pipeline +pto.set_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) +``` + +### `pto.wait_flag(pipe_from, pipe_to, *, event_id=0)` + +**Description**: Waits for a synchronization flag. The consuming pipeline blocks until the flag is set by the producing pipeline. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe_from` | `Pipe` | Source pipeline that set the flag | +| `pipe_to` | `Pipe` | Destination pipeline — the pipeline that is waiting | +| `event_id` | `int` or index-like PTO scalar | Event identifier matching the corresponding `set_flag` (`0`–`7`) | + +**Returns**: None (side-effect operation). + +**Example**: + + +```python +# Vector pipeline waits for MTE2 to finish loading +pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) +``` + +### `pto.pipe_barrier(pipes)` + +**Description**: Executes a barrier across the specified pipelines. All work before the barrier in the named pipelines must complete before any work after the barrier may begin. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipes` | `Pipe` | Pipeline specification — typically `Pipe.ALL` for a full barrier | + +**Returns**: None (side-effect operation). + +**Example**: + + +```python +# Full hardware barrier — all pipelines synchronize +pto.pipe_barrier(pto.Pipe.ALL) +``` + +### Typical explicit-mode usage pattern + +A common explicit-mode pattern interleaves DMA and compute with `set_flag` / +`wait_flag` pairs: + + +```python +# Inside a @pto.jit(mode="explicit") body: +def gemm_block( + q_tile: pto.Tile, + k_part: pto.PartitionTensorView, + v_part: pto.PartitionTensorView, + k_tile: pto.Tile, + v_tile: pto.Tile, + p_tile: pto.Tile, + o_tile: pto.Tile, + o_part: pto.PartitionTensorView, + rows: pto.i32, + cols: pto.i32, +): + # DMA: load K and V tiles from GM to UB + row_bytes = cols * pto.bytewidth(pto.f16) + gm_row_stride = k_part.strides[0] * pto.bytewidth(pto.f16) + ub_row_stride = k_tile.shape[1] * pto.bytewidth(pto.f16) + out_row_bytes = cols * pto.bytewidth(pto.f32) + out_gm_row_stride = o_part.strides[0] * pto.bytewidth(pto.f32) + out_ub_row_stride = o_tile.shape[1] * pto.bytewidth(pto.f32) + pto.mte_load(k_part.as_ptr(), k_tile.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) + pto.mte_load(v_part.as_ptr(), v_tile.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) + + # Signal: DMA done, UB data ready + pto.set_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) + + # Wait: vector pipeline stalls until data arrives + pto.wait_flag(pto.Pipe.MTE2, pto.Pipe.V, event_id=0) + + # Compute: now safe to use k_tile and v_tile + qk_matmul(q_tile, k_tile, p_tile) + pv_matmul(p_tile, v_tile, o_tile) + + # Signal: compute done, results ready for store + pto.set_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) + pto.wait_flag(pto.Pipe.V, pto.Pipe.MTE3, event_id=1) + + # DMA: store results back to GM + pto.mte_store(o_tile.as_ptr(), o_part.as_ptr(), out_row_bytes, + nburst=(rows, out_ub_row_stride, out_gm_row_stride)) +``` + +--- + +## 10.3 Buffer management: `get_buf`, `rls_buf` + +Double-buffering is a common optimization in NPU kernels: while one buffer is being computed on, the other is being loaded with the next block of data. The `get_buf` / `rls_buf` pair coordinates buffer ownership between pipelines. + +### `pto.get_buf(pipe, buf_id, mode=0)` + +**Description**: Acquire a buffer slot for inter-pipeline double-buffering coordination. The calling pipeline claims ownership of the buffer, blocking if the buffer is still in use by another pipeline. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Pipeline identifier of the acquiring pipeline | +| `buf_id` | `pto.i64` | Buffer identifier (0-based index into the buffer pool) | +| `mode` | `pto.i64` | Acquisition mode (default 0) | + +**Returns**: None (side-effect operation). + +### `pto.rls_buf(pipe, buf_id, mode=0)` + +**Description**: Release a buffer slot, allowing another pipeline to acquire it. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Pipeline identifier of the releasing pipeline | +| `buf_id` | `pto.i64` | Buffer identifier matching the corresponding `get_buf` | +| `mode` | `pto.i64` | Release mode (default 0) | + +**Returns**: None (side-effect operation). + +### Double-buffering example + + +```python +# Pipeline V acquires buffer 0 for compute +pto.get_buf(pto.Pipe.V, 0, 0) + +# ... compute into buffer 0 ... + +# Release buffer 0 — DMA can now refill it +pto.rls_buf(pto.Pipe.V, 0, 0) + +# Pipeline MTE2 acquires buffer 0 for reload +pto.get_buf(pto.Pipe.MTE2, 0, 0) + +# ... DMA loads next block into buffer 0 ... + +pto.rls_buf(pto.Pipe.MTE2, 0, 0) +``` + +--- + +## 10.4 Memory barriers: `mem_bar` + +Within a single pipeline, load and store instructions may be reordered by the hardware. `mem_bar` enforces ordering when UB addresses alias between operations — for example, when a store to a region must be visible to a subsequent load from the same region. + +### `pto.mem_bar(barrier_type)` + +**Description**: Inserts a memory barrier that enforces ordering of prior and subsequent instructions within the same pipeline. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `barrier_type` | `BarrierType` | Barrier type controlling which categories of prior instructions must complete before which categories of subsequent instructions may proceed | + +**Returns**: None (side-effect operation). + +**Example**: + + +```python +# Ensure all prior vector stores are visible before any subsequent vector loads +pto.mem_bar(pto.BarrierType.VST_VLD) +``` + +The most commonly used barrier types in practice: + +| Use case | Barrier type | +|----------|--------------| +| General vector ordering | `BarrierType.VV_ALL` | +| Store-then-load to same UB region | `BarrierType.VST_VLD` | +| Vector → scalar handoff | `BarrierType.VS_ALL` | +| Scalar → vector handoff | `BarrierType.SV_ALL` | + +### Usage in explicit orchestration blocks + +In explicit-mode kernels, phase boundaries use `pipe_barrier(Pipe.ALL)`, while +`mem_bar` remains the tool for narrower intra-pipeline ordering: + + +```python +# Inside a @pto.jit(mode="explicit") body: +def flash_attention_block( + q_tile: pto.Tile, + k_part: pto.PartitionTensorView, + v_part: pto.PartitionTensorView, + k_tile: pto.Tile, + v_tile: pto.Tile, + s_tile: pto.Tile, + p_tile: pto.Tile, + pv_tile: pto.Tile, + o_prev_tile: pto.Tile, + o_next_tile: pto.Tile, + rows: pto.i32, + cols: pto.i32, +): + # Phase 1: load K/V + row_bytes = cols * pto.bytewidth(pto.f16) + gm_row_stride = k_part.strides[0] * pto.bytewidth(pto.f16) + ub_row_stride = k_tile.shape[1] * pto.bytewidth(pto.f16) + pto.mte_load(k_part.as_ptr(), k_tile.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) + pto.mte_load(v_part.as_ptr(), v_tile.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, ub_row_stride)) + pto.pipe_barrier(pto.Pipe.ALL) + + # Phase 2: S = Q @ K^T + qk_matmul(q_tile, k_tile, s_tile) + pto.pipe_barrier(pto.Pipe.ALL) + + # Phase 3: softmax(S) + online_softmax(s_tile, p_tile, rows, cols) + pto.mem_bar(pto.BarrierType.VV_ALL) + pto.pipe_barrier(pto.Pipe.ALL) + + # Phase 4: PV = P @ V + pv_matmul(p_tile, v_tile, pv_tile) + pto.pipe_barrier(pto.Pipe.ALL) + + # Phase 5: blend output + blend_output(o_prev_tile, pv_tile, o_next_tile, rows, cols) + pto.pipe_barrier(pto.Pipe.ALL) +``` + +--- + +## 10.5 Cross-core and intra-block synchronization + +Section 10.2 covers the general pipe-to-pipe sync mechanism (`set_flag`/`wait_flag`). This section covers two additional sync domains that the pipe-flag mechanism does not address: **cross-core** communication between separate NPU cores, and **intra-block** synchronization between the Cube and Vector units within a block. + +### 10.5.1 Cross-core sync: `set_cross_flag`, `wait_cross_flag` + +When a kernel spans multiple cores, cores need to coordinate through shared resources. `set_cross_flag` sends a signal to another core; `wait_cross_flag` blocks the calling core until the expected signal arrives. + +These are core-level (SU) operations — `wait_cross_flag` stalls the entire core, not just a single pipeline. Use them sparingly: splitting work so that each core operates independently for as long as possible minimises cross-core sync overhead. + +#### `pto.set_cross_flag(pipe, event_id)` + +**Description**: Signal an event on a synchronization endpoint. In the current PTODSL surface this is authored with a `Pipe`; the backend maps it to the architecture-specific cross-core / intra-block builtin during lowering. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Producing endpoint for the synchronization event. The public DSL accepts `Pipe.FIX` here. | +| `event_id` | `int` | Cross-core event identifier (`0`–`7`) | + +**Returns**: None (side-effect operation). + +**Example**: + + +```python +# Signal from the FIX/Cube-side endpoint +pto.set_cross_flag(pto.Pipe.FIX, 0) +``` + +#### `pto.wait_cross_flag(pipe, event_id)` + +**Description**: Wait for an event on a synchronization endpoint. On architectures that lower this surface to the backend `sync.wait` primitive, the wait is core-level (SU) blocking. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Waiting endpoint for the synchronization event. The public DSL accepts `Pipe.FIX` here. | +| `event_id` | `int` | Event identifier to wait on (`0`–`7`) | + +**Returns**: None (side-effect operation). + +**Example**: + + +```python +# Wait on the FIX/Cube-side endpoint +pto.wait_cross_flag(pto.Pipe.FIX, 0) +``` + +### 10.5.2 Intra-block sync: `set_intra_flag`, `wait_intra_flag` + +The Cube unit (matrix pipeline) has a dedicated synchronization channel separate from the standard pipe-flag mechanism used by MTE and Vector pipelines. `set_intra_flag` and `wait_intra_flag` synchronize Cube and Vector within the same block, ensuring that shared UB tile data is not accessed before the producer finishes. + +Unlike `wait_cross_flag`, `wait_intra_flag` only stalls the specified pipeline — the SU and other pipelines continue executing. + +#### `pto.set_intra_flag(pipe, event_id)` + +**Description**: Signal a synchronization event within a block. The current PTODSL surface authors the trigger endpoint explicitly as a `Pipe`. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Trigger endpoint for the synchronization event. The public DSL accepts `Pipe.MTE3` here. | +| `event_id` | `int` | Event identifier (`0`–`7`) | + +**Returns**: None (side-effect operation). + +**Example**: + + +```python +# Signal event ID0 from the MTE3-side endpoint +pto.set_intra_flag(pto.Pipe.MTE3, 0) +``` + +#### `pto.wait_intra_flag(pipe, event_id)` + +**Description**: Wait for an intra-block event. Only the specified pipeline stalls — the SU and other pipelines continue executing independently. + +**Parameters**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `pipe` | `Pipe` | Waiting endpoint for the synchronization event. The public DSL accepts `Pipe.V` here. | +| `event_id` | `int` | Event identifier to wait on (`0`–`7`) | + +**Returns**: None (side-effect operation). + +**Example**: + + +```python +# Vector-side endpoint waits for event ID0 +pto.wait_intra_flag(pto.Pipe.V, 0) +``` + +## 10.6 Synchronization in the authoring model + +Where do sync operations belong in PTODSL's public entry model? + +| Surface | Sync responsibility | +|---------|---------------------| +| `@pto.jit(mode="auto")` | Users can write sync explicitly when needed. PTOAS also provides an `--enable-insert-sync` option that auto-inserts `set_flag`/`wait_flag` pairs based on op-to-pipe mapping. | +| `@pto.jit(mode="explicit")` | The compiler does not insert sync — the user is fully responsible. Place `set_flag`/`wait_flag` between MTE and compute, `mem_bar` between compute phases, `pipe_barrier` at phase boundaries. | +| Shared `@pto.cube` / `@pto.simd` / `@pto.simt` helpers | Cross-pipeline ordering is provided by the surrounding `@pto.jit` schedule. Helpers may still use `mem_bar` for intra-pipeline ordering when UB addresses alias. | + +**Rule of thumb**: in `mode="auto"`, think in tiles and let the compiler handle +orchestration. In `mode="explicit"`, think in micro-instructions and place the +required sync yourself. + +### Auto-sync at the tile level + +In auto mode, users can still write sync operations directly — `set_flag`/`wait_flag`, `pipe_barrier`, `mem_bar` are available in both modes. For convenience, PTOAS also provides an `--enable-insert-sync` pass: each tile op carries a pipe assignment (e.g., `tile.load` → `PIPE_MTE2`, `tile.add` → `PIPE_V`), and the pass analyzes the op sequence, infers the necessary `set_flag`/`wait_flag` pairs from pipe transitions, and injects them into the lowered code. + +### Quick reference: which sync for which scenario + +| Scenario | Sync primitive | +|----------|----------------| +| DMA load must finish before compute | `set_flag(MTE2, V, event_id=id)` + `wait_flag(MTE2, V, event_id=id)` | +| Compute must finish before DMA store | `set_flag(V, MTE3, event_id=id)` + `wait_flag(V, MTE3, event_id=id)` | +| Two compute phases must not overlap | `mem_bar(BarrierType.VV_ALL)` | +| Store must be visible to later load (same UB) | `mem_bar(BarrierType.VST_VLD)` | +| Full pipeline sync point | `pipe_barrier(Pipe.ALL)` | +| Double-buffer handoff (compute → DMA) | `rls_buf(V, id)` + `get_buf(MTE2, id)` | +| Double-buffer handoff (DMA → compute) | `rls_buf(MTE2, id)` + `get_buf(V, id)` | +| Core A notifies core B | `set_cross_flag(B, id)` + `wait_cross_flag(A, id)` | diff --git a/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md new file mode 100644 index 000000000..6e5a5e98f --- /dev/null +++ b/ptodsl/docs/user_guide/11-flash-attention-walkthrough.md @@ -0,0 +1,645 @@ +# 11. Flash Attention Complete Walkthrough + +This chapter walks through `examples/flash_attention_sketch.py` layer by layer, tracing a complete flash attention implementation from the user-facing Python wrapper down to hardware-bound sub-kernels. Every API discussed in Chapters 1–10 appears in context here. + +The sketch computes **online-softmax flash attention** for one `(batch, head)` slice per launch instance. It partitions Q into blocks along the sequence dimension, iterates over KV blocks for each Q block, and maintains rolling softmax state across KV iterations. + +## 11.1 Architecture overview + +``` +flash_attention(...) L0 user-facing wrapper + └─ @pto.jit(mode="explicit") flash_attention_kernel + ├─ Tile Ops tile.load / tile.store at the GM↔UB boundary + ├─ explicit orchestration mte_load / pipe_barrier / pointer sequencing + ├─ @pto.cube qk_matmul / pv_matmul + ├─ @pto.simd online_softmax_rows + └─ @pto.simt materialize_tile_bounds / blend_output_rows +``` + +The dataflow for one KV block: + +``` +explicit-mode orchestration loads the K/V block and sequences the pipeline + │ + ├─ cube: Q + K ───────────────► S + ├─ simd: S + (m_prev, l_prev) ─► P, (m_next, l_next), alpha, beta + ├─ cube: P + V ───────────────► PV + └─ simt: (o_prev, PV, alpha, beta) ─► o_next + +After each KV block: + (m_prev, l_prev, o_prev) := (m_next, l_next, o_next) +``` + +## 11.2 The Python wrapper + +```python +def flash_attention(Q, K, V, *, O=None, causal=False, + block_q=128, block_kv=128, stream=None): + if O is None: + O = pto.empty_like(Q) + + batch, seq_q, heads, dim = Q.shape + _, seq_k, _, _ = K.shape + + compiled = flash_attention_kernel.compile( + BLOCK_Q=block_q, BLOCK_KV=block_kv, CAUSAL=causal, + ) + compiled[batch * heads, stream](Q, K, V, O) + return O +``` + +This is plain Python — no PTO types, no IR. It handles ergonomic runtime concerns: + +- **Output allocation**: `pto.empty_like(Q)` when the caller doesn't provide one. +- **Shape extraction**: reads `batch`, `seq_q`, `heads`, `dim` from the framework tensors. +- **Compile + launch**: `flash_attention_kernel.compile(...)` JIT-compiles the kernel with the given constexpr parameters, then launches it with a `[batch * heads]` grid — one block per `(batch, head)` slice. + +The wrapper knows nothing about tiles, UB, or pipelines. It is the boundary between the user's tensor world and the PTO device world. + +## 11.3 Top-level `@pto.jit(mode="explicit")` kernel entry + + +```python +@pto.jit(target="a5", mode="explicit") +def flash_attention_kernel( + Q: pto.tensor_spec(rank=4, dtype=pto.f32), + K: pto.tensor_spec(rank=4, dtype=pto.f32), + V: pto.tensor_spec(rank=4, dtype=pto.f32), + O: pto.tensor_spec(rank=4, dtype=pto.f32), + *, + BLOCK_Q: pto.constexpr = 128, + BLOCK_KV: pto.constexpr = 128, + CAUSAL: pto.constexpr = False, + NUM_STAGES: pto.constexpr = 2, +): + # Walkthrough body omitted in this signature overview. + return +``` + +The `@pto.jit(mode="explicit")` decorator marks the compile + launch boundary. Inputs are Python-native tensors; outputs are written in-place to `O`. Keyword-only `constexpr` parameters (`BLOCK_Q`, `BLOCK_KV`, `CAUSAL`) are baked at compile time. + +### 11.3.1 TensorView construction + + +```python +q_view = pto.make_tensor_view(Q, shape=[batch, seq_q, heads, dim], + strides=Q.strides) +k_view = pto.make_tensor_view(K, shape=[batch, seq_k, heads, dim], + strides=K.strides) +v_view = pto.make_tensor_view(V, shape=[batch, seq_k, heads, dim], + strides=V.strides) +o_view = pto.make_tensor_view(O, shape=[batch, seq_q, heads, dim], + strides=O.strides) +``` + +`make_tensor_view` wraps each framework tensor with a PTO TensorView descriptor — a GM pointer paired with shape and stride metadata. These descriptors are what the rest of the kernel uses to address global memory. No data moves yet. + +### 11.3.2 SPMD launch contract + + +```python +block_idx = pto.get_block_idx() +block_num = pto.get_block_num() +subblock_idx = pto.get_subblock_idx() +subblock_num = pto.get_subblock_num() + +batch_idx = block_idx // heads +head_idx = block_idx % heads +``` + +The launch grid is `[batch * heads]`. Each block computes one `(batch, head)` slice. `get_block_idx()` returns the current block's linear index; dividing by `heads` recovers the batch and head indices. + +### 11.3.3 Per-head view partitioning + + +```python +q_head = pto.partition_view( + q_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_q, 1, dim], +) +k_head = pto.partition_view( + k_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_k, 1, dim], +) +v_head = pto.partition_view( + v_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_k, 1, dim], +) +o_head = pto.partition_view( + o_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_q, 1, dim], +) +``` + +There is no dedicated `select_head_view` public helper anymore. Each `(batch, head)` working set is sliced from the 4D TensorView with the standard `partition_view(...)` surface, and further logical slicing composes on top of the same primitive. + +### 11.3.4 Tile allocation + +Three categories of tiles are allocated: + +**MAT-backed bridge tiles** — the logical Q/K/V/P blocks that feed the cube path: + + +```python +q_mat = pto.alloc_tile( + shape=[Br, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_br, dim], + blayout="ColMajor", + slayout="RowMajor", +) +k_mat = pto.alloc_tile( + shape=[Bc, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_bc, dim], + blayout="ColMajor", + slayout="RowMajor", +) +v_mat = pto.alloc_tile( + shape=[Bc, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_bc, dim], + blayout="ColMajor", + slayout="RowMajor", +) +p_mat = pto.alloc_tile( + shape=[Br, Bc], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_br, full_bc], + blayout="ColMajor", + slayout="RowMajor", +) + +o_prev_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +o_next_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +m_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +m_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +l_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + +s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) +p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) +pv_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +``` + +The walkthrough keeps Q/K/V/P on the MAT path so the cube sub-kernels consume the same tile objects that the top-level kernel owns. Runtime tails still live in `valid_shape`; the physical tile shapes stay static. + +**UB-resident state and scratch tiles** — the online-softmax state plus intermediate outputs: + +```python +o_prev_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +o_next_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +m_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +m_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +l_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + +s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) +p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) +pv_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) +alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") +``` + +The online-softmax algorithm requires **ping-pong state tiles**: `m_prev`/`m_next`, `l_prev`/`l_next`, `o_prev`/`o_next`. After each KV block, `next` becomes `prev` for the following iteration. + +**Cube-local scratch tiles** — allocated in specific memory spaces: + + +```python +q_l0a = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, + memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, dim]) +p_l0a = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, + memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, full_bc]) +rhs_l0b = pto.alloc_tile(shape=[Bc, D], dtype=pto.f32, + memory_space=pto.MemorySpace.RIGHT, valid_shape=[full_bc, dim]) +qk_acc_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, + memory_space=pto.MemorySpace.ACC, valid_shape=[full_br, full_bc]) +pv_acc_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, + memory_space=pto.MemorySpace.ACC, valid_shape=[full_br, dim]) +``` + +Cube scratch tiles are NOT UB buffers. `LEFT`, `RIGHT`, and `ACC` are distinct hardware memory spaces inside the Cube unit. They serve as staging for matrix operands and accumulators. + +### 11.3.5 SIMT metadata buffer + + +```python +meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 3]) +meta_ptr = meta_tile.as_ptr() +``` + +A small UB tile stores three scalar loop bounds (`row_start`, `row_stop`, `valid_cols`). `meta_tile.as_ptr()` materializes a typed UB pointer into it, which is passed to the explicit-mode orchestration as scalar control metadata. + +Notice that the row-wise softmax state tiles (`m_*`, `l_*`, `alpha_tile`, +`beta_tile`) are authored as `blayout="ColMajor"`. This is the intended public +surface for logical column vectors; it avoids forcing users to manufacture a +row-major padded physical width just to satisfy row-byte alignment. + +### 11.3.6 Outer Q loop + inner KV loop + + +```python +with pto.for_(0, q_blocks, step=1) as qi: + q_rows = _block_valid_extent(seq_q, qi, Br) + q_part = pto.partition_view(q_head, offsets=[0, qi * Br, 0, 0], + sizes=[1, q_rows, 1, dim]) + o_part = pto.partition_view(o_head, offsets=[0, qi * Br, 0, 0], + sizes=[1, q_rows, 1, dim]) + + q_mat.valid_shape = [q_rows, dim] + o_prev_tile.valid_shape = [q_rows, dim] + o_next_tile.valid_shape = [q_rows, dim] + m_prev_tile.valid_shape = [q_rows, one] + m_next_tile.valid_shape = [q_rows, one] + l_prev_tile.valid_shape = [q_rows, one] + l_next_tile.valid_shape = [q_rows, one] + alpha_tile.valid_shape = [q_rows, one] + beta_tile.valid_shape = [q_rows, one] + p_mat.valid_shape = [q_rows, full_bc] + pv_tile.valid_shape = [q_rows, dim] + q_l0a.valid_shape = [q_rows, dim] + + pto.tile.load(q_part, q_mat) + + m_prev_tile.fill(float("-inf")) + l_prev_tile.fill(0.0) + o_prev_tile.fill(0.0) + + kv_loop = pto.for_(0, kv_blocks, step=1).carry( + m=m_prev_tile, l=l_prev_tile, o=o_prev_tile, + ) + with kv_loop: + kj = kv_loop.iv + m_cur = kv_loop.m + l_cur = kv_loop.l + o_cur = kv_loop.o + kv_rows = _block_valid_extent(seq_k, kj, Bc) + k_part = pto.partition_view(k_head, + offsets=[0, kj * Bc, 0, 0], sizes=[1, kv_rows, 1, dim]) + v_part = pto.partition_view(v_head, + offsets=[0, kj * Bc, 0, 0], sizes=[1, kv_rows, 1, dim]) + + k_mat.valid_shape = [kv_rows, dim] + v_mat.valid_shape = [kv_rows, dim] + s_tile.valid_shape = [q_rows, kv_rows] + p_tile.valid_shape = [q_rows, kv_rows] + p_mat.valid_shape = [q_rows, kv_rows] + pv_tile.valid_shape = [q_rows, dim] + p_l0a.valid_shape = [q_rows, kv_rows] + rhs_l0b.valid_shape = [kv_rows, dim] + qk_acc_tile.valid_shape = [q_rows, kv_rows] + pv_acc_tile.valid_shape = [q_rows, dim] + + kv_block_process( + q_mat, k_part, v_part, k_mat, v_mat, + o_cur, o_next_tile, + m_cur, l_cur, m_next_tile, l_next_tile, + s_tile, p_tile, p_mat, pv_tile, + alpha_tile, beta_tile, + q_l0a, p_l0a, rhs_l0b, + qk_acc_tile, pv_acc_tile, + meta_ptr, + ) + + kv_loop.update(m=m_next_tile, l=l_next_tile, o=o_next_tile) + + o_final_tile = kv_loop.final("o") + pto.tile.store(o_final_tile, o_part) +``` + +Key points: + +- **Static physical shape, dynamic valid extent**: `alloc_tile(shape=...)` stays constexpr. Tail handling is expressed by updating `valid_shape` before each block load and sub-kernel call. +- **`tile.load` at the kernel entry boundary**: Q is loaded once per Q block using a tile op into the MAT-backed bridge tile `q_mat`. The compiler auto-inserts the necessary `set_flag`/`wait_flag` pairs. +- **State initialization**: `fill(float("-inf"))` and `fill(0.0)` initialize the online-softmax accumulators before the first KV block. +- **Carry state**: the inner `kv_loop` carries three ping-pong tiles (`m`, `l`, `o`) across iterations using `.carry(...)` / `.update(...)` / `.final(...)`. After each KV block, the loop updates the carried values to the `_next` tiles. After the loop, `.final("o")` extracts the final output accumulator. +- **`tile.store` at the kernel entry boundary**: writes the final result for this Q block back to GM. + +## 11.4 Explicit orchestration + +```python +# Explicit orchestration helper used by flash_attention_kernel: +def kv_block_process( + q_mat, k_part, v_part, k_mat, v_mat, + o_prev_tile, o_next_tile, + m_prev_tile, l_prev_tile, m_next_tile, l_next_tile, + s_tile, p_tile, p_mat, pv_tile, + alpha_tile, beta_tile, + q_l0a, p_l0a, rhs_l0b, + qk_acc_tile, pv_acc_tile, + meta_ptr, +): +``` + +The explicit-mode body processes one KV block against an already-loaded Q tile. It owns the execution sandwich: + +### Phase 0 — Stage K/V data + + +```python +rows = k_mat.valid_shape[0] +cols = k_mat.valid_shape[1] +row_bytes = cols * pto.bytewidth(pto.f32) +gm_row_stride = k_part.strides[0] * pto.bytewidth(pto.f32) +mat_row_stride = k_mat.shape[1] * pto.bytewidth(pto.f32) +pto.mte_load(k_part.as_ptr(), k_mat.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, mat_row_stride)) +pto.mte_load(v_part.as_ptr(), v_mat.as_ptr(), 0, row_bytes, + nburst=(rows, gm_row_stride, mat_row_stride)) +pto.pipe_barrier(pto.Pipe.ALL) +``` + +`mte_load` is the ptr-based GM→MAT DMA wrapper used by this walkthrough. Explicit mode passes GM/MAT pointers plus the DMA grouping parameters, and `pipe_barrier(Pipe.ALL)` makes the phase boundary explicit before the cube unit reads `k_mat`/`v_mat`. + +### Phase 0b — Materialize loop bounds + + +```python +materialize_tile_bounds(meta_ptr, + q_mat.valid_shape[0], + k_mat.valid_shape[0]) +row_start = scalar.load(meta_ptr + 0) +row_stop = scalar.load(meta_ptr + 1) +valid_cols = scalar.load(meta_ptr + 2) +``` + +The SIMT sub-kernel `materialize_tile_bounds` writes `{0, valid_rows, valid_cols}` into the metadata buffer. The explicit-mode body then loads these scalars. They control the row iteration range in subsequent sub-kernels, handling partial tail blocks. + +### Phase 1 — `S = Q @ K^T` + + +```python +qk_matmul(q_mat, k_mat, q_l0a, rhs_l0b, qk_acc_tile, s_tile) +pto.pipe_barrier(pto.Pipe.ALL) +``` + +Dispatches the cube sub-kernel. `pipe_barrier(Pipe.ALL)` separates the matrix multiply from the subsequent softmax. + +### Phase 2 — Online softmax + + +```python +online_softmax_rows( + s_tile, p_tile, + m_prev_tile, l_prev_tile, + m_next_tile, l_next_tile, + alpha_tile, beta_tile, + row_start, row_stop, valid_cols, +) +pto.pipe_barrier(pto.Pipe.ALL) +``` + +The simd sub-kernel computes per-row softmax on `S`, updates the running `m`/`l` state, and writes `P`, `alpha`, and `beta`. + +### Phase 3 — `PV = P @ V` + + +```python +pto.tile.mov(p_tile, p_mat) +pto.pipe_barrier(pto.Pipe.ALL) + +pv_matmul(p_mat, v_mat, p_l0a, rhs_l0b, pv_acc_tile, pv_tile) +pto.pipe_barrier(pto.Pipe.ALL) +``` + +The probability tile is first staged onto the MAT path with `pto.tile.mov(p_tile, p_mat)`. Then the second cube dispatch reuses `rhs_l0b` for `V` and `pv_acc_tile` for the accumulator. + +### Phase 4 — Blend output + + +```python +blend_output_rows( + o_prev_tile, pv_tile, alpha_tile, beta_tile, + o_next_tile, row_start, row_stop, + v_mat.valid_shape[1], +) +pto.pipe_barrier(pto.Pipe.ALL) +``` + +The simt sub-kernel blends the old output accumulator with the new PV contribution, weighted by `alpha` and `beta`. + +### Why explicit mode owns sync ordering + +Each `pipe_barrier(Pipe.ALL)` between phases is explicit in the orchestration body. This is intentional: at the orchestration boundary, the user controls pipeline ordering. Auto mode may still use synchronization primitives where needed, but it does so around compiler-managed tile staging rather than user-authored instruction scheduling. + +## 11.5 Cube sub-kernel — `@pto.cube` + +### `qk_matmul` — `S = Q @ K^T` + + +```python +@pto.cube +def qk_matmul(q_mat, k_mat, q_l0a, k_l0b, s_acc, s_tile): + m = q_mat.valid_shape[0] + k = q_mat.valid_shape[1] + n = k_mat.valid_shape[0] + + pto.mte_l1_l0a(q_mat.as_ptr(), q_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(k_mat.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) +``` + +Four cube ops: + +1. **`mte_l1_l0a`**: load Q tile from UB into LEFT scratch (`q_l0a`). +2. **`mte_l1_l0b`**: load K tile from UB into RIGHT scratch (`k_l0b`), with `transpose=True` for K^T. +3. **`mad`**: matrix multiply-accumulate — `s_acc = q_l0a @ k_l0b`. +4. **`mte_l0c_ub`**: write the accumulator result to the UB output tile `s_tile`. + +The cube kernel does not allocate scratch — the caller (top-level kernel) owns scratch lifetime. The cube kernel only expresses dataflow. + +### `pv_matmul` — `PV = P @ V` + + +```python +@pto.cube +def pv_matmul(p_mat, v_mat, p_l0a, v_l0b, pv_acc, pv_tile): + m = p_mat.valid_shape[0] + k = p_mat.valid_shape[1] + n = v_mat.valid_shape[1] + + pto.mte_l1_l0a(p_mat.as_ptr(), p_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(v_mat.as_ptr(), v_l0b.as_ptr(), k, n) + pto.mad(p_l0a.as_ptr(), v_l0b.as_ptr(), pv_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(pv_acc.as_ptr(), pv_tile.as_ptr(), m, n, n, n, 0) +``` + +Structurally identical to `qk_matmul`, but without transposition and with different input/output tiles. The scratch tiles `p_l0a`, `v_l0b`, and `pv_acc` are reused across KV blocks — the caller (top-level kernel) allocates them once. + +## 11.6 SIMD sub-kernel — online softmax + +```python +@pto.simd +def online_softmax_rows( + s_tile, p_tile, + m_prev_tile, l_prev_tile, + m_next_tile, l_next_tile, + alpha_tile, beta_tile, + row_start, row_stop, valid_cols, +): +``` + +The simd kernel iterates over rows with `pto.for_`, processing one row per iteration: + + +```python +with pto.for_(row_start, row_stop, step=1) as row: + col_mask = pto.make_mask(pto.f32, valid_cols) + + s_row = pto.vlds(s_tile[row, 0:]) + m_prev = scalar.load(m_prev_tile[row, 0]) + l_prev = scalar.load(l_prev_tile[row, 0]) +``` + +- **Mask creation**: `make_mask(pto.f32, valid_cols)` generates a tail mask for the column dimension. On the last KV block, `valid_cols` may be less than the full block width. +- **Vector load**: `vlds(s_tile[row, 0:])` loads one entire row of `S` from UB into a vector register. The slice syntax `[row, 0:]` selects the full row. +- **Scalar load**: `lds` reads per-row scalars (`m_prev`, `l_prev`) from the state tiles. + +### Softmax computation + + +```python + row_max = pto.vcgmax(s_row, col_mask) + m_next = scalar.max(m_prev, row_max) + + s_shifted = pto.vsubs(s_row, m_next, col_mask) + p_row = pto.vexp(s_shifted, col_mask) + + row_sum = pto.vcgadd(p_row, col_mask) + l_scaled = l_prev * scalar.exp(m_prev - m_next) + l_next = l_scaled + row_sum + + alpha = l_scaled / l_next + beta = 1.0 / l_next +``` + +This implements the online-softmax update from the Flash Attention paper: + +- `vcgmax` (cross-lane max reduction) finds the row maximum. +- `max(m_prev, m_next)` combines with the running maximum. +- `vsubs` subtracts the scalar `m_next` from every lane (stabilized softmax). +- `vexp` computes `exp(s_shifted)` element-wise. +- `vcgadd` (cross-lane sum reduction) computes the row sum. +- `l_scaled` rescales the previous sum with the running-max correction factor. +- `alpha` and `beta` are the blending coefficients for the output update. + +### Store results + + +```python + pto.vsts(p_row, p_tile[row, 0:], col_mask) + scalar.store(m_next, m_next_tile[row, 0]) + scalar.store(l_next, l_next_tile[row, 0]) + scalar.store(alpha, alpha_tile[row, 0]) + scalar.store(beta, beta_tile[row, 0]) +``` + +- `vsts` stores the vector `p_row` back to UB under the column mask. +- `sts` stores each scalar to its respective UB tile. + +**Boundary contract**: vreg values (`s_row`, `p_row`, `row_max`, `row_sum`) never escape the simd kernel. All persistent state is written to UB tiles. + +## 11.7 SIMT sub-kernel — blend output + +### `materialize_tile_bounds` — scalar metadata + + +```python +@pto.simt +def materialize_tile_bounds(meta_ptr, valid_rows, valid_cols): + scalar.store(0, meta_ptr + 0) + scalar.store(valid_rows, meta_ptr + 1) + scalar.store(valid_cols, meta_ptr + 2) +``` + +Three scalar stores write the loop bounds into the metadata buffer. `meta_ptr` is a typed UB pointer; `+ 0`, `+ 1`, `+ 2` are element offsets into `i32` storage, not byte offsets. This is the simplest sub-kernel in the sketch — it handles scalar control metadata, not vector math. + +### `blend_output_rows` — output accumulation + + +```python +@pto.simt +def blend_output_rows(o_prev_tile, pv_tile, alpha_tile, beta_tile, + o_next_tile, row_start, row_stop, valid_dim): + with pto.for_(row_start, row_stop, step=1) as row: + alpha = scalar.load(alpha_tile[row, 0]) + beta = scalar.load(beta_tile[row, 0]) + + with pto.for_(0, valid_dim, step=1) as col: + o_prev = scalar.load(o_prev_tile[row, col]) + pv_val = scalar.load(pv_tile[row, col]) + o_next = alpha * o_prev + beta * pv_val + scalar.store(o_next, o_next_tile[row, col]) +``` + +This is a scalar element-wise blend over the tile domain: + +``` +O_next[row, col] = alpha[row] * O_prev[row, col] + beta[row] * PV[row, col] +``` + +The SIMT kernel walks the tile element by element with nested `pto.for_` loops. Each iteration loads two scalars (`o_prev` and `pv_val`), computes the weighted sum, and stores the result. The `alpha`/`beta` coefficients are per-row (loaded once per row), while the blend is per-element. + +**Why SIMT instead of SIMD?** The intent is to contrast with `online_softmax_rows`: softmax is dominated by row-wise vector reductions and exponentials — natural SIMD work. The final blend is a simple linear combination with per-row coefficients — expressing it as explicit scalar work-items makes the per-element access pattern explicit and leaves the compiler free to vectorize or fuse as it sees fit. + +### Context manager alternative + +For trivial sub-kernels like `materialize_tile_bounds`, a named function is overkill — the context manager form keeps the logic inline where it's used. The inline SIMT scope itself looks like this: + + +```python +with pto.simt(): + scalar.store(0, meta_ptr + 0) + scalar.store(q_mat.valid_shape[0], meta_ptr + 1) + scalar.store(k_mat.valid_shape[0], meta_ptr + 2) +``` + +The `with pto.simt():` block acts as an anonymous inline sub-kernel scope. For 3-line helpers that have no reuse, the context manager avoids the indirection of a separate function. For complex, reusable logic like `online_softmax_rows` or `qk_matmul`, the named decorator form remains the better fit. + +## 11.8 Putting it all together: one KV block execution + +For one KV block, the full execution sequence is: + +| Step | Layer | Operation | Hardware | +|------|-------|-----------|----------| +| 1 | explicit | `tile.load(q_part, q_mat)` | GM → MAT | +| 2 | explicit | `mte_load(k_part.as_ptr(), k_mat.as_ptr(), ...)` | GM → MAT | +| 3 | explicit | `mte_load(v_part.as_ptr(), v_mat.as_ptr(), ...)` | GM → MAT | +| 4 | explicit | `pipe_barrier(Pipe.ALL)` | — | +| 5 | simt | `materialize_tile_bounds` | SIMT | +| 6 | cube | `qk_matmul` (mte_l1_l0a, mte_l1_l0b, mad, mte_l0c_ub) | Cube | +| 7 | explicit | `pipe_barrier(Pipe.ALL)` | — | +| 8 | simd | `online_softmax_rows` (vlds, vcgmax, vexp, vcgadd, vsts, ...) | SIMD | +| 9 | explicit | `pipe_barrier(Pipe.ALL)` | — | +| 10 | explicit | `tile.mov(p_tile, p_mat)` | Tile copy | +| 11 | explicit | `pipe_barrier(Pipe.ALL)` | — | +| 12 | cube | `pv_matmul` | Cube | +| 13 | explicit | `pipe_barrier(Pipe.ALL)` | — | +| 14 | simt | `blend_output_rows` | SIMT | +| 15 | explicit | `pipe_barrier(Pipe.ALL)` | — | + +After all KV blocks: the top-level kernel issues `tile.store(o_final_tile, o_part)` to write the result back to GM. + +## 11.9 Design patterns in this sketch + +**Ping-pong state for online accumulators**: `m_prev`/`m_next`, `l_prev`/`l_next`, `o_prev`/`o_next` make the state transition explicit. After each KV block, the caller swaps the ping-pong pair (via `kv_loop.update(...)`) rather than aliasing in place. + +**Scratch reuse**: `rhs_l0b` serves both `K` (in `qk_matmul`) and `V` (in `pv_matmul`). `pv_acc_tile` reuses the accumulator from QK^T. The caller (top-level kernel) allocates once; the explicit-mode body passes them to both cube sub-kernels. + +**Tile-level boundary vs micro-instruction boundary**: `tile.load`/`tile.store` are the tile-atomic surface used in auto mode and at the top-level tile boundary of this sketch. `mte_load` appears in explicit orchestration, authored as individual pointer-based instructions. The abstraction split is auto mode as tile-centric authoring, explicit mode as user-ordered orchestration. + +**No vreg across sub-kernel boundaries**: vector registers are local to each `@pto.simd` kernel. Data crosses sub-kernel boundaries through UB tiles — the boundary contract is enforced by the type system. + +**Invocation flexibility**: This sketch uses the explicit `@pto.jit(mode="explicit")` path for full micro-instruction control. The same named sub-kernels can also be reused from `@pto.jit(mode="auto")` when the body stays within the auto-mode contract, or written inline as context managers (`with pto.simd():`, etc.). See Chapter 3 for details. diff --git a/ptodsl/docs/user_guide/12-additional-examples.md b/ptodsl/docs/user_guide/12-additional-examples.md new file mode 100644 index 000000000..ff4937dda --- /dev/null +++ b/ptodsl/docs/user_guide/12-additional-examples.md @@ -0,0 +1,382 @@ +# 12. Additional Examples + +This chapter presents four self-contained examples that build on the concepts introduced in Chapters 1–11. Each example demonstrates a specific pattern: blocked 2D processing, tail handling with masks, matrix multiplication on the Cube unit, and loop-carried state for online normalization. + +## 12.1 Blocked 2D elementwise addition + +Chapter 2 showed a 1D vector add with a single blocking dimension. Real workloads often involve 2D tensors — matrices — where blocking happens along both rows and columns. + +```python +@pto.jit(target="a5") +def mat_add(A, B, O, *, BLOCK_M: pto.constexpr = 64, BLOCK_N: pto.constexpr = 128): + M, N_ = A.shape + + a_view = pto.make_tensor_view(A, shape=[M, N_], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[M, N_], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[M, N_], strides=O.strides) + + a_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32) + b_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32) + o_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32) + + num_m = (M + BLOCK_M - 1) // BLOCK_M + num_n = (N_ + BLOCK_N - 1) // BLOCK_N + + with pto.for_(0, num_m, step=1) as mi: + m_off = mi * BLOCK_M + with pto.for_(0, num_n, step=1) as ni: + n_off = ni * BLOCK_N + + a_part = pto.partition_view(a_view, offsets=[m_off, n_off], sizes=[BLOCK_M, BLOCK_N]) + b_part = pto.partition_view(b_view, offsets=[m_off, n_off], sizes=[BLOCK_M, BLOCK_N]) + o_part = pto.partition_view(o_view, offsets=[m_off, n_off], sizes=[BLOCK_M, BLOCK_N]) + + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) + pto.tile.add(a_tile, b_tile, o_tile) + pto.tile.store(o_tile, o_part) +``` + +**Key points**: + +- Nested `pto.for_` loops produce a 2D block traversal. Both loops are recorded as device-side control flow — they adapt to the runtime shape `M`. +- Tile shape `[BLOCK_M, BLOCK_N]` is 2D; all three tiles use the same shape so `tile.add` is elementwise. +- `partition_view` takes 2D offsets and sizes. +- `BLOCK_M` and `BLOCK_N` are `constexpr` — the compiler specializes the kernel per tile shape. + +The Python wrapper follows the same pattern as Chapter 2: + + +```python +def mat_add_wrapper(A, B, O=None, stream=None): + if O is None: + O = pto.empty_like(A) + compiled = mat_add.compile(BLOCK_M=64, BLOCK_N=128) + m, n = A.shape[1], A.shape[2] # assuming batch-first: [batch, M, N] + compiled[A.shape[0], stream](A, B, O) + return O +``` + +The grid is `A.shape[0]` so each SPMD block processes one slice of the leading batch dimension. + +## 12.2 Vector operations with tail handling + +When a data dimension is not evenly divisible by the tile size or the hardware vector width, the last iteration must operate on fewer elements. PTODSL provides masks for this — `make_mask` produces a predicate that guards loads, computes, and stores so out-of-bounds lanes are not touched. + +### 12.2.1 Tail handling in a SIMD kernel + +Below is a self-contained `@pto.simd` kernel that adds two tiles row by row, handling column tails with `make_mask`: + + +```python +@pto.simd +def add_rows_with_tail(a_tile: pto.Tile, b_tile: pto.Tile, o_tile: pto.Tile, + rows: pto.i32, cols: pto.i32): + VEC = pto.elements_per_vreg(pto.f32) # 64 for f32 + + with pto.for_(0, rows, step=1) as r: + col_loop = pto.for_(0, cols, step=VEC).carry(remained=cols) + with col_loop: + c = col_loop.iv + remained = col_loop.remained + mask, remained = pto.make_mask(pto.f32, remained) + + a_vec = pto.vlds(a_tile[r, c:]) # load under mask + b_vec = pto.vlds(b_tile[r, c:]) + o_vec = pto.vadd(a_vec, b_vec, mask) # compute under mask + pto.vsts(o_vec, o_tile[r, c:], mask) # store under mask + + col_loop.update(remained=remained) +``` + +The pattern: + +1. **Chunk**: Each iteration processes `VEC` elements (one vector register's worth). +2. **Mask**: `make_mask` returns a predicate and the updated remainder. On the last iteration, where `remained < VEC`, the mask has `remained` valid lanes followed by inactive lanes. +3. **Guard**: `vlds`, `vadd`, and `vsts` all accept the mask — inactive lanes are neither loaded, computed, nor stored. +4. **Carry**: `.carry(remained=cols)` carries the remaining column count across iterations. `col_loop.update(remained=remained)` feeds the updated count to the next iteration. + +### 12.2.2 Tile-level tail handling + +At the Tile Op level, tail handling is built into `tile.load` and `tile.store`. When a partition size along a dimension is smaller than the tile size, the tile's `valid_shape` tracks the actual data extent: + + +```python +@pto.jit(target="a5") +def vec_add_with_tail( + A: pto.tensor_spec(rank=1, dtype=pto.f32), + B: pto.tensor_spec(rank=1, dtype=pto.f32), + O: pto.tensor_spec(rank=1, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + N = A.shape[0] + + a_view = pto.make_tensor_view(A, shape=[N], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[N], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + a_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32, valid_shape=[pto.const(BLOCK)]) + b_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32, valid_shape=[pto.const(BLOCK)]) + o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32, valid_shape=[pto.const(BLOCK)]) + + num_blocks = (N + BLOCK - 1) // BLOCK + + with pto.for_(0, num_blocks, step=1) as i: + offset = i * BLOCK + this_block = scalar.min(BLOCK, N - offset) + + a_part = pto.partition_view(a_view, offsets=[offset], sizes=[this_block]) + b_part = pto.partition_view(b_view, offsets=[offset], sizes=[this_block]) + o_part = pto.partition_view(o_view, offsets=[offset], sizes=[this_block]) + + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) + + a_tile.valid_shape = [this_block] + b_tile.valid_shape = [this_block] + o_tile.valid_shape = [this_block] + + pto.tile.add(a_tile, b_tile, o_tile) + pto.tile.store(o_tile, o_part) +``` + +- `this_block = scalar.min(BLOCK, N - offset)` computes the actual block size for the tail iteration on the device side. +- `sizes=[this_block]` on the partition and `tile.valid_shape = [...]` on the tile tell `tile.load`/`tile.add`/`tile.store` how many elements are live. + +### 12.2.3 The general rule + +| Tail scenario | Mechanism | +|---------------|-----------| +| Tile Op boundary (tile.load/tile.store) | `valid_shape` on tile + smaller `sizes` on partition | +| SIMD vector boundary (vlds/vadd/vsts) | `make_mask` + mask parameter on op | +| SIMT scalar loop boundary | `min(BLOCK, N - offset)` in loop bound | + +## 12.3 GEMM: matrix multiplication on the Cube unit + +This example demonstrates a complete GEMM kernel: `C = A @ B` where A is `[M, K]` and B is `[K, N]`. It uses `@pto.jit` for tile allocation and loop scheduling, and `@pto.cube` for the actual matrix multiply. + +### 12.3.1 Cube sub-kernel + + +```python +@pto.cube +def gemm_tile(a_mat: pto.Tile, b_mat: pto.Tile, o_tile: pto.Tile, + a_l0a: pto.Tile, b_l0b: pto.Tile, o_acc: pto.Tile): + m = a_mat.valid_shape[0] + k = a_mat.valid_shape[1] + n = b_mat.valid_shape[1] + + pto.mte_l1_l0a(a_mat.as_ptr(), a_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(b_mat.as_ptr(), b_l0b.as_ptr(), k, n) + pto.mad(a_l0a.as_ptr(), b_l0b.as_ptr(), o_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(o_acc.as_ptr(), o_tile.as_ptr(), m, n, n, n, 0) +``` + +The cube sub-kernel consumes MAT staging tiles plus cube-local scratch buffers. The four-step sequence — stage left operand, stage right operand, multiply, writeback — is the canonical cube compute pattern. + +### 12.3.2 Tile orchestration + + +```python +@pto.jit(target="a5", mode="explicit") +def gemm( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + B: pto.tensor_spec(rank=2, dtype=pto.f32), + O: pto.tensor_spec(rank=2, dtype=pto.f32), + *, + BLOCK_M: pto.constexpr = 64, + BLOCK_K: pto.constexpr = 64, + BLOCK_N: pto.constexpr = 64, +): + M, K_ = A.shape + _, N_ = B.shape + + a_view = pto.make_tensor_view(A, shape=[M, K_], strides=A.strides) + b_view = pto.make_tensor_view(B, shape=[K_, N_], strides=B.strides) + o_view = pto.make_tensor_view(O, shape=[M, N_], strides=O.strides) + + a_mat = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f32, + memory_space=pto.MemorySpace.MAT) + b_mat = pto.alloc_tile(shape=[BLOCK_K, BLOCK_N], dtype=pto.f32, + memory_space=pto.MemorySpace.MAT) + o_tile = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32) + + a_l0a = pto.alloc_tile(shape=[BLOCK_M, BLOCK_K], dtype=pto.f32, + memory_space=pto.MemorySpace.LEFT) + b_l0b = pto.alloc_tile(shape=[BLOCK_K, BLOCK_N], dtype=pto.f32, + memory_space=pto.MemorySpace.RIGHT) + o_acc = pto.alloc_tile(shape=[BLOCK_M, BLOCK_N], dtype=pto.f32, + memory_space=pto.MemorySpace.ACC) + + num_m = (M + BLOCK_M - 1) // BLOCK_M + num_n = (N_ + BLOCK_N - 1) // BLOCK_N + num_k = (K_ + BLOCK_K - 1) // BLOCK_K + + with pto.for_(0, num_m, step=1) as mi: + m_off = mi * BLOCK_M + with pto.for_(0, num_n, step=1) as ni: + n_off = ni * BLOCK_N + o_part = pto.partition_view(o_view, offsets=[m_off, n_off], + sizes=[BLOCK_M, BLOCK_N]) + + o_tile.fill(0.0) + + with pto.for_(0, num_k, step=1) as ki: + k_off = ki * BLOCK_K + + a_part = pto.partition_view(a_view, offsets=[m_off, k_off], + sizes=[BLOCK_M, BLOCK_K]) + b_part = pto.partition_view(b_view, offsets=[k_off, n_off], + sizes=[BLOCK_K, BLOCK_N]) + + pto.tile.load(a_part, a_mat) + pto.tile.load(b_part, b_mat) + + gemm_tile(a_mat, b_mat, o_tile, a_l0a, b_l0b, o_acc) + + pto.tile.store(o_tile, o_part) +``` + +**Key points**: + +- **Triply nested loops**: M, N, and K dimensions are all blocked. The K loop accumulates partial results into `o_tile`. +- **Accumulation**: `o_tile.fill(0.0)` resets the accumulator before the K loop. Each K-block calls `gemm_tile` which writes its partial product back to `o_tile`. The Cube unit accumulates implicitly via `mad` — each K-block's partial result is added to the running total in `o_acc`. +- **MAT staging + cube-local scratch**: `a_mat` and `b_mat` are explicit MAT tiles that satisfy the `mte_l1_l0a` / `mte_l1_l0b` source contract. `a_l0a`, `b_l0b`, and `o_acc` are cube-local scratch (`LEFT`, `RIGHT`, `ACC`). +- **Direct sub-kernel call**: `gemm_tile` is called directly from `@pto.jit` — no separate orchestration layer needed. The compiler handles sync between `tile.load` and the Cube sub-kernel. +- **Cube sub-kernel reuse**: the same `gemm_tile` function is called for every K-block — the named decorator form enables reuse. + +### 12.3.3 Python wrapper + + +```python +import numpy as np + + +def gemm_wrapper(A, B, O=None, stream=None): + if O is None: + O = np.empty((A.shape[0], B.shape[1]), dtype=A.dtype) + compiled = gemm.compile(BLOCK_M=64, BLOCK_K=64, BLOCK_N=64) + compiled[1, stream](A, B, O) + return O +``` + +This pattern extends directly to batch-GEMM: pass a grid of `batch` and use `pto.get_block_idx()` to select the per-batch slice from `A` and `B`. + +### 12.3.4 Comparison with explicit-mode orchestration + +For reference, the same GEMM could be written in `mode="explicit"` when the kernel needs micro-instruction control. The direct-call path used above is recommended for most users; explicit mode is for cases that need hand-authored instruction scheduling and ordering. + +## 12.4 Online normalization with loop-carried state + +Chapter 11 demonstrated online softmax with ping-pong state tiles. A simpler but instructive case is **online layer normalization** — computing mean and variance incrementally across blocks while carrying only scalar state between iterations. + +Given a vector `X` of length `N`, the streaming Welford algorithm updates the running mean `mu` and variance `var` as each new element `x` arrives: + +``` +n_next = n_prev + 1 +delta = x - mu_prev +mu_next = mu_prev + delta / n_next +m2_next = m2_prev + delta * (x - mu_next) +``` + +The example below keeps the whole pattern inside one `@pto.jit` kernel. The first pass carries `mu`, `n`, and `m2` across blocks; the second pass reloads each block and applies the normalization explicitly with scalar loads and stores. This version assumes `N > 0`. + +### 12.4.1 JIT example with loop-carried Welford state + + +```python +@pto.jit(target="a5") +def online_layernorm( + X: pto.tensor_spec(rank=1, dtype=pto.f32), + O: pto.tensor_spec(rank=1, dtype=pto.f32), + *, + BLOCK: pto.constexpr = 128, +): + N = X.shape[0] + x_view = pto.make_tensor_view(X, shape=[N], strides=X.strides) + o_view = pto.make_tensor_view(O, shape=[N], strides=O.strides) + + x_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32, valid_shape=[pto.const(BLOCK)]) + o_tile = pto.alloc_tile(shape=[BLOCK], dtype=pto.f32, valid_shape=[pto.const(BLOCK)]) + + num_blocks = (N + BLOCK - 1) // BLOCK + + # Pass 1: running Welford state across blocks. + stats_loop = pto.for_(0, num_blocks, step=1).carry( + mu=pto.f32(0.0), n=pto.f32(0.0), m2=pto.f32(0.0) + ) + with stats_loop: + i = stats_loop.iv + offset = i * BLOCK + this_block = scalar.min(BLOCK, N - offset) + x_part = pto.partition_view(x_view, offsets=[offset], sizes=[this_block]) + pto.tile.load(x_part, x_tile) + x_tile.valid_shape = [this_block] + + elem_loop = pto.for_(0, this_block, step=1).carry( + mu=stats_loop.mu, n=stats_loop.n, m2=stats_loop.m2 + ) + with elem_loop: + j = elem_loop.iv + x = scalar.load(x_tile.as_ptr(), j) + n_next = elem_loop.n + 1.0 + delta = x - elem_loop.mu + mu_next = elem_loop.mu + delta / n_next + delta2 = x - mu_next + m2_next = elem_loop.m2 + delta * delta2 + elem_loop.update(mu=mu_next, n=n_next, m2=m2_next) + + stats_loop.update( + mu=elem_loop.final("mu"), + n=elem_loop.final("n"), + m2=elem_loop.final("m2"), + ) + + mean = stats_loop.final("mu") + count = stats_loop.final("n") + inv_std = 1.0 / scalar.sqrt(stats_loop.final("m2") / count + pto.f32(1.0e-5)) + + # Pass 2: apply (x - mean) / sqrt(var + eps) block by block. + with pto.for_(0, num_blocks, step=1) as i: + offset = i * BLOCK + this_block = scalar.min(BLOCK, N - offset) + x_part = pto.partition_view(x_view, offsets=[offset], sizes=[this_block]) + o_part = pto.partition_view(o_view, offsets=[offset], sizes=[this_block]) + + pto.tile.load(x_part, x_tile) + x_tile.valid_shape = [this_block] + o_tile.valid_shape = [this_block] + + with pto.for_(0, this_block, step=1) as j: + x = scalar.load(x_tile.as_ptr(), j) + y = (x - mean) * inv_std + scalar.store(y, o_tile.as_ptr(), j) + + pto.tile.store(o_tile, o_part) +``` + +**Key points**: + +- **Carry state**: `.carry(mu=..., n=..., m2=...)` on both loops keeps the running Welford state in SSA form. The outer loop carries state across blocks; the inner loop carries state across elements inside one block. +- **Tail handling**: `scalar.min(BLOCK, N - offset)` computes the live width of the current block, and `tile.valid_shape = [this_block]` keeps the tile contract aligned with that tail. +- **No special tile op required**: the normalization pass is written explicitly with `scalar.load(...)`, scalar arithmetic, `scalar.sqrt(...)`, and `scalar.store(...)`. There is no dependency on a dedicated `tnormalize` op. +- **Compare to flash attention**: the flash attention carry in Chapter 11 moves several tiles through ping-pong buffers. Here the carried state is only three scalars, so the same `.carry(...)` surface reads more like a conventional streaming reduction. + +## 12.5 Design guidelines + +**Start simple, refine later.** Begin with `@pto.jit` + Tile Ops. If Tile Ops don't cover the computation (e.g., custom softmax, specialized activation), add a sub-kernel. If you need micro-instruction-level control, switch the kernel to `mode="explicit"`. + +**Choose the right entry for each piece:** + +| Goal | Use | +|------|-----| +| Whole-kernel orchestration, GM↔UB boundary | `@pto.jit` | +| Tile-level data movement | `tile.load` / `tile.store` | +| Custom row-wise vector math | `@pto.simd` | +| Custom per-element logic | `@pto.simt` | +| Matrix multiply | `@pto.cube` | +| Micro-instruction-level control | `mode="explicit"` | +| Inline compute for quick prototyping | `with pto.simd():` etc. | + +**Respect boundary contracts.** Vregs don't cross `@pto.simd` boundaries. Cube-local state doesn't leak into UB. Tile Ops and MTE Ops belong to different programming models — use Tile Ops in `mode="auto"`, and micro-instructions in `mode="explicit"`. diff --git a/ptodsl/examples/flash_attention_sketch.py b/ptodsl/examples/flash_attention_sketch.py new file mode 100644 index 000000000..0a9dcaf60 --- /dev/null +++ b/ptodsl/examples/flash_attention_sketch.py @@ -0,0 +1,729 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +Flash Attention compile-only demo. + +This file is a compileable PTODSL demo whose current milestone is MLIR +emission, inspection, and API review. The goal is to make the intended API +layering explicit and keep the semantic contracts clean: + + emit_flash_attention_mlir(...) compile/inspect wrapper + └─ flash_attention_kernel (@pto.jit, mode="explicit") + ├─ Tile Ops tile.load / tile.store at the GM↔UB boundary + ├─ explicit orchestration mte_load / pipe_barrier / pointer sequencing + ├─ @pto.cube matrix products (QK^T and P@V) + ├─ @pto.simd row-wise online softmax + └─ @pto.simt scalar metadata and output blending + +Design rules illustrated here: + +1. ``@pto.jit`` marks a launchable kernel template. It owns JIT compilation, + cache lookup, and artifact emission, instead of forcing users to hop through + extra builder objects for common cases. +2. The Python wrapper owns compile/inspection concerns such as selecting + specialization knobs and returning the emitted MLIR text for review. +3. ``@pto.jit`` also owns the top-level logical tiling, tile allocation, and + loop scheduling for one already-selected per-head 2D slice. The per-block + DMA and barrier choreography is delegated to explicit orchestration. +4. explicit mode owns the per-block execution sandwich: stage the current K/V + block with explicit micro-instructions, synchronize, call hardware-bound + sub-kernels, and manage scratch/state. +5. ``@pto.jit`` may use tile ops such as ``tile.load`` / ``tile.store`` at the logical + scheduling boundary, but explicit mode can also express GM<->UB movement + directly. Once execution enters explicit orchestration, MTE micro-instructions + such as ``mte_load`` are used instead of tile ops where needed. + ``mte_load`` / ``mte_store`` accept partitions and tiles directly, + deriving strides and burst sizes from the type metadata. +6. ``simd`` / ``simt`` / ``cube`` are hardware boundaries. They do not expose + vreg values across the function boundary. Data crosses the boundary through + UB-backed tiles or typed UB pointers only. +7. Named sub-kernels are reusable wherever their parameter contract is + satisfied. This sketch uses the explicit ``@pto.jit(mode="explicit")`` path + because it needs user-ordered DMA and phase barriers; smaller kernels can + stay in auto mode and rely on tile-atomic staging instead. +8. Online-softmax state is made explicit with ping-pong tiles + (``m_prev``/``m_next``, ``l_prev``/``l_next``, ``o_prev``/``o_next``). + Hiding these dependencies with in-place aliases makes the algorithm harder + to read and obscures what the DSL needs to express. + +Because this demo targets a tracing-style frontend, any control flow that +must reach MLIR is expressed with structured DSL constructs such as +``pto.for_`` instead of native Python ``for`` loops. + +Scalar literals and simple index/integer conversions are also written in the +authored PTODSL surface. The current frontend lowers these through tracing +instead of forcing authors to spell ``pto.const(...)`` or ``index_cast(...)`` +at every use site. +""" + +from pathlib import Path +import sys + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from flash_attention_sketch.py" + ) + +from ptodsl import pto, scalar + + +def _min_index(lhs, rhs): + return scalar.select( + lhs < rhs, + lhs, + rhs, + ) + + +def _block_valid_extent(total, block_index, block_size): + return _min_index(total - block_index * block_size, pto.const(block_size)) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Public API sketch +# ═══════════════════════════════════════════════════════════════════════════════ +# +# This section shows the current compile-only public surface. The split follows +# the common industry pattern: +# +# - a user-facing tensor wrapper +# - a launchable JIT kernel entry +# - hardware-bound sub-kernels below it +# +# The low-level kernel body should not double as the user-facing runtime API. +# +# Two intended usage styles for the current compile-only milestone: +# +# 1. One-shot MLIR emission: +# mlir_text = emit_flash_attention_mlir(head_dim=128, causal=True) +# +# 2. Compile first, then inspect: +# compiled = flash_attention_kernel.compile(BLOCK_Q=128, BLOCK_KV=128, CAUSAL=True) +# mlir_text = compiled.mlir_text() + +def emit_flash_attention_mlir( + *, + head_dim=128, + causal=False, + block_q=128, + block_kv=128, +): + """ + Compile the flash-attention sketch and return its MLIR text. + + The current milestone for this demo is compile / inspect / review, not + runtime launch. The wrapper therefore only specializes the JIT kernel and + returns the emitted MLIR text. + """ + compiled = flash_attention_kernel.compile( + BLOCK_Q=block_q, + BLOCK_KV=block_kv, + HEAD_DIM=head_dim, + CAUSAL=causal, + ) + return compiled.mlir_text() + +@pto.jit(target="a5", mode="explicit") +def flash_attention_kernel( + Q: pto.tensor_spec(rank=4, dtype=pto.f32), # Python/framework tensor, logical [batch, seq_q, heads, dim] + K: pto.tensor_spec(rank=4, dtype=pto.f32), # Python/framework tensor, logical [batch, seq_k, heads, dim] + V: pto.tensor_spec(rank=4, dtype=pto.f32), # Python/framework tensor, logical [batch, seq_k, heads, dim] + O: pto.tensor_spec(rank=4, dtype=pto.f32), # Python/framework tensor, logical [batch, seq_q, heads, dim] + *, + BLOCK_Q: pto.constexpr = 128, + BLOCK_KV: pto.constexpr = 128, + HEAD_DIM: pto.constexpr = 128, + CAUSAL: pto.constexpr = False, + NUM_STAGES: pto.constexpr = 2, +): + """ + Launchable device entry. + + ``@pto.jit`` is the compile boundary. Inputs/outputs at this + boundary are Python-native tensor objects; PTO-specific ``TensorView`` + descriptors are materialized inside the JIT body rather than exposed in the + public signature. Tile sizes and specialization knobs remain constexpr + metadata. + + A launch instance is responsible for one ``(batch, head)`` slice. The + per-slice logical tiling is expressed directly in this top-level JIT entry. + """ + batch, seq_q, heads, dim = Q.shape + _, seq_k, _, _ = K.shape + + q_view = pto.make_tensor_view(Q) + k_view = pto.make_tensor_view(K) + v_view = pto.make_tensor_view(V) + o_view = pto.make_tensor_view(O) + + # Make the SPMD launch contract explicit in the authored surface. + # This sketch uses one block per (batch, head) slice and does not further + # split work across subblocks, but the runtime indices still belong in a + # realistic launchable entry. + block_idx = pto.get_block_idx() + block_num = pto.get_block_num() + subblock_idx = pto.get_subblock_idx() + subblock_num = pto.get_subblock_num() + + # Current mapping: + # - launch grid = batch * heads + # - block_idx selects one (batch, head) slice + # - subblock_idx is queried explicitly, but no extra intra-block partition + # is modeled in this sketch yet + _ = block_num + _ = subblock_idx + _ = subblock_num + + batch_idx = block_idx // heads + head_idx = block_idx % heads + + q_head = pto.partition_view( + q_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_q, 1, dim], + ) + k_head = pto.partition_view( + k_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_k, 1, dim], + ) + v_head = pto.partition_view( + v_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_k, 1, dim], + ) + o_head = pto.partition_view( + o_view, + offsets=[batch_idx, 0, head_idx, 0], + sizes=[1, seq_q, 1, dim], + ) + + Br = BLOCK_Q + Bc = BLOCK_KV + D = HEAD_DIM + full_br = pto.const(Br) + full_bc = pto.const(Bc) + one = pto.const(1) + + q_blocks = (seq_q + Br - 1) // Br + kv_blocks = (seq_k + Bc - 1) // Bc + + # Physical tile shape remains static. Runtime tails live in valid_shape. + # Cube bridge sources are MAT-backed so they can feed LEFT/RIGHT staging. + q_mat = pto.alloc_tile( + shape=[Br, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_br, dim], + blayout="ColMajor", + slayout="RowMajor", + ) + k_mat = pto.alloc_tile( + shape=[Bc, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_bc, dim], + blayout="ColMajor", + slayout="RowMajor", + ) + v_mat = pto.alloc_tile( + shape=[Bc, D], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_bc, dim], + blayout="ColMajor", + slayout="RowMajor", + ) + + o_prev_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) + o_next_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) + m_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + m_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + l_prev_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + l_next_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + + s_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) + p_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, valid_shape=[full_br, full_bc]) + p_mat = pto.alloc_tile( + shape=[Br, Bc], + dtype=pto.f32, + memory_space=pto.MemorySpace.MAT, + valid_shape=[full_br, full_bc], + blayout="ColMajor", + slayout="RowMajor", + ) + pv_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, valid_shape=[full_br, dim]) + alpha_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + beta_tile = pto.alloc_tile(shape=[Br, 1], dtype=pto.f32, valid_shape=[full_br, one], blayout="ColMajor") + + # Cube-local scratch is explicit; it should not be conflated with UB tiles. + q_l0a = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, dim]) + p_l0a = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, memory_space=pto.MemorySpace.LEFT, valid_shape=[full_br, full_bc]) + rhs_l0b = pto.alloc_tile(shape=[Bc, D], dtype=pto.f32, memory_space=pto.MemorySpace.RIGHT, valid_shape=[full_bc, dim]) + qk_acc_tile = pto.alloc_tile(shape=[Br, Bc], dtype=pto.f32, memory_space=pto.MemorySpace.ACC, valid_shape=[full_br, full_bc]) + pv_acc_tile = pto.alloc_tile(shape=[Br, D], dtype=pto.f32, memory_space=pto.MemorySpace.ACC, valid_shape=[full_br, dim]) + + # SIMT metadata buffer. A tiny raw-pointer island is acceptable at the + # explicit-orchestration boundary because this is scalar control data, not + # user-facing math. + meta_tile = pto.alloc_tile(shape=[1, 8], dtype=pto.i32, valid_shape=[1, 3]) + meta_ptr = meta_tile.as_ptr() + + with pto.for_(0, q_blocks, step=1) as qi: + q_rows = _block_valid_extent(seq_q, qi, Br) + q_part = pto.partition_view(q_head, offsets=[0, qi * Br, 0, 0], sizes=[1, q_rows, 1, dim]) + o_part = pto.partition_view(o_head, offsets=[0, qi * Br, 0, 0], sizes=[1, q_rows, 1, dim]) + + q_mat.valid_shape = [q_rows, dim] + o_prev_tile.valid_shape = [q_rows, dim] + o_next_tile.valid_shape = [q_rows, dim] + m_prev_tile.valid_shape = [q_rows, one] + m_next_tile.valid_shape = [q_rows, one] + l_prev_tile.valid_shape = [q_rows, one] + l_next_tile.valid_shape = [q_rows, one] + alpha_tile.valid_shape = [q_rows, one] + beta_tile.valid_shape = [q_rows, one] + p_mat.valid_shape = [q_rows, full_bc] + pv_tile.valid_shape = [q_rows, dim] + q_l0a.valid_shape = [q_rows, dim] + + pto.tile.load(q_part, q_mat) + + # Initial online-softmax state for this Q block. + # ``CAUSAL`` is threaded at the API boundary even though the masking + # details are intentionally omitted from this design-focused sketch. + m_prev_tile.fill(float("-inf")) + l_prev_tile.fill(0.0) + o_prev_tile.fill(0.0) + + kv_loop = pto.for_(0, kv_blocks, step=1).carry( + m=m_prev_tile, + l=l_prev_tile, + o=o_prev_tile, + ) + with kv_loop: + kj = kv_loop.iv + m_cur = kv_loop.m + l_cur = kv_loop.l + o_cur = kv_loop.o + kv_rows = _block_valid_extent(seq_k, kj, Bc) + k_part = pto.partition_view(k_head, offsets=[0, kj * Bc, 0, 0], sizes=[1, kv_rows, 1, dim]) + v_part = pto.partition_view(v_head, offsets=[0, kj * Bc, 0, 0], sizes=[1, kv_rows, 1, dim]) + + k_mat.valid_shape = [kv_rows, dim] + v_mat.valid_shape = [kv_rows, dim] + s_tile.valid_shape = [q_rows, kv_rows] + p_tile.valid_shape = [q_rows, kv_rows] + p_mat.valid_shape = [q_rows, kv_rows] + pv_tile.valid_shape = [q_rows, dim] + p_l0a.valid_shape = [q_rows, kv_rows] + rhs_l0b.valid_shape = [kv_rows, dim] + qk_acc_tile.valid_shape = [q_rows, kv_rows] + pv_acc_tile.valid_shape = [q_rows, dim] + + kv_block_process( + q_mat, + k_part, + v_part, + k_mat, + v_mat, + o_cur, + o_next_tile, + m_cur, + l_cur, + m_next_tile, + l_next_tile, + s_tile, + p_tile, + p_mat, + pv_tile, + alpha_tile, + beta_tile, + q_l0a, + p_l0a, + rhs_l0b, + qk_acc_tile, + pv_acc_tile, + meta_ptr, + ) + + # Loop-carried state is still explicit, but the authored surface no + # longer mirrors raw scf.iter_args / scf.yield spellings. + kv_loop.update( + m=m_next_tile, + l=l_next_tile, + o=o_next_tile, + ) + + o_final_tile = kv_loop.final("o") + pto.tile.store(o_final_tile, o_part) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Hardware-bound sub-kernels +# ═══════════════════════════════════════════════════════════════════════════════ +# +# Boundary contract: +# - Tile arguments are UB-backed or cube-local buffers carrying addressable +# storage. +# - No vector register escapes a simd function. +# - No implicit global-memory access happens inside these kernels. + + +@pto.cube +def qk_matmul( + q_mat: pto.Tile, # MAT, [Br, dim] + k_mat: pto.Tile, # MAT, [Bc, dim] + q_l0a: pto.Tile, # LEFT scratch + k_l0b: pto.Tile, # RIGHT scratch + s_acc: pto.Tile, # ACC scratch + s_tile: pto.Tile, # UB, [Br, Bc] output +): + """ + Compute ``S = Q @ K^T`` for one attention block. + + The key point for the redesign is that the cube kernel consumes MAT tiles and + explicit cube-local scratch, rather than pretending a logical scheduling tile can also stand + in for LEFT/RIGHT/ACC state. + """ + m = q_mat.valid_shape[0] + k = q_mat.valid_shape[1] + n = k_mat.valid_shape[0] + + # Caller owns scratch lifetime. The cube kernel only expresses dataflow. + pto.mte_l1_l0a(q_mat.as_ptr(), q_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(k_mat.as_ptr(), k_l0b.as_ptr(), k, n, transpose=True) + pto.mad(q_l0a.as_ptr(), k_l0b.as_ptr(), s_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(s_acc.as_ptr(), s_tile.as_ptr(), m, n, n, n, 0) + + +@pto.cube +def pv_matmul( + p_mat: pto.Tile, # MAT, [Br, Bc] + v_mat: pto.Tile, # MAT, [Bc, dim] + p_l0a: pto.Tile, # LEFT scratch (reused) + v_l0b: pto.Tile, # RIGHT scratch (reused) + pv_acc: pto.Tile, # ACC scratch (reused) + pv_tile: pto.Tile, # UB, [Br, dim] output +): + """ + Compute ``PV = P @ V`` for the current block. + + This keeps the second matrix product on the cube path as well, instead of + accidentally collapsing it into an elementwise vector expression. + """ + m = p_mat.valid_shape[0] + k = p_mat.valid_shape[1] + n = v_mat.valid_shape[1] + + pto.mte_l1_l0a(p_mat.as_ptr(), p_l0a.as_ptr(), m, k) + pto.mte_l1_l0b(v_mat.as_ptr(), v_l0b.as_ptr(), k, n) + pto.mad(p_l0a.as_ptr(), v_l0b.as_ptr(), pv_acc.as_ptr(), m, n, k) + pto.mte_l0c_ub(pv_acc.as_ptr(), pv_tile.as_ptr(), m, n, n, n, 0) + + +@pto.simd +def online_softmax_rows( + s_tile: pto.Tile, # UB, [Br, Bc] + p_tile: pto.Tile, # UB, [Br, Bc], output + m_prev_tile: pto.Tile, # UB, [Br, 1] + l_prev_tile: pto.Tile, # UB, [Br, 1] + m_next_tile: pto.Tile, # UB, [Br, 1], output + l_next_tile: pto.Tile, # UB, [Br, 1], output + alpha_tile: pto.Tile, # UB, [Br, 1], output + beta_tile: pto.Tile, # UB, [Br, 1], output + row_start: pto.i32, + row_stop: pto.i32, + valid_cols: pto.i32, +): + """ + Per-row online softmax update. + + For each active row:: + + m_next = max(m_prev, row_max(S)) + P = exp(S - m_next) + l_next = l_prev * exp(m_prev - m_next) + row_sum(P) + alpha = l_prev * exp(m_prev - m_next) / l_next + beta = 1 / l_next + + ``alpha`` and ``beta`` are kept explicitly because the output update needs + both the old accumulator and the newly computed ``P @ V`` contribution. + """ + with pto.for_(row_start, row_stop, step=1) as row: + col_mask = pto.make_mask(pto.f32, valid_cols) + + s_row = pto.vlds(s_tile[row, 0:]) + m_prev = scalar.load(m_prev_tile[row, 0]) + l_prev = scalar.load(l_prev_tile[row, 0]) + + row_max = pto.vcgmax(s_row, col_mask) + m_next = scalar.max(m_prev, row_max) + + s_shifted = pto.vsubs(s_row, m_next, col_mask) + p_row = pto.vexp(s_shifted, col_mask) + + row_sum = pto.vcgadd(p_row, col_mask) + l_scaled = l_prev * scalar.exp(m_prev - m_next) + l_next = l_scaled + row_sum + + alpha = l_scaled / l_next + beta = 1.0 / l_next + + pto.vsts(p_row, p_tile[row, 0:], col_mask) + scalar.store(m_next, m_next_tile[row, 0]) + scalar.store(l_next, l_next_tile[row, 0]) + scalar.store(alpha, alpha_tile[row, 0]) + scalar.store(beta, beta_tile[row, 0]) + + +@pto.simt +def blend_output_rows( + o_prev_tile: pto.Tile, # UB, [Br, dim] + pv_tile: pto.Tile, # UB, [Br, dim] + alpha_tile: pto.Tile, # UB, [Br, 1] + beta_tile: pto.Tile, # UB, [Br, 1] + o_next_tile: pto.Tile, # UB, [Br, dim], output + row_start: pto.i32, + row_stop: pto.i32, + valid_dim: pto.i32, +): + """ + Update the output accumulator with SIMT-style scalar element work:: + + O_next[row, col] = alpha[row] * O_prev[row, col] + beta[row] * PV[row, col] + + This intentionally contrasts with ``online_softmax_rows``: the softmax step + stays on the SIMD path because it is dominated by row-wise vector math, + while the final blend is expressed here as explicit scalar work-items over + the tile domain. + """ + with pto.for_(row_start, row_stop, step=1) as row: + alpha = scalar.load(alpha_tile[row, 0]) + beta = scalar.load(beta_tile[row, 0]) + + with pto.for_(0, valid_dim, step=1) as col: + o_prev = scalar.load(o_prev_tile[row, col]) + pv_val = scalar.load(pv_tile[row, col]) + + o_next = alpha * o_prev + beta * pv_val + scalar.store(o_next, o_next_tile[row, col]) + + +@pto.simt +def materialize_tile_bounds( + meta_ptr: pto.ptr(pto.i32, pto.MemorySpace.UB), # [out] {row_start, row_stop, valid_cols} + valid_rows: pto.i32, + valid_cols: pto.i32, +): + """ + Materialize tile-local loop bounds for the current block. + + The SIMT kernel stays intentionally small here: it is responsible for + scalar control metadata, not for rewriting the vector or cube logic. + """ + scalar.store(0, meta_ptr + 0) + scalar.store(valid_rows, meta_ptr + 1) + scalar.store(valid_cols, meta_ptr + 2) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Level 2: explicit orchestration — one KV block worth of execution +# ═══════════════════════════════════════════════════════════════════════════════ + + +def kv_block_process( + q_mat: pto.Tile, # MAT, reused across inner KV loop + k_part: pto.PartitionTensorView, # GM view for current K block + v_part: pto.PartitionTensorView, # GM view for current V block + k_mat: pto.Tile, # MAT scratch + v_mat: pto.Tile, # MAT scratch + o_prev_tile: pto.Tile, # UB state + o_next_tile: pto.Tile, # UB state + m_prev_tile: pto.Tile, # UB state + l_prev_tile: pto.Tile, # UB state + m_next_tile: pto.Tile, # UB state + l_next_tile: pto.Tile, # UB state + s_tile: pto.Tile, # UB scratch for QK^T + p_tile: pto.Tile, # UB scratch for probabilities + p_mat: pto.Tile, # MAT scratch for probabilities + pv_tile: pto.Tile, # UB scratch for P@V + alpha_tile: pto.Tile, # UB scratch + beta_tile: pto.Tile, # UB scratch + q_l0a: pto.Tile, # LEFT scratch for Q + p_l0a: pto.Tile, # LEFT scratch for P + rhs_l0b: pto.Tile, # RIGHT scratch, reused by K/V + qk_acc_tile: pto.Tile, # ACC scratch for QK^T + pv_acc_tile: pto.Tile, # ACC scratch for P@V + meta_ptr: pto.ptr(pto.i32, pto.MemorySpace.UB), +): + """ + Process one KV block against an already-loaded Q tile. + + The explicit-mode body owns: + - staging the current K/V block into reusable UB scratch with explicit + DMA-style micro-instructions, + - synchronizing the hand-off between MTE, cube, simd, and simt stages, + - wiring together the explicit state transition + (prev -> next for m/l/o). + """ + # Current-block GM->MAT staging via explicit ptr-based DMA parameters. + rows = k_mat.valid_shape[0] + cols = k_mat.valid_shape[1] + row_bytes = cols * pto.bytewidth(pto.f32) + gm_row_stride = k_part.strides[0] * pto.bytewidth(pto.f32) + mat_row_stride = k_mat.shape[1] * pto.bytewidth(pto.f32) + pto.mte_load( + k_part.as_ptr(), + k_mat.as_ptr(), + 0, + row_bytes, + nburst=(rows, gm_row_stride, mat_row_stride), + ) + pto.mte_load( + v_part.as_ptr(), + v_mat.as_ptr(), + 0, + row_bytes, + nburst=(rows, gm_row_stride, mat_row_stride), + ) + pto.pipe_barrier(pto.Pipe.ALL) + + materialize_tile_bounds( + meta_ptr, + q_mat.valid_shape[0], + k_mat.valid_shape[0], + ) + row_start = scalar.load(meta_ptr + 0) + row_stop = scalar.load(meta_ptr + 1) + valid_cols = scalar.load(meta_ptr + 2) + + # 1. S = Q @ K^T + qk_matmul(q_mat, k_mat, q_l0a, rhs_l0b, qk_acc_tile, s_tile) + pto.pipe_barrier(pto.Pipe.ALL) + + # 2. Row-wise online softmax over S + online_softmax_rows( + s_tile, + p_tile, + m_prev_tile, + l_prev_tile, + m_next_tile, + l_next_tile, + alpha_tile, + beta_tile, + row_start, + row_stop, + valid_cols, + ) + pto.pipe_barrier(pto.Pipe.ALL) + + # Stage the probability tile onto the cube MAT path. + pto.tile.mov(p_tile, p_mat) + pto.pipe_barrier(pto.Pipe.ALL) + + # 3. PV = P @ V + pv_matmul(p_mat, v_mat, p_l0a, rhs_l0b, pv_acc_tile, pv_tile) + pto.pipe_barrier(pto.Pipe.ALL) + + # 4. O_next = alpha * O_prev + beta * PV + blend_output_rows( + o_prev_tile, + pv_tile, + alpha_tile, + beta_tile, + o_next_tile, + row_start, + row_stop, + v_mat.valid_shape[1], + ) + pto.pipe_barrier(pto.Pipe.ALL) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Layer summary +# ═══════════════════════════════════════════════════════════════════════════════ +# +# ┌──────────────────────────────────────────────────────────────────────────┐ +# │ L0 Python wrapper emit_flash_attention_mlir(...) │ +# │ │ +# │ specialize kernel parameters, compile, emit MLIR text │ +# │ │ +# │ Key idea: current demo goal is compile/inspect, not runtime launch. │ +# ├──────────────────────────────────────────────────────────────────────────┤ +# │ L1 @pto.jit(mode="explicit") flash_attention_kernel │ +# │ │ +# │ flash_attention_kernel.compile(...).mlir_text() │ +# │ TensorView metadata / alloc_tile / partition_view / tile.load / tile.store │ +# │ outer Q loop + inner KV loop + ping-pong state ownership │ +# │ │ +# │ Key idea: one launchable entry owns both runtime binding and logical │ +# │ tile scheduling. │ +# ├──────────────────────────────────────────────────────────────────────────┤ +# │ L2 explicit orchestration Per-block execution sandwich │ +# │ │ +# │ explicit mte_load(part, tile) staging for current K/V block, │ +# │ pipe_barrier, call cube/simd/simt sub-kernels, │ +# │ manage scratch/state hand-off │ +# │ │ +# │ Key idea: one place owns the "how this block runs on hardware" story. │ +# ├──────────────────────────────────────────────────────────────────────────┤ +# │ @pto.cube Matrix-product kernels │ +# │ │ +# │ qk_matmul: Q @ K^T │ +# │ pv_matmul: P @ V │ +# │ explicit LEFT/RIGHT/ACC scratch + UB output │ +# │ │ +# │ Key idea: UB tiles are inputs/outputs; cube-local state is explicit. │ +# ├──────────────────────────────────────────────────────────────────────────┤ +# │ @pto.simd Row-wise vector math │ +# │ │ +# │ online_softmax_rows │ +# │ vreg stays local; persistent state is written back to UB tiles │ +# │ │ +# │ Key idea: no cross-kernel vreg values, only UB-backed state. │ +# ├──────────────────────────────────────────────────────────────────────────┤ +# │ @pto.simt Scalar metadata and pointwise blend │ +# │ │ +# │ materialize_tile_bounds / blend_output_rows │ +# │ │ +# │ Key idea: SIMT handles scalar control facts and scalar tile walks. │ +# └──────────────────────────────────────────────────────────────────────────┘ +# +# dataflow for one KV block +# +# jit kernel alloc/schedule +# │ +# ▼ +# explicit orchestration loads K/V block and sequences the pipeline +# │ +# ├─ cube: Q + K ───────────────► S +# ├─ simd: S + (m_prev, l_prev) ─► P, (m_next, l_next), alpha, beta +# ├─ cube: P + V ───────────────► PV +# └─ simt: (o_prev, PV, alpha, beta) ─► o_next +# +# After each KV block: +# (m_prev, l_prev, o_prev) := (m_next, l_next, o_next) +# +# The important part for the demo is that every cross-stage dependency is +# visible in the surface language and the whole kernel can already be traced to +# MLIR for review. + + +def main(): + print(emit_flash_attention_mlir()) + + +if __name__ == "__main__": + main() diff --git a/ptodsl/examples/jit/flash_attention_softmax_launch.py b/ptodsl/examples/jit/flash_attention_softmax_launch.py new file mode 100644 index 000000000..3ed204bd7 --- /dev/null +++ b/ptodsl/examples/jit/flash_attention_softmax_launch.py @@ -0,0 +1,290 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +Row-wise softmax — end-to-end launch demo. + +This example is the launchable counterpart to the compile-only +``softmax_dsl.py`` sample. It uses an online-softmax recurrence internally, +but the public kernel surface is the ordinary softmax contract: load an input +score matrix, compute the per-row softmax, and store the normalized output. + +Each kernel instance runs on one NPU, preloads the full score matrix to UB, +where the input is already laid out as ``[seq, rows]`` so each UB row +represents one score column, and then streams 64-row packs through the +online-softmax recurrence: + + running_max = max(running_max, score_col) + running_sum = running_sum * exp(old_max - new_max) + exp(score_col - new_max) + out = exp(score_col - final_max) / final_sum + +The demo offers two launchable kernels so the current launch ABI does not need +an extra runtime tile-width parameter: + +- ``rows64_seq128``: full-width 64-row packed softmax +- ``rows81_seq96``: same single NPU, but two sequential row-pack updates +""" + +import argparse +import time +from pathlib import Path +import sys + +import numpy as np + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from flash_attention_softmax_launch.py" + ) + +from ptodsl import pto + +_DEVICE = "npu:0" + + +def _make_softmax_kernel(name: str, *, rows: int, seq: int): + if rows <= 0: + raise ValueError("rows must be positive") + if seq <= 0: + raise ValueError("seq must be positive") + + @pto.jit( + name=name, + target="a5", + mode="explicit", + insert_sync=False + ) + def kernel( + scores: pto.tensor_spec(rank=2, dtype=pto.f32), + out: pto.tensor_spec(rank=2, dtype=pto.f32), + ): + lane_num = pto.elements_per_vreg(pto.f32) + physical_rows = ((rows + lane_num - 1) // lane_num) * lane_num + scores_tile_bytes = seq * physical_rows * pto.bytewidth(pto.f32) + runtime_seq = scores.shape[0] + runtime_rows = scores.shape[1] + total_elems = runtime_rows * runtime_seq + + scores_view = pto.make_tensor_view( + scores, + shape=[1, 1, 1, runtime_seq, runtime_rows], + strides=[total_elems, total_elems, total_elems, runtime_rows, 1], + ) + out_view = pto.make_tensor_view( + out, + shape=[1, 1, 1, runtime_seq, runtime_rows], + strides=[total_elems, total_elems, total_elems, runtime_rows, 1], + ) + scores_part = pto.partition_view( + scores_view, + offsets=[0, 0, 0, 0, 0], + sizes=[1, 1, 1, runtime_seq, runtime_rows], + ) + out_part = pto.partition_view( + out_view, + offsets=[0, 0, 0, 0, 0], + sizes=[1, 1, 1, runtime_seq, runtime_rows], + ) + + scores_tile = pto.alloc_tile( + shape=[seq, physical_rows], + dtype=pto.float32, + addr=0, + valid_shape=[runtime_seq, runtime_rows], + blayout="RowMajor", + ) + out_tile = pto.alloc_tile( + shape=[seq, physical_rows], + dtype=pto.float32, + addr=scores_tile_bytes, + valid_shape=[runtime_seq, runtime_rows], + blayout="RowMajor", + ) + + pto.tile.load(scores_part, scores_tile) + out_tile.fill(0.0) + + pto.set_flag("MTE2", "V", event_id=0) + pto.wait_flag("MTE2", "V", event_id=0) + + with pto.simd(): + row_loop = pto.for_(0, runtime_rows, step=lane_num).carry(remained=runtime_rows) + with row_loop: + row_base = row_loop.iv + remaining_rows = row_loop.remained + active_rows, remaining_after_pack = pto.make_mask(pto.f32, remaining_rows) + running_max = pto.vlds(scores_tile[0, row_base:]) + running_sum = pto.vbr(1.0) + + softmax_loop = pto.for_(1, runtime_seq, step=1).carry( + running_max=running_max, + running_sum=running_sum, + ) + with softmax_loop: + col = softmax_loop.iv + running_max = softmax_loop.running_max + running_sum = softmax_loop.running_sum + col_vec = pto.vlds(scores_tile[col, row_base:]) + merged_max = pto.vmax(running_max, col_vec, active_rows) + running_delta = pto.vsub(running_max, merged_max, active_rows) + scaled_running = pto.vexp(running_delta, active_rows) + running_sum_scaled = pto.vmul(scaled_running, running_sum, active_rows) + col_delta = pto.vsub(col_vec, merged_max, active_rows) + col_exp = pto.vexp(col_delta, active_rows) + merged_sum = pto.vadd(running_sum_scaled, col_exp, active_rows) + softmax_loop.update(running_max=merged_max, running_sum=merged_sum) + + final_max = softmax_loop.final("running_max") + final_sum = softmax_loop.final("running_sum") + + with pto.for_(0, runtime_seq, step=1) as col: + col_vec = pto.vlds(scores_tile[col, row_base:]) + out_delta = pto.vsub(col_vec, final_max, active_rows) + exp_vec = pto.vexp(out_delta, active_rows) + out_vec = pto.vdiv(exp_vec, final_sum, active_rows) + pto.vsts(out_vec, out_tile[col, row_base:], active_rows) + + row_loop.update(remained=remaining_after_pack) + + pto.set_flag("V", "MTE3", event_id=0) + pto.wait_flag("V", "MTE3", event_id=0) + + pto.tile.store(out_tile, out_part) + pto.pipe_barrier(pto.Pipe.ALL) + + return kernel + + +SOFTMAX_ROWS64_SEQ128 = _make_softmax_kernel( + "softmax_rows64_seq128", + rows=64, + seq=128, +) +SOFTMAX_ROWS81_SEQ96 = _make_softmax_kernel( + "softmax_rows81_seq96", + rows=81, + seq=96, +) + +KERNELS = ( + SOFTMAX_ROWS64_SEQ128, + SOFTMAX_ROWS81_SEQ96, +) + +CASES = [ + { + "name": "rows64_seq128", + "kernel": SOFTMAX_ROWS64_SEQ128, + "rows": 64, + "seq": 128, + }, + { + "name": "rows81_seq96", + "kernel": SOFTMAX_ROWS81_SEQ96, + "rows": 81, + "seq": 96, + }, +] + + +def emit_mlir(): + return pto.merge_jit_modules(*KERNELS) + + +def reference_softmax(scores: np.ndarray): + row_max = np.max(scores, axis=0, keepdims=True) + shifted = np.exp(scores - row_max, dtype=np.float32) + row_sum = np.sum(shifted, axis=0, keepdims=True, dtype=np.float32) + return shifted / row_sum + + +def init_runtime(): + import torch + import torch_npu # noqa: F401 + + torch.npu.config.allow_internal_format = False + torch_npu.npu.set_compile_mode(jit_compile=False) + torch.npu.set_device(_DEVICE) + return torch + + +def npu_stream(torch): + return torch.npu.current_stream()._as_parameter_ # noqa: SLF001 + + +def make_case_inputs(case: dict[str, object]): + rows = int(case["rows"]) + seq = int(case["seq"]) + rng = np.random.RandomState(hash(case["name"]) & 0xFFFFFFFF) + + scores = rng.uniform(-4.0, 4.0, size=(seq, rows)).astype(np.float32) + out = np.zeros((seq, rows), dtype=np.float32) + + return scores, out + + +def run_case(case: dict[str, object], torch) -> None: + scores, out = make_case_inputs(case) + ref_out = reference_softmax(scores) + + scores_t = torch.from_numpy(scores).to(_DEVICE) + out_t = torch.from_numpy(out).to(_DEVICE) + stream = npu_stream(torch) + + t0 = time.perf_counter() + compiled = case["kernel"].compile() + compile_s = time.perf_counter() - t0 + + t0 = time.perf_counter() + compiled[1, stream]( + scores_t, + out_t, + ) + torch.npu.synchronize() + launch_s = time.perf_counter() - t0 + + np.testing.assert_allclose(out_t.cpu().numpy(), ref_out, rtol=1e-5, atol=1e-5) + + print( + f"PASS {case['name']} " + f"compile={compile_s:.3f}s launch={launch_s:.3f}s" + ) + + +def test_softmax() -> None: + torch = init_runtime() + for case in CASES: + run_case(case, torch) + print("All cases passed.") + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--emit-mlir", + action="store_true", + help="print the merged MLIR module and exit", + ) + args = parser.parse_args(argv) + + if args.emit_mlir: + print(emit_mlir()) + return 0 + + test_softmax() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ptodsl/examples/jit/tadd_launch.py b/ptodsl/examples/jit/tadd_launch.py new file mode 100644 index 000000000..734041689 --- /dev/null +++ b/ptodsl/examples/jit/tadd_launch.py @@ -0,0 +1,182 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +TADD tile kernel — Python DSL equivalent of + test/tilelang_st/npu/a5/src/st/testcase/tadd/tadd.pto + +End-to-end: @pto.jit → MLIR → binary → launch → accuracy check. +""" + +import argparse +import time +from pathlib import Path +import sys + +import numpy as np + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from tadd_launch.py" + ) + +from ptodsl import pto + +_DEVICE = "npu:0" + + +# --------------------------------------------------------------------------- +# Kernel +# --------------------------------------------------------------------------- + +def _tadd_tile(A, B, C, rows: int, cols: int) -> None: + c0 = pto.const(0) + c1 = pto.const(1) + c_rows = pto.const(rows) + c_cols = c_rows if rows == cols else pto.const(cols) + c_elems = pto.const(rows * cols) + + shape = [c1, c1, c1, c_rows, c_cols] + strides = [c_elems, c_elems, c_elems, c_cols, c1] + off = [c0, c0, c0, c0, c0] + + a_view = pto.make_tensor_view(A, shape=shape, strides=strides) + b_view = pto.make_tensor_view(B, shape=shape, strides=strides) + c_view = pto.make_tensor_view(C, shape=shape, strides=strides) + + a_part = pto.partition_view(a_view, offsets=off, sizes=shape) + b_part = pto.partition_view(b_view, offsets=off, sizes=shape) + c_part = pto.partition_view(c_view, offsets=off, sizes=shape) + + a_tile = pto.alloc_tile(shape=[rows, cols], dtype=pto.float32) + b_tile = pto.alloc_tile(shape=[rows, cols], dtype=pto.float32) + c_tile = pto.alloc_tile(shape=[rows, cols], dtype=pto.float32) + + pto.tile.load(a_part, a_tile) + pto.tile.load(b_part, b_tile) + pto.tile.add(a_tile, b_tile, c_tile) + pto.tile.store(c_tile, c_part) + + +@pto.jit( + name="TADD_f32_16x64", + kernel_kind="vector", + target="a5", +) +def TADD_f32_16x64( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + B: pto.tensor_spec(rank=2, dtype=pto.f32), + C: pto.tensor_spec(rank=2, dtype=pto.f32), +): + _tadd_tile(A, B, C, 16, 64) + + +@pto.jit( + name="TADD_f32_32x32", + kernel_kind="vector", + target="a5", +) +def TADD_f32_32x32( + A: pto.tensor_spec(rank=2, dtype=pto.f32), + B: pto.tensor_spec(rank=2, dtype=pto.f32), + C: pto.tensor_spec(rank=2, dtype=pto.f32), +): + _tadd_tile(A, B, C, 32, 32) + + +KERNELS = (TADD_f32_16x64, TADD_f32_32x32) + + +def emit_mlir(): + return pto.merge_jit_modules(*KERNELS) + + +# --------------------------------------------------------------------------- +# Host +# --------------------------------------------------------------------------- + +CASES = [ + {"name": "f32_16x64", "kernel": TADD_f32_16x64, "shape": (16, 64), "eps": 1e-6}, + {"name": "f32_32x32", "kernel": TADD_f32_32x32, "shape": (32, 32), "eps": 1e-6}, +] + + +def init_torch_npu() -> None: + import torch + import torch_npu # noqa: F401 + + torch.npu.config.allow_internal_format = False + torch_npu.npu.set_compile_mode(jit_compile=False) + torch.npu.set_device(_DEVICE) + return torch + + +def npu_stream(torch): + return torch.npu.current_stream()._as_parameter_ # noqa: SLF001 + + +def run_case(case: dict, torch) -> None: + shape = case["shape"] + rng = np.random.RandomState(hash(case["name"]) & 0xFFFFFFFF) + x = rng.randint(1, 10, size=shape).astype(np.float32) + y = rng.randint(1, 10, size=shape).astype(np.float32) + ref = x + y + + a = torch.from_numpy(x).to(_DEVICE) + b = torch.from_numpy(y).to(_DEVICE) + c = torch.empty(shape, dtype=torch.float32, device=_DEVICE) + stream = npu_stream(torch) + + t0 = time.perf_counter() + compiled = case["kernel"].compile() + compile_s = time.perf_counter() - t0 + + t0 = time.perf_counter() + compiled[1, stream](a, b, c) + torch.npu.synchronize() + launch_s = time.perf_counter() - t0 + + torch.testing.assert_close(ref, c.cpu().numpy(), rtol=case["eps"], atol=case["eps"]) + print( + f"PASS {case['name']} " + f"compile={compile_s:.3f}s launch={launch_s:.3f}s" + ) + + +def test_tadd() -> None: + torch = init_torch_npu() + for case in CASES: + run_case(case, torch) + print("All cases passed.") + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--emit-mlir", + action="store_true", + help="print merged MLIR module and exit (compile-only)", + ) + args = parser.parse_args(argv) + + if args.emit_mlir: + print(emit_mlir()) + return 0 + + test_tadd() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ptodsl/examples/softmax_dsl.py b/ptodsl/examples/softmax_dsl.py new file mode 100644 index 000000000..fb417bbe1 --- /dev/null +++ b/ptodsl/examples/softmax_dsl.py @@ -0,0 +1,165 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +Row-wise softmax kernel – compile-only DSL builder. + +This sample mirrors the launchable softmax demo. It uses a transposed logical +GM view so each UB row holds one score column, then processes 64 rows in +parallel with the online-softmax recurrence using only public PTODSL surface +syntax. +""" + +from pathlib import Path +import sys + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from softmax_dsl.py" + ) + +from ptodsl import pto + + +def _make_softmax_kernel(name: str, *, rows: int, seq: int): + if rows <= 0: + raise ValueError("rows must be positive") + if seq <= 0: + raise ValueError("seq must be positive") + + @pto.jit( + name=name, + kernel_kind="vector", + target="a5", + mode="explicit", + insert_sync=False, + ) + def kernel( + scores: pto.tensor_spec(rank=2, dtype=pto.f32), + out: pto.tensor_spec(rank=2, dtype=pto.f32), + ): + packed_rows = pto.elements_per_vreg(pto.f32) + physical_rows = ((rows + packed_rows - 1) // packed_rows) * packed_rows + scores_tile_bytes = seq * physical_rows * pto.bytewidth(pto.f32) + runtime_rows = scores.shape[0] + runtime_seq = scores.shape[1] + has_rows = runtime_rows > 0 + + with pto.if_(has_rows) as has_rows_br: + with has_rows_br.then_: + scores_view = pto.make_tensor_view( + scores, + shape=[seq, rows], + strides=[1, seq], + ) + out_view = pto.make_tensor_view( + out, + shape=[seq, rows], + strides=[1, seq], + ) + scores_part = pto.partition_view( + scores_view, + offsets=[0, 0], + sizes=[runtime_seq, runtime_rows], + ) + out_part = pto.partition_view( + out_view, + offsets=[0, 0], + sizes=[runtime_seq, runtime_rows], + ) + + scores_tile = pto.alloc_tile( + shape=[seq, physical_rows], + dtype=pto.float32, + addr=pto.const(0, dtype=pto.i64), + valid_shape=[runtime_seq, runtime_rows], + ) + out_tile = pto.alloc_tile( + shape=[seq, physical_rows], + dtype=pto.float32, + addr=pto.const(scores_tile_bytes, dtype=pto.i64), + valid_shape=[runtime_seq, runtime_rows], + ) + + pto.tile.load(scores_part, scores_tile) + + pto.set_flag("MTE2", "V", event_id=0) + pto.wait_flag("MTE2", "V", event_id=0) + + with pto.simd(): + row_loop = pto.for_(0, runtime_rows, step=packed_rows).carry(remained=runtime_rows) + with row_loop: + row_base = row_loop.iv + remaining_rows = row_loop.remained + active_rows, remaining_after_pack = pto.make_mask(pto.f32, remaining_rows) + running_max = pto.vlds(scores_tile[0, row_base:]) + running_sum = pto.vbr(1.0) + + softmax_loop = pto.for_(1, runtime_seq, step=1).carry( + running_max=running_max, + running_sum=running_sum, + ) + with softmax_loop: + col = softmax_loop.iv + running_max = softmax_loop.running_max + running_sum = softmax_loop.running_sum + col_vec = pto.vlds(scores_tile[col, row_base:]) + merged_max = pto.vmax(running_max, col_vec, active_rows) + running_delta = pto.vsub(running_max, merged_max, active_rows) + scaled_running = pto.vexp(running_delta, active_rows) + running_sum_scaled = pto.vmul(scaled_running, running_sum, active_rows) + col_delta = pto.vsub(col_vec, merged_max, active_rows) + col_exp = pto.vexp(col_delta, active_rows) + merged_sum = pto.vadd(running_sum_scaled, col_exp, active_rows) + softmax_loop.update(running_max=merged_max, running_sum=merged_sum) + + final_max = softmax_loop.final("running_max") + final_sum = softmax_loop.final("running_sum") + + with pto.for_(0, runtime_seq, step=1) as col: + col_vec = pto.vlds(scores_tile[col, row_base:]) + out_delta = pto.vsub(col_vec, final_max, active_rows) + exp_vec = pto.vexp(out_delta, active_rows) + out_vec = pto.vdiv(exp_vec, final_sum, active_rows) + pto.vsts(out_vec, out_tile[col, row_base:], active_rows) + + row_loop.update(remained=remaining_after_pack) + + pto.set_flag("V", "MTE3", event_id=0) + pto.wait_flag("V", "MTE3", event_id=0) + + pto.tile.store(out_tile, out_part) + pto.pipe_barrier(pto.Pipe.ALL) + + return kernel + + +SOFTMAX_ROWS64_SEQ128 = _make_softmax_kernel( + "softmax_rows64_seq128_dsl", + rows=64, + seq=128, +) +SOFTMAX_ROWS81_SEQ96 = _make_softmax_kernel( + "softmax_rows81_seq96_dsl", + rows=81, + seq=96, +) + + +def build(): + return pto.merge_jit_modules(SOFTMAX_ROWS64_SEQ128, SOFTMAX_ROWS81_SEQ96) + + +if __name__ == "__main__": + print(build()) diff --git a/ptodsl/examples/softmax_lowlevel.py b/ptodsl/examples/softmax_lowlevel.py new file mode 100644 index 000000000..d93e70ccf --- /dev/null +++ b/ptodsl/examples/softmax_lowlevel.py @@ -0,0 +1,357 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +Low-level builder for the online softmax kernel. + +Reconstructs the IR in + test/tilelang_st/npu/a5/src/st/testcase/softmax/softmax.pto +using raw MLIR Python binding calls, with no additional abstraction layer. +""" + +from mlir.ir import ( + Attribute, + Context, + F32Type, + InsertionPoint, + IntegerType, + IndexType, + Location, + Module, + ShapedType, + StringAttr, + Type, + UnitAttr, +) +from mlir.dialects import arith, func, pto, scf + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(): + # ── Types ──────────────────────────────────────────────────────── + i1 = IntegerType.get_signless(1) + i32 = IntegerType.get_signless(32) + i64 = IntegerType.get_signless(64) + idx = IndexType.get() + f32 = F32Type.get() + + # Address-space attributes used in pointer and tile types + _gm = pto.AddressSpaceAttr.get(pto.AddressSpace.GM) # gm = global memory + _ub = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC) # vec = UB (unified buffer) + # Sentinel value for a dynamic (unknown) dimension + _dyn = ShapedType.get_dynamic_size() + + # Pointer types built with PtrType.get + ptr_gm = pto.PtrType.get(f32, memory_space=_gm) # !pto.ptr + ptr_ub = pto.PtrType.get(f32, memory_space=_ub) # !pto.ptr + + # Tensor-view types built with TensorViewType / PartitionTensorViewType + tv5d = pto.TensorViewType.get(5, f32) # !pto.tensor_view + ptv5d = pto.PartitionTensorViewType.get([_dyn] * 5, f32) # !pto.partition_tensor_view + + # Tile-buffer config attributes + _col_cfg = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(pto.BLayout.ColMajor), + pto.SLayoutAttr.get(pto.SLayout.NoneBox), + 512, pto.PadValueAttr.get(pto.PadValue.Null), + ) + _row_cfg = pto.TileBufConfigAttr.get( + pto.BLayoutAttr.get(pto.BLayout.RowMajor), + pto.SLayoutAttr.get(pto.SLayout.NoneBox), + 512, pto.PadValueAttr.get(pto.PadValue.Null), + ) + # !pto.tile_buf + tile_col = pto.TileBufType.get([8, 1], f32, _ub, [-1, 1], _col_cfg) + # !pto.tile_buf + tile_wide = pto.TileBufType.get([8, 128], f32, _ub, [-1, -1], _row_cfg) + + # VReg and Mask types have no Python-binding constructors yet; + # Type.parse is the only available path for these two. + vreg = Type.parse("!pto.vreg<64xf32>") + mask_b32 = Type.parse("!pto.mask") + + # ── Flat single module ──────────────────────────────────────── + m = Module.create() + m.operation.attributes["pto.target_arch"] = StringAttr.get("a5") + # FunctionKernelKindAttr has no binding; Attribute.parse is the only path. + m.operation.attributes["pto.kernel_kind"] = Attribute.parse( + "#pto.kernel_kind" + ) + + fn_ty = func.FunctionType.get([ptr_gm] * 7 + [i32, i32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("online_softmax_update_kernel_2d", fn_ty) + fn.attributes["pto.aicore"] = UnitAttr.get() + entry = fn.add_entry_block() + + with InsertionPoint(entry): + a0, a1, a2, a3, a4, a5, a6, arg7, arg8 = entry.arguments + + # ── Index constants ─────────────────────────────────────── + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c8 = arith.ConstantOp(idx, 8).result + c64 = arith.ConstantOp(idx, 64).result + c128 = arith.ConstantOp(idx, 128).result + + # ── i64 constants ───────────────────────────────────────── + c0_i64 = arith.ConstantOp(i64, 0).result + c1_i64 = arith.ConstantOp(i64, 1).result + c8_i64 = arith.ConstantOp(i64, 8).result + c16_i64 = arith.ConstantOp(i64, 16).result + c32_i64 = arith.ConstantOp(i64, 32).result + c64_i64 = arith.ConstantOp(i64, 64).result + c128_i64 = arith.ConstantOp(i64, 128).result + c256_i64 = arith.ConstantOp(i64, 256).result + c512_i64 = arith.ConstantOp(i64, 512).result + c8448_i64 = arith.ConstantOp(i64, 8448).result + c16640_i64 = arith.ConstantOp(i64, 16640).result + c16768_i64 = arith.ConstantOp(i64, 16768).result + c16896_i64 = arith.ConstantOp(i64, 16896).result + + # ── i32 constants ───────────────────────────────────────── + c1_i32 = arith.ConstantOp(i32, 1).result + c8_i32 = arith.ConstantOp(i32, 8).result + c64_i32 = arith.ConstantOp(i32, 64).result + c0_i32 = arith.ConstantOp(i32, 0).result + + # ── Block / row computation ─────────────────────────────── + block = pto.GetBlockIdxOp().result # i64 + block_idx = arith.IndexCastOp(idx, block).result + row_base = arith.MulIOp(block_idx, c8).result + block_rows_i32= arith.IndexCastOp(i32, c8).result + row_base_i32 = arith.IndexCastOp(i32, row_base).result + remaining_rows= arith.SubIOp(arg8, row_base_i32).result + has_rows = arith.CmpIOp(arith.CmpIPredicate.sgt, + remaining_rows, c0_i32).result + too_many_rows = arith.CmpIOp(arith.CmpIPredicate.sgt, + remaining_rows, c8_i32).result + row_count_i32 = arith.SelectOp(too_many_rows, c8_i32, + remaining_rows).result + row_count = arith.IndexCastOp(idx, row_count_i32).result + seq = arith.IndexCastOp(idx, arg7).result + rows = arith.IndexCastOp(idx, arg8).result + rows_x_128 = arith.MulIOp(rows, c128).result + + # ── scf.if %has_rows ────────────────────────────────────── + if_rows = scf.IfOp(has_rows) + with InsertionPoint(if_rows.then_block): + + # ── Tensor views ────────────────────────────────────── + s1 = [rows, rows, rows, c1, rows] + s128 = [rows_x_128, rows_x_128, rows_x_128, c128, c1] + sh1 = [c1, c1, c1, rows, c1] + sh128 = [c1, c1, c1, rows, c128] + + oldmax_view = pto.MakeTensorViewOp(tv5d, a0, sh1, s1).result + oldsum_view = pto.MakeTensorViewOp(tv5d, a1, sh1, s1).result + qk_view = pto.MakeTensorViewOp(tv5d, a2, sh128, s128).result + newmax_view = pto.MakeTensorViewOp(tv5d, a3, sh1, s1).result + newsum_view = pto.MakeTensorViewOp(tv5d, a4, sh1, s1).result + expmax_view = pto.MakeTensorViewOp(tv5d, a5, sh1, s1).result + out_view = pto.MakeTensorViewOp(tv5d, a6, sh128, s128).result + + # ── Partition views ─────────────────────────────────── + off5 = [c0, c0, c0, row_base, c0] + sz1 = [c1, c1, c1, row_count, c1] + szs = [c1, c1, c1, row_count, seq] + + oldmax_part = pto.PartitionViewOp(ptv5d, oldmax_view, off5, sz1).result + oldsum_part = pto.PartitionViewOp(ptv5d, oldsum_view, off5, sz1).result + qk_part = pto.PartitionViewOp(ptv5d, qk_view, off5, szs).result + newmax_part = pto.PartitionViewOp(ptv5d, newmax_view, off5, sz1).result + newsum_part = pto.PartitionViewOp(ptv5d, newsum_view, off5, sz1).result + expmax_part = pto.PartitionViewOp(ptv5d, expmax_view, off5, sz1).result + out_part = pto.PartitionViewOp(ptv5d, out_view, off5, szs).result + + # ── Tile allocation ─────────────────────────────────── + oldmax_tile = pto.AllocTileOp(tile_col, addr=c0_i64, valid_row=row_count).result + oldsum_tile = pto.AllocTileOp(tile_col, addr=c128_i64, valid_row=row_count).result + qk_tile = pto.AllocTileOp(tile_wide, addr=c256_i64, valid_row=row_count, valid_col=seq).result + out_tile = pto.AllocTileOp(tile_wide, addr=c8448_i64, valid_row=row_count, valid_col=seq).result + newmax_tile = pto.AllocTileOp(tile_col, addr=c16640_i64, valid_row=row_count).result + newsum_tile = pto.AllocTileOp(tile_col, addr=c16768_i64, valid_row=row_count).result + expmax_tile = pto.AllocTileOp(tile_col, addr=c16896_i64, valid_row=row_count).result + + # ── Tile loads ──────────────────────────────────────── + pto.TLoadOp(None, oldmax_part, oldmax_tile) + pto.TLoadOp(None, oldsum_part, oldsum_tile) + pto.TLoadOp(None, qk_part, qk_tile) + + # ── Sync before vecscope ────────────────────────────── + pto.set_flag("PIPE_MTE2", "PIPE_V", pto.EVENT_ID0) + pto.wait_flag("PIPE_MTE2", "PIPE_V", pto.EVENT_ID0) + + # ── pto.vecscope ────────────────────────────────────── + vs_op = pto.VecScopeOp() + vs_block = vs_op.body.blocks.append() + with InsertionPoint(vs_block): + + # Materialise UB pointers from tile handles + ub_oldmax = pto.TileBufAddrOp(ptr_ub, oldmax_tile).result + ub_oldsum = pto.TileBufAddrOp(ptr_ub, oldsum_tile).result + ub_qk = pto.TileBufAddrOp(ptr_ub, qk_tile).result + ub_out = pto.TileBufAddrOp(ptr_ub, out_tile).result + ub_newmax = pto.TileBufAddrOp(ptr_ub, newmax_tile).result + ub_newsum = pto.TileBufAddrOp(ptr_ub, newsum_tile).result + ub_expmax = pto.TileBufAddrOp(ptr_ub, expmax_tile).result + + active = pto.PsetB32Op(mask_b32, "PAT_ALL").result + plt1 = pto.PltB32Op(mask_b32, i32, c1_i32) + one_mask = plt1.mask + + # ── for row in [0, row_count) ───────────────────── + row_for = scf.ForOp(c0, row_count, c1) + with InsertionPoint(row_for.body): + row = row_for.induction_variable + row_qk = arith.MulIOp(row, c128).result + + oldmax_bc = pto.VldsOp(vreg, ub_oldmax, row, + dist="BRC_B32").result + oldsum_bc = pto.VldsOp(vreg, ub_oldsum, row, + dist="BRC_B32").result + + # ── for chunk in [0,128,64) with iter_args ──── + chunk_for = scf.ForOp(c0, c128, c64, + [oldmax_bc, oldsum_bc]) + with InsertionPoint(chunk_for.body): + chunk = chunk_for.induction_variable + running_max = chunk_for.inner_iter_args[0] + running_sum = chunk_for.inner_iter_args[1] + + chunk_i32 = arith.IndexCastOp(i32, chunk).result + remaining_cols= arith.SubIOp(arg7, chunk_i32).result + has_chunk = arith.CmpIOp( + arith.CmpIPredicate.sgt, + remaining_cols, c0_i32).result + + # ── if has_chunk -> (vreg, vreg) ────────── + c_if = scf.IfOp(has_chunk, [vreg, vreg], + hasElse=True) + with InsertionPoint(c_if.then_block): + cplt = pto.PltB32Op(mask_b32, i32, + remaining_cols) + chunk_mask = cplt.mask + chunk_base = arith.AddIOp(row_qk, + chunk).result + vec = pto.VldsOp(vreg, ub_qk, + chunk_base).result + chunk_max = pto.VcmaxOp(vreg, vec, + chunk_mask).result + chunk_max_bc= pto.VdupOp(vreg, chunk_max, + active, + position="LOWEST").result + merged_max = pto.VmaxOp(vreg, running_max, + chunk_max_bc, + active).result + scaled_run = pto.VexpdifOp(vreg, + running_max, + merged_max, + active, + "ODD").result + run_sum_sc = pto.VmulOp(vreg, scaled_run, + running_sum, + active).result + chunk_exp = pto.VexpdifOp(vreg, vec, + merged_max, + chunk_mask, + "ODD").result + chunk_sum = pto.VcaddOp(vreg, chunk_exp, + chunk_mask).result + chunk_sum_bc= pto.VdupOp(vreg, chunk_sum, + active, + position="LOWEST").result + merged_sum = pto.VaddOp(vreg, run_sum_sc, + chunk_sum_bc, + active).result + scf.YieldOp([merged_max, merged_sum]) + with InsertionPoint(c_if.else_block): + scf.YieldOp([running_max, running_sum]) + + next_max, next_sum = c_if.results + scf.YieldOp([next_max, next_sum]) + + final_max, final_sum = chunk_for.results + + # ── Post-loop: compute expmax ───────────────── + raw_expmax = pto.VexpdifOp(vreg, oldmax_bc, + final_max, active, + "ODD").result + scaled_oldsum = pto.VmulOp(vreg, raw_expmax, + oldsum_bc, + active).result + expmax = pto.VdivOp(vreg, scaled_oldsum, + final_sum, + active).result + + pto.VstsOp(final_max, ub_newmax, row, one_mask, + dist="1PT_B32") + pto.VstsOp(final_sum, ub_newsum, row, one_mask, + dist="1PT_B32") + pto.VstsOp(expmax, ub_expmax, row, one_mask, + dist="1PT_B32") + + # ── Output normalisation loop ───────────────── + out_for = scf.ForOp(c0, c128, c64) + with InsertionPoint(out_for.body): + chunk2 = out_for.induction_variable + ci32_2 = arith.IndexCastOp(i32, + chunk2).result + rem2 = arith.SubIOp(arg7, ci32_2).result + has_chunk2 = arith.CmpIOp( + arith.CmpIPredicate.sgt, + rem2, c0_i32).result + + o_if = scf.IfOp(has_chunk2) + with InsertionPoint(o_if.then_block): + oplt = pto.PltB32Op(mask_b32, i32, + rem2) + cmask2 = oplt.mask + cbase2 = arith.AddIOp(row_qk, + chunk2).result + vec2 = pto.VldsOp(vreg, ub_qk, + cbase2).result + exp2 = pto.VexpdifOp(vreg, vec2, + final_max, + cmask2, + "ODD").result + out2 = pto.VdivOp(vreg, exp2, + final_sum, + cmask2).result + pto.VstsOp(out2, ub_out, cbase2, cmask2) + scf.YieldOp([]) + + scf.YieldOp([]) # out_for body + + scf.YieldOp([]) # row_for body + + # ── Sync after vecscope ─────────────────────────────── + pto.set_flag("PIPE_V", "PIPE_MTE3", pto.EVENT_ID0) + pto.wait_flag("PIPE_V", "PIPE_MTE3", pto.EVENT_ID0) + + # ── Tile stores ─────────────────────────────────────── + pto.TStoreOp(None, newmax_tile, newmax_part) + pto.TStoreOp(None, newsum_tile, newsum_part) + pto.TStoreOp(None, expmax_tile, expmax_part) + pto.TStoreOp(None, out_tile, out_part) + + scf.YieldOp([]) # if_rows then_block + + # ── Barrier and return ──────────────────────────────────── + pto.BarrierOp(pto.PipeAttr.get(pto.PIPE.PIPE_ALL)) + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/ptodsl/examples/tadd_dsl.py b/ptodsl/examples/tadd_dsl.py new file mode 100644 index 000000000..4d1482dce --- /dev/null +++ b/ptodsl/examples/tadd_dsl.py @@ -0,0 +1,67 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +TADD kernel – DSL-style builder. + +Generates the same IR as expand_tileop_to_vpto_result.pto using the +``@pto.jit`` decorator and the ``pto.*`` namespace. + +The Python code maps 1-to-1 to the MLIR IR lines: + + func.func @TADD() { # @pto.jit(name="TADD", …) + %c0_i64 = arith.constant 0 : i64 # pto.const(0, dtype=pto.int64) + %c16 = arith.constant 16 : index # pto.const(16, dtype=pto.index) + … + pto.simd { # with pto.simd(): + %0 = pto.castptr %c4096_i64 … # pto.castptr(c4096_i64, …) + scf.for %arg0 = %c0 to %c16 … { # with pto.for_(c0, c16, step=c1) as i: + %mask, _ = pto.plt_b32 … # pto.plt_b32(c64_i32) + … + } + } + } +""" + +from ptodsl import pto, scalar + +s = scalar # arith shorthand alias + + +@pto.jit(name="TADD", kernel_kind="vector", target="a5") +def TADD(): + c0_i64 = pto.const(0, dtype=pto.int64) + c16 = pto.const(16, dtype=pto.index) + c4096_i64 = pto.const(4096, dtype=pto.int64) + c0 = pto.const(0) + c1 = pto.const(1) + c64_i32 = pto.const(64, dtype=pto.int32) + c64 = pto.const(64) + + with pto.simd(): + ptr_f32_ub = pto.ptr(pto.float32, "ub") + vf32 = pto.vreg_type(64, pto.float32) + ptr_src = pto.castptr(c4096_i64, ptr_f32_ub) + ptr_dst = pto.castptr(c0_i64, ptr_f32_ub) + + with pto.for_(c0, c16, step=c1) as tile_idx: + mask, _ = pto.plt_b32(c64_i32) + tile_off = s.muli(tile_idx, c64) + va = pto.vlds(pto.addptr(ptr_src, tile_off), c0, vf32) + ptr_dst_tile = pto.addptr(ptr_dst, tile_off) + vb = pto.vlds(ptr_dst_tile, c0, vf32) + vc = pto.vadd(va, vb, mask) + pto.vsts(vc, ptr_dst_tile, c0, mask) + + +def build(): + return TADD.mlir_module() + + +if __name__ == "__main__": + print(TADD) diff --git a/ptodsl/examples/tadd_lowlevel.py b/ptodsl/examples/tadd_lowlevel.py new file mode 100644 index 000000000..8dd98c77d --- /dev/null +++ b/ptodsl/examples/tadd_lowlevel.py @@ -0,0 +1,169 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +Builds the MLIR IR module equivalent to expand_tileop_to_vpto_result.pto using +low-level MLIR Python bindings. + +Target IR (expand_tileop_to_vpto_result.pto): + module attributes {pto.target_arch = "a5"} { + module attributes {pto.kernel_kind = #pto.kernel_kind, pto.target_arch = "a5"} { + func.func @TADD() { + %c0_i64 = arith.constant 0 : i64 + %c16 = arith.constant 16 : index + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64_i32 = arith.constant 64 : i32 + %c64 = arith.constant 64 : index + pto.vecscope { + %0 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %1 = pto.castptr %c0_i64 : i64 -> !pto.ptr + scf.for %arg0 = %c0 to %c16 step %c1 { + %mask, %scalar_out = pto.plt_b32 %c64_i32 : i32 -> !pto.mask, i32 + %2 = arith.muli %arg0, %c64 : index + %3 = pto.addptr %0, %2 : -> + %4 = pto.vlds %3[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %5 = pto.addptr %1, %2 : -> + %6 = pto.vlds %5[%c0] : !pto.ptr -> !pto.vreg<64xf32> + %7 = pto.vadd %4, %6, %mask : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %7, %5[%c0], %mask : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + } + return + } + } + } +""" + +from mlir.ir import ( + Attribute, + Context, + F32Type, + IntegerType, + IndexType, + InsertionPoint, + Location, + Module, + Operation, + StringAttr, + Type, +) +from mlir.dialects import arith, func, pto, scf + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(): + # ── Types ──────────────────────────────────────────────────────── + i32 = IntegerType.get_signless(32) + i64 = IntegerType.get_signless(64) + idx = IndexType.get() + f32 = F32Type.get() + + # !pto.ptr – pointer to f32 in the UB (VEC) address space + ptr_f32_ub = pto.PtrType.get( + f32, memory_space=pto.AddressSpaceAttr.get(pto.AddressSpace.VEC) + ) + + # VReg and Mask types have no Python-binding constructors yet; + # Type.parse is the only available path for these two. + vreg_64f32 = Type.parse("!pto.vreg<64xf32>") + mask_b32 = Type.parse("!pto.mask") + + # ── Shared attributes ───────────────────────────────────────── + target_arch_attr = StringAttr.get("a5") + kernel_kind_attr = Attribute.parse("#pto.kernel_kind") + + # ── Outer module ───────────────────────────────────────────── + outer_mod = Module.create() + outer_mod.operation.attributes["pto.target_arch"] = target_arch_attr + + with InsertionPoint(outer_mod.body): + # ── Inner module ───────────────────────────────────────── + # Module.create() does not use the active InsertionPoint, so we + # use Operation.create("builtin.module") directly instead. + inner_op = Operation.create("builtin.module", regions=1) + inner_op.attributes["pto.target_arch"] = target_arch_attr + inner_op.attributes["pto.kernel_kind"] = kernel_kind_attr + + # builtin.module needs exactly one block in its body region. + inner_body = inner_op.regions[0].blocks.append() + + with InsertionPoint(inner_body): + # ── func @TADD() ────────────────────────────────────── + fn_ty = func.FunctionType.get([], []) + fn = func.FuncOp("TADD", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + # Constants live outside vecscope; they are visible + # inside because vecscope is not a new scope for SSA. + c0_i64 = arith.ConstantOp(i64, 0).result + c16 = arith.ConstantOp(idx, 16).result + c4096_i64 = arith.ConstantOp(i64, 4096).result + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c64_i32 = arith.ConstantOp(i32, 64).result + c64 = arith.ConstantOp(idx, 64).result + + # ── pto.vecscope { … } ──────────────────────────── + vecscope_op = pto.VecScopeOp() + # vecscope has one region; we must append its entry block. + vs_block = vecscope_op.body.blocks.append() + + with InsertionPoint(vs_block): + # %0 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + ptr0 = pto.CastPtrOp(ptr_f32_ub, c4096_i64).result + + # %1 = pto.castptr %c0_i64 : i64 -> !pto.ptr + ptr1 = pto.CastPtrOp(ptr_f32_ub, c0_i64).result + + # scf.for %arg0 = %c0 to %c16 step %c1 { … } + for_op = scf.ForOp(c0, c16, c1) + with InsertionPoint(for_op.body): + arg0 = for_op.induction_variable + + # %mask, %scalar_out = pto.plt_b32 %c64_i32 + plt = pto.PltB32Op(mask_b32, i32, c64_i32) + mask = plt.mask + # scalar_out is unused in this kernel + + # %2 = arith.muli %arg0, %c64 : index + off = arith.MulIOp(arg0, c64).result + + # %3 = pto.addptr %0, %2 + ptr3 = pto.AddPtrOp(ptr0, off).result + + # %4 = pto.vlds %3[%c0] : !pto.ptr -> !pto.vreg<64xf32> + vreg4 = pto.VldsOp(vreg_64f32, ptr3, c0).result + + # %5 = pto.addptr %1, %2 + ptr5 = pto.AddPtrOp(ptr1, off).result + + # %6 = pto.vlds %5[%c0] : !pto.ptr -> !pto.vreg<64xf32> + vreg6 = pto.VldsOp(vreg_64f32, ptr5, c0).result + + # %7 = pto.vadd %4, %6, %mask + vreg7 = pto.VaddOp(vreg_64f32, vreg4, vreg6, mask).result + + # pto.vsts %7, %5[%c0], %mask + pto.VstsOp(vreg7, ptr5, c0, mask) + + scf.YieldOp([]) + + func.ReturnOp([]) + + outer_mod.operation.verify() + return outer_mod + + +if __name__ == "__main__": + print(build()) diff --git a/ptodsl/examples/tilelang_codegen.py b/ptodsl/examples/tilelang_codegen.py new file mode 100644 index 000000000..de7dbc6ab --- /dev/null +++ b/ptodsl/examples/tilelang_codegen.py @@ -0,0 +1,314 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +""" +TileLang-generated explicit PTODSL kernel. + +This file keeps the original generated kernel body essentially intact and only +adds the minimum wrapper needed to make it usable as a compile/test target: + +- public `@pto.jit` host ABI via `tensor_spec(...)` +- `--emit-mlir` entry point +- compile smoke path for regression tests +""" + +import argparse +from pathlib import Path +import sys +import time + +import numpy as np + +if __package__ in {None, ""}: + here = Path(__file__).resolve() + for candidate in here.parents: + if (candidate / "ptodsl" / "__init__.py").exists(): + sys.path.insert(0, str(candidate)) + break + else: + raise RuntimeError( + "Unable to locate the PTODSL Python package root from tilelang_codegen.py" + ) + +from ptodsl import pto + + +_DEVICE = "npu:0" + + +def _tilelang_generated_body( + A, + B, + C, +): + bx = pto.get_block_idx() + buf_dyn_shmem = pto.const(0, dtype=pto.int64) + with pto.for_(0, 2, step=1) as f: + pto.set_flag("MTE3", "V", event_id=f) + pto.set_flag("V", "MTE2", event_id=f) + with pto.for_(0, 2048, step=1) as iter: + pto.wait_flag("V", "MTE2", event_id=iter % 2) + pto.mte_gm_ub( + pto.addptr(A, (iter * 524288) + (bx * 8192)), + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (iter % 2) * 8192, + ), + 0, + 32768, + nburst=(1, 0, 0), + ) + pto.mte_gm_ub( + pto.addptr(B, (iter * 524288) + (bx * 8192)), + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 8192) + 16384, + ), + 0, + 32768, + nburst=(1, 0, 0), + ) + pto.set_flag("MTE2", "V", event_id=iter % 2) + pto.wait_flag("MTE2", "V", event_id=iter % 2) + pto.wait_flag("MTE3", "V", event_id=iter % 2) + with pto.simd(): + mask_cnt = 8192 + with pto.for_(0, 128, step=1) as i: + mask = pto.pset_b32("PAT_ALL") + r0 = pto.vlds( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 8192) + (i * 64), + ), + pto.const(0), + pto.vreg_type(64, pto.float32), + ) + r1 = pto.vlds( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (((iter % 2) * 8192) + (i * 64)) + 16384, + ), + pto.const(0), + pto.vreg_type(64, pto.float32), + ) + r0 = pto.vadd(r0, r1, mask) + pto.vsts( + r0, + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (((iter % 2) * 8192) + (i * 64)) + 32768, + ), + pto.const(0), + mask, + ) + pto.set_flag("V", "MTE3", event_id=iter % 2) + pto.set_flag("V", "MTE2", event_id=iter % 2) + pto.wait_flag("V", "MTE3", event_id=iter % 2) + pto.mte_ub_gm( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 8192) + 32768, + ), + pto.addptr(C, (iter * 524288) + (bx * 8192)), + 32768, + nburst=(1, 0, 0), + ) + pto.set_flag("MTE3", "V", event_id=iter % 2) + with pto.for_(0, 2, step=1) as f_1: + pto.wait_flag("MTE3", "V", event_id=f_1) + pto.wait_flag("V", "MTE2", event_id=f_1) + + +@pto.jit( + name="main_kernel", + kernel_kind="vector", + target="a5", + mode="explicit", + insert_sync=False, +) +def main_kernel( + A: pto.tensor_spec(rank=1, dtype=pto.f32), + B: pto.tensor_spec(rank=1, dtype=pto.f32), + C: pto.tensor_spec(rank=1, dtype=pto.f32), +): + _tilelang_generated_body(A.data_handle, B.data_handle, C.data_handle) + + +def _tilelang_generated_body_small(A, B, C): + bx = pto.get_block_idx() + buf_dyn_shmem = pto.const(0, dtype=pto.int64) + with pto.for_(0, 2, step=1) as f: + pto.set_flag("MTE3", "V", event_id=f) + pto.set_flag("V", "MTE2", event_id=f) + with pto.for_(0, 2, step=1) as iter: + pto.wait_flag("V", "MTE2", event_id=iter % 2) + pto.mte_gm_ub( + pto.addptr(A, (iter * 128) + (bx * 128)), + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (iter % 2) * 128, + ), + 0, + 512, + nburst=(1, 0, 0), + ) + pto.mte_gm_ub( + pto.addptr(B, (iter * 128) + (bx * 128)), + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 128) + 256, + ), + 0, + 512, + nburst=(1, 0, 0), + ) + pto.set_flag("MTE2", "V", event_id=iter % 2) + pto.wait_flag("MTE2", "V", event_id=iter % 2) + pto.wait_flag("MTE3", "V", event_id=iter % 2) + with pto.simd(): + with pto.for_(0, 2, step=1) as i: + mask = pto.pset_b32("PAT_ALL") + r0 = pto.vlds( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 128) + (i * 64), + ), + pto.const(0), + pto.vreg_type(64, pto.float32), + ) + r1 = pto.vlds( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (((iter % 2) * 128) + (i * 64)) + 256, + ), + pto.const(0), + pto.vreg_type(64, pto.float32), + ) + r0 = pto.vadd(r0, r1, mask) + pto.vsts( + r0, + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + (((iter % 2) * 128) + (i * 64)) + 512, + ), + pto.const(0), + mask, + ) + pto.set_flag("V", "MTE3", event_id=iter % 2) + pto.set_flag("V", "MTE2", event_id=iter % 2) + pto.wait_flag("V", "MTE3", event_id=iter % 2) + pto.mte_ub_gm( + pto.addptr( + pto.castptr(buf_dyn_shmem, pto.ptr(pto.float32, "ub")), + ((iter % 2) * 128) + 512, + ), + pto.addptr(C, (iter * 128) + (bx * 128)), + 512, + nburst=(1, 0, 0), + ) + pto.set_flag("MTE3", "V", event_id=iter % 2) + with pto.for_(0, 2, step=1) as f_1: + pto.wait_flag("MTE3", "V", event_id=f_1) + pto.wait_flag("V", "MTE2", event_id=f_1) + + +@pto.jit( + name="main_kernel_precision_test", + kernel_kind="vector", + target="a5", + mode="explicit", + insert_sync=False, +) +def main_kernel_precision_test( + A: pto.tensor_spec(rank=1, dtype=pto.f32), + B: pto.tensor_spec(rank=1, dtype=pto.f32), + C: pto.tensor_spec(rank=1, dtype=pto.f32), +): + _tilelang_generated_body_small(A.data_handle, B.data_handle, C.data_handle) + + +def emit_mlir(): + return main_kernel.mlir_text() + + +def compile_kernel(): + compiled = main_kernel.compile() + compiled.verify() + return compiled + + +def init_torch_npu(): + import torch + import torch_npu # noqa: F401 + + torch.npu.config.allow_internal_format = False + torch_npu.npu.set_compile_mode(jit_compile=False) + torch.npu.set_device(_DEVICE) + return torch + + +def npu_stream(torch): + return torch.npu.current_stream()._as_parameter_ # noqa: SLF001 + + +def make_case_inputs(): + total = 256 + rng = np.random.RandomState(20260524) + a = rng.uniform(-3.0, 3.0, size=(total,)).astype(np.float32) + b = rng.uniform(-3.0, 3.0, size=(total,)).astype(np.float32) + c = np.full((total,), np.nan, dtype=np.float32) + return a, b, c + + +def run_precision_case(torch) -> None: + a_np, b_np, c_np = make_case_inputs() + ref = a_np + b_np + + a_t = torch.from_numpy(a_np).to(_DEVICE) + b_t = torch.from_numpy(b_np).to(_DEVICE) + c_t = torch.from_numpy(c_np).to(_DEVICE) + stream = npu_stream(torch) + + t0 = time.perf_counter() + compiled = main_kernel_precision_test.compile() + compile_s = time.perf_counter() - t0 + + t0 = time.perf_counter() + compiled[1, stream](a_t, b_t, c_t) + torch.npu.synchronize() + launch_s = time.perf_counter() - t0 + + np.testing.assert_allclose(c_t.cpu().numpy(), ref, rtol=1e-6, atol=1e-6) + print(f"PASS tilelang_codegen compile={compile_s:.3f}s launch={launch_s:.3f}s") + + +def test_tilelang_codegen() -> None: + torch = init_torch_npu() + run_precision_case(torch) + print("All cases passed.") + + +def main(argv=None) -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--emit-mlir", + action="store_true", + help="print compiled MLIR and exit", + ) + args = parser.parse_args(argv) + + if args.emit_mlir: + print(emit_mlir()) + return 0 + + test_tilelang_codegen() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ptodsl/ptodsl/__init__.py b/ptodsl/ptodsl/__init__.py new file mode 100644 index 000000000..a0722d975 --- /dev/null +++ b/ptodsl/ptodsl/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""ptodsl – PTO MLIR DSL package.""" + +from importlib import import_module + +__all__ = ["pto", "scalar"] + + +def __getattr__(name): + if name in {"pto", "scalar"}: + module = import_module(f".{name}", __name__) + globals()[name] = module + return module + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/ptodsl/ptodsl/_bootstrap.py b/ptodsl/ptodsl/_bootstrap.py new file mode 100644 index 000000000..ec9a6707a --- /dev/null +++ b/ptodsl/ptodsl/_bootstrap.py @@ -0,0 +1,71 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +MLIR path bootstrap and context factory. + +Discovers local LLVM MLIR Python bindings plus PTO Python dialect artifacts so +that ``ptodsl`` can import ``mlir`` / ``mlir.dialects.pto`` directly from a +developer workspace without requiring the caller to pre-seed ``PYTHONPATH``. +""" + +import os +import sys +from pathlib import Path + + +def _candidate_python_roots() -> list[Path]: + here = Path(__file__).resolve() + repo_root = here.parents[2] + workspace_root = repo_root.parent + env_roots = [] + for env_name in ("MLIR_PYTHON_ROOT", "PTO_PYTHON_ROOT"): + raw = os.environ.get(env_name) + if raw: + env_roots.append(Path(raw)) + + return [ + *env_roots, + repo_root / "build" / "python", + repo_root / "install", + workspace_root / "llvm-project" / "build-shared" / "tools" / "mlir" / "python_packages" / "mlir_core", + ] + + +def _bootstrap_python_paths() -> None: + ordered_roots: list[str] = [] + seen = set() + for root in _candidate_python_roots(): + if not root or not root.is_dir(): + continue + if not (root / "mlir").exists(): + continue + root_text = str(root) + if root_text in seen: + continue + ordered_roots.append(root_text) + seen.add(root_text) + for root_text in reversed(ordered_roots): + if root_text in sys.path: + sys.path.remove(root_text) + sys.path.insert(0, root_text) + + +_bootstrap_python_paths() + +from mlir.dialects import pto as _pto_dialect # noqa: E402 +from mlir.ir import Context, Location # noqa: E402 + + +def make_context() -> Context: + """Create a fresh MLIR Context with the PTO dialect loaded.""" + ctx = Context() + _pto_dialect.register_dialect(ctx, load=True) + return ctx + + +__all__ = ["make_context"] diff --git a/ptodsl/ptodsl/_control_flow.py b/ptodsl/ptodsl/_control_flow.py new file mode 100644 index 000000000..5520d7093 --- /dev/null +++ b/ptodsl/ptodsl/_control_flow.py @@ -0,0 +1,578 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +Control-flow context managers for PTO kernels. + +All CMs work with the current MLIR insertion point; no context threading needed. + +Public API +────────── +``vecscope()`` – ``pto.vecscope { … }`` +``for_(lo, hi, step)`` + – ``scf.for`` with optional named carry state via ``.carry(...)`` +``if_(cond)`` – ``scf.if`` via explicit branch handle + automatic named merge +``yield_(*vals)`` – ``scf.yield`` +""" + +from ._bootstrap import make_context # noqa: F401 +from ._runtime_index_ops import coerce_runtime_index +from ._tracing.active import current_session +from ._surface_values import unwrap_surface_value, wrap_like_surface_value, wrap_surface_value + +from mlir.dialects import pto as _pto, scf +from mlir.ir import InsertionPoint + + +# ── vecscope ────────────────────────────────────────────────────────────────── + +class _VecScopeCM: + """Context manager for ``pto.vecscope { … }``.""" + + def __enter__(self): + self._op = _pto.VecScopeOp() + self._block = self._op.body.blocks.append() + self._ip = InsertionPoint(self._block) + self._ip.__enter__() + return None + + def __exit__(self, *exc): + self._ip.__exit__(*exc) + + +def vecscope() -> _VecScopeCM: + """Return a context manager that emits ``pto.vecscope { … }``.""" + return _VecScopeCM() + + +# ── for_ ────────────────────────────────────────────────────────────────────── + +class LoopHandle: + """ + Internal handle for a lowered ``scf.for`` loop. + + Attributes used by the control-flow implementation:: + + loop.iv – induction variable + loop.iter_args – tuple of inner (mutable) SSA values + loop.results – tuple of ForOp results (after loop exit) + """ + + def __init__(self, for_op, *, iter_arg_templates=()): + self._op = for_op + self._iter_arg_templates = tuple(iter_arg_templates) + + @property + def iv(self): + return wrap_surface_value(self._op.induction_variable) + + @property + def iter_args(self): + return tuple( + wrap_like_surface_value(template, value) + for template, value in zip(self._iter_arg_templates, self._op.inner_iter_args) + ) + + @property + def results(self): + return tuple( + wrap_like_surface_value(template, value) + for template, value in zip(self._iter_arg_templates, self._op.results) + ) + + +class _ForCM: + def __init__(self, start, stop, step, iter_args): + self._start = start + self._stop = stop + self._step = step + self._iter_arg_templates = tuple(iter_args) if iter_args is not None else () + self._iter_args = [unwrap_surface_value(value) for value in self._iter_arg_templates] + self._for_op = None + self._ip = None + + def __enter__(self): + self._for_op = scf.ForOp( + _coerce_index(self._start), + _coerce_index(self._stop), + _coerce_index(self._step), + self._iter_args if self._iter_args else None, + ) + self._ip = InsertionPoint(self._for_op.body) + self._ip.__enter__() + if not self._iter_args: + return wrap_surface_value(self._for_op.induction_variable) + return LoopHandle(self._for_op, iter_arg_templates=self._iter_arg_templates) + + def __exit__(self, *exc): + if not self._iter_args: + scf.YieldOp([]) + self._ip.__exit__(*exc) + + +def for_(start, stop, *, step): + """ + ``scf.for`` context manager. + + Yields the induction variable; ``scf.yield`` is inserted automatically:: + + with pto.for_(c0, c16, step=c1) as i: + ... + + Named carry state is expressed with ``.carry(...)``:: + + loop = pto.for_(c0, c128, step=c64).carry(acc=tile) + with loop: + cur = loop.acc + loop.update(acc=cur) + out = loop.final("acc") + """ + return _ForBuilder(start, stop, step) + + +class _CarryLoopStateView: + def __init__(self, names, values): + self._names = tuple(names) + self._values = dict(zip(self._names, values)) + + def __getattr__(self, name): + try: + return self._values[name] + except KeyError as exc: + raise AttributeError(name) from exc + + +class _CarryForCM(_ForCM): + def __init__(self, start, stop, step, state_items): + self._state_items = tuple(state_items) + self._state_names = tuple(name for name, _ in self._state_items) + self._state_templates = tuple(value for _, value in self._state_items) + self._session = None + self._session_frame = None + super().__init__(start, stop, step, self._state_templates) + self._yield_values = None + self._entered = False + + def __enter__(self): + self._session = current_session() + if self._session is not None: + self._session_frame = self._session.begin_carry_loop( + self._start, + self._stop, + self._step, + self._state_items, + ) + self._for_op = self._session_frame.for_op + handle = LoopHandle(self._for_op, iter_arg_templates=self._state_templates) + else: + handle = super().__enter__() + self._entered = True + self._yield_values = None + self._loop_handle = handle + self._state = _CarryLoopStateView(self._state_names, handle.iter_args) + return self + + def __exit__(self, exc_type, exc, tb): + try: + if self._session_frame is not None: + self._session.finish_carry_loop(self._session_frame, exc_type, exc, tb) + return None + if exc_type is None: + if self._yield_values is None: + raise RuntimeError( + "pto.for_(...).carry(...) requires loop.update(...) before leaving the loop body" + ) + scf.YieldOp(self._yield_values) + return super().__exit__(exc_type, exc, tb) + finally: + self._entered = False + self._session = None + self._session_frame = None + + @property + def iv(self): + if not self._entered: + raise RuntimeError("loop.iv is only available inside an active carry loop body") + return self._loop_handle.iv + + def __getattr__(self, name): + if name in self._state_names: + if not self._entered: + raise RuntimeError(f"loop.{name} is only available inside an active carry loop body") + return getattr(self._state, name) + raise AttributeError(name) + + def update(self, **kwargs): + if not self._entered: + raise RuntimeError("loop.update(...) may only be called inside the loop body") + if self._session_frame is not None: + self._session.update_carry_loop(self._session_frame, **kwargs) + return + missing = [name for name in self._state_names if name not in kwargs] + extra = [name for name in kwargs if name not in self._state_names] + if missing or extra: + pieces = [] + if missing: + pieces.append(f"missing: {', '.join(missing)}") + if extra: + pieces.append(f"unexpected: {', '.join(extra)}") + raise RuntimeError("loop.update(...) must match carry names exactly; " + "; ".join(pieces)) + if self._yield_values is not None: + raise RuntimeError("loop.update(...) may only be called once per loop body") + self._yield_values = [ + unwrap_surface_value(kwargs[name]) + for name in self._state_names + ] + + def final(self, name): + if self._for_op is None: + raise RuntimeError("loop.final(...) is only available after the loop has been built") + try: + index = self._state_names.index(name) + except ValueError as exc: + raise RuntimeError( + f"loop.final(...) requested unknown carry state '{name}'; " + f"expected one of: {', '.join(self._state_names)}" + ) from exc + return wrap_like_surface_value(self._state_templates[index], self._for_op.results[index]) + + +class _ForBuilder: + def __init__(self, start, stop, step): + self._start = start + self._stop = stop + self._step = step + + def __enter__(self): + self._cm = _ForCM(self._start, self._stop, self._step, None) + return self._cm.__enter__() + + def __exit__(self, *exc): + return self._cm.__exit__(*exc) + + def carry(self, **kwargs): + if not kwargs: + raise ValueError("carry(...) requires at least one named loop-carried value") + for name in kwargs: + if not isinstance(name, str) or not name: + raise TypeError("carry(...) names must be non-empty strings") + return _CarryForCM(self._start, self._stop, self._step, tuple(kwargs.items())) + + +def _coerce_index(value): + raw_value = unwrap_surface_value(value) + return coerce_runtime_index(raw_value, context="pto.for_(...) loop bound") + + +# ── if_ ─────────────────────────────────────────────────────────────────────── + +def _find_parent_block(op_view): + """Return the block that directly contains *op_view*.""" + parent_op = op_view.operation.parent + if parent_op is None: + raise RuntimeError("unable to locate the parent block for pto.if_(...)") + for region in parent_op.regions: + for block in region.blocks: + for candidate in block.operations: + if candidate.operation is op_view.operation: + return block + raise RuntimeError("unable to locate the parent block for pto.if_(...)") + + +def _move_block_ops(src_block, dst_block, *, yield_values): + """Move all non-terminator ops from *src_block* into *dst_block* and yield.""" + with InsertionPoint(dst_block): + terminator = scf.YieldOp(list(yield_values)) + yield_anchor = terminator.operation.opview + for op in list(src_block.operations): + if op.operation.name == "scf.yield": + continue + op.move_before(yield_anchor) + + +class _IfBranchCM: + """Enters the insertion point of one branch block for ``with br.then_:`` style.""" + + def __init__(self, owner, branch_name, block): + self._owner = owner + self._branch_name = branch_name + self._block = block + self._ip = None + + def __enter__(self): + self._owner._enter_branch(self._branch_name) + self._ip = InsertionPoint(self._block) + self._ip.__enter__() + + def __exit__(self, *exc): + try: + self._ip.__exit__(*exc) + finally: + self._owner._leave_branch(self._branch_name) + + +class BranchHandle: + """ + Handle for one authored ``pto.if_(...)`` branch pair. + + Usage:: + + with pto.if_(cond) as br: + with br.then_: + br.assign(val=x) + with br.else_: + br.assign(val=y) + out = br.val + """ + + def __init__(self, owner): + self._owner = owner + self.then_ = _IfBranchCM(owner, "then", owner._tmp_if.then_block) + self.else_ = _IfBranchCM(owner, "else", owner._tmp_if.else_block) + + def assign(self, **kwargs): + self._owner._assign_branch_values(kwargs) + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(name) + return self._owner._get_merged_value(name) + + +class _IfCM: + def __init__(self, cond): + self._cond = cond + self._cond_value = None + self._tmp_if = None + self._parent_block = None + self._active_branch = None + self._branch_closed = {"then": False, "else": False} + self._branch_entered = {"then": False, "else": False} + self._branch_assignments = {"then": None, "else": None} + self._merged_values = None + self._finalized = False + self._handle = None + + def __enter__(self): + self._cond_value = unwrap_surface_value(self._cond) + self._tmp_if = scf.IfOp(self._cond_value, hasElse=True) + self._parent_block = _find_parent_block(self._tmp_if) + self._handle = BranchHandle(self) + return self._handle + + def __exit__(self, exc_type, exc, tb): + if exc_type is not None: + self._erase_tmp_if() + return None + try: + self._finalize() + except Exception: + self._erase_tmp_if() + raise + return None + + def _enter_branch(self, branch_name): + if self._finalized: + raise RuntimeError("pto.if_(...) branches are no longer available after the conditional closes") + if self._active_branch is not None: + raise RuntimeError( + "pto.if_(...) does not support nested branch entry; close the current " + f"br.{self._active_branch}_ block before entering br.{branch_name}_" + ) + if self._branch_closed[branch_name]: + raise RuntimeError(f"br.{branch_name}_ may only be entered once per pto.if_(...)") + self._active_branch = branch_name + self._branch_entered[branch_name] = True + + def _leave_branch(self, branch_name): + if self._active_branch == branch_name: + self._active_branch = None + self._branch_closed[branch_name] = True + + def _assign_branch_values(self, kwargs): + if self._active_branch is None: + raise RuntimeError("br.assign(...) may only be used inside br.then_ or br.else_") + if not kwargs: + raise ValueError("br.assign(...) requires at least one named value") + branch_name = self._active_branch + if self._branch_assignments[branch_name] is not None: + raise RuntimeError(f"br.{branch_name}_ may call br.assign(...) at most once") + raw_values = {} + templates = {} + order = tuple(kwargs.keys()) + for name, value in kwargs.items(): + raw_value = unwrap_surface_value(value) + if not hasattr(raw_value, "type"): + raise TypeError( + "br.assign(...) expects PTO runtime values or authored surface values; " + f"'{name}' received {type(value).__name__}" + ) + raw_values[name] = raw_value + templates[name] = value + self._branch_assignments[branch_name] = { + "order": order, + "raw_values": raw_values, + "templates": templates, + } + + def _get_merged_value(self, name): + if not self._finalized: + raise RuntimeError(f"br.{name} is only available after the pto.if_(...) block closes") + if self._merged_values is None or name not in self._merged_values: + expected = () + if self._merged_values: + expected = tuple(self._merged_values.keys()) + if expected: + raise AttributeError( + f"br.{name} was not assigned by this conditional; " + f"expected one of: {', '.join(expected)}" + ) + raise AttributeError(f"br.{name} was not assigned by this conditional") + return self._merged_values[name] + + def _finalize(self): + self._validate_no_stray_ops() + if not any(self._branch_entered.values()): + raise RuntimeError( + "pto.if_(...) requires at least one explicit branch block; " + "use 'with br.then_:' and optionally 'with br.else_:'" + ) + merge_spec = self._validate_merge_spec() + if merge_spec is None: + self._finalize_side_effect_if() + else: + self._finalize_merged_if(merge_spec) + self._finalized = True + + def _validate_no_stray_ops(self): + parent_ops = list(self._parent_block.operations) + if not parent_ops or parent_ops[-1].operation is not self._tmp_if.operation: + raise RuntimeError( + "pto.if_(...) body may only contain explicit 'with br.then_:' / " + "'with br.else_:' blocks; PTODSL found operations emitted directly " + "in the outer if body" + ) + + def _validate_merge_spec(self): + then_assignment = self._branch_assignments["then"] + else_assignment = self._branch_assignments["else"] + if then_assignment is None and else_assignment is None: + return None + if then_assignment is None or else_assignment is None: + raise RuntimeError( + "automatic branch merge requires both br.then_ and br.else_ to call br.assign(...)" + ) + + then_names = set(then_assignment["raw_values"].keys()) + else_names = set(else_assignment["raw_values"].keys()) + if then_names != else_names: + missing_in_else = sorted(then_names - else_names) + missing_in_then = sorted(else_names - then_names) + pieces = [] + if missing_in_else: + pieces.append(f"missing in else: {', '.join(missing_in_else)}") + if missing_in_then: + pieces.append(f"missing in then: {', '.join(missing_in_then)}") + raise RuntimeError("br.assign(...) names must match across branches; " + "; ".join(pieces)) + + order = then_assignment["order"] + result_types = [] + for name in order: + then_value = then_assignment["raw_values"][name] + else_value = else_assignment["raw_values"][name] + if then_value.type != else_value.type: + raise RuntimeError( + f"br.assign(...) type mismatch for '{name}': " + f"then branch yields {then_value.type}, else branch yields {else_value.type}" + ) + result_types.append(then_value.type) + + return { + "order": order, + "result_types": result_types, + "then": then_assignment, + "else": else_assignment, + } + + def _finalize_side_effect_if(self): + has_else = self._branch_entered["else"] + final_if = scf.IfOp(self._cond_value, hasElse=has_else) + _move_block_ops(self._tmp_if.then_block, final_if.then_block, yield_values=[]) + if has_else: + _move_block_ops(self._tmp_if.else_block, final_if.else_block, yield_values=[]) + self._merged_values = {} + self._tmp_if.erase() + self._tmp_if = final_if + + def _finalize_merged_if(self, merge_spec): + final_if = scf.IfOp(self._cond_value, merge_spec["result_types"], hasElse=True) + then_yield_values = [ + merge_spec["then"]["raw_values"][name] + for name in merge_spec["order"] + ] + else_yield_values = [ + merge_spec["else"]["raw_values"][name] + for name in merge_spec["order"] + ] + _move_block_ops(self._tmp_if.then_block, final_if.then_block, yield_values=then_yield_values) + _move_block_ops(self._tmp_if.else_block, final_if.else_block, yield_values=else_yield_values) + + merged = {} + for name, template, result in zip( + merge_spec["order"], + (merge_spec["then"]["templates"][name] for name in merge_spec["order"]), + final_if.results, + ): + merged[name] = wrap_like_surface_value(template, result) + self._merged_values = merged + self._tmp_if.erase() + self._tmp_if = final_if + + def _erase_tmp_if(self): + if self._tmp_if is None: + return + try: + self._tmp_if.erase() + except Exception: + pass + finally: + self._tmp_if = None + + +def if_(cond) -> _IfCM: + """ + ``scf.if`` context manager with explicit branch handles. + + Side-effect-only form:: + + with pto.if_(has_rows) as br: + with br.then_: + ... + + Automatic named merge form:: + + with pto.if_(has_chunk) as br: + with br.then_: + br.assign(x=a) + with br.else_: + br.assign(x=b) + x = br.x + """ + return _IfCM(cond) + + +# ── yield_ ──────────────────────────────────────────────────────────────────── + +def yield_(*vals): + """Emit ``scf.yield`` with the given values.""" + scf.YieldOp([unwrap_surface_value(value) for value in vals]) + + +__all__ = [ + "vecscope", "LoopHandle", "BranchHandle", + "for_", "if_", "yield_", +] diff --git a/ptodsl/ptodsl/_diagnostics.py b/ptodsl/ptodsl/_diagnostics.py new file mode 100644 index 000000000..a8c60b28a --- /dev/null +++ b/ptodsl/ptodsl/_diagnostics.py @@ -0,0 +1,216 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Shared user-facing diagnostics for PTODSL tracing misuse.""" + +from __future__ import annotations + + +class PTODSLTracingMisuseError(TypeError): + """Raised when authored Python misuses PTODSL runtime values during tracing.""" + + +def _format_source_context(function_name: str | None, source_file: str | None, source_line: int | None) -> str: + details = [] + if function_name: + details.append(f"kernel {function_name!r}") + if source_file is not None: + location = source_file + if source_line is not None: + location = f"{location}:{source_line}" + details.append(location) + if not details: + return "" + return f" ({', '.join(details)})" + + +def native_python_control_flow_error(usage: str) -> PTODSLTracingMisuseError: + """Return one actionable diagnostic for native Python control-flow misuse.""" + return PTODSLTracingMisuseError( + f"native Python {usage} cannot consume a PTODSL runtime value during tracing. " + "This value is a device-side SSA/runtime-metadata value, not a Python bool/int. " + "Use pto.if_(...) or pto.for_(...) for device-side control flow, or keep the " + "bound/condition in pto.constexpr." + ) + + +def host_tensor_metadata_error(message: str, *, param_name: str | None = None) -> TypeError: + """Return one actionable diagnostic for unsupported host-tensor metadata.""" + prefix = "host tensor metadata is incomplete or unsupported" + if param_name is not None: + prefix = f"@pto.jit host tensor '{param_name}' metadata is incomplete or unsupported" + return TypeError(f"{prefix}: {message}") + + +def jit_missing_annotation_error(name: str) -> TypeError: + """Return one diagnostic for missing ``@pto.jit`` positional ABI annotations.""" + return TypeError( + f"@pto.jit positional parameter '{name}' does not declare an entry ABI annotation. " + "Use pto.ptr(..., 'gm') for device pointers, pto.tensor_spec(...) for runtime tensors, a PTO scalar type such as " + "pto.i32/pto.f32/pto.i1 for runtime scalars, or move compile-time values " + "to keyword-only pto.constexpr parameters." + ) + + +def jit_illegal_formal_annotation_error(name: str, annotation: object) -> TypeError: + """Return one diagnostic for unsupported ``@pto.jit`` positional annotations.""" + return TypeError( + f"@pto.jit positional parameter '{name}' uses unsupported entry annotation {annotation!r}. " + "The public @pto.jit entry ABI accepts pto.ptr(..., 'gm') device pointers, pto.tensor_spec(...) runtime tensors, " + "PTO scalar annotations such as pto.i32/pto.f32/pto.i1 for runtime scalars, " + "and keyword-only pto.constexpr compile-time parameters. " + "Non-GM pointers, Tile, PartitionTensorView, and VReg belong inside the kernel body " + "or across sub-kernel boundaries, not at the host/kernel entry." + ) + + +def subkernel_host_tensor_boundary_error(role: str, name: str) -> TypeError: + """Return one diagnostic for host-tensor usage outside the JIT boundary.""" + return TypeError( + f"@pto.{role} parameter '{name}' uses a host tensor value, but host tensors only belong " + "at the @pto.jit boundary. Pass PTODSL device-side values such as Tile, " + "PartitionTensorView, typed pointers, or PTO scalars instead." + ) + + +def subkernel_signature_boundary_error(role: str, name: str) -> TypeError: + """Return one diagnostic for illegal host-tensor formal annotations on a subkernel.""" + return TypeError( + f"@pto.{role} parameter '{name}' cannot be annotated with pto.tensor_spec(...). " + "Host tensors are only valid as @pto.jit positional parameters." + ) + + +def illegal_subkernel_placement_error(role: str, outer_role: str | None) -> RuntimeError: + """Return one diagnostic for a subkernel call placed outside the supported layer graph.""" + if role == "simt": + return RuntimeError( + "@pto.simt helper materialization is only supported from the top-level @pto.jit body; " + f"it cannot be materialized inside @pto.{outer_role}." + ) + return RuntimeError( + f"@pto.{role} may only be called from the top-level @pto.jit body; " + f"nested invocation inside @pto.{outer_role} is not part of the PTODSL layer contract." + ) + + +def illegal_inline_subkernel_placement_error(role: str, outer_role: str | None) -> RuntimeError: + """Return one diagnostic for an inline subkernel scope placed outside the supported layer graph.""" + return RuntimeError( + f"inline pto.{role}() may only be used from the top-level @pto.jit body; " + f"nested use inside @pto.{outer_role} is not part of the PTODSL layer contract." + ) + + +def simd_value_escape_error(type_text: str) -> RuntimeError: + """Return one diagnostic for transient SIMD values escaping a simd subkernel boundary.""" + return RuntimeError( + f"@pto.simd cannot return transient SIMD values across the subkernel boundary " + f"(got {type_text}). Write the value back to a Tile/UB buffer instead." + ) + + +def tile_row_alignment_error(*, shape, dtype, row_bytes: int, required_alignment: int) -> TypeError: + """Return one diagnostic for authored tile shapes violating row-byte alignment.""" + return TypeError( + "alloc_tile(shape=...) physical row layout is invalid for the current PTODSL tile contract: " + f"shape={list(shape)!r} with dtype={dtype!r} gives a row byte size of {row_bytes}, " + f"but row-major none-box tiles must be {required_alignment}-byte aligned. " + "For logical column tiles such as [Br, 1], prefer blayout='ColMajor' instead of authoring them " + "as row-major narrow tiles. If row-major is truly required, keep the physical tile shape explicitly " + "aligned and express the logical tail with valid_shape=[...]." + ) + + +def explicit_mode_required_error(surface: str, current_mode: str | None) -> RuntimeError: + """Return one diagnostic for explicit-only surfaces used outside explicit mode.""" + observed_mode = "unknown" if current_mode is None else current_mode + return RuntimeError( + f"{surface} is an auto-mode contract violation: it is only available in " + f'@pto.jit(mode="explicit"); current kernel mode is {observed_mode!r}. ' + "Move the kernel to explicit mode before authoring this surface." + ) + + +def explicit_mode_required_with_context_error(surface: str, module_spec) -> RuntimeError: + """Return one diagnostic for explicit-only surfaces used outside explicit mode with source context.""" + observed_mode = getattr(module_spec, "mode", None) + context = _format_source_context( + getattr(module_spec, "function_name", None), + getattr(module_spec, "source_file", None), + getattr(module_spec, "source_line", None), + ) + observed_mode = "unknown" if observed_mode is None else observed_mode + return RuntimeError( + f"{surface} is an auto-mode contract violation{context}: it is only available in " + f'@pto.jit(mode="explicit"); current kernel mode is {observed_mode!r}. ' + "Move the kernel to explicit mode before authoring this surface." + ) + + +def invalid_jit_mode_error( + mode: str, + *, + function_name: str | None = None, + source_file: str | None = None, + source_line: int | None = None, +) -> ValueError: + """Return one diagnostic for unsupported ``@pto.jit(mode=...)`` values.""" + context = _format_source_context(function_name, source_file, source_line) + return ValueError( + f"unsupported PTODSL jit mode {mode!r}{context}; expected 'auto' or 'explicit'" + ) + + +def unsupported_public_surface_error(name: str) -> AttributeError: + """Return one diagnostic for unsupported names on the public ``pto`` surface.""" + hints = { + "ukernel": ( + 'Use @pto.jit(mode="explicit") for explicit DMA orchestration, and call or inline ' + "@pto.simd/@pto.simt/@pto.cube directly from that kernel." + ), + "tile_buf_type": ( + "Use pto.alloc_tile(shape=..., dtype=..., memory_space=..., valid_shape=..., addr=...) " + "to author tiles, and keep explicit tile-type construction inside internal implementation code only." + ), + "vecscope": ( + "Use @pto.simd for named SIMD helpers, or inline SIMD code with `with pto.simd():`." + ), + "as_ptr": ( + "Use tile.as_ptr(), view.as_ptr(), or partition.as_ptr() on the authored object itself " + "instead of the removed pto.as_ptr(...) helper." + ), + "vbrc_load": ( + 'Use pto.vlds(ptr, offset, dist="BRC_B32") instead of the removed pto.vbrc_load(...) helper.' + ), + "vsts_1pt": ( + 'Use pto.vsts(vec, ptr, offset, mask, dist="1PT_B32") instead of the removed pto.vsts_1pt(...) helper.' + ), + } + suffix = hints.get(name, "Use the documented PTODSL public surface instead.") + return AttributeError( + f"pto.{name} is not a supported PTODSL public interface. {suffix}" + ) + + +__all__ = [ + "PTODSLTracingMisuseError", + "explicit_mode_required_error", + "explicit_mode_required_with_context_error", + "host_tensor_metadata_error", + "jit_illegal_formal_annotation_error", + "jit_missing_annotation_error", + "illegal_inline_subkernel_placement_error", + "illegal_subkernel_placement_error", + "invalid_jit_mode_error", + "native_python_control_flow_error", + "simd_value_escape_error", + "subkernel_host_tensor_boundary_error", + "subkernel_signature_boundary_error", + "tile_row_alignment_error", + "unsupported_public_surface_error", +] diff --git a/ptodsl/ptodsl/_host_tensors.py b/ptodsl/ptodsl/_host_tensors.py new file mode 100644 index 000000000..0217ffb5f --- /dev/null +++ b/ptodsl/ptodsl/_host_tensors.py @@ -0,0 +1,242 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Host-tensor boundary helpers for ``@pto.jit``.""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass + +from ._diagnostics import host_tensor_metadata_error +from ._types import _ensure_non_storage_only_authored_dtype, _resolve, index, ptr + + +def _normalize_tensor_shape(shape): + try: + return tuple(int(dim) for dim in shape) + except TypeError as exc: + raise host_tensor_metadata_error("missing iterable .shape") from exc + except ValueError as exc: + raise host_tensor_metadata_error(".shape must contain integer-like dimensions") from exc + + +def _normalize_tensor_strides(tensor): + stride_method = getattr(tensor, "stride", None) + if callable(stride_method): + try: + return tuple(int(dim) for dim in stride_method()) + except TypeError as exc: + raise host_tensor_metadata_error(".stride() must return an iterable of integer-like dimensions") from exc + except ValueError as exc: + raise host_tensor_metadata_error(".stride() must return integer-like dimensions") from exc + strides = getattr(tensor, "strides", None) + if strides is None: + raise host_tensor_metadata_error("missing .strides or .stride()") + try: + return tuple(int(dim) for dim in strides) + except TypeError as exc: + raise host_tensor_metadata_error(".strides must be iterable") from exc + except ValueError as exc: + raise host_tensor_metadata_error(".strides must contain integer-like dimensions") from exc + + +def _extract_tensor_data_handle(tensor): + for attr_name in ("data_ptr", "ptr"): + attr = getattr(tensor, attr_name, None) + if callable(attr): + value = attr() + else: + value = attr + if value is not None: + try: + return int(value) + except (TypeError, ValueError) as exc: + raise host_tensor_metadata_error( + f"{attr_name} must return an integer-like data handle" + ) from exc + array_interface = getattr(tensor, "__array_interface__", None) + if array_interface is not None: + data = array_interface.get("data") + if isinstance(data, tuple) and data: + try: + return int(data[0]) + except (TypeError, ValueError) as exc: + raise host_tensor_metadata_error( + "__array_interface__['data'][0] must be an integer-like data handle" + ) from exc + raise host_tensor_metadata_error( + "missing data handle; expected .data_ptr(), .ptr, or __array_interface__" + ) + + +@dataclass(frozen=True) +class HostTensorMetadata: + """Concrete runtime metadata extracted from a Python host tensor.""" + + shape: tuple[int, ...] + strides: tuple[int, ...] + dtype: object + data_handle: int + + +def inspect_host_tensor_metadata(tensor) -> HostTensorMetadata: + """Extract shape / strides / dtype / data-handle from a Python tensor-like object.""" + shape = _normalize_tensor_shape(getattr(tensor, "shape", None)) + strides = _normalize_tensor_strides(tensor) + dtype = getattr(tensor, "dtype", None) + if dtype is None: + raise host_tensor_metadata_error("missing .dtype") + return HostTensorMetadata( + shape=shape, + strides=strides, + dtype=dtype, + data_handle=_extract_tensor_data_handle(tensor), + ) + + +@dataclass(frozen=True) +class TensorSpec: + """Static ABI hint for one Python-native ``@pto.jit`` tensor parameter.""" + + rank: int + dtype: object + address_space: str = "gm" + + def __post_init__(self): + if self.rank <= 0: + raise ValueError("tensor_spec(rank=...) expects a positive rank") + _ensure_non_storage_only_authored_dtype( + self.dtype, + context="pto.tensor_spec(...)", + ) + + def entry_arg_types(self): + data_type = _resolve(ptr(self.dtype, self.address_space)) + index_type = _resolve(index) + return ( + data_type, + *([index_type] * self.rank), + *([index_type] * self.rank), + ) + + def abi_signature(self): + return ( + "tensor_spec", + self.rank, + self.dtype, + self.address_space, + ) + + def __repr__(self): + return ( + f"pto.tensor_spec(rank={self.rank}, dtype={self.dtype!r}, " + f"address_space={self.address_space!r})" + ) + + +def tensor_spec(*, rank: int, dtype, address_space: str = "gm") -> TensorSpec: + """Declare the ABI contract of one Python-native ``@pto.jit`` tensor parameter.""" + return TensorSpec(rank=rank, dtype=dtype, address_space=address_space) + + +class HostTensorValue: + """Tracing-time proxy for one Python-native tensor at the ``@pto.jit`` boundary.""" + + def __init__(self, name: str, spec: TensorSpec, data_handle, shape, strides): + from ._surface_values import wrap_surface_value + self.name = name + self.spec = spec + self.data_handle = wrap_surface_value(data_handle) + self.shape = tuple(wrap_surface_value(dim) for dim in shape) + self.strides = tuple(wrap_surface_value(dim) for dim in strides) + self.dtype = spec.dtype + + @property + def rank(self): + return self.spec.rank + + def __repr__(self): + return ( + f"" + ) + + +def bind_host_tensor_argument(name: str, spec: TensorSpec, entry_arguments): + """Bind one flattened entry-ABI slice into a ``HostTensorValue``.""" + expected = 1 + spec.rank + spec.rank + if len(entry_arguments) < expected: + raise RuntimeError( + f"entry ABI for host tensor '{name}' is incomplete: expected {expected} " + f"arguments, got {len(entry_arguments)}" + ) + data_handle = entry_arguments[0] + shape = entry_arguments[1:1 + spec.rank] + strides = entry_arguments[1 + spec.rank:1 + spec.rank + spec.rank] + return ( + HostTensorValue(name, spec, data_handle, shape, strides), + entry_arguments[expected:], + ) + + +def infer_jit_host_tensor_spec(param: inspect.Parameter): + """ + Resolve one ``@pto.jit`` positional parameter to a host-tensor contract. + + V1 cannot infer rank/dtype from an unannotated formal parameter while still + tracing at compile time, so host tensors currently require an explicit + ``pto.tensor_spec(...)`` ABI hint. + """ + if isinstance(param.annotation, TensorSpec): + return param.annotation + if param.annotation is inspect.Parameter.empty: + raise TypeError( + f"@pto.jit positional parameter '{param.name}' uses the host-tensor " + "boundary but does not declare an ABI hint. Add an annotation such " + "as `Q: pto.tensor_spec(rank=4, dtype=pto.f32)`." + ) + return None + + +def resolve_tensor_data_entry(value): + """Return the pointer-like data entry behind a host tensor proxy or raw value.""" + if isinstance(value, HostTensorValue): + return value.data_handle + return value + + +def looks_like_host_tensor(value) -> bool: + """Best-effort predicate for Python-native tensor-like objects at the JIT boundary.""" + if isinstance(value, HostTensorValue): + return True + return ( + getattr(value, "shape", None) is not None + and getattr(value, "dtype", None) is not None + and ( + callable(getattr(value, "stride", None)) + or getattr(value, "strides", None) is not None + ) + and ( + callable(getattr(value, "data_ptr", None)) + or getattr(value, "ptr", None) is not None + or getattr(value, "__array_interface__", None) is not None + ) + ) + + +__all__ = [ + "HostTensorMetadata", + "TensorSpec", + "HostTensorValue", + "bind_host_tensor_argument", + "tensor_spec", + "infer_jit_host_tensor_spec", + "inspect_host_tensor_metadata", + "looks_like_host_tensor", + "resolve_tensor_data_entry", +] diff --git a/ptodsl/ptodsl/_jit.py b/ptodsl/ptodsl/_jit.py new file mode 100644 index 000000000..4c7012518 --- /dev/null +++ b/ptodsl/ptodsl/_jit.py @@ -0,0 +1,166 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""``@pto.jit`` decorator and compiled-kernel handles.""" + +from __future__ import annotations + +import inspect + +from ._diagnostics import invalid_jit_mode_error +from ._kernel_compilation import CompiledKernelHandle, KernelCompiler +from ._kernel_signature import parse_jit_kernel_signature +from ._tracing import ( + KernelModuleSpec, + ModuleArtifact, + ModuleStyle, +) + +from mlir.ir import InsertionPoint + + +_MODULE_ATTRS = ("pto.target_arch", "pto.kernel_kind", "pto.mode") + + +def _normalize_mode(mode: str, *, fn=None) -> str: + if mode not in {"auto", "explicit"}: + source_file = None + source_line = None + function_name = None + if fn is not None: + function_name = fn.__name__ + try: + source_file = inspect.getsourcefile(fn) or inspect.getfile(fn) + except (OSError, TypeError): + source_file = None + source_line = getattr(getattr(fn, "__code__", None), "co_firstlineno", None) + raise invalid_jit_mode_error( + mode, + function_name=function_name, + source_file=source_file, + source_line=source_line, + ) + return mode + + +def _module_attr_map(module): + attrs = module.operation.attributes + return {name: str(attrs[name]) for name in _MODULE_ATTRS if name in attrs} + + +def merge_jit_modules(*kernels: KernelHandle): + """ + Merge multiple ``@pto.jit`` kernels into one MLIR module. + + Each handle must have been compiled with the same ``target``, + ``kernel_kind``, and ``mode`` module attributes. Function order follows + *kernels*. + """ + if not kernels: + raise ValueError("merge_jit_modules() requires at least one kernel handle") + + merged = kernels[0].build() + expected_attrs = _module_attr_map(merged) + + for kernel in kernels[1:]: + module = kernel.build() + actual_attrs = _module_attr_map(module) + if actual_attrs != expected_attrs: + raise ValueError( + "merge_jit_modules() requires compatible module attributes; " + f"expected {expected_attrs}, got {actual_attrs}" + ) + with InsertionPoint(merged.body): + for op in module.body.operations: + op.operation.clone() + + merged.operation.verify() + return merged + + +def jit( + name=None, + *, + target: str = "a5", + kernel_kind: str = "vector", + mode: str = "auto", + insert_sync: bool | None = None, +): + """ + Decorator that wraps a Python function as a PTODSL JIT kernel template. + + Parameters + ---------- + name: IR function name (defaults to the Python function name). + target: Target architecture string, e.g. ``"a5"``. + kernel_kind: ``"vector"`` or ``"cube"`` – sets ``pto.kernel_kind``. + mode: ``"auto"`` or ``"explicit"`` – sets ``pto.mode``. + insert_sync: ``True``/``False`` to explicitly control PTOAS sync insertion + for launch builds. ``None`` keeps the mode-based default + behavior. + + The decorated function is replaced by a :class:`KernelHandle` that: + + - supports ``my_kernel.compile(**constexprs)`` specialization, + - prints as the default-specialization MLIR text, + - exposes ``my_kernel.mlir_module()`` / ``verify()`` / ``emit()`` on the + default specialization for convenience. + - emits a flat aicore launch-entry module by default. + """ + + def decorator(fn): + fn_name = name or fn.__name__ + kernel_signature = parse_jit_kernel_signature(fn) + normalized_mode = _normalize_mode(mode, fn=fn) + source_file = None + try: + source_file = inspect.getsourcefile(fn) or inspect.getfile(fn) + except (OSError, TypeError): + source_file = None + compiler = KernelCompiler( + fn.__name__, + KernelModuleSpec( + function_name=fn_name, + target_arch=target, + kernel_kind=kernel_kind, + mode=normalized_mode, + insert_sync=insert_sync, + module_style=ModuleStyle.FLAT_AICORE, + source_file=source_file, + source_line=getattr(fn.__code__, "co_firstlineno", None), + ), + kernel_signature, + fn, + ) + return KernelHandle(fn.__name__, compiler) + + return decorator + + +class KernelHandle(ModuleArtifact): + """ + Represents a JIT kernel template plus its compiled specializations. + + ``handle.compile(**constexprs)`` returns one compiled specialization. + ``print(handle)`` emits the default-specialization MLIR module text. + """ + + def __init__(self, py_name: str, compiler: KernelCompiler): + self._compiler = compiler + super().__init__(py_name, module_factory=self._build_default_module) + + def compile(self, **constexpr_bindings) -> CompiledKernelHandle: + return self._compiler.compile(**constexpr_bindings) + + def cached_specializations(self): + return self._compiler.cached_specializations() + + def _build_default_module(self): + return self.compile().build() + + +__all__ = ["CompiledKernelHandle", "jit", "KernelHandle", "merge_jit_modules"] diff --git a/ptodsl/ptodsl/_kernel_compilation.py b/ptodsl/ptodsl/_kernel_compilation.py new file mode 100644 index 000000000..15cae91d9 --- /dev/null +++ b/ptodsl/ptodsl/_kernel_compilation.py @@ -0,0 +1,99 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Kernel specialization and compilation helpers for ``@pto.jit``.""" + +from __future__ import annotations + +from ._runtime.launch import LaunchHandle, parse_launch_spec +from ._tracing import ModuleArtifact, SignatureTracingRuntime + + +class CompiledKernelHandle(ModuleArtifact): + """One compiled ``@pto.jit`` specialization.""" + + def __init__( + self, + py_name: str, + *, + specialization_key, + constexpr_bindings, + module_factory, + module_spec, + kernel_signature, + ): + super().__init__(py_name, module_factory=module_factory) + self._specialization_key = specialization_key + self._constexpr_bindings = dict(constexpr_bindings) + self._module_spec = module_spec + self._kernel_signature = kernel_signature + + @property + def specialization_key(self): + return self._specialization_key + + @property + def constexpr_bindings(self): + return dict(self._constexpr_bindings) + + @property + def ir_function_name(self): + return self._module_spec.function_name + + def __getitem__(self, launch_spec): + grid, stream = parse_launch_spec(launch_spec) + return LaunchHandle(self, grid, stream) + + +class KernelCompiler: + """Per-kernel specialization cache and module builder.""" + + def __init__(self, py_name: str, module_spec, kernel_signature, callback): + self._py_name = py_name + self._module_spec = module_spec + self._kernel_signature = kernel_signature + self._callback = callback + self._kernel_identity = id(callback) + self._compiled_cache = {} + + def compile(self, **constexpr_bindings): + normalized_bindings = self._kernel_signature.bind_constexpr_bindings(constexpr_bindings) + specialization_key = self._kernel_signature.specialization_key( + self._kernel_identity, + normalized_bindings, + ) + + cached = self._compiled_cache.get(specialization_key) + if cached is not None: + return cached + + runtime = SignatureTracingRuntime( + self._module_spec, + self._kernel_signature, + self._callback, + constexpr_bindings=normalized_bindings, + ) + compiled = CompiledKernelHandle( + self._py_name, + specialization_key=specialization_key, + constexpr_bindings=normalized_bindings, + module_factory=runtime.build_module, + module_spec=self._module_spec, + kernel_signature=self._kernel_signature, + ) + compiled.build() + self._compiled_cache[specialization_key] = compiled + return compiled + + def cached_specializations(self): + return tuple(self._compiled_cache.values()) + + +__all__ = [ + "CompiledKernelHandle", + "KernelCompiler", +] diff --git a/ptodsl/ptodsl/_kernel_signature.py b/ptodsl/ptodsl/_kernel_signature.py new file mode 100644 index 000000000..ea2784c9a --- /dev/null +++ b/ptodsl/ptodsl/_kernel_signature.py @@ -0,0 +1,243 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Declarative PTODSL kernel-signature parsing and entry-ABI binding.""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass + +from ._diagnostics import ( + jit_illegal_formal_annotation_error, + jit_missing_annotation_error, +) +from ._host_tensors import bind_host_tensor_argument, infer_jit_host_tensor_spec +from ._surface_values import wrap_surface_value +from ._surface_types import constexpr as _constexpr_marker +from ._types import ( + _DType, + _MaskDescriptor, + _PtrDescriptor, + _VRegDescriptor, + _normalize_address_space, + _resolve, +) + +from mlir.dialects import pto as _pto + + +@dataclass(frozen=True) +class KernelSpecializationKey: + kernel_identity: int + abi_signature: tuple + constexpr_signature: tuple[tuple[str, object], ...] + + +@dataclass(frozen=True) +class DeviceParameterSpec: + name: str + annotation: object + + def entry_arg_types(self): + return (_resolve(self.annotation),) + + def bind_entry_arguments(self, entry_arguments): + if not entry_arguments: + raise RuntimeError(f"entry ABI for device parameter '{self.name}' is incomplete") + return wrap_surface_value(entry_arguments[0]), entry_arguments[1:] + + def abi_signature(self): + return ("device", self.name, _hashable_signature_atom(self.annotation)) + + +@dataclass(frozen=True) +class RuntimeScalarParameterSpec: + name: str + annotation: object + + def entry_arg_types(self): + return (_resolve(self.annotation),) + + def bind_entry_arguments(self, entry_arguments): + if not entry_arguments: + raise RuntimeError(f"entry ABI for runtime scalar parameter '{self.name}' is incomplete") + return wrap_surface_value(entry_arguments[0]), entry_arguments[1:] + + def abi_signature(self): + return ("scalar", self.name, _hashable_signature_atom(self.annotation)) + + +@dataclass(frozen=True) +class TensorSpecParameterSpec: + name: str + tensor_spec: object + + def entry_arg_types(self): + return tuple(self.tensor_spec.entry_arg_types()) + + def bind_entry_arguments(self, entry_arguments): + return bind_host_tensor_argument(self.name, self.tensor_spec, entry_arguments) + + def abi_signature(self): + return ("tensor", self.name, self.tensor_spec.abi_signature()) + + +@dataclass(frozen=True) +class ConstexprParameterSpec: + name: str + default: object + + def bind_specialization(self, provided_bindings): + value = provided_bindings.get(self.name, self.default) + try: + hash(value) + except TypeError as exc: + raise TypeError( + f"@pto.jit constexpr parameter '{self.name}' must be hashable so it can " + "participate in the specialization cache" + ) from exc + return value + + +def _hashable_signature_atom(value): + try: + hash(value) + except TypeError: + return repr(value) + return value + + +def _is_supported_runtime_scalar_annotation(annotation) -> bool: + return ( + isinstance(annotation, _DType) + and not isinstance(annotation, (_PtrDescriptor, _VRegDescriptor, _MaskDescriptor)) + ) + + +def _is_supported_device_parameter_annotation(annotation) -> bool: + if not isinstance(annotation, _PtrDescriptor): + return False + return _normalize_address_space(annotation._space) == _pto.AddressSpace.GM + + +@dataclass(frozen=True) +class KernelSignature: + positional_parameters: tuple + constexpr_parameters: tuple[ConstexprParameterSpec, ...] + + def compute_entry_arg_types(self): + arg_types = [] + for param in self.positional_parameters: + arg_types.extend(param.entry_arg_types()) + return tuple(arg_types) + + def bind_entry_arguments(self, entry_arguments): + remaining = tuple(entry_arguments) + bound_args = [] + for param in self.positional_parameters: + bound_value, remaining = param.bind_entry_arguments(remaining) + bound_args.append(bound_value) + if remaining: + raise RuntimeError(f"unexpected trailing entry arguments in PTODSL kernel ABI: {len(remaining)}") + return tuple(bound_args) + + def default_constexpr_bindings(self): + return {param.name: param.default for param in self.constexpr_parameters} + + def bind_constexpr_bindings(self, provided_bindings): + provided = dict(provided_bindings) + expected_names = {param.name for param in self.constexpr_parameters} + unknown = sorted(name for name in provided if name not in expected_names) + if unknown: + raise TypeError( + f"unknown @pto.jit constexpr parameter(s): {', '.join(unknown)}" + ) + + bound = {} + for param in self.constexpr_parameters: + bound[param.name] = param.bind_specialization(provided) + return bound + + def abi_signature(self): + return tuple(param.abi_signature() for param in self.positional_parameters) + + def specialization_key(self, kernel_identity, constexpr_bindings): + return KernelSpecializationKey( + kernel_identity=kernel_identity, + abi_signature=self.abi_signature(), + constexpr_signature=tuple( + (param.name, constexpr_bindings[param.name]) + for param in self.constexpr_parameters + ), + ) + + +def parse_jit_kernel_signature(py_fn) -> KernelSignature: + """Parse one authored ``@pto.jit`` function signature.""" + sig = inspect.signature(py_fn) + positional_parameters = [] + constexpr_parameters = [] + + for param in sig.parameters.values(): + if param.kind in { + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + }: + if param.annotation is inspect.Parameter.empty: + raise jit_missing_annotation_error(param.name) + host_tensor_spec = infer_jit_host_tensor_spec(param) + if host_tensor_spec is not None: + positional_parameters.append( + TensorSpecParameterSpec(param.name, host_tensor_spec) + ) + elif _is_supported_device_parameter_annotation(param.annotation): + positional_parameters.append( + DeviceParameterSpec(param.name, param.annotation) + ) + elif _is_supported_runtime_scalar_annotation(param.annotation): + positional_parameters.append( + RuntimeScalarParameterSpec(param.name, param.annotation) + ) + else: + raise jit_illegal_formal_annotation_error(param.name, param.annotation) + continue + + if param.kind is inspect.Parameter.KEYWORD_ONLY: + if param.annotation is not _constexpr_marker: + raise TypeError( + f"@pto.jit keyword-only parameter '{param.name}' must be annotated " + "with pto.constexpr in PTODSL v1" + ) + if param.default is inspect.Parameter.empty: + raise TypeError( + f"@pto.jit constexpr parameter '{param.name}' must declare a default " + "value until explicit compile-time specialization is implemented" + ) + constexpr_parameters.append(ConstexprParameterSpec(param.name, param.default)) + continue + + raise TypeError( + f"@pto.jit parameter '{param.name}' uses unsupported parameter kind " + f"{param.kind!r}" + ) + + return KernelSignature( + positional_parameters=tuple(positional_parameters), + constexpr_parameters=tuple(constexpr_parameters), + ) + + +__all__ = [ + "ConstexprParameterSpec", + "DeviceParameterSpec", + "KernelSpecializationKey", + "KernelSignature", + "RuntimeScalarParameterSpec", + "TensorSpecParameterSpec", + "parse_jit_kernel_signature", +] diff --git a/ptodsl/ptodsl/_ops.py b/ptodsl/ptodsl/_ops.py new file mode 100644 index 000000000..1042d39b1 --- /dev/null +++ b/ptodsl/ptodsl/_ops.py @@ -0,0 +1,3271 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +PTO operation wrappers. + +Every function in this module emits one or more MLIR operations at the +active insertion point and returns the primary SSA result(s). + +Design rules: +- Vector math ops infer the result type from the first operand's type. +- ``vlds(tile[row, col:])`` and ``vlds(ptr, offset)`` infer the result + ``vreg`` type from the source element type. ``vbrc_load`` still requires an + explicit result ``vreg`` type because broadcast widths are authored + explicitly in the current surface. +- ``make_tensor_view`` infers the TensorViewType from ``len(shape)`` and the + pointer's element type. +- ``partition_view`` infers the PartitionTensorViewType from the source type. +""" + +from functools import wraps + +from ._bootstrap import make_context # noqa: F401 – ensure MLIR on sys.path +from ._diagnostics import explicit_mode_required_with_context_error, tile_row_alignment_error +from ._host_tensors import resolve_tensor_data_entry +from ._scalar_coercion import coerce_scalar_to_type, materialize_scalar_literal +from ._runtime_scalar_ops import classify_runtime_scalar_type, emit_runtime_binary_op +from ._surface_values import ( + MaskResultValue, + PartitionTensorViewValue, + TensorViewValue, + TileSliceValue, + TileValue, + _coerce_index_value, + _static_index_dims, + _unwrap_sequence, + compose_partition_spec, + emit_as_ptr, + infer_tile_element_type, + parse_tile_type_metadata, + unwrap_surface_value, + wrap_surface_value, +) +from ._types import ( + _isinstance_pto_type, + _integer_signedness, + _materialize_integer_literal, + _resolve, + _strip_integer_signedness, + mask_type, + part_tensor_view_type, + part_tensor_view_type_from_dims, + tensor_view_type, + tensor_view_type_from_dims, + vreg_type, +) + +from mlir.dialects import arith, pto as _pto +from mlir.ir import ( + Attribute, + BF16Type, + F16Type, + F32Type, + Float8E4M3FNType, + Float8E5M2Type, + FloatAttr, + IndexType, + IntegerType, + MemRefType, + Type, +) + +# Pipe name shorthands → canonical PIPE_* names +_PIPE_ALIASES = { + "MTE1": "PIPE_MTE1", + "MTE2": "PIPE_MTE2", + "MTE3": "PIPE_MTE3", + "MTE4": "PIPE_MTE4", + "V": "PIPE_V", + "M": "PIPE_M", + "S": "PIPE_S", + "ALL": "PIPE_ALL", +} + + +def _pipe_attr(name: str): + if not isinstance(name, str): + return _pto.PipeAttr.get(name) + canonical = _PIPE_ALIASES.get(name, name) + if not canonical.startswith("PIPE_"): + canonical = "PIPE_" + canonical + return _pto.PipeAttr.get(getattr(_pto.PIPE, canonical)) + + +def _event_attr(event_id: int): + return getattr(_pto, f"EVENT_ID{event_id}") + + +def _canonical_pipe_token(pipe): + if isinstance(pipe, str): + canonical = _PIPE_ALIASES.get(pipe, pipe) + if not canonical.startswith("PIPE_"): + canonical = "PIPE_" + canonical + return canonical + + for canonical in ( + "PIPE_FIX", "PIPE_MTE1", "PIPE_MTE2", "PIPE_MTE3", "PIPE_MTE4", + "PIPE_V", "PIPE_M", "PIPE_S", "PIPE_V2", "PIPE_ALL", + ): + pipe_attr = getattr(_pto.PIPE, canonical, None) + if pipe_attr is not None and pipe == pipe_attr: + return canonical + return None + + +def _validate_static_event_id(event_id, *, context: str): + if isinstance(event_id, int) and not 0 <= event_id <= 7: + raise ValueError(f"{context} expects static event_id in [0, 7], got {event_id}") + + +def _validate_sync_pipe(pipe, *, context: str, allowed: tuple[str, ...]): + canonical = _canonical_pipe_token(pipe) + if canonical is None: + raise TypeError(f"{context} expects a concrete Pipe value, got {pipe!r}") + if canonical not in allowed: + expected = ", ".join(f"<{name}>" for name in allowed) + raise ValueError(f"{context} expects pipe to be one of {expected}, got <{canonical}>") + + +def _require_explicit_mode(surface: str): + try: + from ._tracing.active import current_session + session = current_session() + except Exception: + session = None + if session is None: + return + current_mode = getattr(session.module_spec, "mode", None) + if current_mode != "explicit": + raise explicit_mode_required_with_context_error(surface, session.module_spec) + + +def _explicit_mode_only(surface: str): + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + _require_explicit_mode(surface) + return fn(*args, **kwargs) + + return wrapper + + return decorator + + +# ── Constants ──────────────────────────────────────────────────────────────── + +def const(value: int, *, dtype=None): + """ + Emit an ``arith.constant``. + + ``dtype`` is a ``_DType`` descriptor or a concrete ``mlir.ir.Type``. + Defaults to ``index`` when omitted. + """ + from ._types import index as _idx_dtype + mlir_type = _resolve(dtype) if dtype is not None else _resolve(_idx_dtype) + if any(cls.isinstance(mlir_type) for cls in (F16Type, BF16Type, F32Type)): + return wrap_surface_value(arith.ConstantOp(mlir_type, FloatAttr.get(mlir_type, value)).result) + if IntegerType.isinstance(mlir_type): + return wrap_surface_value(_materialize_integer_literal(mlir_type, value)) + return wrap_surface_value(arith.ConstantOp(mlir_type, value).result) + + +# ── Pointer ops ─────────────────────────────────────────────────────────────── + +def castptr(int_addr, result_ptr_type): + """``pto.castptr`` – cast an integer address to a typed PTO pointer.""" + return wrap_surface_value( + _pto.CastPtrOp(_resolve(result_ptr_type), unwrap_surface_value(int_addr)).result + ) + + +def addptr(base_ptr, index_offset): + """``pto.addptr`` – advance a pointer by an index offset.""" + return wrap_surface_value( + _pto.AddPtrOp( + unwrap_surface_value(base_ptr), + _coerce_index(index_offset, context="addptr(ptr, offset)"), + ).result + ) + + +# ── Vector load / store ─────────────────────────────────────────────────────── + +_VLOAD_DIST_TOKENS = { + "NORM", + "UNPK_B8", "UNPK_B16", "UNPK_B32", + "BRC_B8", "BRC_B16", "BRC_B32", + "US_B8", "US_B16", + "DS_B8", "DS_B16", +} + + +def vlds(src_ptr, offset=None, result_vreg_type=None, *, dist=None): + """``pto.vlds`` – vector load from a tile slice or from *src_ptr* at *offset*.""" + if isinstance(src_ptr, TileSliceValue): + if offset is not None or result_vreg_type is not None: + raise TypeError("vlds(tile[row, col:]) infers its memref slice and vreg type; do not pass offset/result_vreg_type") + kwargs = {} + if dist is not None: + kwargs["dist"] = _normalize_dist_token( + dist, + allowed=_VLOAD_DIST_TOKENS, + context="vlds(..., dist)", + ) + return wrap_surface_value(_pto.VldsOp( + _infer_vreg_type_from_tile_slice(src_ptr), + unwrap_surface_value(src_ptr), + _index_zero(), + **kwargs, + ).result) + + if offset is None: + raise TypeError("vlds(ptr, offset, result_vreg_type=None) requires an explicit offset") + if result_vreg_type is None: + result_vreg_type = _infer_vreg_type_from_address_source(src_ptr) + kwargs = {} + if dist is not None: + kwargs["dist"] = _normalize_dist_token( + dist, + allowed=_VLOAD_DIST_TOKENS, + context="vlds(..., dist)", + ) + return wrap_surface_value(_pto.VldsOp( + _resolve(result_vreg_type), + unwrap_surface_value(src_ptr), + unwrap_surface_value(offset), + **kwargs, + ).result) + + +def vldas(source): + """``pto.vldas`` – prime alignment state for a following unaligned load stream.""" + if isinstance(source, TileSliceValue): + source = _tile_slice_ptr(source) + return wrap_surface_value( + _pto.VldasOp( + _pto.AlignType.get(), + unwrap_surface_value(source), + ).result + ) + + +def vldus(source, align): + """``pto.vldus`` – unaligned vector load threaded through alignment state.""" + result_type = ( + _infer_vreg_type_from_tile_slice(source) + if isinstance(source, TileSliceValue) + else _infer_vreg_type_from_address_source(source) + ) + if isinstance(source, TileSliceValue): + source = _tile_slice_ptr(source) + op = _pto.VldusOp( + result_type, + _pto.AlignType.get(), + unwrap_surface_value(source), + unwrap_surface_value(align), + ) + return wrap_surface_value(op.result), wrap_surface_value(op.updated_align) + + +_DEINTERLEAVE_DIST_TOKENS = {"DINTLV_B8", "DINTLV_B16", "DINTLV_B32", "BDINTLV"} +_INTERLEAVE_DIST_TOKENS = {"INTLV_B8", "INTLV_B16", "INTLV_B32"} +_VSTORE_DIST_TOKENS = { + "NORM_B8", "NORM_B16", "NORM_B32", + "1PT_B8", "1PT_B16", "1PT_B32", + "PK_B16", "PK_B32", "PK_B64", "PK4_B32", + "MRG4CHN_B8", "MRG2CHN_B8", "MRG2CHN_B16", +} + + +def _normalize_dist_token(dist, *, allowed: set[str], context: str): + token = dist + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + normalized = token.strip().upper() + if normalized.startswith("_"): + normalized = normalized[1:] + if normalized not in allowed: + expected = ", ".join(sorted(allowed)) + raise ValueError(f"{context} does not support dist {dist!r}; expected one of {expected}") + return normalized + + +def vldsx2(source, offset_or_dist, dist=None): + """``pto.vldsx2`` – dual vector load with deinterleave.""" + if isinstance(source, TileSliceValue): + if dist is not None: + raise TypeError("vldsx2(tile[row, col:], dist) does not accept a separate offset argument") + result_type = _infer_vreg_type_from_tile_slice(source) + op = _pto.Vldsx2Op( + result_type, + result_type, + unwrap_surface_value(source), + _index_zero(), + _normalize_dist_token( + offset_or_dist, + allowed=_DEINTERLEAVE_DIST_TOKENS, + context="vldsx2(..., dist)", + ), + ) + return wrap_surface_value(op.low), wrap_surface_value(op.high) + + if dist is None: + raise TypeError("vldsx2(ptr, offset, dist) requires an explicit offset and dist") + result_type = _infer_vreg_type_from_address_source(source) + op = _pto.Vldsx2Op( + result_type, + result_type, + unwrap_surface_value(source), + _coerce_index(offset_or_dist, context="vldsx2(ptr, offset, dist)"), + _normalize_dist_token( + dist, + allowed=_DEINTERLEAVE_DIST_TOKENS, + context="vldsx2(..., dist)", + ), + ) + return wrap_surface_value(op.low), wrap_surface_value(op.high) + + +def vbitcast(vector_value, to_dtype): + """``pto.vbitcast`` – reinterpret one vector register as a different element type.""" + target_elem = _resolve(to_dtype) + target_type = _resolve(vreg_type(_elements_per_vreg(target_elem), target_elem)) + return wrap_surface_value( + _pto.VbitcastOp( + target_type, + unwrap_surface_value(vector_value), + ).result + ) + + +def pbitcast(mask_value, to_type): + """``pto.pbitcast`` – reinterpret one mask register at a different granularity.""" + return wrap_surface_value( + _pto.PbitcastOp( + _resolve(to_type), + unwrap_surface_value(mask_value), + ).result + ) + + +def vsts(val, dst_ptr, offset, mask=None, *, dist=None): + """``pto.vsts`` – vector store to a tile slice or to *dst_ptr* at *offset*.""" + if isinstance(dst_ptr, TileSliceValue): + if mask is not None: + raise TypeError("vsts(vec, tile[row, col:], mask) does not accept a separate offset argument") + kwargs = {} + if dist is not None: + kwargs["dist"] = _normalize_dist_token( + dist, + allowed=_VSTORE_DIST_TOKENS, + context="vsts(..., dist)", + ) + _pto.VstsOp( + unwrap_surface_value(val), + unwrap_surface_value(dst_ptr), + _index_zero(), + unwrap_surface_value(offset), + **kwargs, + ) + return + + if mask is None: + raise TypeError("vsts(vec, ptr, offset, mask) requires an explicit mask") + kwargs = {} + if dist is not None: + kwargs["dist"] = _normalize_dist_token( + dist, + allowed=_VSTORE_DIST_TOKENS, + context="vsts(..., dist)", + ) + _pto.VstsOp( + unwrap_surface_value(val), + unwrap_surface_value(dst_ptr), + unwrap_surface_value(offset), + unwrap_surface_value(mask), + **kwargs, + ) + + +def vstsx2(low, high, dst_ptr, offset_or_dist, dist_or_mask=None, mask=None): + """``pto.vstsx2`` – dual interleaving vector store.""" + if isinstance(dst_ptr, TileSliceValue): + if mask is not None: + raise TypeError("vstsx2(low, high, tile[row, col:], dist, mask) does not accept a separate offset argument") + _pto.Vstsx2Op( + unwrap_surface_value(low), + unwrap_surface_value(high), + unwrap_surface_value(dst_ptr), + _index_zero(), + _normalize_dist_token( + offset_or_dist, + allowed=_INTERLEAVE_DIST_TOKENS, + context="vstsx2(..., dist)", + ), + unwrap_surface_value(dist_or_mask), + ) + return + + if mask is None: + raise TypeError("vstsx2(low, high, ptr, offset, dist, mask) requires an explicit offset, dist, and mask") + _pto.Vstsx2Op( + unwrap_surface_value(low), + unwrap_surface_value(high), + unwrap_surface_value(dst_ptr), + _coerce_index(offset_or_dist, context="vstsx2(ptr, offset, dist, mask)"), + _normalize_dist_token( + dist_or_mask, + allowed=_INTERLEAVE_DIST_TOKENS, + context="vstsx2(..., dist)", + ), + unwrap_surface_value(mask), + ) + + +def vgather2(buf, offsets, mask, result_vreg_type=None): + """``pto.vgather2`` – indexed gather from UB.""" + rt = result_vreg_type if result_vreg_type is not None else _infer_vreg_type_from_address_source(buf) + return wrap_surface_value( + _pto.Vgather2Op( + _resolve(rt), + unwrap_surface_value(buf), + unwrap_surface_value(offsets), + unwrap_surface_value(mask), + ).result + ) + + +def vgather2_bc(buf, offsets, mask, result_vreg_type=None): + """``pto.vgather2_bc`` – indexed gather from UB with masked zero-fill.""" + rt = result_vreg_type if result_vreg_type is not None else _infer_vreg_type_from_address_source(buf) + return wrap_surface_value( + _pto.Vgather2BcOp( + _resolve(rt), + unwrap_surface_value(buf), + unwrap_surface_value(offsets), + unwrap_surface_value(mask), + ).result + ) + + +def vgatherb(buf, offsets, mask, result_vreg_type=None): + """``pto.vgatherb`` – block gather from UB using byte offsets.""" + rt = result_vreg_type if result_vreg_type is not None else _infer_vreg_type_from_address_source(buf) + return wrap_surface_value( + _pto.VgatherbOp( + _resolve(rt), + unwrap_surface_value(buf), + unwrap_surface_value(offsets), + unwrap_surface_value(mask), + ).result + ) + + +def vscatter(value, destination, offsets, mask): + """``pto.vscatter`` – indexed scatter to UB.""" + _pto.VscatterOp( + unwrap_surface_value(value), + unwrap_surface_value(destination), + unwrap_surface_value(offsets), + unwrap_surface_value(mask), + ) + + +def _coerce_i16(value, *, context: str): + raw_value = unwrap_surface_value(value) + i16_type = IntegerType.get_signless(16) + if isinstance(raw_value, bool): + raise TypeError(f"{context} does not accept bool values") + if isinstance(raw_value, int): + return _materialize_integer_literal(i16_type, raw_value) + kind = classify_runtime_scalar_type(raw_value.type) + if kind == "float": + raise TypeError(f"{context} expects an integer-like scalar, got {raw_value.type}") + if kind == "index": + return arith.IndexCastOp(i16_type, raw_value).result + signless_value = _strip_integer_signedness(raw_value) + if signless_value.type == i16_type: + return signless_value + width = IntegerType(raw_value.type).width + if width < 16: + if _integer_signedness(raw_value.type) == "unsigned": + return arith.ExtUIOp(i16_type, signless_value).result + return arith.ExtSIOp(i16_type, signless_value).result + if width > 16: + return arith.TruncIOp(i16_type, signless_value).result + return signless_value + + +def vsldb(source, block_stride, repeat_stride, mask): + """``pto.vsldb`` – block-strided load.""" + result_type = ( + _infer_vreg_type_from_tile_slice(source) + if isinstance(source, TileSliceValue) + else _infer_vreg_type_from_address_source(source) + ) + return wrap_surface_value( + _pto.VsldbOp( + result_type, + unwrap_surface_value(source), + _coerce_i16(block_stride, context="vsldb(..., block_stride, repeat_stride, mask)"), + _coerce_i16(repeat_stride, context="vsldb(..., block_stride, repeat_stride, mask)"), + unwrap_surface_value(mask), + ).result + ) + + +def vsstb(value, destination, block_stride, repeat_stride, mask): + """``pto.vsstb`` – block-strided store.""" + _pto.VsstbOp( + unwrap_surface_value(value), + unwrap_surface_value(destination), + _coerce_i16(block_stride, context="vsstb(..., block_stride, repeat_stride, mask)"), + _coerce_i16(repeat_stride, context="vsstb(..., block_stride, repeat_stride, mask)"), + unwrap_surface_value(mask), + ) + + +# ── Mask / predicate ops ────────────────────────────────────────────────────── + +_MASK_PATTERN_TOKENS = { + "PAT_ALL", + "PAT_ALLF", + "PAT_H", + "PAT_Q", + "PAT_M3", + "PAT_M4", + *(f"PAT_VL{count}" for count in range(1, 129)), +} + +_CMP_MODE_TOKENS = {"eq", "ne", "lt", "le", "gt", "ge"} +_PREDICATE_PART_TOKENS = {"LOWER", "HIGHER"} +_PREDICATE_LOAD_DIST_TOKENS = {"NORM", "US", "DS"} +_PREDICATE_STORE_DIST_TOKENS = {"NORM", "PK"} +_POST_UPDATE_TOKENS = {"NO_POST_UPDATE", "POST_UPDATE"} + + +def _normalize_mask_pattern(pattern): + token = pattern + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + token = token.strip().upper() + normalized = token if token.startswith("PAT_") else f"PAT_{token}" + if normalized not in _MASK_PATTERN_TOKENS: + raise ValueError( + f"unsupported mask pattern {pattern!r}; expected one of PAT_ALL, PAT_ALLF, " + "PAT_H, PAT_Q, PAT_VL1..PAT_VL128, PAT_M3, PAT_M4" + ) + return normalized + + +def _normalize_cmp_mode(cmp_mode): + token = cmp_mode + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + normalized = token.strip().lower() + if normalized not in _CMP_MODE_TOKENS: + raise ValueError( + f"unsupported cmp_mode {cmp_mode!r}; expected one of EQ, NE, LT, LE, GT, GE" + ) + return normalized + + +def _cmp_mode_attr(cmp_mode): + return Attribute.parse(f"#pto") + + +def _normalize_predicate_part(part): + token = part + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + normalized = token.strip().upper() + if normalized not in _PREDICATE_PART_TOKENS: + raise ValueError(f"unsupported predicate part {part!r}; expected LOWER or HIGHER") + return normalized + + +def _normalize_predicate_dist(dist, *, allowed: set[str], context: str): + token = dist + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + normalized = token.strip().upper() + if normalized not in allowed: + expected = ", ".join(sorted(allowed)) + raise ValueError(f"{context} does not support dist {dist!r}; expected one of {expected}") + return normalized + + +def _normalize_post_update_mode(mode, *, context: str): + token = mode + if not isinstance(token, str): + token = str(token) + if "." in token: + token = token.rsplit(".", 1)[-1] + normalized = token.strip().upper() + if normalized in {"OFF", "NO_POST_UPDATE"}: + return "NO_POST_UPDATE" + if normalized in {"ON", "POST_UPDATE"}: + return "POST_UPDATE" + expected = ", ".join(sorted(_POST_UPDATE_TOKENS)) + raise ValueError(f"{context} does not support mode {mode!r}; expected one of ON/OFF ({expected})") + + +def _mask_type_from_bits(mask_bits: int): + return _resolve(mask_type(f"b{mask_bits}")) + + +def _infer_mask_metadata(mask_value, *, context: str): + raw_type = unwrap_surface_value(mask_value).type + try: + mask_ty = _pto.MaskType(raw_type) + except Exception as exc: + raise TypeError(f"{context} expects a PTO mask value, got {raw_type}") from exc + granularity = mask_ty.granularity + return int(granularity[1:]), raw_type + + +def _require_same_mask_types(values, *, context: str): + raw_types = [unwrap_surface_value(value).type for value in values] + first = raw_types[0] + for other in raw_types[1:]: + if other != first: + raise TypeError(f"{context} expects masks of the same granularity, got {first} and {other}") + return first + + +def _pointer_element_type(ptr_value, *, context: str): + raw_type = unwrap_surface_value(ptr_value).type + try: + return _pto.PtrType(raw_type).element_type + except Exception: + try: + return MemRefType(raw_type).element_type + except Exception as exc: + raise TypeError(f"{context} expects a PTO pointer or memref-backed address, got {raw_type}") from exc + + +def _coerce_index(value, *, context: str): + raw_value = unwrap_surface_value(value) + index_type = IndexType.get() + if isinstance(raw_value, bool): + raise TypeError(f"{context} does not accept bool values") + if isinstance(raw_value, int): + return arith.ConstantOp(index_type, raw_value).result + kind = classify_runtime_scalar_type(raw_value.type) + if kind == "float": + raise TypeError(f"{context} expects an index-like scalar, got {raw_value.type}") + if IndexType.isinstance(raw_value.type): + return raw_value + if IntegerType.isinstance(raw_value.type): + return arith.IndexCastOp(index_type, _strip_integer_signedness(raw_value)).result + raise TypeError(f"{context} expects an index-like scalar, got {raw_value.type}") + + +def init_align(): + """``pto.init_align`` – materialize the initial alignment state.""" + return wrap_surface_value(_pto.InitAlignOp(_pto.AlignType.get()).result) + + +def _plt_impl(mask_bits: int, scalar): + plt_op = _plt_op_for_mask_bits(mask_bits)( + _mask_type_from_bits(mask_bits), + IntegerType.get_signless(32), + _coerce_i32(scalar, context=f"plt_b{mask_bits}(scalar)"), + ) + return wrap_surface_value(plt_op.mask), wrap_surface_value(plt_op.scalar_out) + + +def plt_b8(scalar): + """``pto.plt_b8`` – predicate-load from a 32-bit scalar into a b8 mask.""" + return _plt_impl(8, scalar) + + +def plt_b16(scalar): + """``pto.plt_b16`` – predicate-load from a 32-bit scalar into a b16 mask.""" + return _plt_impl(16, scalar) + + +def plt_b32(scalar): + """ + ``pto.plt_b32`` – predicate-load from a 32-bit scalar. + + Returns ``(mask_value, scalar_out)``. ``scalar_out`` is often unused + and can be discarded with ``_``. + """ + return _plt_impl(32, scalar) + + +def _pset_impl(mask_bits: int, pattern): + return wrap_surface_value( + _pset_op_for_mask_bits(mask_bits)( + _mask_type_from_bits(mask_bits), + _normalize_mask_pattern(pattern), + ).result + ) + + +def pset_b8(pattern): + """``pto.pset_b8(pattern)`` → ``!pto.mask``.""" + return _pset_impl(8, pattern) + + +def pset_b16(pattern): + """``pto.pset_b16(pattern)`` → ``!pto.mask``.""" + return _pset_impl(16, pattern) + + +def pset_b32(pattern): + """``pto.pset_b32(pattern)`` → ``!pto.mask``.""" + return _pset_impl(32, pattern) + + +def _pge_op_for_mask_bits(mask_bits: int): + return { + 8: _pto.PgeB8Op, + 16: _pto.PgeB16Op, + 32: _pto.PgeB32Op, + }[mask_bits] + + +def _pge_impl(mask_bits: int, pattern): + return wrap_surface_value( + _pge_op_for_mask_bits(mask_bits)( + _mask_type_from_bits(mask_bits), + _normalize_mask_pattern(pattern), + ).result + ) + + +def pge_b8(pattern): + """``pto.pge_b8(pattern)`` → ``!pto.mask``.""" + return _pge_impl(8, pattern) + + +def pge_b16(pattern): + """``pto.pge_b16(pattern)`` → ``!pto.mask``.""" + return _pge_impl(16, pattern) + + +def pge_b32(pattern): + """``pto.pge_b32(pattern)`` → ``!pto.mask``.""" + return _pge_impl(32, pattern) + + +def pand(src0, src1, mask): + """``pto.pand`` – gated mask AND.""" + result_type = _require_same_mask_types((src0, src1, mask), context="pand(src0, src1, mask)") + return wrap_surface_value( + _pto.PandOp( + result_type, + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(mask), + ).result + ) + + +def por(src0, src1, mask): + """``pto.por`` – gated mask OR.""" + result_type = _require_same_mask_types((src0, src1, mask), context="por(src0, src1, mask)") + return wrap_surface_value( + _pto.PorOp( + result_type, + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(mask), + ).result + ) + + +def pxor(src0, src1, mask): + """``pto.pxor`` – gated mask XOR.""" + result_type = _require_same_mask_types((src0, src1, mask), context="pxor(src0, src1, mask)") + return wrap_surface_value( + _pto.PxorOp( + result_type, + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(mask), + ).result + ) + + +def pnot(src, mask): + """``pto.pnot`` – gated mask NOT.""" + result_type = _require_same_mask_types((src, mask), context="pnot(src, mask)") + return wrap_surface_value( + _pto.PnotOp( + result_type, + unwrap_surface_value(src), + unwrap_surface_value(mask), + ).result + ) + + +def psel(src0, src1, sel): + """``pto.psel`` – per-lane mask select.""" + result_type = _require_same_mask_types((src0, src1, sel), context="psel(src0, src1, sel)") + return wrap_surface_value( + _pto.PselOp( + result_type, + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(sel), + ).result + ) + + +def ppack(mask_value, part): + """``pto.ppack`` – pack predicate bits into the selected half.""" + _, result_type = _infer_mask_metadata(mask_value, context="ppack(mask, part)") + return wrap_surface_value( + _pto.PpackOp( + result_type, + unwrap_surface_value(mask_value), + _normalize_predicate_part(part), + ).result + ) + + +def punpack(mask_value, part): + """``pto.punpack`` – unpack predicate bits from the selected half.""" + _, result_type = _infer_mask_metadata(mask_value, context="punpack(mask, part)") + return wrap_surface_value( + _pto.PunpackOp( + result_type, + unwrap_surface_value(mask_value), + _normalize_predicate_part(part), + ).result + ) + + +def _pintlv_op_for_mask_bits(mask_bits: int): + return { + 8: _pto.PintlvB8Op, + 16: _pto.PintlvB16Op, + 32: _pto.PintlvB32Op, + }[mask_bits] + + +def _pdintlv_op_for_mask_bits(mask_bits: int): + return { + 8: _pto.PdintlvB8Op, + 16: _pto.PdintlvB16Op, + 32: _pto.PdintlvB32Op, + }[mask_bits] + + +def _mask_pair_op(op_resolver, lhs, rhs, *, expected_mask_bits: int, context: str): + mask_bits, result_type = _infer_mask_metadata(lhs, context=context) + if mask_bits != expected_mask_bits: + raise TypeError(f"{context} expects mask_b{expected_mask_bits} operands, got mask_b{mask_bits}") + _require_same_mask_types((lhs, rhs), context=context) + op = op_resolver(mask_bits)( + result_type, + result_type, + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + ) + return wrap_surface_value(op.low), wrap_surface_value(op.high) + + +def pintlv_b8(lhs, rhs): + """``pto.pintlv_b8`` – interleave two b8 masks.""" + return _mask_pair_op( + _pintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=8, + context="pintlv_b8(lhs, rhs)", + ) + + +def pintlv_b16(lhs, rhs): + """``pto.pintlv_b16`` – interleave two b16 masks.""" + return _mask_pair_op( + _pintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=16, + context="pintlv_b16(lhs, rhs)", + ) + + +def pintlv_b32(lhs, rhs): + """``pto.pintlv_b32`` – interleave two b32 masks.""" + return _mask_pair_op( + _pintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=32, + context="pintlv_b32(lhs, rhs)", + ) + + +def pdintlv_b8(lhs, rhs): + """``pto.pdintlv_b8`` – deinterleave two b8 masks.""" + return _mask_pair_op( + _pdintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=8, + context="pdintlv_b8(lhs, rhs)", + ) + + +def pdintlv_b16(lhs, rhs): + """``pto.pdintlv_b16`` – deinterleave two b16 masks.""" + return _mask_pair_op( + _pdintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=16, + context="pdintlv_b16(lhs, rhs)", + ) + + +def pdintlv_b32(lhs, rhs): + """``pto.pdintlv_b32`` – deinterleave two b32 masks.""" + return _mask_pair_op( + _pdintlv_op_for_mask_bits, + lhs, + rhs, + expected_mask_bits=32, + context="pdintlv_b32(lhs, rhs)", + ) + + +def vcmp(src0, src1, seed_mask, cmp_mode): + """``pto.vcmp`` – vector/vector comparison producing a predicate mask.""" + _, elem_type = _infer_vreg_metadata(src0) + result_type = _mask_type_from_bits(_mask_bits_for_dtype(elem_type)) + seed_type = unwrap_surface_value(seed_mask).type + if seed_type != result_type: + raise TypeError( + f"vcmp(src0, src1, seed_mask, cmp_mode) expects seed_mask {result_type}, got {seed_type}" + ) + return wrap_surface_value( + _pto.VcmpOp( + result_type, + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(seed_mask), + _normalize_cmp_mode(cmp_mode), + ).result + ) + + +def vcmps(src, scalar, seed_mask, cmp_mode): + """``pto.vcmps`` – vector/scalar comparison producing a predicate mask.""" + _, elem_type = _infer_vreg_metadata(src) + result_type = _mask_type_from_bits(_mask_bits_for_dtype(elem_type)) + seed_type = unwrap_surface_value(seed_mask).type + if seed_type != result_type: + raise TypeError( + f"vcmps(src, scalar, seed_mask, cmp_mode) expects seed_mask {result_type}, got {seed_type}" + ) + scalar_value = _coerce_scalar_like_vector_element(src, scalar, context="vcmps") + return wrap_surface_value( + _pto.VcmpsOp( + result_type, + unwrap_surface_value(src), + unwrap_surface_value(scalar_value), + unwrap_surface_value(seed_mask), + _normalize_cmp_mode(cmp_mode), + ).result + ) + + +def plds(buf, offset, *, dist="NORM"): + """``pto.plds`` – load a predicate mask from UB memory.""" + elem_type = _pointer_element_type(buf, context="plds(buf, offset)") + result_type = _mask_type_from_bits(_mask_bits_for_dtype(elem_type)) + return wrap_surface_value( + _pto.PldsOp( + result_type, + unwrap_surface_value(buf), + _coerce_index(offset, context="plds(buf, offset)"), + _normalize_predicate_dist( + dist, + allowed=_PREDICATE_LOAD_DIST_TOKENS, + context="plds(..., dist)", + ), + ).result + ) + + +def psts(mask_value, buf, offset, *, dist="NORM"): + """``pto.psts`` – store a predicate mask to UB memory.""" + _infer_mask_metadata(mask_value, context="psts(mask, buf, offset)") + _pto.PstsOp( + unwrap_surface_value(mask_value), + unwrap_surface_value(buf), + _coerce_index(offset, context="psts(mask, buf, offset)"), + _normalize_predicate_dist( + dist, + allowed=_PREDICATE_STORE_DIST_TOKENS, + context="psts(..., dist)", + ), + ) + + +def pstu(align_in, mask_value, buf): + """``pto.pstu`` – unaligned predicate store with threaded alignment state.""" + mask_bits, _ = _infer_mask_metadata(mask_value, context="pstu(align_in, mask, buf)") + if mask_bits not in {16, 32}: + raise TypeError("pstu(align_in, mask, buf) currently supports only mask_b16 and mask_b32") + elem_type = _pointer_element_type(buf, context="pstu(align_in, mask, buf)") + expected_bytes = mask_bits // 8 + actual_bytes = _element_bytewidth(elem_type) + if actual_bytes != expected_bytes: + raise TypeError( + f"pstu(align_in, mask, buf) expects a {expected_bytes}-byte pointer element for mask_b{mask_bits}, " + f"got {elem_type}" + ) + align_type = _pto.AlignType.get() + base_type = unwrap_surface_value(buf).type + op = _pto.PstuOp( + align_type, + base_type, + unwrap_surface_value(align_in), + unwrap_surface_value(mask_value), + unwrap_surface_value(buf), + ) + return wrap_surface_value(op.align_out), wrap_surface_value(op.base_out) + + +def vstar(align, destination): + """``pto.vstar`` – flush alignment-buffered tail bytes to the destination base.""" + _pto.VstarOp( + unwrap_surface_value(align), + unwrap_surface_value(destination), + ) + + +def vstas(align, destination, offset): + """``pto.vstas`` – flush alignment-buffered tail bytes with an explicit offset.""" + _pto.VstasOp( + unwrap_surface_value(align), + unwrap_surface_value(destination), + _coerce_i32(offset, context="vstas(align, destination, offset)"), + ) + + +def vstur(align_in, value, base, mode="NO_POST_UPDATE"): + """``pto.vstur`` – unaligned vector store that updates only alignment state.""" + return wrap_surface_value( + _pto.VsturOp( + _pto.AlignType.get(), + unwrap_surface_value(align_in), + unwrap_surface_value(value), + unwrap_surface_value(base), + _normalize_post_update_mode(mode, context="vstur(..., mode)"), + ).align_out + ) + + +def vstus(align_in, offset, value, base): + """``pto.vstus`` – scalar-offset unaligned vector store that updates alignment state.""" + return wrap_surface_value( + _pto.VstusOp( + _pto.AlignType.get(), + unwrap_surface_value(align_in), + _coerce_i32(offset, context="vstus(align, offset, value, base)"), + unwrap_surface_value(value), + unwrap_surface_value(base), + ).align_out + ) + + +# ── Vector math (result type inferred from first operand) ───────────────────── + +def vbr(value): + """``pto.vbr`` – broadcast one scalar value to all vector lanes.""" + raw_value = unwrap_surface_value(value) + if isinstance(raw_value, bool): + raise TypeError("vbr(value) does not accept bool values") + + if hasattr(raw_value, "type"): + scalar_kind = classify_runtime_scalar_type(raw_value.type) + if scalar_kind == "index": + raise TypeError("vbr(value) does not support index scalars") + scalar_value = raw_value + elem_type = raw_value.type + else: + if isinstance(raw_value, float): + elem_type = F32Type.get() + elif isinstance(raw_value, int): + elem_type = IntegerType.get_signless(32) + else: + raise TypeError("vbr(value) expects a runtime scalar or one Python int/float literal") + scalar_value = materialize_scalar_literal(raw_value, elem_type, context="vbr(value)") + + try: + result_type = _resolve(vreg_type(_elements_per_vreg(elem_type), elem_type)) + except TypeError as exc: + raise TypeError(f"vbr(value) does not support scalar type {elem_type}") from exc + + return wrap_surface_value(_pto.VbrOp(result_type, scalar_value).result) + + +def _emit_unary_vec_op(op_ctor, inp, mask): + return wrap_surface_value( + op_ctor( + unwrap_surface_value(inp).type, + unwrap_surface_value(inp), + unwrap_surface_value(mask), + ).result + ) + + +def _emit_binary_vec_op(op_ctor, lhs, rhs, mask): + return wrap_surface_value( + op_ctor( + unwrap_surface_value(lhs).type, + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(mask), + ).result + ) + + +def _emit_vec_scalar_masked_op(op_ctor, inp, scalar, mask, *, context: str): + scalar_value = _coerce_scalar_like_vector_element(inp, scalar, context=context) + return wrap_surface_value( + op_ctor( + unwrap_surface_value(inp).type, + unwrap_surface_value(inp), + unwrap_surface_value(scalar_value), + unwrap_surface_value(mask), + ).result + ) + + +def vadd(lhs, rhs, mask, result_type=None): + """``pto.vadd`` – element-wise add.""" + rt = result_type if result_type is not None else lhs.type + return wrap_surface_value( + _pto.VaddOp( + _resolve(rt), + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(mask), + ).result + ) + + +def vsub(lhs, rhs, mask): + """``pto.vsub`` – element-wise subtract.""" + return _emit_binary_vec_op(_pto.VsubOp, lhs, rhs, mask) + + +def vmul(lhs, rhs, mask): + """``pto.vmul`` – element-wise multiply.""" + return _emit_binary_vec_op(_pto.VmulOp, lhs, rhs, mask) + + +def vmax(lhs, rhs, mask): + """``pto.vmax`` – element-wise maximum.""" + return _emit_binary_vec_op(_pto.VmaxOp, lhs, rhs, mask) + + +def vmin(lhs, rhs, mask): + """``pto.vmin`` – element-wise minimum.""" + return _emit_binary_vec_op(_pto.VminOp, lhs, rhs, mask) + + +def vand(lhs, rhs, mask): + """``pto.vand`` – element-wise bitwise and.""" + return _emit_binary_vec_op(_pto.VandOp, lhs, rhs, mask) + + +def vor(lhs, rhs, mask): + """``pto.vor`` – element-wise bitwise or.""" + return _emit_binary_vec_op(_pto.VorOp, lhs, rhs, mask) + + +def vxor(lhs, rhs, mask): + """``pto.vxor`` – element-wise bitwise xor.""" + return _emit_binary_vec_op(_pto.VxorOp, lhs, rhs, mask) + + +def vdiv(lhs, rhs, mask): + """``pto.vdiv`` – element-wise divide.""" + return _emit_binary_vec_op(_pto.VdivOp, lhs, rhs, mask) + + +def vshl(lhs, rhs, mask): + """``pto.vshl`` – element-wise shift left.""" + return _emit_binary_vec_op(_pto.VshlOp, lhs, rhs, mask) + + +def vshr(lhs, rhs, mask): + """``pto.vshr`` – element-wise shift right.""" + return _emit_binary_vec_op(_pto.VshrOp, lhs, rhs, mask) + + +def vcmax(v, mask): + """``pto.vcmax`` – cross-lane maximum reduction.""" + return _emit_unary_vec_op(_pto.VcmaxOp, v, mask) + + +def vcadd(v, mask): + """``pto.vcadd`` – cross-lane add (sum reduction).""" + return _emit_unary_vec_op(_pto.VcaddOp, v, mask) + + +def vcmin(v, mask): + """``pto.vcmin`` – cross-lane minimum reduction.""" + return _emit_unary_vec_op(_pto.VcminOp, v, mask) + + +def vdup(v, mask, *, position=None): + """``pto.vdup`` – duplicate a lane value into all lanes. + + Pass ``position="LOWEST"`` to broadcast the lowest (lane-0) element. + """ + return wrap_surface_value( + _pto.VdupOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + position=position, + ).result + ) + + +def vln(inp, mask): + """``pto.vln`` – element-wise natural logarithm.""" + return _emit_unary_vec_op(_pto.VlnOp, inp, mask) + + +def vsqrt(inp, mask): + """``pto.vsqrt`` – element-wise square root.""" + return _emit_unary_vec_op(_pto.VsqrtOp, inp, mask) + + +def vabs(inp, mask): + """``pto.vabs`` – element-wise absolute value.""" + return _emit_unary_vec_op(_pto.VabsOp, inp, mask) + + +def vneg(inp, mask): + """``pto.vneg`` – element-wise negation.""" + return _emit_unary_vec_op(_pto.VnegOp, inp, mask) + + +def vrelu(inp, mask): + """``pto.vrelu`` – element-wise ReLU.""" + return _emit_unary_vec_op(_pto.VreluOp, inp, mask) + + +def vnot(inp, mask): + """``pto.vnot`` – element-wise bitwise/logical not.""" + return _emit_unary_vec_op(_pto.VnotOp, inp, mask) + + +def vexpdif(inp, ref, mask, part: str = "ODD"): + """``pto.vexpdif`` – ``exp(inp - ref)`` selecting ODD or EVEN lanes.""" + return wrap_surface_value( + _pto.VexpdifOp( + unwrap_surface_value(inp).type, + unwrap_surface_value(inp), + unwrap_surface_value(ref), + unwrap_surface_value(mask), + part, + ).result + ) + + +def vexp(inp, mask): + """``pto.vexp`` – element-wise exponential.""" + return _emit_unary_vec_op(_pto.VexpOp, inp, mask) + + +def vrec(inp, mask): + """``pto.vrec`` – reciprocal, surfaced as ``1 / inp``.""" + zero_vec = vmuls(inp, 0, mask) + one_vec = vadds(zero_vec, 1, mask) + return vdiv(one_vec, inp, mask) + + +def vrsqrt(inp, mask): + """``pto.vrsqrt`` – inverse square root, surfaced as ``1 / sqrt(inp)``.""" + sqrt_vec = vsqrt(inp, mask) + return vrec(sqrt_vec, mask) + + +def vcgmax(v, mask): + """``pto.vcgmax`` – group maximum reduction, surfaced as the lowest-lane scalar.""" + reduced = _pto.VcgmaxOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + ).result + return _extract_lowest_lane_scalar(reduced, mask) + + +def vcgadd(v, mask): + """``pto.vcgadd`` – group sum reduction, surfaced as the lowest-lane scalar.""" + reduced = _pto.VcgaddOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + ).result + return _extract_lowest_lane_scalar(reduced, mask) + + +def vcgmin(v, mask): + """``pto.vcgmin`` – group minimum reduction, surfaced as the lowest-lane scalar.""" + reduced = _pto.VcgminOp( + unwrap_surface_value(v).type, + unwrap_surface_value(v), + unwrap_surface_value(mask), + ).result + return _extract_lowest_lane_scalar(reduced, mask) + + +def vcpadd(v, mask): + """``pto.vcpadd`` – inclusive prefix sum.""" + return _emit_unary_vec_op(_pto.VcpaddOp, v, mask) + + +def vadds(inp, scalar, mask): + """``pto.vadds`` – vector plus scalar under mask.""" + return _emit_vec_scalar_masked_op(_pto.VaddsOp, inp, scalar, mask, context="vadds") + + +def vsubs(inp, scalar, mask): + """``pto.vsubs`` – vector minus scalar under mask.""" + raw_scalar = _coerce_scalar_like_vector_element(inp, scalar, context="vsubs") + neg_scalar = _negate_runtime_scalar(raw_scalar) + return wrap_surface_value( + _pto.VaddsOp( + unwrap_surface_value(inp).type, + unwrap_surface_value(inp), + neg_scalar, + unwrap_surface_value(mask), + ).result + ) + + +def vmuls(inp, scalar, mask): + """``pto.vmuls`` – vector times scalar under mask.""" + return _emit_vec_scalar_masked_op(_pto.VmulsOp, inp, scalar, mask, context="vmuls") + + +def vmaxs(inp, scalar, mask): + """``pto.vmaxs`` – vector/scalar maximum under mask.""" + return _emit_vec_scalar_masked_op(_pto.VmaxsOp, inp, scalar, mask, context="vmaxs") + + +def vmins(inp, scalar, mask): + """``pto.vmins`` – vector/scalar minimum under mask.""" + return _emit_vec_scalar_masked_op(_pto.VminsOp, inp, scalar, mask, context="vmins") + + +def vlrelu(inp, alpha, mask): + """``pto.vlrelu`` – vector leaky ReLU under mask.""" + return _emit_vec_scalar_masked_op(_pto.VlreluOp, inp, alpha, mask, context="vlrelu") + + +def vaddrelu(lhs, rhs, mask): + """``pto.vaddrelu`` – add, then apply ReLU.""" + return vrelu(vadd(lhs, rhs, mask), mask) + + +def vsubrelu(lhs, rhs, mask): + """``pto.vsubrelu`` – subtract, then apply ReLU.""" + return vrelu(vsub(lhs, rhs, mask), mask) + + +def vaxpy(alpha, x, y, mask): + """``pto.vaxpy`` – fused ``alpha * x + y``.""" + alpha_value = _coerce_scalar_like_vector_element(x, alpha, context="vaxpy") + return wrap_surface_value( + _pto.VaxpyOp( + unwrap_surface_value(x).type, + unwrap_surface_value(x), + unwrap_surface_value(y), + unwrap_surface_value(alpha_value), + unwrap_surface_value(mask), + ).result + ) + + +def vsel(true_v, false_v, mask): + """``pto.vsel`` – element-wise select under a predicate mask.""" + return wrap_surface_value( + _pto.VselOp( + unwrap_surface_value(true_v).type, + unwrap_surface_value(true_v), + unwrap_surface_value(false_v), + unwrap_surface_value(mask), + ).result + ) + + +# ── Tile-domain operations ──────────────────────────────────────────────────── + +def make_tensor_view(ptr, *, shape=None, strides=None): + """ + ``pto.make_tensor_view`` – wrap a pointer as a tensor view. + + Type is inferred: rank from ``len(shape)``, element type from ``ptr``. + """ + authored_ptr = ptr + if shape is None: + shape = getattr(authored_ptr, "shape", None) + if strides is None: + strides = getattr(authored_ptr, "strides", None) + if shape is None or strides is None: + raise TypeError("make_tensor_view() requires shape= and strides=, or a host tensor proxy carrying both") + ptr = resolve_tensor_data_entry(authored_ptr) + rank = len(shape) + raw_ptr = unwrap_surface_value(ptr) + elem = _pto.PtrType(raw_ptr.type).element_type + static_dims = _static_index_dims(shape) + tv_type = ( + tensor_view_type_from_dims(static_dims, elem) + if static_dims is not None + else tensor_view_type(rank, elem) + ) + shape_operands = [_coerce_index_value(dim) for dim in shape] + stride_operands = [_coerce_index_value(dim) for dim in strides] + value = _pto.MakeTensorViewOp( + tv_type, + raw_ptr, + shape_operands, + stride_operands, + ).result + return TensorViewValue(value, shape=tuple(shape), strides=tuple(strides)) + + +def _normalize_static_tile_shape(shape): + static_shape = [] + for dim in shape: + if isinstance(dim, bool) or not isinstance(dim, int): + raise TypeError( + "alloc_tile(shape=...) currently requires a static physical tile shape. " + "Use constexpr/static integers for shape and place runtime metadata in valid_shape." + ) + static_shape.append(dim) + return tuple(static_shape) + + +def _authored_tile_physical_shape(shape): + if len(shape) == 1: + return (1, shape[0]) + return tuple(shape) + + +def _split_valid_shape(shape, valid_shape): + logical_rank = len(shape) + if valid_shape is None: + return _authored_tile_physical_shape(shape), None, None, tuple(shape) + + if len(valid_shape) != logical_rank: + raise TypeError( + f"alloc_tile(valid_shape=...) rank mismatch: expected {logical_rank} dims, got {len(valid_shape)}" + ) + + surface_valid_shape = [] + if logical_rank == 1: + dim = valid_shape[0] + surface_valid_shape.append(dim) + if isinstance(dim, bool): + raise TypeError("alloc_tile(valid_shape=...) does not accept bool dimensions") + if isinstance(dim, int): + return (1, dim), None, None, tuple(surface_valid_shape) + return (-1, -1), 1, dim, tuple(surface_valid_shape) + + type_valid_shape = [] + valid_row = None + valid_col = None + for index, dim in enumerate(valid_shape): + surface_valid_shape.append(dim) + if isinstance(dim, bool): + raise TypeError("alloc_tile(valid_shape=...) does not accept bool dimensions") + if isinstance(dim, int): + type_valid_shape.append(dim) + continue + type_valid_shape.append(-1) + if index == 0: + valid_row = dim + continue + if index == 1: + valid_col = dim + continue + raise TypeError( + "alloc_tile(valid_shape=...) currently only supports dynamic runtime metadata " + "for the first two dimensions" + ) + return tuple(type_valid_shape), valid_row, valid_col, tuple(surface_valid_shape) + + +def _uses_row_major_none_box_layout(blayout, slayout) -> bool: + return str(blayout).lower() == "rowmajor" and str(slayout).lower() == "nonebox" + + +def _validate_authored_tile_row_alignment(shape, dtype, *, blayout, slayout): + if not _uses_row_major_none_box_layout(blayout, slayout): + return + if not shape: + return + elem_bytewidth = _element_bytewidth(_resolve(dtype)) + row_bytes = shape[-1] * elem_bytewidth + required_alignment = 32 + if row_bytes % required_alignment == 0: + return + raise tile_row_alignment_error( + shape=shape, + dtype=str(_resolve(dtype)), + row_bytes=row_bytes, + required_alignment=required_alignment, + ) + + +def partition_view(tv, *, offsets, sizes): + """ + ``pto.partition_view`` – slice a tensor view. + + Type is inferred from the source tensor-view type. + """ + spec = compose_partition_spec(tv, offsets=offsets, sizes=sizes) + if spec is not None: + source = spec.root_tensor_view + offsets = spec.offsets + sizes = spec.sizes + else: + source = tv + + raw_source = unwrap_surface_value(source) + src_type = _pto.TensorViewType(raw_source.type) + rank = src_type.rank + elem = src_type.element_type + static_dims = _static_index_dims(sizes) + ptv_type = ( + part_tensor_view_type_from_dims(static_dims, elem) + if static_dims is not None + else part_tensor_view_type(rank, elem) + ) + value = _pto.PartitionViewOp( + ptv_type, + raw_source, + _unwrap_sequence(offsets), + _unwrap_sequence(sizes), + ).result + return wrap_surface_value( + value, + root_tensor_view=source if spec is None else spec.root_tensor_view, + offsets=tuple(offsets), + sizes=tuple(sizes), + ) + + +def alloc_tile( + tile_type=None, + *, + shape=None, + dtype=None, + memory_space="ub", + valid_shape=None, + blayout: str = "RowMajor", + slayout: str = "NoneBox", + fractal_size: int = 512, + pad: str = "Null", + addr=None, + valid_row=None, + valid_col=None, +): + """ + ``pto.alloc_tile``. + + Accepts either the authored surface form: + + ``alloc_tile(shape=[...], dtype=..., memory_space=..., valid_shape=..., addr=...)`` + + or the low-level explicit-type form: + + ``alloc_tile(tile_type, addr=..., valid_row=..., valid_col=...)``. + """ + if tile_type is not None and shape is not None: + raise TypeError("alloc_tile() accepts either tile_type or shape=/dtype=, not both") + + if tile_type is None: + if shape is None or dtype is None: + raise TypeError("alloc_tile() requires either tile_type or both shape= and dtype=") + if valid_row is not None or valid_col is not None: + raise TypeError( + "alloc_tile(shape=..., dtype=...) uses the authored surface form; " + "use valid_shape=... instead of valid_row=/valid_col=" + ) + logical_shape = _normalize_static_tile_shape(shape) + physical_shape = _authored_tile_physical_shape(logical_shape) + _validate_authored_tile_row_alignment(physical_shape, dtype, blayout=blayout, slayout=slayout) + type_valid_shape, valid_row, valid_col, surface_valid_shape = _split_valid_shape(logical_shape, valid_shape) + from ._types import tile_buf_type + tile_type = tile_buf_type( + physical_shape, + dtype, + type_valid_shape, + blayout=blayout, + address_space=memory_space, + slayout=slayout, + fractal_size=fractal_size, + pad=pad, + ) + shape = logical_shape + else: + physical_shape = None + surface_valid_shape = None + + value = _pto.AllocTileOp( + _resolve(tile_type), + addr=_coerce_i64(addr, context="alloc_tile(addr)") if addr is not None else None, + valid_row=_coerce_index(valid_row, context="alloc_tile(valid_row)") if valid_row is not None else None, + valid_col=_coerce_index(valid_col, context="alloc_tile(valid_col)") if valid_col is not None else None, + ).result + if tile_type is not None and (valid_row is not None or valid_col is not None): + parsed_tile_type = parse_tile_type_metadata(_resolve(tile_type)) + rank = len(shape) if shape is not None else len(parsed_tile_type["shape_dims"]) + surface_valid_shape = [None] * rank + if rank >= 1: + surface_valid_shape[0] = valid_row + if rank >= 2: + surface_valid_shape[1] = valid_col + surface_valid_shape = tuple(surface_valid_shape) + return wrap_surface_value( + value, + tile_metadata={ + "shape": shape, + "physical_shape": physical_shape, + "dtype": dtype, + "memory_space": memory_space, + "valid_shape": surface_valid_shape, + }, + ) + + +def set_tile_valid_shape(tile, valid_shape): + """Update the runtime valid-shape metadata of an authored dynamic tile.""" + parsed_tile_type = parse_tile_type_metadata(unwrap_surface_value(tile).type) + if parsed_tile_type is None: + raise TypeError("tile.valid_shape assignment expects a tile_buf-backed value") + if len(parsed_tile_type["shape_dims"]) != 2: + raise TypeError("tile.valid_shape assignment currently only supports rank-2 tiles") + logical_rank = len(tile.shape) if getattr(tile, "shape", None) is not None else 2 + if logical_rank == 1: + if len(valid_shape) != 1: + raise TypeError("rank-1 tile.valid_shape assignment expects exactly one dimension") + if parsed_tile_type["valid_dims"] != (None, None): + raise TypeError( + "rank-1 tile.valid_shape assignment requires a tile allocated with " + "valid_shape=[...] so the physical valid row/col metadata remain dynamic" + ) + valid_row = _coerce_index_value(1) + valid_col, = _unwrap_sequence(valid_shape) + else: + if len(valid_shape) != 2: + raise TypeError("tile.valid_shape assignment currently expects exactly two dimensions") + if parsed_tile_type["valid_dims"] != (None, None): + raise TypeError( + "tile.valid_shape assignment requires a tile allocated with fully dynamic " + "valid_shape=[..., ...]" + ) + valid_row, valid_col = _unwrap_sequence(valid_shape) + _pto.SetValidShapeOp( + unwrap_surface_value(tile), + valid_row, + valid_col, + ) + + +def tload(part, tile): + """``pto.tload ins(part) outs(tile)``.""" + _pto.TLoadOp(None, unwrap_surface_value(part), unwrap_surface_value(tile)) + + +def tstore(tile, part): + """``pto.tstore ins(tile) outs(part)``.""" + _pto.TStoreOp(None, unwrap_surface_value(tile), unwrap_surface_value(part)) + + +def tmov(src, dst): + """``pto.tmov ins(src) outs(dst)`` – move data between tile domains.""" + _pto.TMovOp(None, unwrap_surface_value(src), unwrap_surface_value(dst)) + + +def _coerce_tile_scalar_operand(tile, scalar, *, context: str): + return _constant_like(scalar, infer_tile_element_type(wrap_surface_value(tile))) + + +def tadd(src0, src1, dst): + """``pto.tadd ins(src0, src1) outs(dst)``.""" + _pto.tadd( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tsub(src0, src1, dst): + """``pto.tsub ins(src0, src1) outs(dst)``.""" + _pto.tsub( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tmul(src0, src1, dst): + """``pto.tmul ins(src0, src1) outs(dst)``.""" + _pto.tmul( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tdiv(src0, src1, dst, *, precision_mode=None): + """``pto.tdiv ins(src0, src1) outs(dst)``.""" + _pto.tdiv( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tmax(src0, src1, dst): + """``pto.tmax ins(src0, src1) outs(dst)``.""" + _pto.tmax( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tmin(src0, src1, dst): + """``pto.tmin ins(src0, src1) outs(dst)``.""" + _pto.tmin( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tadds(src, scalar, dst): + """``pto.tadds ins(src, scalar) outs(dst)``.""" + _pto.tadds( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tadds"), + unwrap_surface_value(dst), + ) + + +def tsubs(src, scalar, dst): + """``pto.tsubs ins(src, scalar) outs(dst)``.""" + _pto.tsubs( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tsubs"), + unwrap_surface_value(dst), + ) + + +def tmuls(src, scalar, dst): + """``pto.tmuls ins(src, scalar) outs(dst)``.""" + _pto.tmuls( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tmuls"), + unwrap_surface_value(dst), + ) + + +def tdivs(src, scalar, dst, *, precision_mode=None): + """``pto.tdivs ins(src, scalar) outs(dst)``.""" + _pto.tdivs( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tdivs"), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tmaxs(src, scalar, dst): + """``pto.tmaxs ins(src, scalar) outs(dst)``.""" + _pto.tmaxs( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tmaxs"), + unwrap_surface_value(dst), + ) + + +def tmins(src, scalar, dst): + """``pto.tmins ins(src, scalar) outs(dst)``.""" + _pto.tmins( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tmins"), + unwrap_surface_value(dst), + ) + + +def texp(src, dst, *, precision_mode=None): + """``pto.texp ins(src) outs(dst)``.""" + _pto.texp( + unwrap_surface_value(src), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tlog(src, dst, *, precision_mode=None): + """``pto.tlog ins(src) outs(dst)``.""" + _pto.tlog( + unwrap_surface_value(src), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tsqrt(src, dst, *, precision_mode=None): + """``pto.tsqrt ins(src) outs(dst)``.""" + _pto.tsqrt( + unwrap_surface_value(src), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def trsqrt(src, dst, *, tmp=None, precision_mode=None): + """``pto.trsqrt ins(src, tmp?) outs(dst)``.""" + _pto.trsqrt( + unwrap_surface_value(src), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + precision_mode=precision_mode, + ) + + +def trecip(src, dst, *, precision_mode=None): + """``pto.trecip ins(src) outs(dst)``.""" + _pto.trecip( + unwrap_surface_value(src), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tabs(src, dst): + """``pto.tabs ins(src) outs(dst)``.""" + _pto.tabs( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tneg(src, dst): + """``pto.tneg ins(src) outs(dst)``.""" + _pto.tneg( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def trelu(src, dst): + """``pto.trelu ins(src) outs(dst)``.""" + _pto.trelu( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tlrelu(src, slope, dst): + """``pto.tlrelu ins(src, slope) outs(dst)``.""" + _pto.tlrelu( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, slope, context="tlrelu"), + unwrap_surface_value(dst), + ) + + +def trowsum(src, tmp, dst): + """``pto.trowsum ins(src, tmp) outs(dst)``.""" + _pto.trowsum( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def trowmax(src, tmp, dst): + """``pto.trowmax ins(src, tmp) outs(dst)``.""" + _pto.trowmax( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def trowmin(src, tmp, dst): + """``pto.trowmin ins(src, tmp) outs(dst)``.""" + _pto.trowmin( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def trowprod(src, tmp, dst): + """``pto.trowprod ins(src, tmp) outs(dst)``.""" + _pto.trowprod( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def trowargmax(src, tmp, dst): + """``pto.trowargmax ins(src, tmp) outs(dst)``.""" + _pto.trowargmax( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def trowargmin(src, tmp, dst): + """``pto.trowargmin ins(src, tmp) outs(dst)``.""" + _pto.trowargmin( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def tcolsum(src, dst, *, tmp=None, is_binary=None): + """``pto.tcolsum ins(src, tmp?) outs(dst)``.""" + _pto.tcolsum( + unwrap_surface_value(src), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + is_binary=is_binary, + ) + + +def tcolmax(src, dst): + """``pto.tcolmax ins(src) outs(dst)``.""" + _pto.tcolmax( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tcolmin(src, dst): + """``pto.tcolmin ins(src) outs(dst)``.""" + _pto.tcolmin( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tcolprod(src, dst): + """``pto.tcolprod ins(src) outs(dst)``.""" + _pto.tcolprod( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tcolargmax(src, tmp, dst): + """``pto.tcolargmax ins(src, tmp) outs(dst)``.""" + _pto.tcolargmax( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def tcolargmin(src, tmp, dst): + """``pto.tcolargmin ins(src, tmp) outs(dst)``.""" + _pto.tcolargmin( + unwrap_surface_value(src), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def tcmp(src0, src1, dst, *, cmp_mode=None): + """``pto.tcmp ins(src0, src1) outs(dst)``.""" + _pto.tcmp( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + cmp_mode=None if cmp_mode is None else _cmp_mode_attr(cmp_mode), + ) + + +def tcmps(src, scalar, dst, *, cmp_mode=None): + """``pto.tcmps ins(src, scalar) outs(dst)``.""" + _pto.tcmps( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tcmps"), + unwrap_surface_value(dst), + cmp_mode=None if cmp_mode is None else _cmp_mode_attr(cmp_mode), + ) + + +def texpands(scalar, dst): + """``pto.texpands ins(scalar) outs(dst)``.""" + _pto.texpands( + _coerce_tile_scalar_operand(dst, scalar, context="texpands"), + unwrap_surface_value(dst), + ) + + +def trowexpand(src, dst): + """``pto.trowexpand ins(src) outs(dst)``.""" + _pto.trowexpand( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tcolexpand(src, dst): + """``pto.tcolexpand ins(src) outs(dst)``.""" + _pto.tcolexpand( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def trowexpandadd(src0, src1, dst): + """``pto.trowexpandadd ins(src0, src1) outs(dst)``.""" + _pto.trowexpandadd( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def trowexpandsub(src0, src1, dst, *, tmp=None): + """``pto.trowexpandsub ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpandsub( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + ) + + +def trowexpandmul(src0, src1, dst, *, tmp=None): + """``pto.trowexpandmul ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpandmul( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + ) + + +def trowexpanddiv(src0, src1, dst, *, tmp=None, precision_mode=None): + """``pto.trowexpanddiv ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpanddiv( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + precision_mode=precision_mode, + ) + + +def trowexpandmax(src0, src1, dst, *, tmp=None): + """``pto.trowexpandmax ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpandmax( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + ) + + +def trowexpandmin(src0, src1, dst, *, tmp=None): + """``pto.trowexpandmin ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpandmin( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + ) + + +def trowexpandexpdif(src0, src1, dst, *, tmp=None): + """``pto.trowexpandexpdif ins(src0, src1, tmp?) outs(dst)``.""" + _pto.trowexpandexpdif( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + ) + + +def tcolexpandadd(src0, src1, dst): + """``pto.tcolexpandadd ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandadd( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tcolexpandsub(src0, src1, dst): + """``pto.tcolexpandsub ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandsub( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tcolexpandmul(src0, src1, dst): + """``pto.tcolexpandmul ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandmul( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tcolexpanddiv(src0, src1, dst, *, precision_mode=None): + """``pto.tcolexpanddiv ins(src0, src1) outs(dst)``.""" + _pto.tcolexpanddiv( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + precision_mode=precision_mode, + ) + + +def tcolexpandmax(src0, src1, dst): + """``pto.tcolexpandmax ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandmax( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tcolexpandmin(src0, src1, dst): + """``pto.tcolexpandmin ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandmin( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tcolexpandexpdif(src0, src1, dst): + """``pto.tcolexpandexpdif ins(src0, src1) outs(dst)``.""" + _pto.tcolexpandexpdif( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def _resolve_selection_tmp(dst, tmp, *, context: str): + if tmp is not None: + return tmp + + session = None + try: + from ._tracing.active import current_session + session = current_session() + except Exception: + session = None + + if session is not None and getattr(session.module_spec, "target_arch", None) == "a5": + return dst + + return alloc_tile(tile_type=unwrap_surface_value(dst).type) + + +def tsel(mask, src0, src1, dst, *, tmp=None): + """``pto.tsel ins(mask, src0, src1, tmp) outs(dst)`` with synthesized scratch when omitted.""" + resolved_tmp = tmp if tmp is not None else _resolve_selection_tmp(dst, tmp, context="tsel") + _pto.tsel( + unwrap_surface_value(mask), + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(resolved_tmp), + unwrap_surface_value(dst), + ) + + +def tsels(mask, src, scalar, dst, *, tmp=None): + """``pto.tsels ins(mask, src, tmp, scalar) outs(dst)`` with synthesized scratch when omitted.""" + resolved_tmp = tmp if tmp is not None else _resolve_selection_tmp(dst, tmp, context="tsels") + _pto.tsels( + unwrap_surface_value(mask), + unwrap_surface_value(src), + unwrap_surface_value(resolved_tmp), + _coerce_tile_scalar_operand(src, scalar, context="tsels"), + unwrap_surface_value(dst), + ) + + +def tcvt(src, dst, *, tmp=None, rmode=None, sat_mode=None): + """``pto.tcvt ins(src, tmp?) outs(dst)``.""" + _pto.tcvt( + unwrap_surface_value(src), + unwrap_surface_value(dst), + tmp=None if tmp is None else unwrap_surface_value(tmp), + rmode=rmode, + sat_mode=sat_mode, + ) + + +def tnot(src, dst): + """``pto.tnot ins(src) outs(dst)``.""" + _pto.tnot( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tand(src0, src1, dst): + """``pto.tand ins(src0, src1) outs(dst)``.""" + _pto.tand( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tands(src, scalar, dst): + """``pto.tands ins(src, scalar) outs(dst)``.""" + _pto.tands( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tands"), + unwrap_surface_value(dst), + ) + + +def tor(src0, src1, dst): + """``pto.tor ins(src0, src1) outs(dst)``.""" + _pto.tor( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tors(src, scalar, dst): + """``pto.tors ins(src, scalar) outs(dst)``.""" + _pto.tors( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tors"), + unwrap_surface_value(dst), + ) + + +def txor(src0, src1, tmp, dst): + """``pto.txor ins(src0, src1, tmp) outs(dst)``.""" + _pto.txor( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def txors(src, scalar, tmp, dst): + """``pto.txors ins(src, scalar, tmp) outs(dst)``.""" + _pto.txors( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="txors"), + unwrap_surface_value(tmp), + unwrap_surface_value(dst), + ) + + +def tshl(src0, src1, dst): + """``pto.tshl ins(src0, src1) outs(dst)``.""" + _pto.tshl( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tshls(src, scalar, dst): + """``pto.tshls ins(src, scalar) outs(dst)``.""" + _pto.tshls( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tshls"), + unwrap_surface_value(dst), + ) + + +def tshr(src0, src1, dst): + """``pto.tshr ins(src0, src1) outs(dst)``.""" + _pto.tshr( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tshrs(src, scalar, dst): + """``pto.tshrs ins(src, scalar) outs(dst)``.""" + _pto.tshrs( + unwrap_surface_value(src), + _coerce_tile_scalar_operand(src, scalar, context="tshrs"), + unwrap_surface_value(dst), + ) + + +def tpartadd(src0, src1, dst): + """``pto.tpartadd ins(src0, src1) outs(dst)``.""" + _pto.tpartadd( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tpartmul(src0, src1, dst): + """``pto.tpartmul ins(src0, src1) outs(dst)``.""" + _pto.tpartmul( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tpartmax(src0, src1, dst): + """``pto.tpartmax ins(src0, src1) outs(dst)``.""" + _pto.tpartmax( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tpartmin(src0, src1, dst): + """``pto.tpartmin ins(src0, src1) outs(dst)``.""" + _pto.tpartmin( + unwrap_surface_value(src0), + unwrap_surface_value(src1), + unwrap_surface_value(dst), + ) + + +def tfillpad(src, dst): + """``pto.tfillpad ins(src) outs(dst)``.""" + _pto.tfillpad( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tfillpad_expand(src, dst): + """``pto.tfillpad_expand ins(src) outs(dst)``.""" + _pto.tfillpad_expand( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def tfillpad_inplace(src, dst): + """``pto.tfillpad_inplace ins(src) outs(dst)``.""" + _pto.tfillpad_inplace( + unwrap_surface_value(src), + unwrap_surface_value(dst), + ) + + +def as_ptr(value): + """Materialize a typed pointer from a tile or tensor-view descriptor.""" + wrapped = wrap_surface_value(value) + return emit_as_ptr(wrapped) + + +def _constant_like(value, mlir_type): + value = unwrap_surface_value(value) + if hasattr(value, "type"): + return value + if isinstance(value, float): + return arith.ConstantOp(mlir_type, FloatAttr.get(mlir_type, value)).result + if IntegerType.isinstance(mlir_type): + return _materialize_integer_literal(mlir_type, value) + return arith.ConstantOp(mlir_type, value).result + + +def _index_zero(): + return arith.ConstantOp(IndexType.get(), 0).result + + +def _tile_slice_linear_offset(tile_slice: TileSliceValue): + offsets = tile_slice.offsets + if len(offsets) == 1: + return offsets[0] + if len(offsets) != 2: + raise RuntimeError("tile slice pointer lowering only supports rank-1 or rank-2 offsets") + + physical_shape = getattr(tile_slice.tile, "physical_shape", None) + if physical_shape is None or len(physical_shape) != 2 or physical_shape[1] is None: + raise RuntimeError("tile slice pointer lowering requires static physical column shape metadata") + + row, col = offsets + stride = physical_shape[1] + if isinstance(row, int) and isinstance(col, int): + return row * stride + col + + row_value = _coerce_index(row, context="tile slice pointer lowering") + row_stride = arith.MulIOp(row_value, arith.ConstantOp(IndexType.get(), stride).result).result + col_value = _coerce_index(col, context="tile slice pointer lowering") + return arith.AddIOp(row_stride, col_value).result + + +def _tile_slice_ptr(tile_slice: TileSliceValue): + base_ptr = emit_as_ptr(tile_slice.tile) + linear_offset = _tile_slice_linear_offset(tile_slice) + if isinstance(linear_offset, int) and linear_offset == 0: + return base_ptr + return addptr(base_ptr, _coerce_index(linear_offset, context="tile slice pointer lowering")) + + +def _infer_vreg_type_from_tile_slice(tile_slice: TileSliceValue): + memref_type = MemRefType(tile_slice.type) + elem_type = memref_type.element_type + lanes = _elements_per_vreg(elem_type) + return _resolve(vreg_type(lanes, elem_type)) + + +def _infer_vreg_type_from_address_source(src_ptr): + raw_source = unwrap_surface_value(src_ptr) + source_type = raw_source.type + try: + elem_type = _pto.PtrType(source_type).element_type + except Exception: + try: + elem_type = MemRefType(source_type).element_type + except Exception as exc: + raise TypeError( + f"vlds(ptr, offset) cannot infer a vector-register type from source {source_type}; " + "pass result_vreg_type= explicitly" + ) from exc + lanes = _elements_per_vreg(elem_type) + return _resolve(vreg_type(lanes, elem_type)) + + +def _elements_per_vreg(elem_type): + try: + bytewidth = _element_bytewidth(elem_type) + except TypeError as exc: + raise TypeError(f"vlds/vsts tile-slice sugar does not support element type {elem_type}") + return 256 // bytewidth + + +def _infer_vreg_metadata(vector_value): + raw_type = unwrap_surface_value(vector_value).type + try: + vreg_type = _pto.VRegType(raw_type) + return vreg_type.lanes, vreg_type.element_type + except Exception: + text = str(raw_type) + if not text.startswith("!pto.vreg<") or "x" not in text: + raise TypeError(f"expected PTO vector-register type, got {raw_type}") + body = text[len("!pto.vreg<"):-1] + lanes_text, elem_text = body.split("x", 1) + return int(lanes_text), Type.parse(elem_text) + + +def _extract_lowest_lane_scalar(vector_value, mask): + lanes, elem_type = _infer_vreg_metadata(vector_value) + tmp_tile = alloc_tile(shape=[1, lanes], dtype=elem_type, valid_shape=[1, 1]) + vsts(vector_value, tmp_tile.as_ptr(), _index_zero(), mask, dist="1PT_B32") + from . import scalar as _scalar + return _scalar.load(tmp_tile[0, 0]) + + +def _element_bytewidth(elem_type): + if F32Type.isinstance(elem_type): + return 4 + if any(cls.isinstance(elem_type) for cls in (F16Type, BF16Type)): + return 2 + if Float8E4M3FNType.isinstance(elem_type) or Float8E5M2Type.isinstance(elem_type): + return 1 + if any(_isinstance_pto_type(elem_type, name) for name in ("HiF8Type", "F4E1M2x2Type", "F4E2M1x2Type")): + return 1 + if IntegerType.isinstance(elem_type): + width = IntegerType(elem_type).width + if width % 8 != 0: + raise TypeError(f"unsupported sub-byte integer element type {elem_type}") + return width // 8 + raise TypeError(f"unsupported element type {elem_type}") + + +def bytewidth(dtype): + """Return the size in bytes of one element of *dtype*.""" + return _element_bytewidth(_resolve(dtype)) + + +def elements_per_vreg(dtype): + """Return how many elements of *dtype* fit in one 256-byte vector register.""" + return _elements_per_vreg(_resolve(dtype)) + + +def _mask_bits_for_dtype(dtype): + elem_type = _resolve(dtype) + bytewidth = _element_bytewidth(elem_type) + if bytewidth == 4: + return 32 + if bytewidth == 2: + return 16 + if bytewidth == 1: + return 8 + raise TypeError(f"make_mask(...) does not support dtype {elem_type}") + + +def _pset_op_for_mask_bits(mask_bits: int): + return { + 8: _pto.PsetB8Op, + 16: _pto.PsetB16Op, + 32: _pto.PsetB32Op, + }[mask_bits] + + +def _plt_op_for_mask_bits(mask_bits: int): + return { + 8: _pto.PltB8Op, + 16: _pto.PltB16Op, + 32: _pto.PltB32Op, + }[mask_bits] + + +def _coerce_i32(value, *, context: str): + raw_value = unwrap_surface_value(value) + i32_type = IntegerType.get_signless(32) + if isinstance(raw_value, bool): + raise TypeError(f"{context} does not accept bool values") + if isinstance(raw_value, int): + return _materialize_integer_literal(i32_type, raw_value) + kind = classify_runtime_scalar_type(raw_value.type) + if kind == "float": + raise TypeError(f"{context} expects an integer-like scalar, got {raw_value.type}") + if kind == "index": + return arith.IndexCastOp(i32_type, raw_value).result + signless_value = _strip_integer_signedness(raw_value) + if signless_value.type == i32_type: + return signless_value + width = IntegerType(raw_value.type).width + if width < 32: + if _integer_signedness(raw_value.type) == "unsigned": + return arith.ExtUIOp(i32_type, signless_value).result + return arith.ExtSIOp(i32_type, signless_value).result + if width > 32: + return arith.TruncIOp(i32_type, signless_value).result + return signless_value + + +def _coerce_i64(value, *, context: str): + raw_value = unwrap_surface_value(value) + i64_type = IntegerType.get_signless(64) + if isinstance(raw_value, bool): + raise TypeError(f"{context} does not accept bool values") + if isinstance(raw_value, int): + return _materialize_integer_literal(i64_type, raw_value) + kind = classify_runtime_scalar_type(raw_value.type) + if kind == "float": + raise TypeError(f"{context} expects an integer-like scalar, got {raw_value.type}") + if kind == "index": + return arith.IndexCastOp(i64_type, raw_value).result + signless_value = _strip_integer_signedness(raw_value) + if signless_value.type == i64_type: + return signless_value + width = IntegerType(raw_value.type).width + if width < 64: + if _integer_signedness(raw_value.type) == "unsigned": + return arith.ExtUIOp(i64_type, signless_value).result + return arith.ExtSIOp(i64_type, signless_value).result + if width > 64: + return arith.TruncIOp(i64_type, signless_value).result + return signless_value + + +def _i64_zero(): + return arith.ConstantOp(IntegerType.get_signless(64), 0).result + + +def _coerce_scalar_like_vector_element(vector_value, scalar_value, *, context: str): + _, elem_type = _infer_vreg_metadata(vector_value) + return coerce_scalar_to_type(scalar_value, elem_type, context=f"{context}(...)") + + +def _negate_runtime_scalar(value): + raw_value = unwrap_surface_value(value) + kind = classify_runtime_scalar_type(raw_value.type) + zero = materialize_scalar_literal(0.0 if kind == "float" else 0, raw_value.type, context="_negate_runtime_scalar(...)") + return emit_runtime_binary_op("sub", zero, raw_value) + + +def _mul_bytes(value, elem_type): + factor = _element_bytewidth(_resolve(elem_type)) + raw_value = unwrap_surface_value(value) + if isinstance(raw_value, int): + return raw_value * factor + return emit_runtime_binary_op("mul", raw_value, factor) + + +def _membar_attr(kind: str): + normalized = str(kind) + supported = { + "VV_ALL", + "VST_VLD", + "VLD_VST", + "VST_VST", + "VS_ALL", + "VST_LD", + "VLD_ST", + "VST_ST", + "SV_ALL", + "ST_VLD", + "LD_VST", + "ST_VST", + "SS_ALL", + "ST_LD", + "LD_ST", + "ST_ST", + } + if normalized not in supported: + raise ValueError(f"unsupported mem_bar kind {kind!r}") + return Attribute.parse(f"#pto.membar<{normalized}>") + + +def _acc_store_ub_dst_mode_attr(mode): + normalized = { + 0: "single", + 1: "split_m", + 2: "split_n", + "single": "single", + "split_m": "split_m", + "split_n": "split_n", + }.get(mode if isinstance(mode, int) else str(mode).lower()) + if normalized is None: + raise ValueError(f"unsupported mte_l0c_ub dst_mode {mode!r}") + return Attribute.parse(f"#pto") + + +def _infer_dma_partition_row_stride(partition: PartitionTensorViewValue): + if partition.shape is None or partition.strides is None: + raise TypeError("mte_load/mte_store require partition view shape/stride metadata") + outer_dims = list(partition.shape[:-1]) + non_unit = [i for i, dim in enumerate(outer_dims) if dim != 1] + if len(non_unit) > 1: + raise TypeError( + "mte_load/mte_store currently only support partitions with at most one non-unit " + "dimension before the contiguous innermost dimension" + ) + if not non_unit: + return 1, 0 + dim_index = non_unit[0] + return partition.shape[dim_index], partition.strides[dim_index] + + +def _infer_dma_tile_geometry(tile: TileValue): + if tile.shape is None: + raise TypeError("mte_load/mte_store require tile shape metadata") + if len(tile.shape) == 1: + valid_cols = tile.valid_shape[0] + return 1, valid_cols, tile.shape[0] + if len(tile.shape) == 2: + return tile.valid_shape[0], tile.valid_shape[1], tile.shape[1] + raise TypeError("mte_load/mte_store currently only support rank-1 or rank-2 tiles") + + +def _infer_dma_2d_copy_signature(partition, tile, *, direction: str): + row_count, src_row_stride = _infer_dma_partition_row_stride(partition) + tile_rows, valid_cols, physical_cols = _infer_dma_tile_geometry(tile) + if direction == "gm_to_ub": + return row_count, valid_cols, _mul_bytes(src_row_stride, infer_tile_element_type(tile)), physical_cols * _element_bytewidth(infer_tile_element_type(tile)) + return row_count, valid_cols, physical_cols * _element_bytewidth(infer_tile_element_type(tile)), _mul_bytes(src_row_stride, infer_tile_element_type(tile)) + + +def fill_tile(tile, value): + """Broadcast a scalar into an entire tile.""" + wrapped_tile = wrap_surface_value(tile) + scalar_value = _constant_like(value, infer_tile_element_type(wrapped_tile)) + _pto.TExpandsOp(scalar_value, unwrap_surface_value(wrapped_tile)) + + +def make_mask(dtype, value): + """Create a predicate mask matching *dtype* granularity.""" + mask_bits = _mask_bits_for_dtype(dtype) + result_type = _mask_type_from_bits(mask_bits) + + if isinstance(value, str): + return wrap_surface_value( + _pset_op_for_mask_bits(mask_bits)(result_type, _normalize_mask_pattern(value)).result + ) + + raw_value = unwrap_surface_value(value) + authored_scalar_type = raw_value.type if hasattr(raw_value, "type") else IntegerType.get_signless(32) + raw_value = _coerce_i32(raw_value, context="make_mask(..., value)") + plt_op = _plt_op_for_mask_bits(mask_bits)(result_type, IntegerType.get_signless(32), raw_value) + next_value = coerce_scalar_to_type( + plt_op.scalar_out, + authored_scalar_type, + context="make_mask(..., value) result", + ) + return MaskResultValue(plt_op.mask, next_value) + + +# ── Hardware / sync ─────────────────────────────────────────────────────────── + +def _require_pto_ptr_operand(value, *, context: str): + raw_value = unwrap_surface_value(value) + try: + _pto.PtrType(raw_value.type) + except Exception as exc: + raise TypeError(f"{context} expects PTO ptr operands, got {raw_value.type}") from exc + return raw_value + + +@_explicit_mode_only("pto.mte_load(...)") +def mte_load(source, destination, l2_cache_ctl, len_burst, *, nburst, loops=None, pad=None): + """ + Ptr-based GM->UB DMA wrapper aligned with the underlying ``pto.dma_load`` surface. + + This wrapper intentionally accepts only explicit pointer operands. It does + not infer burst shape or strides from TensorView / PartitionTensorView / + Tile metadata. + """ + n_burst, nburst_src_stride, nburst_dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_load(...)", + ) + loop_counts, loop_src_strides, loop_dst_strides = _normalize_dma_loops( + loops, + context="mte_load(...)", + ) + pad_value, left_padding_count, right_padding_count = _normalize_dma_pad( + pad, + context="mte_load(...)", + ) + _pto.MteGmUbOp( + _require_pto_ptr_operand(source, context="mte_load(...)"), + _require_pto_ptr_operand(destination, context="mte_load(...)"), + _coerce_i64(l2_cache_ctl, context="mte_load l2_cache_ctl"), + _coerce_i64(len_burst, context="mte_load len_burst"), + n_burst, + nburst_src_stride, + nburst_dst_stride, + loop_counts, + loop_src_strides, + loop_dst_strides, + pad_value=pad_value, + left_padding_count=left_padding_count, + right_padding_count=right_padding_count, + ) + + +@_explicit_mode_only("pto.mte_store(...)") +def mte_store(source, destination, len_burst, *, nburst, loops=None): + """Ptr-based UB->GM DMA wrapper aligned with the underlying ``pto.dma_store`` surface.""" + n_burst, nburst_src_stride, nburst_dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_store(...)", + ) + loop_counts, loop_src_strides, loop_dst_strides = _normalize_dma_loops( + loops, + context="mte_store(...)", + ) + _pto.MteUbGmOp( + _require_pto_ptr_operand(source, context="mte_store(...)"), + _require_pto_ptr_operand(destination, context="mte_store(...)"), + _coerce_i64(len_burst, context="mte_store len_burst"), + n_burst, + nburst_src_stride, + nburst_dst_stride, + loop_counts, + loop_src_strides, + loop_dst_strides, + ) + + +def _normalize_dma_group(name, triple, *, context: str): + if not isinstance(triple, tuple) or len(triple) != 3: + raise TypeError(f"{context} expects {name}=(count, src_stride, dst_stride)") + count, src_stride, dst_stride = triple + return ( + _coerce_i64(count, context=f"{context} {name}[0]"), + _coerce_i64(src_stride, context=f"{context} {name}[1]"), + _coerce_i64(dst_stride, context=f"{context} {name}[2]"), + ) + + +def _normalize_dma_loops(loops, *, context: str): + if loops is None: + return [], [], [] + if not isinstance(loops, (list, tuple)): + raise TypeError(f"{context} expects loops to be a list[tuple[int, int, int]] or None") + counts = [] + src_strides = [] + dst_strides = [] + for i, loop in enumerate(loops): + count, src_stride, dst_stride = _normalize_dma_group( + f"loops[{i}]", + loop, + context=context, + ) + counts.append(count) + src_strides.append(src_stride) + dst_strides.append(dst_stride) + return counts, src_strides, dst_strides + + +def _normalize_dma_pad(pad, *, context: str): + if pad is None: + return None, None, None + if not isinstance(pad, tuple): + raise TypeError(f"{context} expects pad to be tuple[ScalarType] or tuple[ScalarType, int, int]") + if len(pad) == 1: + pad_value = pad[0] + left_count = 0 + right_count = 0 + elif len(pad) == 3: + pad_value, left_count, right_count = pad + else: + raise TypeError(f"{context} expects pad to have length 1 or 3") + return ( + materialize_scalar_literal(pad_value, F32Type.get(), context=f"{context} pad[0]") + if not hasattr(pad_value, "type") else unwrap_surface_value(pad_value), + _coerce_i64(left_count, context=f"{context} pad[1]"), + _coerce_i64(right_count, context=f"{context} pad[2]"), + ) + + +@_explicit_mode_only("pto.mte_gm_ub(...)") +def mte_gm_ub(source, destination, l2_cache_ctl, len_burst, *, nburst, loops=None, pad=None): + """``pto.mte_gm_ub`` – grouped GM-to-UB DMA surface.""" + n_burst, nburst_src_stride, nburst_dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_gm_ub(...)", + ) + loop_counts, loop_src_strides, loop_dst_strides = _normalize_dma_loops( + loops, + context="mte_gm_ub(...)", + ) + pad_value, left_padding_count, right_padding_count = _normalize_dma_pad( + pad, + context="mte_gm_ub(...)", + ) + _pto.MteGmUbOp( + unwrap_surface_value(source), + unwrap_surface_value(destination), + _coerce_i64(l2_cache_ctl, context="mte_gm_ub l2_cache_ctl"), + _coerce_i64(len_burst, context="mte_gm_ub len_burst"), + n_burst, + nburst_src_stride, + nburst_dst_stride, + loop_counts, + loop_src_strides, + loop_dst_strides, + pad_value=pad_value, + left_padding_count=left_padding_count, + right_padding_count=right_padding_count, + ) + + +@_explicit_mode_only("pto.mte_ub_gm(...)") +def mte_ub_gm(source, destination, len_burst, *, nburst, loops=None): + """``pto.mte_ub_gm`` – grouped UB-to-GM DMA surface.""" + n_burst, nburst_src_stride, nburst_dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_ub_gm(...)", + ) + loop_counts, loop_src_strides, loop_dst_strides = _normalize_dma_loops( + loops, + context="mte_ub_gm(...)", + ) + _pto.MteUbGmOp( + unwrap_surface_value(source), + unwrap_surface_value(destination), + _coerce_i64(len_burst, context="mte_ub_gm len_burst"), + n_burst, + nburst_src_stride, + nburst_dst_stride, + loop_counts, + loop_src_strides, + loop_dst_strides, + ) + + +@_explicit_mode_only("pto.mte_ub_ub(...)") +def mte_ub_ub(source, destination, len_burst, *, nburst): + """``pto.mte_ub_ub`` – grouped UB-to-UB DMA surface.""" + n_burst, src_stride, dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_ub_ub(...)", + ) + _pto.MteUbUbOp( + unwrap_surface_value(source), + unwrap_surface_value(destination), + n_burst, + _coerce_i64(len_burst, context="mte_ub_ub len_burst"), + src_stride, + dst_stride, + ) + + +@_explicit_mode_only("pto.mte_ub_l1(...)") +def mte_ub_l1(source, destination, len_burst, *, nburst): + """``pto.mte_ub_l1`` – grouped UB-to-L1 DMA surface.""" + n_burst, src_stride, dst_stride = _normalize_dma_group( + "nburst", + nburst, + context="mte_ub_l1(...)", + ) + _pto.MteUbL1Op( + unwrap_surface_value(source), + unwrap_surface_value(destination), + n_burst, + _coerce_i64(len_burst, context="mte_ub_l1 len_burst"), + src_stride, + dst_stride, + ) + + +def mem_bar(barrier_type): + """``pto.mem_bar`` with a small authored enum surface.""" + barrier_name = getattr(barrier_type, "value", barrier_type) + _pto.MemBarOp(kind=_membar_attr(barrier_name)) + + +@_explicit_mode_only("pto.mte_l1_l0a(...)") +def mte_l1_l0a(source, destination, m, k, *, transpose=False): + """``pto.mte_l1_l0a`` – cube-side LEFT staging.""" + _pto.MteL1L0aOp( + unwrap_surface_value(source), + unwrap_surface_value(destination), + _coerce_i64(m, context="mte_l1_l0a m"), + _coerce_i64(k, context="mte_l1_l0a k"), + transpose=transpose, + ) + + +@_explicit_mode_only("pto.mte_l1_l0b(...)") +def mte_l1_l0b(source, destination, k, n, *, transpose=False): + """``pto.mte_l1_l0b`` – cube-side RIGHT staging.""" + _pto.MteL1L0bOp( + unwrap_surface_value(source), + unwrap_surface_value(destination), + _coerce_i64(k, context="mte_l1_l0b k"), + _coerce_i64(n, context="mte_l1_l0b n"), + transpose=transpose, + ) + + +@_explicit_mode_only("pto.mte_l0c_ub(...)") +def mte_l0c_ub(source, destination, m, n, src_stride, dst_stride, sub_blockid=0, *, dst_mode="single"): + """``pto.mte_l0c_ub`` – ACC to UB store.""" + _pto.MteL0cUbOp( + unwrap_surface_value(source), + unwrap_surface_value(destination), + _coerce_i64(m, context="mte_l0c_ub m"), + _coerce_i64(n, context="mte_l0c_ub n"), + _coerce_i64(src_stride, context="mte_l0c_ub src_stride"), + _coerce_i64(dst_stride, context="mte_l0c_ub dst_stride"), + _acc_store_ub_dst_mode_attr(dst_mode), + sub_blockid=_coerce_i64(sub_blockid, context="mte_l0c_ub sub_blockid"), + ) + + +def mad(lhs, rhs, dst, m, n, k): + """``pto.mad`` – cube matmul accumulate.""" + _pto.MadOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + _coerce_i64(m, context="mad m"), + _coerce_i64(n, context="mad n"), + _coerce_i64(k, context="mad k"), + ) + + +def mad_acc(lhs, rhs, dst, m, n, k): + """``pto.mad_acc`` – cube matmul accumulate into an existing accumulator.""" + _pto.MadAccOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + _coerce_i64(m, context="mad_acc m"), + _coerce_i64(n, context="mad_acc n"), + _coerce_i64(k, context="mad_acc k"), + ) + + +def mad_bias(lhs, rhs, dst, bias, m, n, k): + """``pto.mad_bias`` – cube matmul initialized from a bias buffer.""" + _pto.MadBiasOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + unwrap_surface_value(bias), + _coerce_i64(m, context="mad_bias m"), + _coerce_i64(n, context="mad_bias n"), + _coerce_i64(k, context="mad_bias k"), + ) + + +def mad_mx(lhs, rhs, dst, m, n, k): + """``pto.mad_mx`` – MX-format cube matmul.""" + _pto.MadMxOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + _coerce_i64(m, context="mad_mx m"), + _coerce_i64(n, context="mad_mx n"), + _coerce_i64(k, context="mad_mx k"), + ) + + +def mad_mx_acc(lhs, rhs, dst, m, n, k): + """``pto.mad_mx_acc`` – MX-format cube matmul accumulate.""" + _pto.MadMxAccOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + _coerce_i64(m, context="mad_mx_acc m"), + _coerce_i64(n, context="mad_mx_acc n"), + _coerce_i64(k, context="mad_mx_acc k"), + ) + + +def mad_mx_bias(lhs, rhs, dst, bias, m, n, k): + """``pto.mad_mx_bias`` – MX-format cube matmul initialized from a bias buffer.""" + _pto.MadMxBiasOp( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + unwrap_surface_value(dst), + unwrap_surface_value(bias), + _coerce_i64(m, context="mad_mx_bias m"), + _coerce_i64(n, context="mad_mx_bias n"), + _coerce_i64(k, context="mad_mx_bias k"), + ) + +def get_block_idx(): + """``pto.get_block_idx`` → i64 block index.""" + return wrap_surface_value(_pto.GetBlockIdxOp().result) + + +def get_block_num(): + """``pto.get_block_num`` → i64 block count.""" + return wrap_surface_value(_pto.GetBlockNumOp().result) + + +def get_subblock_idx(): + """``pto.get_subblock_idx`` → i64 subblock index.""" + return wrap_surface_value(_pto.GetSubBlockIdxOp().result) + + +def get_subblock_num(): + """``pto.get_subblock_num`` → i64 subblock count.""" + return wrap_surface_value(_pto.GetSubBlockNumOp().result) + + +def store_vfsimt_info(dim_z, dim_y, dim_x): + """``pto.store_vfsimt_info`` – configure the SIMT VF launch descriptor.""" + _pto.StoreVfSimtInfoOp( + unwrap_surface_value(dim_z), + unwrap_surface_value(dim_y), + unwrap_surface_value(dim_x), + ) + + +def get_tid_x(): + """``pto.get_tid_x`` → i32 SIMT lane X coordinate.""" + return wrap_surface_value(_pto.GetTidXOp().result) + + +def get_tid_y(): + """``pto.get_tid_y`` → i32 SIMT lane Y coordinate.""" + return wrap_surface_value(_pto.GetTidYOp().result) + + +def get_tid_z(): + """``pto.get_tid_z`` → i32 SIMT lane Z coordinate.""" + return wrap_surface_value(_pto.GetTidZOp().result) + + +def pipe_barrier(pipe): + """``pto.pipe_barrier(pipe)`` – drain the specified hardware pipeline.""" + _pto.BarrierOp(_pipe_attr(pipe)) + + +def get_buf(pipe, buf_id, mode=0): + """``pto.get_buf(pipe, buf_id, mode=0)`` – acquire a buffer token.""" + _pto.GetBufOp( + _pipe_attr(pipe), + buf_id, + mode=mode, + ) + + +def rls_buf(pipe, buf_id, mode=0): + """``pto.rls_buf(pipe, buf_id, mode=0)`` – release a buffer token.""" + _pto.RlsBufOp( + _pipe_attr(pipe), + buf_id, + mode=mode, + ) + + +def _sync_event_id_operand(event_id, *, context: str): + _validate_static_event_id(event_id, context=context) + return event_id if isinstance(event_id, int) else unwrap_surface_value(event_id) + + +def _flag_event_id_operand(event_id, *, context: str): + if isinstance(event_id, int): + _validate_static_event_id(event_id, context=context) + return event_id, True + return _coerce_index(event_id, context=context), False + + +def set_cross_flag(pipe, event_id): + """``pto.set_cross_flag(pipe, event_id)`` – cross-core sync facade for ``pto.sync.set``.""" + _validate_sync_pipe(pipe, context="set_cross_flag(pipe, event_id)", allowed=("PIPE_FIX",)) + event_operand = _sync_event_id_operand(event_id, context="set_cross_flag(..., event_id=...)") + _pto.sync_set(_pipe_attr(pipe), event_operand) + + +def wait_cross_flag(pipe, event_id): + """``pto.wait_cross_flag(pipe, event_id)`` – cross-core sync facade for ``pto.sync.wait``.""" + _validate_sync_pipe(pipe, context="wait_cross_flag(pipe, event_id)", allowed=("PIPE_FIX",)) + event_operand = _sync_event_id_operand(event_id, context="wait_cross_flag(..., event_id=...)") + _pto.sync_wait(_pipe_attr(pipe), event_operand) + + +def set_intra_flag(pipe, event_id): + """``pto.set_intra_flag(pipe, event_id)`` – intra-block sync facade for ``pto.sync.set``.""" + _validate_sync_pipe(pipe, context="set_intra_flag(pipe, event_id)", allowed=("PIPE_MTE3",)) + event_operand = _sync_event_id_operand(event_id, context="set_intra_flag(..., event_id=...)") + _pto.sync_set(_pipe_attr(pipe), event_operand) + + +def wait_intra_flag(pipe, event_id): + """``pto.wait_intra_flag(pipe, event_id)`` – intra-block sync facade for ``pto.sync.wait``.""" + _validate_sync_pipe(pipe, context="wait_intra_flag(pipe, event_id)", allowed=("PIPE_V",)) + event_operand = _sync_event_id_operand(event_id, context="wait_intra_flag(..., event_id=...)") + _pto.sync_wait(_pipe_attr(pipe), event_operand) + + +def set_flag(src: str, dst: str, *, event_id: int = 0): + """``pto.set_flag[src, dst, event_id]``. + + Accepts short pipe names (``"MTE2"``, ``"V"``, …) or full ``"PIPE_MTE2"`` + names. Static ``event_id`` values in ``[0, 7]`` lower to ``pto.set_flag``; + runtime index-like values lower to ``pto.set_flag_dyn``. + """ + event_operand, is_static = _flag_event_id_operand( + event_id, + context="set_flag(..., event_id=...)", + ) + if is_static: + _pto.set_flag(_pipe_attr(src), _pipe_attr(dst), _event_attr(event_operand)) + return + _pto.set_flag_dyn(_pipe_attr(src), _pipe_attr(dst), event_operand) + + +def wait_flag(src: str, dst: str, *, event_id: int = 0): + """``pto.wait_flag[src, dst, event_id]``. + + Static ``event_id`` values in ``[0, 7]`` lower to ``pto.wait_flag``; + runtime index-like values lower to ``pto.wait_flag_dyn``. + """ + event_operand, is_static = _flag_event_id_operand( + event_id, + context="wait_flag(..., event_id=...)", + ) + if is_static: + _pto.wait_flag(_pipe_attr(src), _pipe_attr(dst), _event_attr(event_operand)) + return + _pto.wait_flag_dyn(_pipe_attr(src), _pipe_attr(dst), event_operand) + + +__all__ = [ + "const", + "castptr", "addptr", + "vlds", "vldas", "vldus", "vldsx2", "vsts", "vstsx2", + "init_align", + "plt_b8", "plt_b16", "plt_b32", + "pset_b8", "pset_b16", "pset_b32", + "pge_b8", "pge_b16", "pge_b32", + "make_mask", + "pand", "por", "pxor", "pnot", "psel", + "pbitcast", "ppack", "punpack", + "pintlv_b8", "pintlv_b16", "pintlv_b32", + "pdintlv_b8", "pdintlv_b16", "pdintlv_b32", + "vgather2", "vgather2_bc", "vgatherb", "vscatter", "vsldb", "vsstb", + "vcmp", "vcmps", + "plds", "psts", "pstu", "vstar", "vstas", "vstur", "vstus", + "vbitcast", + "vbr", + "vadd", "vsub", "vmul", "vdiv", "vmax", "vmin", + "vand", "vor", "vxor", "vshl", "vshr", + "vcmax", "vcadd", "vcmin", "vdup", "vexpdif", + "vexp", "vln", "vsqrt", "vabs", "vneg", "vrec", "vrsqrt", "vrelu", "vnot", + "vcgmax", "vcgadd", "vcgmin", "vcpadd", + "vadds", "vsubs", "vmuls", "vmaxs", "vmins", "vlrelu", + "vaxpy", "vaddrelu", "vsubrelu", + "vsel", + "make_tensor_view", "partition_view", + "alloc_tile", + "tload", "tstore", "tmov", + "tadd", "tsub", "tmul", "tdiv", "tmax", "tmin", + "tadds", "tsubs", "tmuls", "tdivs", "tmaxs", "tmins", + "texp", "tlog", "tsqrt", "trsqrt", "trecip", "tabs", "tneg", + "trelu", "tlrelu", + "trowsum", "trowmax", "trowmin", "trowprod", "trowargmax", "trowargmin", + "tcolsum", "tcolmax", "tcolmin", "tcolprod", "tcolargmax", "tcolargmin", + "tcmp", "tcmps", + "texpands", "trowexpand", "tcolexpand", + "trowexpandadd", "trowexpandsub", "trowexpandmul", "trowexpanddiv", "trowexpandmax", "trowexpandmin", "trowexpandexpdif", + "tcolexpandadd", "tcolexpandsub", "tcolexpandmul", "tcolexpanddiv", "tcolexpandmax", "tcolexpandmin", "tcolexpandexpdif", + "tsel", "tsels", "tcvt", + "tnot", "tand", "tands", "tor", "tors", "txor", "txors", "tshl", "tshls", "tshr", "tshrs", + "tpartadd", "tpartmul", "tpartmax", "tpartmin", + "tfillpad", "tfillpad_expand", "tfillpad_inplace", + "as_ptr", + "mte_load", "mte_store", "mte_gm_ub", "mte_ub_gm", "mte_ub_ub", "mte_ub_l1", "mem_bar", + "mte_l1_l0a", "mte_l1_l0b", "mte_l0c_ub", + "mad", "mad_acc", "mad_bias", "mad_mx", "mad_mx_acc", "mad_mx_bias", + "get_block_idx", "get_block_num", "get_subblock_idx", "get_subblock_num", + "store_vfsimt_info", "get_tid_x", "get_tid_y", "get_tid_z", + "pipe_barrier", "get_buf", "rls_buf", + "set_cross_flag", "wait_cross_flag", "set_intra_flag", "wait_intra_flag", + "set_flag", "wait_flag", +] diff --git a/ptodsl/ptodsl/_runtime/__init__.py b/ptodsl/ptodsl/_runtime/__init__.py new file mode 100644 index 000000000..a83ab9eaf --- /dev/null +++ b/ptodsl/ptodsl/_runtime/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Private runtime helpers for ``@pto.jit`` launch.""" + +from .launch import LaunchHandle +from .native_build import build_native_library + +__all__ = [ + "LaunchHandle", + "build_native_library", +] diff --git a/ptodsl/ptodsl/_runtime/cache.py b/ptodsl/ptodsl/_runtime/cache.py new file mode 100644 index 000000000..8964cd1c0 --- /dev/null +++ b/ptodsl/ptodsl/_runtime/cache.py @@ -0,0 +1,90 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Artifact cache layout for JIT-compiled native libraries.""" + +from __future__ import annotations + +import hashlib +import json +import os +from dataclasses import dataclass +from pathlib import Path + + +def _default_cache_root() -> Path: + override = os.environ.get("PTODSL_CACHE_DIR") + if override: + return Path(override) + return Path.home() / ".cache" / "ptodsl" + + +def _specialization_digest(specialization_key) -> str: + payload = repr(specialization_key).encode("utf-8") + return hashlib.sha256(payload).hexdigest()[:16] + + +@dataclass(frozen=True) +class NativeBuildArtifacts: + """Paths to one compiled native kernel specialization.""" + + cache_dir: Path + mlir_path: Path + kernel_object: Path + launch_cpp: Path + shared_library: Path + manifest_path: Path + + +def artifact_paths(py_name: str, ir_function_name: str, specialization_key) -> NativeBuildArtifacts: + """Return stable artifact paths for one compiled specialization.""" + digest = _specialization_digest(specialization_key) + safe_name = py_name.replace("/", "_") + cache_dir = _default_cache_root() / f"{safe_name}_{digest}" + lib_name = f"lib{ir_function_name}.so" + return NativeBuildArtifacts( + cache_dir=cache_dir, + mlir_path=cache_dir / "kernel.mlir", + kernel_object=cache_dir / "kernel.o", + launch_cpp=cache_dir / "launch.cpp", + shared_library=cache_dir / lib_name, + manifest_path=cache_dir / "manifest.json", + ) + + +def write_manifest(artifacts: NativeBuildArtifacts, *, ir_function_name: str, launch_symbol: str) -> None: + artifacts.cache_dir.mkdir(parents=True, exist_ok=True) + manifest = { + "ir_function_name": ir_function_name, + "launch_symbol": launch_symbol, + "shared_library": str(artifacts.shared_library), + } + artifacts.manifest_path.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8") + + +def read_manifest(artifacts: NativeBuildArtifacts) -> dict: + return json.loads(artifacts.manifest_path.read_text(encoding="utf-8")) + + +def is_native_build_current(artifacts: NativeBuildArtifacts) -> bool: + required = ( + artifacts.mlir_path, + artifacts.kernel_object, + artifacts.launch_cpp, + artifacts.shared_library, + artifacts.manifest_path, + ) + return all(path.is_file() for path in required) + + +__all__ = [ + "NativeBuildArtifacts", + "artifact_paths", + "is_native_build_current", + "read_manifest", + "write_manifest", +] diff --git a/ptodsl/ptodsl/_runtime/codegen.py b/ptodsl/ptodsl/_runtime/codegen.py new file mode 100644 index 000000000..f118225c1 --- /dev/null +++ b/ptodsl/ptodsl/_runtime/codegen.py @@ -0,0 +1,157 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Generate host-side launch wrappers for traced PTODSL kernels.""" + +from __future__ import annotations + +from mlir.ir import BF16Type, F16Type, F32Type, IndexType, IntegerType + +from .._bootstrap import make_context +from .._kernel_signature import DeviceParameterSpec, RuntimeScalarParameterSpec, TensorSpecParameterSpec +from .._types import _PtrDescriptor, _resolve + + +def _resolve_type(annotation): + with make_context(): + return _resolve(annotation) + + +def _elem_cpp_type(elem) -> str: + name = getattr(elem, "__name__", repr(elem)).lower() + mapping = { + "float32": "float", + "f32": "float", + "float16": "__fp16", + "f16": "__fp16", + "bf16": "__bf16", + "int8": "int8_t", + "int16": "int16_t", + "int32": "int32_t", + "int64": "int64_t", + "ui8": "uint8_t", + "ui16": "uint16_t", + "ui32": "uint32_t", + "ui64": "uint64_t", + } + for key, cpp in mapping.items(): + if key in name: + return cpp + return "float" + + +def _device_param_cpp_type(annotation) -> str: + if isinstance(annotation, _PtrDescriptor): + return _elem_cpp_type(annotation._elem) + type_repr = repr(annotation).replace(" ", "").lower() + if "f32" in type_repr or "float32" in type_repr: + return "float" + if "i32" in type_repr or "int32" in type_repr: + return "int32_t" + if "i64" in type_repr or "int64" in type_repr: + return "int64_t" + return "float" + + +def _runtime_scalar_cpp_type(annotation) -> str: + type_obj = _resolve_type(annotation) + if IndexType.isinstance(type_obj): + return "int64_t" + if IntegerType.isinstance(type_obj): + width = IntegerType(type_obj).width + if width == 1: + return "bool" + signedness = str(type_obj) + if signedness.startswith("ui"): + return { + 8: "uint8_t", + 16: "uint16_t", + 32: "uint32_t", + 64: "uint64_t", + }[width] + return { + 8: "int8_t", + 16: "int16_t", + 32: "int32_t", + 64: "int64_t", + }[width] + if F32Type.isinstance(type_obj): + return "float" + if F16Type.isinstance(type_obj): + return "__fp16" + if BF16Type.isinstance(type_obj): + return "__bf16" + raise TypeError(f"unsupported @pto.jit runtime scalar codegen type {type_obj}") + + +def launch_symbol_name(ir_function_name: str) -> str: + return f"ptodsl_launch_{ir_function_name}" + + +def _tensor_metadata_cpp_type() -> str: + # Host-visible tensor shape/stride metadata is marshaled as 64-bit integers. + return "int64_t" + + +def generate_launch_cpp(*, ir_function_name: str, kernel_signature) -> str: + """Return C++ source for one extern-C launch entry point.""" + gm_params = [] + host_params = [] + kernel_args = [] + + for param in kernel_signature.positional_parameters: + if isinstance(param, DeviceParameterSpec): + cpp_type = _device_param_cpp_type(param.annotation) + gm_params.append(f"__gm__ {cpp_type} *{param.name}") + host_params.append(f"{cpp_type} *{param.name}") + kernel_args.append(f"(__gm__ {cpp_type} *){param.name}") + continue + if isinstance(param, RuntimeScalarParameterSpec): + cpp_type = _runtime_scalar_cpp_type(param.annotation) + gm_params.append(f"{cpp_type} {param.name}") + host_params.append(f"{cpp_type} {param.name}") + kernel_args.append(param.name) + continue + if isinstance(param, TensorSpecParameterSpec): + cpp_type = _elem_cpp_type(param.tensor_spec.dtype) + meta_cpp_type = _tensor_metadata_cpp_type() + rank = param.tensor_spec.rank + gm_params.append(f"__gm__ {cpp_type} *{param.name}_ptr") + host_params.append(f"{cpp_type} *{param.name}_ptr") + kernel_args.append(f"(__gm__ {cpp_type} *){param.name}_ptr") + for idx in range(rank): + gm_params.append(f"{meta_cpp_type} {param.name}_shape_{idx}") + host_params.append(f"{meta_cpp_type} {param.name}_shape_{idx}") + kernel_args.append(f"{param.name}_shape_{idx}") + for idx in range(rank): + gm_params.append(f"{meta_cpp_type} {param.name}_stride_{idx}") + host_params.append(f"{meta_cpp_type} {param.name}_stride_{idx}") + kernel_args.append(f"{param.name}_stride_{idx}") + continue + raise TypeError(f"unsupported launch parameter spec: {param!r}") + + gm_sig = ", ".join(gm_params) + host_sig = ", ".join(["uint32_t grid", "void *stream"] + host_params) + kernel_call = ", ".join(kernel_args) + launch_symbol = launch_symbol_name(ir_function_name) + + return ( + "#include \n\n" + "#ifndef AICORE\n" + "#define AICORE [aicore]\n" + "#endif\n\n" + f'extern "C" __global__ AICORE void {ir_function_name}({gm_sig});\n\n' + f"extern \"C\" void {launch_symbol}({host_sig}) {{\n" + f" {ir_function_name}<<>>({kernel_call});\n" + "}\n" + ) + + +__all__ = [ + "generate_launch_cpp", + "launch_symbol_name", +] diff --git a/ptodsl/ptodsl/_runtime/launch.py b/ptodsl/ptodsl/_runtime/launch.py new file mode 100644 index 000000000..9ae78dbd3 --- /dev/null +++ b/ptodsl/ptodsl/_runtime/launch.py @@ -0,0 +1,200 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Launch handles and ctypes dispatch for compiled PTODSL kernels.""" + +from __future__ import annotations + +import ctypes +from typing import TYPE_CHECKING + +from .._host_tensors import ( + inspect_host_tensor_metadata, + looks_like_host_tensor, +) +from .._bootstrap import make_context +from .._kernel_signature import DeviceParameterSpec, RuntimeScalarParameterSpec, TensorSpecParameterSpec +from .._types import _resolve +from .native_build import build_native_library + +from mlir.ir import BF16Type, F16Type, F32Type, IndexType, IntegerType + +if TYPE_CHECKING: + from .._kernel_compilation import CompiledKernelHandle + + +def _resolve_type(annotation): + with make_context(): + return _resolve(annotation) + + +def _normalize_stream_ptr(stream): + if stream is None: + try: + import torch + except ImportError as exc: + raise ImportError( + "stream=None requires torch; install torch and torch_npu for default-stream launch" + ) from exc + return torch.npu.current_stream()._as_parameter_ # noqa: SLF001 + + if isinstance(stream, ctypes.c_void_p): + return stream + if isinstance(stream, int): + return ctypes.c_void_p(stream) + if hasattr(stream, "value"): + return ctypes.c_void_p(int(stream.value)) + return stream + + +def _as_void_ptr(value): + if isinstance(value, ctypes.c_void_p): + return value + if hasattr(value, "data_ptr"): + return ctypes.c_void_p(value.data_ptr()) + if isinstance(value, int): + return ctypes.c_void_p(value) + raise TypeError(f"expected a pointer-like launch argument, got {type(value)!r}") + + +def _ctype_for_runtime_scalar(annotation): + type_obj = _resolve_type(annotation) + if IndexType.isinstance(type_obj): + return ctypes.c_int64 + if IntegerType.isinstance(type_obj): + width = IntegerType(type_obj).width + if width == 1: + return ctypes.c_bool + if width == 8: + return ctypes.c_int8 + if width == 16: + return ctypes.c_int16 + if width == 32: + return ctypes.c_int32 + if width == 64: + return ctypes.c_int64 + if F32Type.isinstance(type_obj): + return ctypes.c_float + if F16Type.isinstance(type_obj) or BF16Type.isinstance(type_obj): + raise TypeError( + f"runtime launch does not yet support host scalar marshaling for {type_obj}; " + "use pto.f32 / integer scalar parameters or tensorize this value for now" + ) + raise TypeError(f"unsupported @pto.jit runtime scalar launch type {type_obj}") + + +def _marshal_runtime_scalar(annotation, value): + ctype = _ctype_for_runtime_scalar(annotation) + if ctype is ctypes.c_bool: + return ctype(bool(value)) + return ctype(value) + + +def _marshal_launch_args(kernel_signature, args): + if len(args) != len(kernel_signature.positional_parameters): + raise TypeError( + f"expected {len(kernel_signature.positional_parameters)} launch argument(s), " + f"got {len(args)}" + ) + + marshaled = [] + for param, value in zip(kernel_signature.positional_parameters, args): + if isinstance(param, DeviceParameterSpec): + marshaled.append(_as_void_ptr(value)) + continue + if isinstance(param, RuntimeScalarParameterSpec): + marshaled.append(_marshal_runtime_scalar(param.annotation, value)) + continue + if isinstance(param, TensorSpecParameterSpec): + if not looks_like_host_tensor(value): + raise TypeError( + f"launch argument '{param.name}' expects a Python-native tensor-like object" + ) + meta = inspect_host_tensor_metadata(value) + marshaled.append(_as_void_ptr(meta.data_handle)) + for dim in meta.shape: + marshaled.append(ctypes.c_int64(dim)) + for dim in meta.strides: + marshaled.append(ctypes.c_int64(dim)) + continue + raise TypeError(f"unsupported launch parameter spec: {param!r}") + return marshaled + + +class LaunchHandle: + """Callable launch binding returned by ``compiled[grid, stream]``.""" + + def __init__(self, compiled: CompiledKernelHandle, grid: int, stream): + if not isinstance(grid, int) or grid <= 0: + raise ValueError("launch grid must be a positive integer") + self._compiled = compiled + self._grid = grid + self._stream = stream + self._launch_fn = None + self._launch_symbol = None + + def _ensure_launch_fn(self): + if self._launch_fn is not None: + return + + lib_path, launch_symbol = build_native_library( + py_name=self._compiled._py_name, + module_spec=self._compiled._module_spec, + kernel_signature=self._compiled._kernel_signature, + mlir_text=self._compiled.mlir_text(), + specialization_key=self._compiled.specialization_key, + ) + lib = ctypes.CDLL(str(lib_path)) + fn = getattr(lib, launch_symbol) + fn.argtypes = _launch_argtypes(self._compiled._kernel_signature) + fn.restype = None + self._launch_fn = fn + self._launch_symbol = launch_symbol + + def __call__(self, *args): + self._ensure_launch_fn() + marshaled = _marshal_launch_args(self._compiled._kernel_signature, args) + self._launch_fn( + ctypes.c_uint32(self._grid), + _normalize_stream_ptr(self._stream), + *marshaled, + ) + + +def _launch_argtypes(kernel_signature): + argtypes = [ctypes.c_uint32, ctypes.c_void_p] + for param in kernel_signature.positional_parameters: + if isinstance(param, DeviceParameterSpec): + argtypes.append(ctypes.c_void_p) + continue + if isinstance(param, RuntimeScalarParameterSpec): + argtypes.append(_ctype_for_runtime_scalar(param.annotation)) + continue + if isinstance(param, TensorSpecParameterSpec): + argtypes.append(ctypes.c_void_p) + rank = param.tensor_spec.rank + argtypes.extend([ctypes.c_int64] * (rank + rank)) + continue + raise TypeError(f"unsupported launch parameter spec: {param!r}") + return argtypes + + +def parse_launch_spec(launch_spec) -> tuple[int, object]: + if not isinstance(launch_spec, tuple) or len(launch_spec) != 2: + raise TypeError( + "compiled launch syntax expects compiled[grid, stream]; " + f"got {type(launch_spec)!r} with length " + f"{len(launch_spec) if isinstance(launch_spec, tuple) else 'n/a'}" + ) + grid, stream = launch_spec + return grid, stream + + +__all__ = [ + "LaunchHandle", + "parse_launch_spec", +] diff --git a/ptodsl/ptodsl/_runtime/native_build.py b/ptodsl/ptodsl/_runtime/native_build.py new file mode 100644 index 000000000..38fc59286 --- /dev/null +++ b/ptodsl/ptodsl/_runtime/native_build.py @@ -0,0 +1,217 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""MLIR → ptoas → bisheng native library build.""" + +from __future__ import annotations + +import subprocess +from pathlib import Path + +from .cache import ( + NativeBuildArtifacts, + artifact_paths, + is_native_build_current, + write_manifest, +) +from .codegen import generate_launch_cpp, launch_symbol_name +from .toolchain import ( + aicore_arch_for_kernel_kind, + common_include_flags, + resolve_bisheng, + resolve_ptoas_binary, +) + + +def _run(cmd: list[str], *, cwd: Path | None = None) -> None: + result = subprocess.run(cmd, cwd=str(cwd) if cwd else None, capture_output=True, text=True) + if result.returncode != 0: + output = (result.stdout or "") + (result.stderr or "") + raise RuntimeError( + f"command failed ({result.returncode}): {' '.join(cmd)}\n{output}" + ) + + +def _run_ptoas( + mlir_path: Path, + kernel_object: Path, + *, + target_arch: str, + mode: str, + insert_sync: bool | None, +) -> None: + ptoas = resolve_ptoas_binary() + cmd = [ + str(ptoas), + f"--pto-arch={target_arch}", + "--pto-backend=vpto", + ] + effective_insert_sync = (mode != "explicit") if insert_sync is None else insert_sync + if mode == "explicit": + cmd.append("--pto-level=level3") + if effective_insert_sync: + cmd.append("--enable-insert-sync") + cmd.extend([ + "--enable-tile-op-expand", + str(mlir_path), + "-o", + str(kernel_object), + ]) + _run( + cmd + ) + + +def _host_compile_flags() -> list[str]: + return common_include_flags() + [ + "-std=gnu++17", + "-O2", + "-Wno-macro-redefined", + "-Wno-ignored-attributes", + "-Wno-unknown-attributes", + "-xc++", + "-include", + "stdint.h", + "-include", + "stddef.h", + "-fPIC", + ] + + +def _kernel_compile_flags(kernel_kind: str) -> list[str]: + arch = aicore_arch_for_kernel_kind(kernel_kind) + return common_include_flags() + [ + "-std=gnu++17", + "-O2", + "-Wno-macro-redefined", + "-Wno-ignored-attributes", + "-Wno-unknown-attributes", + "-fPIC", + "-xcce", + "-Xhost-start", + "-Xhost-end", + "-mllvm", + "-cce-aicore-stack-size=0x8000", + "-mllvm", + "-cce-aicore-function-stack-size=0x8000", + "-mllvm", + "-cce-aicore-record-overflow=true", + "-mllvm", + "-cce-aicore-addr-transform", + "-mllvm", + "-cce-aicore-dcci-insert-for-scalar=false", + f"--cce-aicore-arch={arch}", + ] + + +def _compile_launch_cpp( + launch_cpp: Path, + launch_object: Path, + *, + kernel_kind: str, + export_macro: str, +) -> None: + bisheng = resolve_bisheng() + _run( + [ + bisheng, + *_kernel_compile_flags(kernel_kind), + f"-D{export_macro}", + "-c", + str(launch_cpp), + "-o", + str(launch_object), + ] + ) + + +def _link_shared_library( + launch_object: Path, + kernel_object: Path, + shared_library: Path, + *, + kernel_kind: str, +) -> None: + bisheng = resolve_bisheng() + soname = shared_library.name + _run( + [ + bisheng, + "-fPIC", + "-shared", + "--cce-fatobj-link", + f"-Wl,-soname,{soname}", + "-o", + str(shared_library), + str(launch_object), + str(kernel_object), + ] + ) + + +def build_native_library( + *, + py_name: str, + module_spec, + kernel_signature, + mlir_text: str, + specialization_key, +) -> tuple[Path, str]: + """Build or reuse the shared library for one compiled specialization.""" + ir_function_name = module_spec.function_name + artifacts = artifact_paths(py_name, ir_function_name, specialization_key) + launch_symbol = launch_symbol_name(ir_function_name) + + if is_native_build_current(artifacts): + return artifacts.shared_library, launch_symbol + + artifacts.cache_dir.mkdir(parents=True, exist_ok=True) + artifacts.mlir_path.write_text(mlir_text, encoding="utf-8") + artifacts.launch_cpp.write_text( + generate_launch_cpp( + ir_function_name=ir_function_name, + kernel_signature=kernel_signature, + ), + encoding="utf-8", + ) + + _run_ptoas( + artifacts.mlir_path, + artifacts.kernel_object, + target_arch=module_spec.target_arch, + mode=module_spec.mode, + insert_sync=module_spec.insert_sync, + ) + + launch_object = artifacts.cache_dir / "launch.o" + export_macro = f"{ir_function_name}_EXPORTS" + _compile_launch_cpp( + artifacts.launch_cpp, + launch_object, + kernel_kind=module_spec.kernel_kind, + export_macro=export_macro, + ) + _link_shared_library( + launch_object, + artifacts.kernel_object, + artifacts.shared_library, + kernel_kind=module_spec.kernel_kind, + ) + write_manifest( + artifacts, + ir_function_name=ir_function_name, + launch_symbol=launch_symbol, + ) + return artifacts.shared_library, launch_symbol + + +__all__ = [ + "NativeBuildArtifacts", + "artifact_paths", + "build_native_library", + "is_native_build_current", +] diff --git a/ptodsl/ptodsl/_runtime/toolchain.py b/ptodsl/ptodsl/_runtime/toolchain.py new file mode 100644 index 000000000..edd9a43ad --- /dev/null +++ b/ptodsl/ptodsl/_runtime/toolchain.py @@ -0,0 +1,87 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Resolve external toolchain binaries and CANN paths.""" + +from __future__ import annotations + +import os +import shutil +from pathlib import Path + + +def resolve_ptoas_binary() -> Path: + repo_root = Path(__file__).resolve().parents[4] + candidates = [ + repo_root / "build" / "tools" / "ptoas" / "ptoas", + repo_root / "install" / "bin" / "ptoas", + ] + for candidate in candidates: + if candidate.is_file(): + return candidate + + from_path = shutil.which("ptoas") + if from_path: + return Path(from_path) + + raise FileNotFoundError( + "unable to locate ptoas; build ptoas or add it to PATH after sourcing set_ptoas_env.sh" + ) + + +def resolve_bisheng() -> str: + ascend_home = ascend_home_path() + candidate = ascend_home / "bin" / "bisheng" + if candidate.is_file(): + return str(candidate) + + found = shutil.which("bisheng") + if found: + return found + + raise FileNotFoundError("bisheng compiler not found; source ASCEND setenv.bash first") + + +def ascend_home_path() -> Path: + home = os.environ.get("ASCEND_HOME_PATH") + if not home: + raise EnvironmentError("ASCEND_HOME_PATH is not set; source CANN setenv.bash first") + return Path(home) + + +def ascend_driver_path() -> Path: + return Path(os.environ.get("ASCEND_DRIVER_PATH", "/usr/local/Ascend/driver")) + + +def common_include_flags() -> list[str]: + ascend = ascend_home_path() + driver = ascend_driver_path() + return [ + f"-I{ascend}/include", + f"-I{driver}/kernel/inc", + f"-I{ascend}/pkg_inc", + f"-I{ascend}/pkg_inc/profiling", + f"-I{ascend}/pkg_inc/runtime/runtime", + ] + + +def aicore_arch_for_kernel_kind(kernel_kind: str) -> str: + if kernel_kind == "vector": + return "dav-c310-vec" + if kernel_kind == "cube": + return "dav-c310-cube" + raise ValueError(f"unsupported kernel_kind for native build: {kernel_kind!r}") + + +__all__ = [ + "aicore_arch_for_kernel_kind", + "ascend_driver_path", + "ascend_home_path", + "common_include_flags", + "resolve_bisheng", + "resolve_ptoas_binary", +] diff --git a/ptodsl/ptodsl/_runtime_index_ops.py b/ptodsl/ptodsl/_runtime_index_ops.py new file mode 100644 index 000000000..03256c8aa --- /dev/null +++ b/ptodsl/ptodsl/_runtime_index_ops.py @@ -0,0 +1,43 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Tracing-time helpers for coercing authored runtime values to MLIR index.""" + +from __future__ import annotations + +from mlir.dialects import arith +from mlir.ir import IndexType, IntegerType + + +def coerce_runtime_index(value, *, context: str): + """Normalize one authored loop/slice bound to an MLIR index SSA value.""" + if isinstance(value, bool): + raise TypeError(f"{context} does not accept bool values") + + if isinstance(value, int): + return arith.ConstantOp(IndexType.get(), value).result + + if not hasattr(value, "type"): + raise TypeError( + f"{context} expects a Python int, an index value, or an integer runtime scalar; " + f"got {value!r}" + ) + + value_type = value.type + if IndexType.isinstance(value_type): + return value + if IntegerType.isinstance(value_type): + return arith.IndexCastOp(IndexType.get(), value).result + + raise TypeError( + f"{context} expects an index or integer runtime scalar, got {value_type}" + ) + + +__all__ = [ + "coerce_runtime_index", +] diff --git a/ptodsl/ptodsl/_runtime_scalar_ops.py b/ptodsl/ptodsl/_runtime_scalar_ops.py new file mode 100644 index 000000000..989dca8bf --- /dev/null +++ b/ptodsl/ptodsl/_runtime_scalar_ops.py @@ -0,0 +1,302 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Tracing-time authored scalar operator lowering for runtime values.""" + +from __future__ import annotations + +from ._types import ( + _integer_signedness, + _materialize_integer_literal, + _restore_integer_signedness, + _strip_integer_signedness, +) + +from mlir.dialects import arith, math +from mlir.ir import BF16Type, F16Type, F32Type, FloatAttr, IndexType, IntegerType + + +_FLOAT_BINARY_OPS = { + "add": arith.AddFOp, + "sub": arith.SubFOp, + "mul": arith.MulFOp, + "truediv": arith.DivFOp, +} + + +def emit_runtime_binary_op(op_name: str, lhs, rhs): + """Lower one authored runtime scalar binary operator.""" + lhs, rhs, kind = normalize_runtime_binary_operands(lhs, rhs) + if kind in {"index", "integer"}: + op_cls = _integer_binary_op(op_name, lhs.type) + if op_cls is None: + raise TypeError(f"runtime scalar operator '{op_name}' is not supported for integer/index values") + authored_type = lhs.type + if kind == "integer": + lhs = _strip_integer_signedness(lhs) + rhs = _strip_integer_signedness(rhs) + result = op_cls(lhs, rhs).result + if kind == "index": + return result + return _restore_runtime_integer_result(result, authored_type) + if kind == "float": + op_cls = _FLOAT_BINARY_OPS.get(op_name) + if op_cls is None: + raise TypeError(f"runtime scalar operator '{op_name}' is not supported for floating-point values") + return op_cls(lhs, rhs).result + raise TypeError(f"unsupported runtime scalar operand category '{kind}'") + + +def emit_runtime_max(lhs, rhs): + """Lower one authored runtime scalar max operation.""" + lhs, rhs, kind = normalize_runtime_binary_operands(lhs, rhs) + if kind == "float": + return arith.MaximumFOp(lhs, rhs).result + if kind == "integer": + signedness = _integer_signedness(lhs.type) + signless_lhs = _strip_integer_signedness(lhs) + signless_rhs = _strip_integer_signedness(rhs) + if signedness == "unsigned": + result = arith.MaxUIOp(signless_lhs, signless_rhs).result + else: + result = arith.MaxSIOp(signless_lhs, signless_rhs).result + return _restore_integer_signedness(result, lhs.type) + if kind == "index": + cond = arith.CmpIOp(arith.CmpIPredicate.sge, lhs, rhs).result + return arith.SelectOp(cond, lhs, rhs).result + raise TypeError(f"unsupported runtime scalar operand category '{kind}'") + + +def emit_runtime_min(lhs, rhs): + """Lower one authored runtime scalar min operation.""" + lhs, rhs, kind = normalize_runtime_binary_operands(lhs, rhs) + if kind == "float": + return arith.MinimumFOp(lhs, rhs).result + if kind == "integer": + signedness = _integer_signedness(lhs.type) + signless_lhs = _strip_integer_signedness(lhs) + signless_rhs = _strip_integer_signedness(rhs) + if signedness == "unsigned": + result = arith.MinUIOp(signless_lhs, signless_rhs).result + else: + result = arith.MinSIOp(signless_lhs, signless_rhs).result + return _restore_integer_signedness(result, lhs.type) + if kind == "index": + cond = arith.CmpIOp(arith.CmpIPredicate.sle, lhs, rhs).result + return arith.SelectOp(cond, lhs, rhs).result + raise TypeError(f"unsupported runtime scalar operand category '{kind}'") + + +def normalize_runtime_binary_operands(lhs, rhs): + lhs_is_value = _is_mlir_value(lhs) + rhs_is_value = _is_mlir_value(rhs) + + if not lhs_is_value and not rhs_is_value: + raise TypeError("runtime scalar operators require at least one traced runtime operand") + + if lhs_is_value and rhs_is_value: + return _reconcile_typed_operands(lhs, rhs) + + anchor_type = lhs.type if lhs_is_value else rhs.type + lhs = lhs if lhs_is_value else _materialize_literal(lhs, anchor_type) + rhs = rhs if rhs_is_value else _materialize_literal(rhs, anchor_type) + return _reconcile_typed_operands(lhs, rhs) + + +def _reconcile_typed_operands(lhs, rhs): + lhs_type = lhs.type + rhs_type = rhs.type + + if lhs_type == rhs_type: + return lhs, rhs, classify_runtime_scalar_type(lhs_type) + + if IndexType.isinstance(lhs_type) and IntegerType.isinstance(rhs_type): + rhs = arith.IndexCastOp(IndexType.get(), _strip_integer_signedness(rhs)).result + return lhs, rhs, "index" + + if IntegerType.isinstance(lhs_type) and IndexType.isinstance(rhs_type): + lhs = arith.IndexCastOp(IndexType.get(), _strip_integer_signedness(lhs)).result + return lhs, rhs, "index" + + raise TypeError( + "runtime scalar operators require matching scalar types or an index/integer pair; " + f"got {lhs_type} and {rhs_type}" + ) + + +def _materialize_literal(value, anchor_type): + if isinstance(value, bool): + raise TypeError("runtime scalar operators do not accept bool literals") + + kind = classify_runtime_scalar_type(anchor_type) + if kind == "float": + return arith.ConstantOp(anchor_type, FloatAttr.get(anchor_type, float(value))).result + if kind == "index": + return arith.ConstantOp(anchor_type, int(value)).result + + if isinstance(value, float): + raise TypeError( + "runtime scalar operators cannot materialize a floating-point literal " + f"against non-floating operand type {anchor_type}" + ) + + return _materialize_integer_literal(anchor_type, value) + + +def classify_runtime_scalar_type(type_obj): + if IndexType.isinstance(type_obj): + return "index" + if IntegerType.isinstance(type_obj): + return "integer" + if any(cls.isinstance(type_obj) for cls in (BF16Type, F16Type, F32Type)): + return "float" + raise TypeError(f"runtime scalar operators only support index/int/float values, got {type_obj}") + + +def _is_mlir_value(value) -> bool: + return not isinstance(value, (bool, int, float)) and hasattr(value, "type") + + +def _restore_runtime_integer_result(result, authored_type): + if IndexType.isinstance(authored_type): + return result + if not IntegerType.isinstance(authored_type): + return result + return _restore_integer_signedness(result, authored_type) + + +def emit_runtime_compare(op_name: str, lhs, rhs): + """Lower one authored runtime scalar comparison operator.""" + lhs, rhs, kind = normalize_runtime_binary_operands(lhs, rhs) + + if kind == "float": + predicate = { + "lt": arith.CmpFPredicate.OLT, + "le": arith.CmpFPredicate.OLE, + "gt": arith.CmpFPredicate.OGT, + "ge": arith.CmpFPredicate.OGE, + "eq": arith.CmpFPredicate.OEQ, + "ne": arith.CmpFPredicate.ONE, + }.get(op_name) + if predicate is None: + raise TypeError(f"runtime scalar comparison '{op_name}' is not supported for floating-point values") + return arith.CmpFOp(predicate, lhs, rhs).result + + if kind == "index": + predicate = { + "lt": arith.CmpIPredicate.slt, + "le": arith.CmpIPredicate.sle, + "gt": arith.CmpIPredicate.sgt, + "ge": arith.CmpIPredicate.sge, + "eq": arith.CmpIPredicate.eq, + "ne": arith.CmpIPredicate.ne, + }.get(op_name) + if predicate is None: + raise TypeError(f"runtime scalar comparison '{op_name}' is not supported for index values") + return arith.CmpIOp(predicate, lhs, rhs).result + + if kind == "integer": + signedness = _integer_signedness(lhs.type) + signed_predicates = { + "lt": arith.CmpIPredicate.slt, + "le": arith.CmpIPredicate.sle, + "gt": arith.CmpIPredicate.sgt, + "ge": arith.CmpIPredicate.sge, + "eq": arith.CmpIPredicate.eq, + "ne": arith.CmpIPredicate.ne, + } + unsigned_predicates = { + "lt": arith.CmpIPredicate.ult, + "le": arith.CmpIPredicate.ule, + "gt": arith.CmpIPredicate.ugt, + "ge": arith.CmpIPredicate.uge, + "eq": arith.CmpIPredicate.eq, + "ne": arith.CmpIPredicate.ne, + } + predicate = (unsigned_predicates if signedness == "unsigned" else signed_predicates).get(op_name) + if predicate is None: + raise TypeError(f"runtime scalar comparison '{op_name}' is not supported for integer values") + return arith.CmpIOp(predicate, _strip_integer_signedness(lhs), _strip_integer_signedness(rhs)).result + + raise TypeError(f"unsupported runtime scalar operand category '{kind}'") + + +def emit_runtime_bitwise_op(op_name: str, lhs, rhs): + """Lower one authored runtime scalar bitwise operator.""" + lhs, rhs, kind = normalize_runtime_binary_operands(lhs, rhs) + if kind != "integer": + raise TypeError( + f"runtime scalar bitwise operator '{op_name}' expects integer-like operands, got {lhs.type} and {rhs.type}" + ) + + op_cls = { + "and": arith.AndIOp, + "or": arith.OrIOp, + "xor": arith.XOrIOp, + }.get(op_name) + if op_cls is None: + raise TypeError(f"unsupported runtime scalar bitwise operator '{op_name}'") + + authored_type = lhs.type + result = op_cls(_strip_integer_signedness(lhs), _strip_integer_signedness(rhs)).result + return _restore_integer_signedness(result, authored_type) + + +def emit_runtime_abs(value): + """Lower one authored runtime scalar absolute-value operation.""" + kind = classify_runtime_scalar_type(value.type) + if kind == "float": + return math.AbsFOp(value).result + if kind == "index": + return value + if kind == "integer": + signedness = _integer_signedness(value.type) + if signedness == "unsigned": + return value + result = math.AbsIOp(_strip_integer_signedness(value)).result + return _restore_integer_signedness(result, value.type) + raise TypeError(f"unsupported runtime scalar operand category '{kind}'") + + +def _integer_binary_op(op_name: str, authored_type): + if IndexType.isinstance(authored_type): + return { + "add": arith.AddIOp, + "sub": arith.SubIOp, + "mul": arith.MulIOp, + "floordiv": arith.FloorDivSIOp, + "mod": arith.RemSIOp, + }.get(op_name) + + signedness = _integer_signedness(authored_type) + if op_name in {"add", "sub", "mul"}: + return { + "add": arith.AddIOp, + "sub": arith.SubIOp, + "mul": arith.MulIOp, + }[op_name] + if op_name == "floordiv": + if signedness == "unsigned": + return arith.DivUIOp + return arith.FloorDivSIOp + if op_name == "mod": + if signedness == "unsigned": + return arith.RemUIOp + return arith.RemSIOp + return None + + +__all__ = [ + "classify_runtime_scalar_type", + "emit_runtime_abs", + "emit_runtime_binary_op", + "emit_runtime_compare", + "emit_runtime_bitwise_op", + "emit_runtime_max", + "emit_runtime_min", + "normalize_runtime_binary_operands", +] diff --git a/ptodsl/ptodsl/_scalar_coercion.py b/ptodsl/ptodsl/_scalar_coercion.py new file mode 100644 index 000000000..ad839977c --- /dev/null +++ b/ptodsl/ptodsl/_scalar_coercion.py @@ -0,0 +1,119 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Shared authored scalar type-adaptation helpers for PTODSL surface lowering.""" + +from __future__ import annotations + +from ._runtime_scalar_ops import classify_runtime_scalar_type +from ._surface_values import unwrap_surface_value +from ._types import ( + _integer_signedness, + _materialize_integer_literal, + _restore_integer_signedness, + _signless_integer_type, + _strip_integer_signedness, +) + +from mlir.dialects import arith +from mlir.ir import BF16Type, F16Type, F32Type, FloatAttr, IndexType, IntegerType + + +def coerce_scalar_to_type(value, target_type, *, context: str): + """Normalize one authored scalar value/literal to *target_type*.""" + raw_value = unwrap_surface_value(value) + if not hasattr(raw_value, "type"): + return materialize_scalar_literal(raw_value, target_type, context=context) + + if raw_value.type == target_type: + return raw_value + + source_kind = classify_runtime_scalar_type(raw_value.type) + target_kind = classify_runtime_scalar_type(target_type) + + if source_kind == "index" and target_kind == "integer": + return _coerce_integer_like(raw_value, target_type) + if source_kind == "integer" and target_kind == "index": + return arith.IndexCastOp(target_type, _strip_integer_signedness(raw_value)).result + if source_kind == "integer" and target_kind == "integer": + return _coerce_integer_like(raw_value, target_type) + if source_kind == "float" and target_kind == "float": + return _coerce_float_like(raw_value, target_type) + + raise TypeError( + f"{context} cannot coerce the authored value to the expected scalar type: " + f"got {raw_value.type}, expected {target_type}" + ) + + +def materialize_scalar_literal(value, target_type, *, context: str): + """Materialize one Python literal as an MLIR scalar constant of *target_type*.""" + if isinstance(value, bool): + raise TypeError(f"{context} does not accept bool literals") + + target_kind = classify_runtime_scalar_type(target_type) + if target_kind == "float": + return arith.ConstantOp(target_type, FloatAttr.get(target_type, float(value))).result + if target_kind == "index": + return arith.ConstantOp(target_type, int(value)).result + + if isinstance(value, float): + raise TypeError( + f"{context} cannot materialize a floating-point literal against non-floating " + f"target type {target_type}" + ) + + return _materialize_integer_literal(target_type, value) + + +def _coerce_integer_like(raw_value, target_type): + if IndexType.isinstance(raw_value.type): + signless_target = _signless_integer_type(target_type) + adapted = arith.IndexCastOp(signless_target, raw_value).result + return _restore_integer_signedness(adapted, target_type) + + source_type = raw_value.type + source_width = IntegerType(source_type).width + target_width = IntegerType(target_type).width + signless_source = _strip_integer_signedness(raw_value) + signless_target = _signless_integer_type(target_type) + + if source_width < target_width: + source_signedness = _integer_signedness(source_type) + if source_signedness == "unsigned": + widened = arith.ExtUIOp(signless_target, signless_source).result + else: + widened = arith.ExtSIOp(signless_target, signless_source).result + return _restore_integer_signedness(widened, target_type) + if source_width > target_width: + truncated = arith.TruncIOp(signless_target, signless_source).result + return _restore_integer_signedness(truncated, target_type) + return _restore_integer_signedness(signless_source, target_type) + + +def _coerce_float_like(raw_value, target_type): + source_width = _float_bytewidth(raw_value.type) + target_width = _float_bytewidth(target_type) + if source_width < target_width: + return arith.ExtFOp(target_type, raw_value).result + if source_width > target_width: + return arith.TruncFOp(target_type, raw_value).result + return raw_value + + +def _float_bytewidth(type_obj): + if BF16Type.isinstance(type_obj) or F16Type.isinstance(type_obj): + return 2 + if F32Type.isinstance(type_obj): + return 4 + raise TypeError(f"unsupported floating-point type {type_obj}") + + +__all__ = [ + "coerce_scalar_to_type", + "materialize_scalar_literal", +] diff --git a/ptodsl/ptodsl/_subkernels.py b/ptodsl/ptodsl/_subkernels.py new file mode 100644 index 000000000..66d44e442 --- /dev/null +++ b/ptodsl/ptodsl/_subkernels.py @@ -0,0 +1,204 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Layered PTODSL subkernel decorators.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from functools import update_wrapper +import inspect + +from ._diagnostics import ( + illegal_inline_subkernel_placement_error, + illegal_subkernel_placement_error, + simd_value_escape_error, + subkernel_host_tensor_boundary_error, + subkernel_signature_boundary_error, +) +from ._host_tensors import TensorSpec, looks_like_host_tensor +from ._surface_values import unwrap_surface_value +from ._tracing import current_runtime, current_session + + +class KernelRole(str, Enum): + CUBE = "cube" + SIMD = "simd" + SIMT = "simt" + + +@dataclass(frozen=True) +class SubkernelSpec: + """Declarative metadata for a PTODSL subkernel surface.""" + + role: KernelRole + symbol_name: str + target: str = "a5" + + +class SubkernelTemplate: + """Callable decorated PTODSL subkernel surface.""" + + def __init__(self, spec: SubkernelSpec, py_fn): + self.spec = spec + self.py_fn = py_fn + self.signature = inspect.signature(py_fn) + self._validate_definition() + update_wrapper(self, py_fn) + + def emit_body(self, *args, **kwargs): + """Emit this subkernel body into the currently active trace.""" + result = self.py_fn(*args, **kwargs) + self._validate_result(result) + return result + + def trace_body(self, *args, **kwargs): + """Backward-compatible alias for body emission.""" + return self.emit_body(*args, **kwargs) + + def __call__(self, *args, **kwargs): + runtime = current_runtime() + if runtime is None: + raise RuntimeError( + f"@pto.{self.spec.role.value} kernels may only be called while tracing " + "a compatible PTODSL kernel" + ) + self._validate_invocation(*args, **kwargs) + return runtime.dispatch_subkernel_call(self, *args, **kwargs) + + def _validate_definition(self) -> None: + for param in self.signature.parameters.values(): + if isinstance(param.annotation, TensorSpec): + raise subkernel_signature_boundary_error(self.spec.role.value, param.name) + + def _validate_invocation(self, *args, **kwargs) -> None: + session = current_session() + outer = session.current_subkernel if session is not None else None + _validate_subkernel_placement(self.spec.role, outer) + + bound = self.signature.bind_partial(*args, **kwargs) + for name, value in bound.arguments.items(): + if looks_like_host_tensor(value): + raise subkernel_host_tensor_boundary_error(self.spec.role.value, name) + + def _validate_result(self, result) -> None: + if self.spec.role != KernelRole.SIMD: + return + escaped_type = _find_transient_simd_escape(result) + if escaped_type is not None: + raise simd_value_escape_error(escaped_type) + + +def _find_transient_simd_escape(value): + if value is None: + return None + if isinstance(value, (tuple, list)): + for item in value: + escaped = _find_transient_simd_escape(item) + if escaped is not None: + return escaped + return None + if isinstance(value, dict): + for item in value.values(): + escaped = _find_transient_simd_escape(item) + if escaped is not None: + return escaped + return None + raw_value = unwrap_surface_value(value) + type_obj = getattr(raw_value, "type", None) + if type_obj is None: + return None + type_text = str(type_obj) + if type_text.startswith("!pto.vreg<") or type_text.startswith("!pto.mask<"): + return type_text + return None + + +def _validate_subkernel_placement(role: KernelRole, outer_frame, *, inline: bool = False) -> None: + if outer_frame is None: + return + if inline: + raise illegal_inline_subkernel_placement_error(role.value, outer_frame.role) + raise illegal_subkernel_placement_error(role.value, outer_frame.role) + + +class _SubkernelSurface: + """Dual-use surface that supports both decorators and inline context-manager scopes.""" + + def __init__(self, role: KernelRole, *, name: str | None = None, target: str = "a5"): + self._role = role + self._name = name + self._target = target + self._session_cm = None + + def __call__(self, fn): + return SubkernelTemplate( + SubkernelSpec( + role=self._role, + symbol_name=self._name or fn.__name__, + target=self._target, + ), + fn, + ) + + def __enter__(self): + runtime = current_runtime() + if runtime is None: + raise RuntimeError( + f"inline pto.{self._role.value}() may only be used while tracing " + "a compatible PTODSL kernel" + ) + session = current_session() + outer = session.current_subkernel if session is not None else None + _validate_subkernel_placement(self._role, outer, inline=True) + symbol_name = self._name or f"inline_{self._role.value}" + self._session_cm = session.enter_inline_subkernel( + self._role.value, + symbol_name, + self._target, + ) + self._session_cm.__enter__() + return None + + def __exit__(self, *exc): + try: + return self._session_cm.__exit__(*exc) + finally: + self._session_cm = None + + +def _subkernel_decorator(role: KernelRole, *, name: str | None = None, target: str = "a5"): + return _SubkernelSurface(role, name=name, target=target) + + +def _decorate_subkernel(role: KernelRole, fn=None, *, name: str | None = None, target: str = "a5"): + if fn is not None: + return _subkernel_decorator(role, name=name, target=target)(fn) + return _subkernel_decorator(role, name=name, target=target) + + +def cube(fn=None, *, name: str | None = None, target: str = "a5"): + return _decorate_subkernel(KernelRole.CUBE, fn, name=name, target=target) + + +def simd(fn=None, *, name: str | None = None, target: str = "a5"): + return _decorate_subkernel(KernelRole.SIMD, fn, name=name, target=target) + + +def simt(fn=None, *, name: str | None = None, target: str = "a5"): + return _decorate_subkernel(KernelRole.SIMT, fn, name=name, target=target) + + +__all__ = [ + "KernelRole", + "SubkernelSpec", + "SubkernelTemplate", + "cube", + "simd", + "simt", +] diff --git a/ptodsl/ptodsl/_surface_types.py b/ptodsl/ptodsl/_surface_types.py new file mode 100644 index 000000000..8ea729eaa --- /dev/null +++ b/ptodsl/ptodsl/_surface_types.py @@ -0,0 +1,200 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Public PTODSL surface markers and enums.""" + +from ._bootstrap import make_context # noqa: F401 +from ._host_tensors import TensorSpec, tensor_spec + +from mlir.dialects import pto as _pto + + +class _ConstexprMarker: + """Marker annotation for PTODSL compile-time specialization parameters.""" + + def __repr__(self): + return "pto.constexpr" + + +constexpr = _ConstexprMarker() + + +class MemorySpace: + """Public PTODSL memory-space enum aliases.""" + + GM = _pto.AddressSpace.GM + UB = _pto.AddressSpace.VEC + VEC = _pto.AddressSpace.VEC + MAT = _pto.AddressSpace.MAT + LEFT = _pto.AddressSpace.LEFT + RIGHT = _pto.AddressSpace.RIGHT + ACC = _pto.AddressSpace.ACC + BIAS = _pto.AddressSpace.BIAS + SCALING = _pto.AddressSpace.SCALING + + +class BarrierType: + """Public PTODSL memory-barrier kind aliases.""" + + VV_ALL = "VV_ALL" + VST_VLD = "VST_VLD" + VLD_VST = "VLD_VST" + VST_VST = "VST_VST" + VS_ALL = "VS_ALL" + VST_LD = "VST_LD" + VLD_ST = "VLD_ST" + VST_ST = "VST_ST" + SV_ALL = "SV_ALL" + ST_VLD = "ST_VLD" + LD_VST = "LD_VST" + ST_VST = "ST_VST" + SS_ALL = "SS_ALL" + ST_LD = "ST_LD" + LD_ST = "LD_ST" + ST_ST = "ST_ST" + + +class Pipe: + """Public PTODSL pipeline aliases for pipeline-level sync ops.""" + + S = _pto.PIPE.PIPE_S + V = _pto.PIPE.PIPE_V + M = _pto.PIPE.PIPE_M + MTE1 = _pto.PIPE.PIPE_MTE1 + MTE2 = _pto.PIPE.PIPE_MTE2 + MTE3 = _pto.PIPE.PIPE_MTE3 + MTE4 = _pto.PIPE.PIPE_MTE4 + MTE5 = _pto.PIPE.PIPE_MTE5 + V2 = _pto.PIPE.PIPE_V2 + FIX = _pto.PIPE.PIPE_FIX + ALL = _pto.PIPE.PIPE_ALL + + +class MaskPattern: + """Public PTODSL mask-pattern tokens.""" + + ALL = "PAT_ALL" + ALLF = "PAT_ALLF" + H = "PAT_H" + Q = "PAT_Q" + M3 = "PAT_M3" + M4 = "PAT_M4" + + +for _vl in range(1, 129): + setattr(MaskPattern, f"VL{_vl}", f"PAT_VL{_vl}") + + +class CmpMode: + """Public PTODSL compare-mode tokens.""" + + EQ = "eq" + NE = "ne" + LT = "lt" + LE = "le" + GT = "gt" + GE = "ge" + + +class PredicatePart: + """Public PTODSL predicate pack/unpack part tokens.""" + + LOWER = "LOWER" + HIGHER = "HIGHER" + + +class PredicateDist: + """Public PTODSL predicate load/store distribution tokens.""" + + NORM = "NORM" + US = "US" + DS = "DS" + PK = "PK" + + +class VStoreDist: + """Public PTODSL vector-store distribution tokens.""" + + NORM_B8 = "NORM_B8" + NORM_B16 = "NORM_B16" + NORM_B32 = "NORM_B32" + _1PT_B8 = "1PT_B8" + _1PT_B16 = "1PT_B16" + _1PT_B32 = "1PT_B32" + PK_B16 = "PK_B16" + PK_B32 = "PK_B32" + PK_B64 = "PK_B64" + PK4_B32 = "PK4_B32" + MRG4CHN_B8 = "MRG4CHN_B8" + MRG2CHN_B8 = "MRG2CHN_B8" + MRG2CHN_B16 = "MRG2CHN_B16" + + +setattr(VStoreDist, "1PT_B8", VStoreDist._1PT_B8) +setattr(VStoreDist, "1PT_B16", VStoreDist._1PT_B16) +setattr(VStoreDist, "1PT_B32", VStoreDist._1PT_B32) + + +class DeinterleaveDist: + """Public PTODSL dual-load distribution tokens.""" + + DINTLV_B8 = "DINTLV_B8" + DINTLV_B16 = "DINTLV_B16" + DINTLV_B32 = "DINTLV_B32" + BDINTLV = "BDINTLV" + + +class InterleaveDist: + """Public PTODSL dual-store distribution tokens.""" + + INTLV_B8 = "INTLV_B8" + INTLV_B16 = "INTLV_B16" + INTLV_B32 = "INTLV_B32" + + +class PostUpdate: + """Public PTODSL post-update mode tokens for stateful stores.""" + + OFF = "NO_POST_UPDATE" + ON = "POST_UPDATE" + + +AlignType = _pto.AlignType + + +class TensorView: + """Authoring-time marker for a tensor-view descriptor value.""" + + +class PartitionTensorView: + """Authoring-time marker for a partitioned tensor-view descriptor value.""" + + +class Tile: + """Authoring-time marker for an on-chip tile value.""" + + +__all__ = [ + "constexpr", + "TensorSpec", + "MemorySpace", + "BarrierType", + "Pipe", + "MaskPattern", + "CmpMode", + "PredicatePart", + "PredicateDist", + "VStoreDist", + "DeinterleaveDist", + "InterleaveDist", + "PostUpdate", + "AlignType", + "TensorView", + "PartitionTensorView", + "Tile", + "tensor_spec", +] diff --git a/ptodsl/ptodsl/_surface_values.py b/ptodsl/ptodsl/_surface_values.py new file mode 100644 index 000000000..be9a89097 --- /dev/null +++ b/ptodsl/ptodsl/_surface_values.py @@ -0,0 +1,944 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Tracing-time wrappers for authored PTODSL surface values.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass + +from ._diagnostics import native_python_control_flow_error +from ._runtime_scalar_ops import emit_runtime_binary_op, emit_runtime_bitwise_op, emit_runtime_compare +from ._surface_types import PartitionTensorView, TensorView, Tile +from ._types import _normalize_address_space, _resolve, ptr + +from mlir.dialects import arith +from mlir.dialects import memref +from mlir.dialects import pto as _pto +from mlir.ir import IndexType, IntegerAttr, IntegerType, MemRefType, ShapedType, StridedLayoutAttr, Type + + +def unwrap_surface_value(value): + """Return the underlying MLIR SSA value for a surface wrapper.""" + return value.value if isinstance(value, _SurfaceValue) else value + + +def _unwrap_sequence(values): + normalized = [] + interned_ints = {} + for value in values: + if isinstance(value, int): + if value not in interned_ints: + interned_ints[value] = _index_const(value) + normalized.append(interned_ints[value]) + else: + normalized.append(unwrap_surface_value(value)) + return normalized + + +def _normalize_index(value): + return unwrap_surface_value(value) + + +def _index_const(value: int): + return arith.ConstantOp(IndexType.get(), value).result + + +def _add_index(lhs, rhs): + if isinstance(lhs, int) and lhs == 0: + return _normalize_index(rhs) + if isinstance(rhs, int) and rhs == 0: + return _normalize_index(lhs) + lhs = _normalize_index(lhs) + rhs = _normalize_index(rhs) + if isinstance(lhs, int) and isinstance(rhs, int): + return lhs + rhs + if isinstance(lhs, int): + lhs = _index_const(lhs) + if isinstance(rhs, int): + rhs = _index_const(rhs) + return arith.AddIOp(lhs, rhs).result + + +def _try_get_constant_index(value) -> int | None: + """Return a compile-time index when *value* is a Python int or ``arith.constant``.""" + if isinstance(value, int): + return value + raw = unwrap_surface_value(value) + owner = getattr(raw, "owner", None) + if owner is None or not hasattr(owner, "operation"): + return None + if owner.operation.name != "arith.constant": + return None + attrs = owner.operation.attributes + if "value" not in attrs: + return None + try: + return IntegerAttr(attrs["value"]).value + except Exception: + return None + + +def _static_index_dims(values) -> tuple[int, ...] | None: + """Return static index dimensions when every entry is known at trace time.""" + dims = [] + for value in values: + dim = _try_get_constant_index(value) + if dim is None: + return None + dims.append(dim) + return tuple(dims) + + +def _maybe_cast_tensor_view_type(type_obj): + try: + return _pto.TensorViewType(type_obj) + except Exception: + return None + + +def _maybe_cast_partition_tensor_view_type(type_obj): + try: + return _pto.PartitionTensorViewType(type_obj) + except Exception: + return None + + +def _maybe_cast_tile_buf_type(type_obj): + try: + return _pto.TileBufType(type_obj) + except Exception: + return None + + +def wrap_surface_value( + value, + *, + root_tensor_view=None, + offsets=None, + sizes=None, + tile_metadata=None, +): + """Wrap a raw MLIR value into the authored PTODSL surface type when needed.""" + if isinstance(value, _SurfaceValue): + return value + + type_obj = value.type + if _maybe_cast_tensor_view_type(type_obj) is not None: + return TensorViewValue(value) + if _maybe_cast_partition_tensor_view_type(type_obj) is not None: + return PartitionTensorViewValue( + value, + root_tensor_view=root_tensor_view, + offsets=offsets, + sizes=sizes, + ) + if _maybe_cast_tile_buf_type(type_obj) is not None: + return TileValue(value, **(tile_metadata or {})) + try: + MemRefType(type_obj) + return AddressValue(value) + except Exception: + pass + return RuntimeValue(value) + + +class _SurfaceValue: + """Base class for authored PTODSL values backed by one MLIR SSA value.""" + + def __init__(self, value): + self._value = value + + @property + def value(self): + return self._value + + @property + def type(self): + return self._value.type + + @property + def surface_metadata(self): + return None + + def __bool__(self): + raise native_python_control_flow_error("if/while condition") + + def __iter__(self): + raise native_python_control_flow_error("for-loop iteration") + + def __repr__(self): + return repr(self._value) + + +class RuntimeValue(_SurfaceValue): + """Generic authored runtime value wrapper with fail-fast Python misuse diagnostics.""" + + def __index__(self): + raise native_python_control_flow_error("range()/loop bound") + + def __int__(self): + raise native_python_control_flow_error("int() coercion") + + def __add__(self, other): + return wrap_surface_value(emit_runtime_binary_op("add", self.value, unwrap_surface_value(other))) + + def __radd__(self, other): + return wrap_surface_value(emit_runtime_binary_op("add", unwrap_surface_value(other), self.value)) + + def __sub__(self, other): + return wrap_surface_value(emit_runtime_binary_op("sub", self.value, unwrap_surface_value(other))) + + def __rsub__(self, other): + return wrap_surface_value(emit_runtime_binary_op("sub", unwrap_surface_value(other), self.value)) + + def __mul__(self, other): + return wrap_surface_value(emit_runtime_binary_op("mul", self.value, unwrap_surface_value(other))) + + def __rmul__(self, other): + return wrap_surface_value(emit_runtime_binary_op("mul", unwrap_surface_value(other), self.value)) + + def __truediv__(self, other): + return wrap_surface_value(emit_runtime_binary_op("truediv", self.value, unwrap_surface_value(other))) + + def __rtruediv__(self, other): + return wrap_surface_value(emit_runtime_binary_op("truediv", unwrap_surface_value(other), self.value)) + + def __floordiv__(self, other): + return wrap_surface_value(emit_runtime_binary_op("floordiv", self.value, unwrap_surface_value(other))) + + def __rfloordiv__(self, other): + return wrap_surface_value(emit_runtime_binary_op("floordiv", unwrap_surface_value(other), self.value)) + + def __mod__(self, other): + return wrap_surface_value(emit_runtime_binary_op("mod", self.value, unwrap_surface_value(other))) + + def __rmod__(self, other): + return wrap_surface_value(emit_runtime_binary_op("mod", unwrap_surface_value(other), self.value)) + + def __lt__(self, other): + return wrap_surface_value(emit_runtime_compare("lt", self.value, unwrap_surface_value(other))) + + def __le__(self, other): + return wrap_surface_value(emit_runtime_compare("le", self.value, unwrap_surface_value(other))) + + def __gt__(self, other): + return wrap_surface_value(emit_runtime_compare("gt", self.value, unwrap_surface_value(other))) + + def __ge__(self, other): + return wrap_surface_value(emit_runtime_compare("ge", self.value, unwrap_surface_value(other))) + + def __eq__(self, other): + return wrap_surface_value(emit_runtime_compare("eq", self.value, unwrap_surface_value(other))) + + def __ne__(self, other): + return wrap_surface_value(emit_runtime_compare("ne", self.value, unwrap_surface_value(other))) + + def __and__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("and", self.value, unwrap_surface_value(other))) + + def __rand__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("and", unwrap_surface_value(other), self.value)) + + def __or__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("or", self.value, unwrap_surface_value(other))) + + def __ror__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("or", unwrap_surface_value(other), self.value)) + + def __xor__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("xor", self.value, unwrap_surface_value(other))) + + def __rxor__(self, other): + return wrap_surface_value(emit_runtime_bitwise_op("xor", unwrap_surface_value(other), self.value)) + + +class MaskResultValue(_SurfaceValue): + """Mask value that also supports `(mask, remained)` unpacking.""" + + def __init__(self, mask_value, scalar_out): + super().__init__(mask_value) + self.scalar_out = wrap_surface_value(scalar_out) + + def __iter__(self): + yield self + yield self.scalar_out + + +class AddressValue(_SurfaceValue): + """Author-facing address view backed by either a PTO ptr or a memref.""" + + def __add__(self, offset): + return AddressOffsetValue(self, offset) + + def __radd__(self, offset): + return AddressOffsetValue(self, offset) + + +@dataclass(frozen=True) +class AddressOffsetValue: + """Address view plus an element offset, used by scalar.load/store sugar.""" + + base: AddressValue + offset: object + + def __add__(self, other): + return AddressOffsetValue(self.base, _add_index(self.offset, other)) + + def __radd__(self, other): + return AddressOffsetValue(self.base, _add_index(other, self.offset)) + + def __bool__(self): + raise native_python_control_flow_error("if/while condition") + + def __iter__(self): + raise native_python_control_flow_error("for-loop iteration") + + +@dataclass(frozen=True) +class TileElementRef: + """One logical tile element selected by tile[row, col] surface syntax.""" + + tile: "TileValue" + linear_offset: object + + def __bool__(self): + raise native_python_control_flow_error("if/while condition") + + def __iter__(self): + raise native_python_control_flow_error("for-loop iteration") + + +class TileSliceValue(_SurfaceValue): + """Author-facing memref view produced by `tile[row, col:]` style indexing.""" + + def __init__(self, value, *, tile: "TileValue", offsets, shape): + super().__init__(value) + self.tile = tile + self.offsets = tuple(offsets) + self.shape = tuple(shape) + + @property + def surface_metadata(self): + return { + "tile": self.tile, + "offsets": self.offsets, + "shape": self.shape, + } + + +class TensorViewValue(_SurfaceValue, TensorView): + """Author-facing tensor-view descriptor value.""" + + def __init__(self, value, *, shape=None, strides=None): + super().__init__(value) + self.shape = tuple(shape) if shape is not None else None + self.strides = tuple(strides) if strides is not None else None + + @property + def surface_metadata(self): + return { + "shape": self.shape, + "strides": self.strides, + } + + def as_ptr(self): + from ._ops import as_ptr + return as_ptr(self) + + +class PartitionTensorViewValue(_SurfaceValue, PartitionTensorView): + """Author-facing partitioned tensor-view descriptor value.""" + + def __init__(self, value, *, root_tensor_view=None, offsets=None, sizes=None): + super().__init__(value) + self.root_tensor_view = root_tensor_view + self.offsets = tuple(offsets) if offsets is not None else None + self.sizes = tuple(sizes) if sizes is not None else None + self.shape = self.sizes + self.strides = getattr(root_tensor_view, "strides", None) + + def as_ptr(self): + from ._ops import as_ptr + return as_ptr(self) + + +class _TileValidShapeView: + """Tuple-like proxy that lowers `tile.valid_shape[i]` on demand.""" + + def __init__(self, tile: "TileValue"): + self._tile = tile + self._cache: dict[int, object] = {} + + def __getitem__(self, index: int): + logical_rank = len(self._tile.shape) if self._tile.shape is not None else 2 + allowed = {0} if logical_rank == 1 else {0, 1} + if index not in allowed: + if logical_rank == 1: + raise IndexError("PTODSL rank-1 tile.valid_shape currently supports only index 0") + raise IndexError("PTODSL tile.valid_shape currently supports indices 0 and 1") + cached = self._cache.get(index) + if cached is not None: + return cached + if self._tile.static_valid_shape is not None: + dim = self._tile.static_valid_shape[index] + if dim is not None: + value = _index_const(dim) if isinstance(dim, int) else unwrap_surface_value(dim) + value = wrap_surface_value(value) + self._cache[index] = value + return value + try: + if logical_rank == 1: + value = wrap_surface_value(_pto.TileValidColsOp(self._tile.value).result) + elif index == 0: + value = wrap_surface_value(_pto.TileValidRowsOp(self._tile.value).result) + else: + value = wrap_surface_value(_pto.TileValidColsOp(self._tile.value).result) + except Exception: + static_dim = _fallback_static_valid_dim(self._tile.type, index) + if static_dim is None: + raise RuntimeError( + "tile.valid_shape could not be lowered because the current " + "Python bindings do not materialize pto.tile_valid_* and " + "the tile type does not carry a recoverable static bound" + ) from None + value = wrap_surface_value(_index_const(static_dim)) + self._cache[index] = value + return value + + +class TileValue(_SurfaceValue, Tile): + """Author-facing tile handle with surface-style accessors.""" + + def __init__( + self, + value, + *, + shape=None, + physical_shape=None, + dtype=None, + memory_space=None, + valid_shape=None, + ): + super().__init__(value) + parsed = parse_tile_type_metadata(value.type) + self.shape = tuple(shape) if shape is not None else ( + parsed["shape_dims"] if parsed is not None else None + ) + self.physical_shape = tuple(physical_shape) if physical_shape is not None else ( + tuple(shape) if shape is not None else ( + parsed["shape_dims"] if parsed is not None else None + ) + ) + self.dtype = dtype if dtype is not None else ( + parsed["element_type"] if parsed is not None else None + ) + self.memory_space = memory_space if memory_space is not None else ( + parsed["memory_space"] if parsed is not None else None + ) + self.static_valid_shape = tuple(valid_shape) if valid_shape is not None else ( + parsed["valid_dims"] if parsed is not None else None + ) + self._valid_shape = _TileValidShapeView(self) + + @property + def valid_shape(self): + return self._valid_shape + + @valid_shape.setter + def valid_shape(self, dims): + from ._ops import set_tile_valid_shape + + set_tile_valid_shape(self, dims) + self.static_valid_shape = tuple(dims) + self._valid_shape._cache.clear() + + @property + def surface_metadata(self): + return { + "shape": self.shape, + "physical_shape": self.physical_shape, + "dtype": self.dtype, + "memory_space": self.memory_space, + "valid_shape": self.static_valid_shape, + } + + def as_ptr(self): + from ._ops import as_ptr + return as_ptr(self) + + def fill(self, value): + from ._ops import fill_tile + fill_tile(self, value) + + def __getitem__(self, key): + if not isinstance(key, tuple): + key = (key,) + if self.shape is None: + raise RuntimeError("tile indexing requires tile shape metadata") + + if _is_tile_slice_key(key, self.shape): + return _materialize_tile_slice(self, key) + + if len(key) != len(self.shape): + raise TypeError( + f"tile indexing expects {len(self.shape)} indices, got {len(key)}" + ) + linear_offset = 0 + stride = 1 + for index, dim in zip(reversed(key), reversed(self.shape)): + linear_offset = _add_index(linear_offset, _mul_index(index, stride)) + if dim is None: + raise RuntimeError("tile indexing requires static tile shape metadata") + stride *= dim + return TileElementRef(self, linear_offset) + + +@dataclass(frozen=True) +class PartitionSpec: + """Logical authored partition metadata used to compose nested slices.""" + + root_tensor_view: object + offsets: tuple + sizes: tuple + + +def wrap_like_surface_value(template, value): + """Wrap *value* using the same authored surface contract as *template*.""" + if isinstance(template, PartitionTensorViewValue): + return PartitionTensorViewValue( + value, + root_tensor_view=template.root_tensor_view, + offsets=template.offsets, + sizes=template.sizes, + ) + if isinstance(template, TensorViewValue): + return TensorViewValue(value, shape=template.shape, strides=template.strides) + if isinstance(template, TileValue): + return TileValue(value, **template.surface_metadata) + if isinstance(template, AddressValue): + return AddressValue(value) + return wrap_surface_value(value) + + +def extract_partition_spec(source) -> PartitionSpec | None: + """Return the root tensor-view + composed slice metadata when available.""" + if isinstance(source, PartitionTensorViewValue) and source.root_tensor_view is not None: + return PartitionSpec( + root_tensor_view=source.root_tensor_view, + offsets=source.offsets or (), + sizes=source.sizes or (), + ) + if isinstance(source, TensorViewValue): + return PartitionSpec(root_tensor_view=source, offsets=(), sizes=()) + return None + + +def compose_partition_spec(source, *, offsets, sizes) -> PartitionSpec | None: + """Compose a nested `partition_view(...)` against an existing partition.""" + parent = extract_partition_spec(source) + if parent is None: + return None + if isinstance(source, TensorViewValue): + return PartitionSpec( + root_tensor_view=source, + offsets=tuple(offsets), + sizes=tuple(sizes), + ) + if parent.offsets and len(parent.offsets) != len(offsets): + raise ValueError("nested partition_view rank mismatch") + composed_offsets = tuple( + _add_index(parent_offset, child_offset) + for parent_offset, child_offset in zip(parent.offsets, offsets) + ) + return PartitionSpec( + root_tensor_view=parent.root_tensor_view, + offsets=composed_offsets, + sizes=tuple(sizes), + ) + + +def infer_ptr_type_from_surface_value(surface_value): + """Infer a PTO pointer type for `as_ptr()` from the authored source value.""" + value_type = surface_value.type + + tv_type = _maybe_cast_tensor_view_type(value_type) + if tv_type is not None: + return _resolve(ptr(tv_type.element_type, "gm")) + + part_type = _maybe_cast_partition_tensor_view_type(value_type) + if part_type is not None: + return _resolve(ptr(part_type.element_type, "gm")) + + tile_type = _maybe_cast_tile_buf_type(value_type) + if tile_type is None: + raise TypeError("as_ptr() expects a Tile, TensorView, or PartitionTensorView surface value") + + memory_space = getattr(tile_type, "memory_space", None) + parsed = None + if memory_space is None: + parsed = parse_tile_type_metadata(value_type) + if parsed is None: + raise RuntimeError("unable to infer tile pointer type: tile type is missing memory-space metadata") + memory_space = parsed["memory_space"] + + space_enum = getattr(memory_space, "value", None) + if space_enum is not None: + space_enum = _normalize_address_space(_ADDRESS_SPACE_VALUE_TO_KEYWORD.get(space_enum)) + else: + space_enum = _normalize_address_space(str(memory_space)) + if space_enum is None: + raise RuntimeError("unable to infer tile pointer type: unsupported tile memory space") + + return _resolve(ptr(tile_type.element_type, space_enum)) + + +def emit_as_ptr(surface_value): + """Lower `as_ptr()` on a surface value to the appropriate PTO op.""" + value = unwrap_surface_value(surface_value) + result_type = infer_address_type_from_surface_value(surface_value) + + if isinstance(surface_value, (TensorViewValue, PartitionTensorViewValue)): + return AddressValue(_pto.TensorViewAddrOp(result_type, value).result) + if isinstance(surface_value, TileValue): + return AddressValue(_pto.TileBufAddrOp(result_type, value).result) + raise TypeError("as_ptr() expects a Tile, TensorView, or PartitionTensorView surface value") + + +_TILE_TYPE_RE = re.compile( + r"!pto\.tile_buf<(?P[^,]+),\s*(?P.+?)x(?P[^,x>]+),\s*valid=(?P[^,>]+)(?:,.*)?>" +) + + +_ADDRESS_SPACE_VALUE_TO_KEYWORD = { + 1: "gm", + 2: "mat", + 3: "left", + 4: "right", + 5: "acc", + 6: "vec", + 7: "bias", + 8: "scaling", +} + + +def _read_tile_type_metadata_from_binding(type_obj): + required = ("shape", "element_type", "memory_space", "valid_shape") + if not all(hasattr(type_obj, name) for name in required): + return None + + memory_space_attr = type_obj.memory_space + memory_space_value = getattr(memory_space_attr, "value", None) + memory_space = _ADDRESS_SPACE_VALUE_TO_KEYWORD.get(memory_space_value) + if memory_space is None: + return None + + def _normalize_dims(seq): + dims = [] + for dim in seq: + dims.append(None if dim == ShapedType.get_dynamic_size() else int(dim)) + return tuple(dims) + + return { + "memory_space": memory_space, + "shape_dims": _normalize_dims(type_obj.shape), + "element_type": type_obj.element_type, + "valid_dims": _normalize_dims(type_obj.valid_shape), + } + + +def _fallback_static_valid_dim(type_obj, index: int): + parsed = parse_tile_type_metadata(type_obj) + if parsed is None: + return None + shape_dims = parsed["shape_dims"] + valid_dims = parsed["valid_dims"] + if index >= len(shape_dims) or index >= len(valid_dims): + return None + valid_dim = valid_dims[index] + if valid_dim is not None: + return valid_dim + return shape_dims[index] + + +def parse_tile_type_metadata(type_obj): + bound = _read_tile_type_metadata_from_binding(type_obj) + if bound is not None: + return bound + + match = _TILE_TYPE_RE.match(str(type_obj)) + if match is None: + return None + shape_dims = [ + None if dim == "?" else int(dim) + for dim in match.group("shape").split("x") + ] + valid_dims = [ + None if dim == "?" else int(dim) + for dim in match.group("valid").split("x") + ] + return { + "memory_space": match.group("space"), + "shape_dims": tuple(shape_dims), + "element_type": Type.parse(match.group("elem")), + "valid_dims": tuple(valid_dims), + } + + +def infer_tile_element_type(tile): + """Recover the tile element type from authored metadata or type text.""" + if isinstance(tile, TileValue) and tile.dtype is not None: + return _resolve(tile.dtype) + parsed = parse_tile_type_metadata(tile.type if isinstance(tile, TileValue) else tile) + if parsed is None: + raise RuntimeError("unable to recover tile element type from tile surface value") + return parsed["element_type"] + + +def infer_address_type_from_surface_value(surface_value): + """Infer the concrete result type emitted by `as_ptr()`.""" + return infer_ptr_type_from_surface_value(surface_value) + + +def infer_memref_type_from_surface_value(surface_value): + """Build a memref address-view type that preserves element/rank/address-space.""" + if isinstance(surface_value, TileSliceValue): + return surface_value.type + + if isinstance(surface_value, TileValue): + physical_shape = getattr(surface_value, "physical_shape", None) + if physical_shape is not None and surface_value.dtype is not None and surface_value.memory_space is not None: + space_enum = _normalize_address_space(surface_value.memory_space) + if space_enum is None: + raise RuntimeError("unsupported tile memory space for memref address view") + return MemRefType.get( + list(physical_shape), + _resolve(surface_value.dtype), + memory_space=_pto.AddressSpaceAttr.get(space_enum), + ) + + value_type = surface_value.type + + tv_type = _maybe_cast_tensor_view_type(value_type) + if tv_type is not None: + return MemRefType.get( + [ShapedType.get_dynamic_size()] * tv_type.rank, + tv_type.element_type, + memory_space=_pto.AddressSpaceAttr.get(_pto.AddressSpace.GM), + ) + + part_type = _maybe_cast_partition_tensor_view_type(value_type) + if part_type is not None: + return MemRefType.get( + [ShapedType.get_dynamic_size()] * part_type.rank, + part_type.element_type, + memory_space=_pto.AddressSpaceAttr.get(_pto.AddressSpace.GM), + ) + + tile_type = _maybe_cast_tile_buf_type(value_type) + if tile_type is None: + raise TypeError("memref address inference expects a Tile, TensorView, or PartitionTensorView") + + parsed = parse_tile_type_metadata(value_type) + if parsed is None: + raise RuntimeError("unable to recover tile memref shape/address-space") + space_enum = _normalize_address_space(parsed["memory_space"]) + if space_enum is None: + raise RuntimeError("unsupported tile memory space for memref address view") + return MemRefType.get( + list(parsed["shape_dims"]), + parsed["element_type"], + memory_space=_pto.AddressSpaceAttr.get(space_enum), + ) + + +def resolve_address_access(target, offset=None): + """Normalize address/tile element sugar into `(buffer, index_offset)`.""" + if isinstance(target, TileElementRef): + base = emit_as_ptr(target.tile) + resolved_offset = target.linear_offset + elif isinstance(target, AddressOffsetValue): + base = target.base + resolved_offset = target.offset + elif isinstance(target, AddressValue): + base = target + resolved_offset = 0 + else: + base = target + resolved_offset = 0 + + if offset is not None: + resolved_offset = _add_index(resolved_offset, offset) + + return unwrap_surface_value(base), _coerce_index_value(resolved_offset) + + +def _is_tile_slice_key(key, shape): + if len(shape) == 1: + return len(key) == 1 and isinstance(key[0], slice) + if len(shape) == 2: + return len(key) == 2 and isinstance(key[1], slice) + return False + + +def _materialize_tile_slice(tile: TileValue, key): + rank = len(tile.shape) + if rank == 1: + start_slice = key[0] + if start_slice.stop is not None or start_slice.step is not None: + raise TypeError("tile[start:] only supports an open-ended slice") + start = 0 if start_slice.start is None else start_slice.start + return _build_tile_slice_view( + tile, + raw_offsets=[0, start], + shape=[_dynamic_extent(tile.shape[0], start)], + ) + + row, col_slice = key + if col_slice.stop is not None or col_slice.step is not None: + raise TypeError("tile[row, col:] only supports an open-ended column slice") + col = 0 if col_slice.start is None else col_slice.start + return _build_tile_slice_view( + tile, + raw_offsets=[row, col], + shape=[_dynamic_extent(tile.shape[1], col)], + ) + + +def _build_tile_slice_view(tile: TileValue, *, raw_offsets, shape): + base_memref = _emit_tile_memref(tile) + base_type = MemRefType(base_memref.type) + rank = len(base_type.shape) + offset_operands, static_offsets = _split_dynamic_index_operands(raw_offsets) + shape_operands, static_shape = _split_dynamic_index_operands(shape) + if rank == 1: + slice_type = _make_strided_memref_type( + [_static_extent_if_known(shape[0])], + base_type.element_type, + [1], + base_type.memory_space, + ) + slice_value = memref.SubViewOp( + slice_type, + base_memref, + offset_operands, + shape_operands, + [], + static_offsets, + static_shape, + [1], + ).result + return TileSliceValue(slice_value, tile=tile, offsets=tuple(raw_offsets), shape=shape) + + slice_type = _make_strided_memref_type( + [_static_extent_if_known(shape[0])], + base_type.element_type, + [1], + base_type.memory_space, + ) + slice_value = memref.SubViewOp( + slice_type, + base_memref, + offset_operands, + shape_operands, + [], + static_offsets, + [1, static_shape[0]], + [1, 1], + ).result + return TileSliceValue(slice_value, tile=tile, offsets=tuple(raw_offsets), shape=shape) + + +def _emit_tile_memref(tile: TileValue): + memref_type = infer_memref_type_from_surface_value(tile) + return _pto.TileBufAddrOp(memref_type, tile.value).result + + +def _dynamic_extent(static_dim, start): + if isinstance(start, int): + return static_dim - start + return arith.SubIOp(_index_const(static_dim), _coerce_index_value(start)).result + + +def _static_extent_if_known(extent): + return extent if isinstance(extent, int) else ShapedType.get_dynamic_size() + + +def _static_index_attr(value): + return value if isinstance(value, int) else ShapedType.get_dynamic_size() + + +def _split_dynamic_index_operands(values): + operands = [] + static_attrs = [] + for value in values: + if isinstance(value, int): + static_attrs.append(value) + else: + operands.append(_coerce_index_value(value)) + static_attrs.append(ShapedType.get_dynamic_size()) + return operands, static_attrs + + +def _make_strided_memref_type(shape, element_type, strides, memory_space): + return MemRefType.get( + list(shape), + element_type, + StridedLayoutAttr.get(ShapedType.get_dynamic_size(), list(strides)), + memory_space, + ) + + +def _mul_index(lhs, rhs): + lhs = _normalize_index(lhs) + rhs = _normalize_index(rhs) + if isinstance(lhs, int) and isinstance(rhs, int): + return lhs * rhs + if isinstance(lhs, int): + lhs = _index_const(lhs) + if isinstance(rhs, int): + rhs = _index_const(rhs) + return arith.MulIOp(lhs, rhs).result + + +def _coerce_index_value(value): + value = _normalize_index(value) + if isinstance(value, int): + return _index_const(value) + if IndexType.isinstance(value.type): + return value + if IntegerType.isinstance(value.type): + return arith.IndexCastOp(IndexType.get(), value).result + raise TypeError(f"expected an index-like value, got {value.type}") + + +__all__ = [ + "AddressOffsetValue", + "AddressValue", + "MaskResultValue", + "PartitionSpec", + "PartitionTensorViewValue", + "RuntimeValue", + "TileElementRef", + "TileSliceValue", + "TensorViewValue", + "TileValue", + "compose_partition_spec", + "emit_as_ptr", + "extract_partition_spec", + "infer_tile_element_type", + "infer_address_type_from_surface_value", + "infer_memref_type_from_surface_value", + "infer_ptr_type_from_surface_value", + "parse_tile_type_metadata", + "resolve_address_access", + "unwrap_surface_value", + "wrap_like_surface_value", + "wrap_surface_value", + "_unwrap_sequence", +] diff --git a/ptodsl/ptodsl/_tensor_factories.py b/ptodsl/ptodsl/_tensor_factories.py new file mode 100644 index 000000000..4336c3df9 --- /dev/null +++ b/ptodsl/ptodsl/_tensor_factories.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Small host-side tensor factory helpers used by PTODSL wrappers.""" + +from __future__ import annotations + + +def empty_like(tensor): + """Allocate one host-side tensor with the same logical metadata as *tensor*.""" + new_empty = getattr(tensor, "new_empty", None) + if callable(new_empty): + return new_empty(tensor.shape) + + try: + import torch # type: ignore + except Exception: + torch = None + if torch is not None and isinstance(tensor, torch.Tensor): + return torch.empty_like(tensor) + + try: + import numpy as np # type: ignore + except Exception: + np = None + if np is not None and isinstance(tensor, np.ndarray): + return np.empty_like(tensor) + + raise TypeError( + "pto.empty_like(...) could not infer how to allocate an output tensor for " + f"{type(tensor)!r}; provide O= explicitly or use a tensor type exposing " + ".new_empty(...), torch.empty_like, or numpy.empty_like support" + ) + + +__all__ = [ + "empty_like", +] diff --git a/ptodsl/ptodsl/_tile_namespace.py b/ptodsl/ptodsl/_tile_namespace.py new file mode 100644 index 000000000..66d01224e --- /dev/null +++ b/ptodsl/ptodsl/_tile_namespace.py @@ -0,0 +1,128 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from . import _ops + + +def _resolve_row_reduction_tmp(src, tmp): + if tmp is not None: + return tmp + return _ops.alloc_tile(tile_type=_ops.unwrap_surface_value(src).type) + + +class _TileNamespace: + load = staticmethod(_ops.tload) + store = staticmethod(_ops.tstore) + mov = staticmethod(_ops.tmov) + + add = staticmethod(_ops.tadd) + sub = staticmethod(_ops.tsub) + mul = staticmethod(_ops.tmul) + div = staticmethod(_ops.tdiv) + max = staticmethod(_ops.tmax) + min = staticmethod(_ops.tmin) + + adds = staticmethod(_ops.tadds) + subs = staticmethod(_ops.tsubs) + muls = staticmethod(_ops.tmuls) + divs = staticmethod(_ops.tdivs) + maxs = staticmethod(_ops.tmaxs) + mins = staticmethod(_ops.tmins) + + exp = staticmethod(_ops.texp) + log = staticmethod(_ops.tlog) + sqrt = staticmethod(_ops.tsqrt) + rsqrt = staticmethod(_ops.trsqrt) + recip = staticmethod(_ops.trecip) + abs = staticmethod(_ops.tabs) + neg = staticmethod(_ops.tneg) + + relu = staticmethod(_ops.trelu) + lrelu = staticmethod(_ops.tlrelu) + + @staticmethod + def rowsum(src, dst, *, tmp=None): + return _ops.trowsum(src, _resolve_row_reduction_tmp(src, tmp), dst) + + @staticmethod + def rowmax(src, dst, *, tmp=None): + return _ops.trowmax(src, _resolve_row_reduction_tmp(src, tmp), dst) + + @staticmethod + def rowmin(src, dst, *, tmp=None): + return _ops.trowmin(src, _resolve_row_reduction_tmp(src, tmp), dst) + + @staticmethod + def rowprod(src, dst, *, tmp=None): + return _ops.trowprod(src, _resolve_row_reduction_tmp(src, tmp), dst) + + @staticmethod + def rowargmax(src, dst, *, tmp=None): + return _ops.trowargmax(src, _resolve_row_reduction_tmp(src, tmp), dst) + + @staticmethod + def rowargmin(src, dst, *, tmp=None): + return _ops.trowargmin(src, _resolve_row_reduction_tmp(src, tmp), dst) + + colsum = staticmethod(_ops.tcolsum) + colmax = staticmethod(_ops.tcolmax) + colmin = staticmethod(_ops.tcolmin) + colprod = staticmethod(_ops.tcolprod) + colargmax = staticmethod(_ops.tcolargmax) + colargmin = staticmethod(_ops.tcolargmin) + + cmp = staticmethod(_ops.tcmp) + cmps = staticmethod(_ops.tcmps) + + expands = staticmethod(_ops.texpands) + rowexpand = staticmethod(_ops.trowexpand) + colexpand = staticmethod(_ops.tcolexpand) + + rowexpandadd = staticmethod(_ops.trowexpandadd) + rowexpandsub = staticmethod(_ops.trowexpandsub) + rowexpandmul = staticmethod(_ops.trowexpandmul) + rowexpanddiv = staticmethod(_ops.trowexpanddiv) + rowexpandmax = staticmethod(_ops.trowexpandmax) + rowexpandmin = staticmethod(_ops.trowexpandmin) + rowexpandexpdif = staticmethod(_ops.trowexpandexpdif) + + colexpandadd = staticmethod(_ops.tcolexpandadd) + colexpandsub = staticmethod(_ops.tcolexpandsub) + colexpandmul = staticmethod(_ops.tcolexpandmul) + colexpanddiv = staticmethod(_ops.tcolexpanddiv) + colexpandmax = staticmethod(_ops.tcolexpandmax) + colexpandmin = staticmethod(_ops.tcolexpandmin) + colexpandexpdif = staticmethod(_ops.tcolexpandexpdif) + + sel = staticmethod(_ops.tsel) + sels = staticmethod(_ops.tsels) + cvt = staticmethod(_ops.tcvt) + + bit_not = staticmethod(_ops.tnot) + bit_and = staticmethod(_ops.tand) + bit_ands = staticmethod(_ops.tands) + bit_or = staticmethod(_ops.tor) + bit_ors = staticmethod(_ops.tors) + bit_xor = staticmethod(_ops.txor) + bit_xors = staticmethod(_ops.txors) + bit_shl = staticmethod(_ops.tshl) + bit_shls = staticmethod(_ops.tshls) + bit_shr = staticmethod(_ops.tshr) + bit_shrs = staticmethod(_ops.tshrs) + + partadd = staticmethod(_ops.tpartadd) + partmul = staticmethod(_ops.tpartmul) + partmax = staticmethod(_ops.tpartmax) + partmin = staticmethod(_ops.tpartmin) + + fillpad = staticmethod(_ops.tfillpad) + fillpad_expand = staticmethod(_ops.tfillpad_expand) + fillpad_inplace = staticmethod(_ops.tfillpad_inplace) + + +tile = _TileNamespace() diff --git a/ptodsl/ptodsl/_tile_template_tracing.py b/ptodsl/ptodsl/_tile_template_tracing.py new file mode 100644 index 000000000..971e8ece3 --- /dev/null +++ b/ptodsl/ptodsl/_tile_template_tracing.py @@ -0,0 +1,718 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +Tile-template tracing implementation for PTODSL tile templates. + +This module keeps the authored Python body close to TileLang-style templates, +but traces execution directly into MLIR Python bindings instead of going through +an AST-capture frontend. + +Current scope: +- bare ``Tile`` parameters with static 2D specializations +- ``dst.element_type`` / ``dst.valid_shape`` +- optional `with pto.vecscope():` +- explicit structured `with pto.for_(...) as ...:` +- optional named loop-carried state via ``state={...}`` +- ``get_lanes(dtype)`` +- ``make_mask(dtype, remained)`` +- ``vlds(tile[row, col:])`` +- ``vadd(lhs, rhs, mask)`` +- ``vsts(vec, tile[row, col:], mask)`` + +The current goal is to keep a narrow tile-template tracing path that already +builds real MLIR Python objects, while keeping its scope explicit and aligned +with the main PTODSL tracing runtime. +""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass +from pathlib import Path +from . import scalar as _scalar +from ._surface_types import Tile +from ._tracing import ( + KernelModuleSpec, + ModuleArtifact, + ModuleStyle, + TracingRuntime, + require_active_runtime, +) +from ._types import ( + _resolve, + float16 as _float16, + float32 as _float32, + index as _index, + int8 as _int8, + int16 as _int16, + int32 as _int32, + int64 as _int64, + mask_type as _mask_type, + ptr as _ptr, + tile_buf_type as _tile_buf_type, + vreg_type as _vreg_type, +) + +from mlir.dialects import arith, pto as _pto, scf +from mlir.ir import InsertionPoint, IntegerType, Type + + +@dataclass(frozen=True) +class ScalarType: + name: str + lanes: int + mask_bits: int + bytewidth: int + + def __repr__(self) -> str: + return self.name + + +f32 = ScalarType("f32", lanes=64, mask_bits=32, bytewidth=4) +f16 = ScalarType("f16", lanes=128, mask_bits=16, bytewidth=2) +bf16 = ScalarType("bf16", lanes=128, mask_bits=16, bytewidth=2) +i32 = ScalarType("i32", lanes=64, mask_bits=32, bytewidth=4) +i16 = ScalarType("i16", lanes=128, mask_bits=16, bytewidth=2) +i8 = ScalarType("i8", lanes=256, mask_bits=8, bytewidth=1) + + +@dataclass(frozen=True) +class TileSpec: + shape: tuple[int, int] + dtype: ScalarType + memory_space: str = "ub" + + def __post_init__(self): + if len(self.shape) != 2: + raise ValueError("TileSpec currently only supports rank-2 tile shapes") + if any(not isinstance(dim, int) or dim <= 0 for dim in self.shape): + raise ValueError("TileSpec.shape must contain positive integers") + if self.memory_space != "ub": + raise ValueError("TileSpec currently only supports ub tiles") + + def mlir_type(self): + rows, cols = self.shape + return _tile_buf_type( + [rows, cols], + _scalar_descriptor(self.dtype), + [rows, cols], + blayout="RowMajor", + address_space=self.memory_space, + slayout="NoneBox", + fractal_size=512, + pad="Null", + ) + + +@dataclass(frozen=True) +class _Value: + value: object + const_value: int | None = None + + def __repr__(self) -> str: + return str(self.value) + + @property + def type_text(self) -> str: + return str(self.value.type) + + @property + def is_const(self) -> bool: + return self.const_value is not None + + +@dataclass(frozen=True) +class _MaskValue: + value: object + dtype: ScalarType + + @property + def type_text(self) -> str: + return str(self.value.type) + + +@dataclass(frozen=True) +class _VectorValue: + value: object + dtype: ScalarType + + @property + def type_text(self) -> str: + return str(self.value.type) + + +@dataclass(frozen=True) +class _TileSlice: + tile: "_TileProxy" + row: int | _Value + col: int | _Value + + +class _TileProxy: + def __init__(self, trace: "_TraceBuilder", arg_value, spec: TileSpec): + self._trace = trace + self._arg_value = arg_value + self._spec = spec + + @property + def element_type(self) -> ScalarType: + return self._spec.dtype + + @property + def valid_shape(self) -> tuple[_Value, _Value]: + return ( + self._trace.index_const(self._spec.shape[0]), + self._trace.index_const(self._spec.shape[1]), + ) + + @property + def type_text(self) -> str: + return str(self._arg_value.type) + + def __getitem__(self, key): + if ( + not isinstance(key, tuple) + or len(key) != 2 + or not _is_index_like(key[0]) + or not isinstance(key[1], slice) + ): + raise TypeError("tile-template tracing only supports tile[row, col:] indexing") + row, col_slice = key + if col_slice.stop is not None or col_slice.step is not None: + raise TypeError("tile-template tracing only supports tile[row, col:] slices") + col = 0 if col_slice.start is None else col_slice.start + if not _is_index_like(col): + raise TypeError("tile-template tracing only supports integer/index column offsets") + _validate_static_bound(row, self._spec.shape[0], "row") + _validate_static_bound(col, self._spec.shape[1], "column") + return _TileSlice(self, row=row, col=col) + + +class _LoopStateView: + def __init__(self, names: tuple[str, ...], values: tuple[_Value, ...]): + self._values = dict(zip(names, values)) + + def __getattr__(self, name: str) -> _Value: + try: + return self._values[name] + except KeyError as exc: + raise AttributeError(name) from exc + + +class _LoopHandle: + def __init__( + self, + trace: "_TraceBuilder", + for_op, + iv: _Value, + iter_args: tuple[_Value, ...], + state_names: tuple[str, ...] = (), + ): + self._trace = trace + self._for_op = for_op + self.iv = iv + self.iter_args = iter_args + self._state_names = state_names + self.state = _LoopStateView(state_names, iter_args) if state_names else None + self.results: tuple[_Value, ...] = () + + def _finalize(self) -> None: + self.results = tuple(_Value(result) for result in self._for_op.results) + + def yield_state(self, **kwargs) -> None: + if not self._state_names: + raise RuntimeError("loop.yield_state(...) requires for_(..., state={...})") + missing = [name for name in self._state_names if name not in kwargs] + extra = [name for name in kwargs if name not in self._state_names] + if missing or extra: + pieces = [] + if missing: + pieces.append(f"missing: {', '.join(missing)}") + if extra: + pieces.append(f"unexpected: {', '.join(extra)}") + raise RuntimeError( + "loop.yield_state(...) must match loop state names exactly; " + + "; ".join(pieces) + ) + ordered = tuple(kwargs[name] for name in self._state_names) + self._trace._yield_loop_values(ordered, surface="loop.yield_state", from_named_state=True) + + +class _VecScopeCM: + def __init__(self, trace: "_TraceBuilder"): + self._trace = trace + + def __enter__(self): + self._trace._enter_vecscope() + return None + + def __exit__(self, exc_type, exc, tb): + self._trace._exit_vecscope(exc_type, exc, tb) + + +class _ForCM: + def __init__(self, trace: "_TraceBuilder", start, stop, step, iter_args, state): + self._trace = trace + self._start = start + self._stop = stop + self._step = step + self._iter_args = list(iter_args) if iter_args is not None else [] + self._state = tuple(state.items()) if state is not None else () + self._handle: _LoopHandle | None = None + + def __enter__(self): + self._handle = self._trace._enter_for( + self._start, + self._stop, + self._step, + self._iter_args, + self._state, + ) + if self._iter_args or self._state: + return self._handle + return self._handle.iv + + def __exit__(self, exc_type, exc, tb): + self._trace._exit_for(self._handle, exc_type, exc, tb) + + +class _TraceBuilder(TracingRuntime): + def __init__(self, descriptor: "TileTemplate", tile_specs: dict[str, TileSpec]): + super().__init__( + KernelModuleSpec( + function_name=descriptor.name, + target_arch=descriptor.target, + kernel_kind="vector", + mode="auto", + module_style=ModuleStyle.NESTED, + source_file=inspect.getsourcefile(descriptor.py_fn) or inspect.getfile(descriptor.py_fn), + source_line=getattr(descriptor.py_fn.__code__, "co_firstlineno", None), + ) + ) + self.descriptor = descriptor + self.tile_specs = tile_specs + self._const_cache: dict[tuple[int, str], _Value] = {} + self._tile_ptr_cache: dict[int, _Value] = {} + self._row_offset_cache: dict[tuple[str, str], _Value] = {} + self._loop_stack: list[dict] = [] + self._inside_vecscope = False + self._ordered_specs: list[tuple[str, TileSpec]] = [] + signature = inspect.signature(self.descriptor.py_fn) + self._signature_parameters = tuple(signature.parameters.items()) + + def compute_argument_types(self): + arg_types = [] + ordered_specs = [] + for param_name, param in self._signature_parameters: + if not _is_tile_annotation(param.annotation): + raise TypeError( + "tile-template tracing currently only supports Tile parameters; " + f"parameter {param_name!r} uses {param.annotation!r}" + ) + spec = self.tile_specs.get(param_name) + if spec is None: + raise ValueError(f"missing specialization for Tile parameter {param_name!r}") + ordered_specs.append((param_name, spec)) + arg_types.append(spec.mlir_type()) + self._ordered_specs = ordered_specs + return arg_types + + def bind_entry_arguments(self, entry_arguments): + args = [] + for arg_value, (_, spec) in zip(entry_arguments, self._ordered_specs): + args.append(_TileProxy(self, arg_value, spec)) + return tuple(args) + + def trace_entry(self, *args): + self.descriptor.py_fn(*args) + + def validate_trace_state(self): + if self._inside_vecscope: + raise RuntimeError("tile-template trace exited with an open vecscope block") + if self._loop_stack: + raise RuntimeError("tile-template trace exited with an open scf.for block") + + def vecscope(self) -> _VecScopeCM: + return _VecScopeCM(self) + + def for_(self, start, stop, *, step, iter_args=None, state=None) -> _ForCM: + if iter_args is not None and state is not None: + raise ValueError("for_() accepts either iter_args= or state=, not both") + if state is not None: + if not hasattr(state, "items"): + raise TypeError("for_(..., state=...) expects a mapping of name -> initial value") + for name in state: + if not isinstance(name, str) or not name: + raise TypeError("for_ state names must be non-empty strings") + return _ForCM(self, start, stop, step, iter_args, state) + + def yield_(self, *vals): + self._yield_loop_values(vals, surface="yield_", from_named_state=False) + + def _yield_loop_values(self, vals, *, surface: str, from_named_state: bool): + if not self._loop_stack: + raise RuntimeError(f"{surface}(...) may only be used inside a tile-template for_ block") + frame = self._loop_stack[-1] + if frame["kind"] != "for": + raise RuntimeError(f"{surface}(...) may only be used inside a tile-template for_ block") + if frame["state_names"] and not from_named_state: + raise RuntimeError( + f"{surface}(...) is ambiguous for tile-template for_ with named state; " + "use loop.yield_state(...) instead" + ) + if frame["yielded"]: + raise RuntimeError( + f"{surface}(...) may only be emitted once per tile-template for_ block" + ) + if len(vals) != len(frame["iter_args"]): + raise RuntimeError( + f"{surface}(...) expected {len(frame['iter_args'])} value(s), got {len(vals)}" + ) + coerced = tuple( + self._coerce_like(arg, expected.type_text) + for arg, expected in zip(vals, frame["iter_args"]) + ) + scf.YieldOp([val.value for val in coerced]) + frame["yielded"] = True + frame["yield_vals"] = coerced + + def index_const(self, value: int) -> _Value: + return self._const(value, _resolve(_index)) + + def scalar_const(self, value: int, dtype: ScalarType) -> _Value: + return self._const(value, _resolve(_scalar_descriptor(dtype))) + + def _const(self, value: int, mlir_type) -> _Value: + cache_key = (value, str(mlir_type)) + cached = self._const_cache.get(cache_key) + if cached is not None: + return cached + const = _Value(arith.ConstantOp(mlir_type, value).result, const_value=value) + self._const_cache[cache_key] = const + return const + + def ensure_tile_ptr(self, tile: _TileProxy) -> _Value: + cache_key = id(tile._arg_value) + cached = self._tile_ptr_cache.get(cache_key) + if cached is not None: + return cached + ptr_type = _resolve(_ptr(_scalar_descriptor(tile.element_type), tile._spec.memory_space)) + ptr_value = _Value(_pto.TileBufAddrOp(ptr_type, tile._arg_value).result) + self._tile_ptr_cache[cache_key] = ptr_value + return ptr_value + + def materialize_linear_offset(self, tile_slice: _TileSlice) -> _Value: + cols = tile_slice.tile._spec.shape[1] + row = self._coerce_index(tile_slice.row) + col = self._coerce_index(tile_slice.col) + if row.is_const and col.is_const: + return self.index_const(row.const_value * cols + col.const_value) + row_stride = self.index_const(cols) + row_off = self._materialize_row_offset(row, row_stride) + return _Value(_scalar.addi(row_off.value, col.value)) + + def _enter_vecscope(self): + if self._inside_vecscope: + raise RuntimeError( + "nested tile-template vecscope blocks are not supported in the current implementation" + ) + vecscope_op = _pto.VecScopeOp() + vecscope_block = vecscope_op.body.blocks.append() + vecscope_ip = InsertionPoint(vecscope_block) + vecscope_ip.__enter__() + self._loop_stack.append( + { + "kind": "vecscope", + "ip": vecscope_ip, + } + ) + self._inside_vecscope = True + + def _exit_vecscope(self, exc_type, exc, tb): + if not self._inside_vecscope: + raise RuntimeError("vecscope exit without matching enter") + frame = self._loop_stack.pop() + if frame["kind"] != "vecscope": + raise RuntimeError("tile-template vecscope stack corruption detected") + frame["ip"].__exit__(exc_type, exc, tb) + self._inside_vecscope = False + + def _enter_for(self, start, stop, step, iter_args, state_items) -> _LoopHandle: + start_val = self._coerce_index(start) + stop_val = self._coerce_index(stop) + step_val = self._coerce_index(step) + state_names = tuple(name for name, _ in state_items) + if state_names: + iter_arg_vals = tuple(self._coerce_value(arg) for _, arg in state_items) + else: + iter_arg_vals = tuple(self._coerce_value(arg) for arg in iter_args) + for_op = scf.ForOp( + start_val.value, + stop_val.value, + step_val.value, + [arg.value for arg in iter_arg_vals] if iter_arg_vals else None, + ) + loop_ip = InsertionPoint(for_op.body) + loop_ip.__enter__() + iv = _Value(for_op.induction_variable) + inner_iter_args = tuple(_Value(arg) for arg in for_op.inner_iter_args) + handle = _LoopHandle(self, for_op, iv, inner_iter_args, state_names=state_names) + self._loop_stack.append( + { + "kind": "for", + "handle": handle, + "ip": loop_ip, + "iter_args": inner_iter_args, + "state_names": state_names, + "yielded": False, + "yield_vals": (), + } + ) + return handle + + def _exit_for(self, handle: _LoopHandle | None, exc_type, exc, tb): + if handle is None: + raise RuntimeError("for_ exit without a loop handle") + frame = self._loop_stack.pop() + if frame["kind"] != "for" or frame["handle"] is not handle: + raise RuntimeError("tile-template for_ stack corruption detected") + if exc_type is None: + if frame["iter_args"] and not frame["yielded"]: + if frame["state_names"]: + raise RuntimeError( + "tile-template for_ with named state requires explicit loop.yield_state(...)" + ) + raise RuntimeError("tile-template for_ with iter_args requires explicit yield_(...)") + if not frame["iter_args"]: + scf.YieldOp([]) + frame["ip"].__exit__(exc_type, exc, tb) + if exc_type is not None: + return + handle._finalize() + + def _materialize_row_offset(self, row: _Value, row_stride: _Value) -> _Value: + if row.is_const and row_stride.is_const: + return self.index_const(row.const_value * row_stride.const_value) + cache_key = (str(row.value), str(row_stride.value)) + cached = self._row_offset_cache.get(cache_key) + if cached is not None: + return cached + result = _Value(_scalar.muli(row.value, row_stride.value)) + self._row_offset_cache[cache_key] = result + return result + + def _coerce_index(self, value) -> _Value: + coerced = self._coerce_value(value) + if coerced.type_text != str(_resolve(_index)): + raise TypeError(f"expected index value, got {coerced.type_text}") + return coerced + + def _coerce_value(self, value) -> _Value: + if isinstance(value, _Value): + return value + if isinstance(value, int): + return self.index_const(value) + if hasattr(value, "type"): + return _Value(value) + raise TypeError(f"unsupported tile-template scalar value {value!r}") + + def _coerce_like(self, value, ty: str) -> _Value: + coerced = self._coerce_value(value) + if coerced.type_text != ty: + raise TypeError(f"expected value of type {ty}, got {coerced.type_text}") + return coerced + + +@dataclass(frozen=True) +class TileTemplate: + py_fn: object + target: str + op: str + name: str + source_label: str + + def specialize(self, **tile_specs: TileSpec) -> "SpecializedTileTemplate": + return SpecializedTileTemplate(self, tile_specs) + + +class SpecializedTileTemplate(ModuleArtifact): + def __init__(self, descriptor: TileTemplate, tile_specs: dict[str, TileSpec]): + super().__init__( + descriptor.name, + module_factory=lambda: _TraceBuilder(descriptor, tile_specs).build_module(), + ) + self.descriptor = descriptor + self.tile_specs = tile_specs + + +def tile_template(*, target: str = "a5", op: str, name: str | None = None): + if target != "a5": + raise ValueError("tile-template tracing currently only supports target='a5'") + + def decorator(fn): + source_path = Path(inspect.getsourcefile(fn) or "") + descriptor_name = name or fn.__name__ + return TileTemplate( + py_fn=fn, + target=target, + op=op, + name=descriptor_name, + source_label=f"{source_path}:{fn.__name__}", + ) + + return decorator + + +def vecscope() -> _VecScopeCM: + return require_active_runtime("vecscope", expected_type=_TraceBuilder).vecscope() + + +def for_(start, stop, *, step, iter_args=None, state=None) -> _ForCM: + return require_active_runtime("for_", expected_type=_TraceBuilder).for_( + start, stop, step=step, iter_args=iter_args, state=state + ) + + +def yield_(*vals): + require_active_runtime("yield_", expected_type=_TraceBuilder).yield_(*vals) + + +def get_lanes(dtype: ScalarType) -> _Value: + return require_active_runtime("get_lanes", expected_type=_TraceBuilder).index_const(dtype.lanes) + + +def scalar_const(value: int, dtype: ScalarType) -> _Value: + return require_active_runtime("scalar_const", expected_type=_TraceBuilder).scalar_const(value, dtype) + + +def make_mask(dtype: ScalarType, remained) -> tuple[_MaskValue, _Value]: + trace = require_active_runtime("make_mask", expected_type=_TraceBuilder) + remained_val = trace._coerce_value(remained) + expected_scalar_ty = str(_resolve(_scalar_descriptor(_scalar_type_for_mask(dtype)))) + if remained_val.type_text != expected_scalar_ty: + raise TypeError( + f"tile-template tracing expects make_mask remained to use {expected_scalar_ty}, got {remained_val.type_text}" + ) + if dtype.mask_bits not in {8, 16, 32}: + raise ValueError(f"unsupported mask bit-width {dtype.mask_bits}") + mask_ty = _resolve(_mask_type(f"b{dtype.mask_bits}")) + scalar_ty = IntegerType.get_signless(dtype.mask_bits) + op_cls = getattr(_pto, f"PltB{dtype.mask_bits}Op", None) + if op_cls is None: + raise NotImplementedError( + f"pto.PltB{dtype.mask_bits}Op is not available in the current Python bindings" + ) + plt_op = op_cls(mask_ty, scalar_ty, remained_val.value) + lanes = trace.scalar_const(dtype.lanes, _scalar_type_for_mask(dtype)) + next_value = _Value(_scalar.subi(remained_val.value, lanes.value)) + return _MaskValue(plt_op.mask, dtype), next_value + + +def vlds(tile_slice: _TileSlice) -> _VectorValue: + trace = require_active_runtime("vlds", expected_type=_TraceBuilder) + if not isinstance(tile_slice, _TileSlice): + raise TypeError("tile-template tracing only supports vlds(tile[row, col:])") + ptr_value = trace.ensure_tile_ptr(tile_slice.tile) + offset = trace.materialize_linear_offset(tile_slice) + vector_ty = _resolve(_vreg_type(tile_slice.tile.element_type.lanes, _scalar_descriptor(tile_slice.tile.element_type))) + result = _pto.VldsOp(vector_ty, ptr_value.value, offset.value).result + return _VectorValue(result, tile_slice.tile.element_type) + + +def vadd(lhs: _VectorValue, rhs: _VectorValue, mask: _MaskValue) -> _VectorValue: + if lhs.dtype != rhs.dtype: + raise TypeError("tile-template tracing expects vadd operands to use the same dtype") + if lhs.dtype != mask.dtype: + raise TypeError("tile-template tracing expects vadd mask dtype to match vector dtype") + result = _pto.VaddOp(lhs.value.type, lhs.value, rhs.value, mask.value).result + return _VectorValue(result, lhs.dtype) + + +def vsts(vec: _VectorValue, tile_slice: _TileSlice, mask: _MaskValue) -> None: + trace = require_active_runtime("vsts", expected_type=_TraceBuilder) + if vec.dtype != mask.dtype: + raise TypeError("tile-template tracing expects vsts mask dtype to match vector dtype") + if vec.dtype != tile_slice.tile.element_type: + raise TypeError("tile-template tracing expects vsts destination dtype to match vector dtype") + ptr_value = trace.ensure_tile_ptr(tile_slice.tile) + offset = trace.materialize_linear_offset(tile_slice) + _pto.VstsOp(vec.value, ptr_value.value, offset.value, mask.value) + + +def _is_tile_annotation(annotation) -> bool: + if annotation is Tile: + return True + if isinstance(annotation, str): + return annotation == "Tile" or annotation.endswith(".Tile") + return getattr(annotation, "__name__", None) == "Tile" + + +def _is_index_like(value) -> bool: + return isinstance(value, int) or (isinstance(value, _Value) and value.type_text == str(_resolve(_index))) + + +def _validate_static_bound(value, upper_bound: int, label: str): + if isinstance(value, int): + if value < 0 or value >= upper_bound: + raise IndexError(f"{label} {value} is outside tile bound {upper_bound}") + return + if isinstance(value, _Value) and value.is_const: + concrete = value.const_value + if concrete < 0 or concrete >= upper_bound: + raise IndexError(f"{label} {concrete} is outside tile bound {upper_bound}") + + +def _scalar_descriptor(dtype: ScalarType): + descriptors = { + "f32": _float32, + "f16": _float16, + "bf16": Type.parse("bf16"), + "i8": _int8, + "i16": _int16, + "i32": _int32, + "i64": _int64, + } + descriptor = descriptors.get(dtype.name) + if descriptor is None: + raise ValueError(f"unsupported scalar dtype {dtype.name}") + return descriptor + + +def _scalar_type_for_mask(dtype: ScalarType) -> ScalarType: + if dtype.mask_bits == 8: + return i8 + if dtype.mask_bits == 16: + return i16 + if dtype.mask_bits == 32: + return i32 + raise ValueError(f"unsupported mask bit-width {dtype.mask_bits}") + + +__all__ = [ + "Tile", + "TileSpec", + "TileTemplate", + "SpecializedTileTemplate", + "ScalarType", + "f32", + "f16", + "bf16", + "i32", + "i16", + "i8", + "tile_template", + "vecscope", + "for_", + "yield_", + "get_lanes", + "scalar_const", + "make_mask", + "vlds", + "vadd", + "vsts", +] diff --git a/ptodsl/ptodsl/_tracing/__init__.py b/ptodsl/ptodsl/_tracing/__init__.py new file mode 100644 index 000000000..70901127d --- /dev/null +++ b/ptodsl/ptodsl/_tracing/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Shared tracing runtime building blocks for PTODSL frontends.""" + +from .active import ( + activate_runtime, + activate_session, + current_runtime, + current_session, + require_active_runtime, + require_active_session, +) +from .artifacts import ModuleArtifact +from .module_builder import KernelModuleSpec, ModuleStyle, create_kernel_module +from .runtime import CallbackTracingRuntime, SignatureTracingRuntime, TracingRuntime +from .session import HelperFunctionSpec, SubkernelTraceFrame, TraceSession + +__all__ = [ + "activate_runtime", + "activate_session", + "current_runtime", + "current_session", + "require_active_runtime", + "require_active_session", + "ModuleArtifact", + "KernelModuleSpec", + "ModuleStyle", + "create_kernel_module", + "CallbackTracingRuntime", + "SignatureTracingRuntime", + "TracingRuntime", + "HelperFunctionSpec", + "SubkernelTraceFrame", + "TraceSession", +] diff --git a/ptodsl/ptodsl/_tracing/active.py b/ptodsl/ptodsl/_tracing/active.py new file mode 100644 index 000000000..0a32a2b52 --- /dev/null +++ b/ptodsl/ptodsl/_tracing/active.py @@ -0,0 +1,86 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Active tracing-runtime stack shared by PTODSL frontends.""" + +from __future__ import annotations + +from contextlib import contextmanager + +_ACTIVE_RUNTIME_STACK = [] +_ACTIVE_SESSION_STACK = [] + + +@contextmanager +def activate_runtime(runtime): + """Push *runtime* as the current active tracing runtime.""" + _ACTIVE_RUNTIME_STACK.append(runtime) + try: + yield runtime + finally: + popped = _ACTIVE_RUNTIME_STACK.pop() + if popped is not runtime: + raise RuntimeError("PTODSL active tracing runtime stack corruption detected") + + +@contextmanager +def activate_session(session): + """Push *session* as the current active trace session.""" + _ACTIVE_SESSION_STACK.append(session) + try: + yield session + finally: + popped = _ACTIVE_SESSION_STACK.pop() + if popped is not session: + raise RuntimeError("PTODSL active trace-session stack corruption detected") + + +def current_runtime(expected_type=None): + """Return the current active tracing runtime, or ``None`` if inactive.""" + if not _ACTIVE_RUNTIME_STACK: + return None + runtime = _ACTIVE_RUNTIME_STACK[-1] + if expected_type is not None and not isinstance(runtime, expected_type): + return None + return runtime + + +def current_session(): + """Return the current active trace session, or ``None`` if inactive.""" + if not _ACTIVE_SESSION_STACK: + return None + return _ACTIVE_SESSION_STACK[-1] + + +def require_active_runtime(surface: str, expected_type=None): + """Return the active runtime or raise a surface-specific error.""" + runtime = current_runtime(expected_type=expected_type) + if runtime is None: + raise RuntimeError( + f"{surface}() may only be used while tracing a compatible PTODSL kernel" + ) + return runtime + + +def require_active_session(surface: str): + """Return the active trace session or raise a surface-specific error.""" + session = current_session() + if session is None: + raise RuntimeError( + f"{surface}() may only be used while tracing a compatible PTODSL kernel" + ) + return session + + +__all__ = [ + "activate_runtime", + "activate_session", + "current_runtime", + "current_session", + "require_active_runtime", + "require_active_session", +] diff --git a/ptodsl/ptodsl/_tracing/artifacts.py b/ptodsl/ptodsl/_tracing/artifacts.py new file mode 100644 index 000000000..650d0517f --- /dev/null +++ b/ptodsl/ptodsl/_tracing/artifacts.py @@ -0,0 +1,58 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Reusable module-backed artifacts for PTODSL tracing frontends.""" + +from __future__ import annotations + +from pathlib import Path + + +class ModuleArtifact: + """ + Cached module-backed artifact. + + Subclasses may either pass an eager ``module`` or a lazy ``module_factory``. + """ + + def __init__(self, py_name: str, *, module=None, module_factory=None): + self._py_name = py_name + self._cached_module = module + self._module_factory = module_factory + + def build(self): + """Return the cached ``mlir.ir.Module``.""" + if self._cached_module is None: + if self._module_factory is None: + raise RuntimeError(f"{self._py_name} has no module factory") + self._cached_module = self._module_factory() + return self._cached_module + + def mlir_module(self): + """Return the cached ``mlir.ir.Module``.""" + return self.build() + + def mlir_text(self) -> str: + """Return the textual MLIR form.""" + return str(self.build()) + + def verify(self) -> None: + """Verify the cached module operation.""" + self.build().operation.verify() + + def emit(self, path: str | Path) -> None: + """Write the textual MLIR form to *path*.""" + Path(path).write_text(self.mlir_text(), encoding="utf-8") + + def __str__(self): + return self.mlir_text() + + def __repr__(self): + return self.mlir_text() + + +__all__ = ["ModuleArtifact"] diff --git a/ptodsl/ptodsl/_tracing/control_flow.py b/ptodsl/ptodsl/_tracing/control_flow.py new file mode 100644 index 000000000..c2c370fe7 --- /dev/null +++ b/ptodsl/ptodsl/_tracing/control_flow.py @@ -0,0 +1,92 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Tracing-time helpers for structured PTODSL control-flow lowering.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from .._runtime_index_ops import coerce_runtime_index +from .._surface_values import unwrap_surface_value + +from mlir.dialects import scf +from mlir.ir import InsertionPoint + + +@dataclass +class CarryLoopFrame: + """Active loop-carry lowering frame for one authored ``pto.for_().carry()``.""" + + for_op: object + insertion_point: InsertionPoint + state_names: tuple[str, ...] + state_templates: tuple[object, ...] + yielded: bool = False + + +def build_carry_loop_frame(start, stop, step, state_items) -> CarryLoopFrame: + """Materialize one ``scf.for`` carry loop and enter its body insertion point.""" + state_items = tuple(state_items) + state_names = tuple(name for name, _ in state_items) + state_templates = tuple(value for _, value in state_items) + iter_args = [unwrap_surface_value(value) for value in state_templates] + for_op = scf.ForOp( + _coerce_index(start), + _coerce_index(stop), + _coerce_index(step), + iter_args, + ) + insertion_point = InsertionPoint(for_op.body) + insertion_point.__enter__() + return CarryLoopFrame( + for_op=for_op, + insertion_point=insertion_point, + state_names=state_names, + state_templates=state_templates, + ) + + +def yield_carry_loop_state(frame: CarryLoopFrame, **kwargs) -> None: + """Validate one ``loop.update(...)`` call and emit the matching ``scf.yield``.""" + missing = [name for name in frame.state_names if name not in kwargs] + extra = [name for name in kwargs if name not in frame.state_names] + if missing or extra: + pieces = [] + if missing: + pieces.append(f"missing: {', '.join(missing)}") + if extra: + pieces.append(f"unexpected: {', '.join(extra)}") + raise RuntimeError("loop.update(...) must match carry names exactly; " + "; ".join(pieces)) + if frame.yielded: + raise RuntimeError("loop.update(...) may only be called once per loop body") + scf.YieldOp([unwrap_surface_value(kwargs[name]) for name in frame.state_names]) + frame.yielded = True + + +def finish_carry_loop_frame(frame: CarryLoopFrame, exc_type, exc, tb) -> None: + """Leave one active carry-loop frame and close its insertion point.""" + try: + if exc_type is None and not frame.yielded: + raise RuntimeError( + "pto.for_(...).carry(...) requires loop.update(...) before leaving the loop body" + ) + finally: + frame.insertion_point.__exit__(exc_type, exc, tb) + + +def _coerce_index(value): + raw_value = unwrap_surface_value(value) + return coerce_runtime_index(raw_value, context="pto.for_(...).carry(...) loop bound") + + +__all__ = [ + "CarryLoopFrame", + "build_carry_loop_frame", + "yield_carry_loop_state", + "finish_carry_loop_frame", +] diff --git a/ptodsl/ptodsl/_tracing/module_builder.py b/ptodsl/ptodsl/_tracing/module_builder.py new file mode 100644 index 000000000..baba789c6 --- /dev/null +++ b/ptodsl/ptodsl/_tracing/module_builder.py @@ -0,0 +1,88 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Common MLIR module/container builders for PTODSL tracing frontends.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + +from mlir.dialects import func +from mlir.ir import Attribute, InsertionPoint, Module, Operation, StringAttr, UnitAttr + + +class ModuleStyle(str, Enum): + """Supported top-level PTODSL module layouts.""" + + FLAT_AICORE = "flat_aicore" + NESTED = "nested" + + +@dataclass(frozen=True) +class KernelModuleSpec: + """Declarative description of a traced PTODSL kernel container.""" + + function_name: str + target_arch: str + kernel_kind: str + mode: str = "auto" + insert_sync: bool | None = None + module_style: ModuleStyle = ModuleStyle.NESTED + source_file: str | None = None + source_line: int | None = None + + +def _kernel_kind_attr(kernel_kind: str): + return Attribute.parse(f"#pto.kernel_kind<{kernel_kind}>") + + +def _build_flat_aicore_module(spec: KernelModuleSpec, arg_types): + module = Module.create() + module.operation.attributes["pto.target_arch"] = StringAttr.get(spec.target_arch) + module.operation.attributes["pto.kernel_kind"] = _kernel_kind_attr(spec.kernel_kind) + module.operation.attributes["pto.mode"] = StringAttr.get(spec.mode) + fn_ty = func.FunctionType.get(arg_types, []) + with InsertionPoint(module.body): + ir_fn = func.FuncOp(spec.function_name, fn_ty) + ir_fn.attributes["pto.aicore"] = UnitAttr.get() + return module, ir_fn + + +def _build_nested_module(spec: KernelModuleSpec, arg_types): + outer = Module.create() + outer.operation.attributes["pto.target_arch"] = StringAttr.get(spec.target_arch) + outer.operation.attributes["pto.mode"] = StringAttr.get(spec.mode) + + with InsertionPoint(outer.body): + inner_op = Operation.create("builtin.module", regions=1) + inner_op.attributes["pto.target_arch"] = StringAttr.get(spec.target_arch) + inner_op.attributes["pto.kernel_kind"] = _kernel_kind_attr(spec.kernel_kind) + inner_op.attributes["pto.mode"] = StringAttr.get(spec.mode) + inner_body = inner_op.regions[0].blocks.append() + + with InsertionPoint(inner_body): + fn_ty = func.FunctionType.get(arg_types, []) + ir_fn = func.FuncOp(spec.function_name, fn_ty) + + return outer, ir_fn + + +def create_kernel_module(spec: KernelModuleSpec, arg_types): + """Create the top-level module and entry function for *spec*.""" + if spec.module_style == ModuleStyle.FLAT_AICORE: + return _build_flat_aicore_module(spec, arg_types) + if spec.module_style == ModuleStyle.NESTED: + return _build_nested_module(spec, arg_types) + raise ValueError(f"unsupported PTODSL module style {spec.module_style!r}") + + +__all__ = [ + "KernelModuleSpec", + "ModuleStyle", + "create_kernel_module", +] diff --git a/ptodsl/ptodsl/_tracing/runtime.py b/ptodsl/ptodsl/_tracing/runtime.py new file mode 100644 index 000000000..630cd9931 --- /dev/null +++ b/ptodsl/ptodsl/_tracing/runtime.py @@ -0,0 +1,131 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Base tracing runtimes shared by PTODSL frontends.""" + +from __future__ import annotations + +from .active import activate_runtime, activate_session, require_active_session +from .module_builder import create_kernel_module +from .session import TraceSession +from .._bootstrap import make_context +from .._types import _resolve + +from mlir.dialects import func +from mlir.ir import InsertionPoint, Location + + +class TracingRuntime: + """Shared module-building runtime for tracing-based PTODSL frontends.""" + + def __init__(self, module_spec): + self.module_spec = module_spec + + def compute_argument_types(self): + """Return the MLIR entry argument types for this runtime.""" + raise NotImplementedError + + def bind_entry_arguments(self, entry_arguments): + """Wrap raw entry-block arguments into surface values.""" + return tuple(entry_arguments) + + def trace_entry(self, *args): + """Emit the traced function body using wrapped entry arguments.""" + raise NotImplementedError + + def validate_trace_state(self): + """Validate runtime-local tracing state before the function returns.""" + + def emit_return(self): + """Emit the function return terminator.""" + func.ReturnOp([]) + + def verify_module(self, module): + """Verify the completed module.""" + module.operation.verify() + + def create_session(self, module, entry_function): + """Create the shared trace session for this build.""" + return TraceSession(self.module_spec, module, entry_function) + + def initialize_session(self, session, entry_block): + """Populate runtime-specific session state before tracing.""" + session.bind_entry_block(entry_block) + + def finalize_session(self, session): + """Finalize runtime-specific session state after tracing.""" + + def dispatch_subkernel_call(self, subkernel, *args, **kwargs): + """Dispatch a decorated PTODSL subkernel call in the active trace.""" + session = require_active_session(f"@pto.{subkernel.spec.role.value}") + if subkernel.spec.role.value in {"cube", "simd"}: + return session.lower_inline_subkernel(subkernel, *args, **kwargs) + if subkernel.spec.role.value == "simt": + return session.lower_simt_helper_subkernel(subkernel, *args, **kwargs) + return subkernel.emit_body(*args, **kwargs) + + def build_module(self): + """Materialize the full MLIR module for this runtime.""" + ctx = make_context() + with ctx, Location.unknown(): + arg_types = list(self.compute_argument_types()) + module, ir_fn = create_kernel_module(self.module_spec, arg_types) + session = self.create_session(module, ir_fn) + entry = ir_fn.add_entry_block() + with InsertionPoint(entry), activate_runtime(self), activate_session(session): + self.initialize_session(session, entry) + args = self.bind_entry_arguments(entry.arguments) + self.trace_entry(*args) + self.validate_trace_state() + self.emit_return() + self.finalize_session(session) + session.validate_final_state() + self.verify_module(module) + return module + + +class CallbackTracingRuntime(TracingRuntime): + """Small tracing runtime for eager callback-style module materialization.""" + + def __init__(self, module_spec, arg_types, callback): + super().__init__(module_spec) + self._arg_types = tuple(arg_types) + self._callback = callback + + def compute_argument_types(self): + return tuple(_resolve(arg_type) for arg_type in self._arg_types) + + def trace_entry(self, *args): + self._callback(*args) + + +class SignatureTracingRuntime(TracingRuntime): + """Tracing runtime that binds a parsed PTODSL kernel signature.""" + + def __init__(self, module_spec, kernel_signature, callback, *, constexpr_bindings=None): + super().__init__(module_spec) + self._kernel_signature = kernel_signature + self._callback = callback + self._constexpr_bindings = dict(constexpr_bindings or {}) + + def compute_argument_types(self): + return self._kernel_signature.compute_entry_arg_types() + + def bind_entry_arguments(self, entry_arguments): + return self._kernel_signature.bind_entry_arguments(entry_arguments) + + def trace_entry(self, *args): + kwargs = self._kernel_signature.default_constexpr_bindings() + kwargs.update(self._constexpr_bindings) + self._callback(*args, **kwargs) + + +__all__ = [ + "CallbackTracingRuntime", + "SignatureTracingRuntime", + "TracingRuntime", +] diff --git a/ptodsl/ptodsl/_tracing/session.py b/ptodsl/ptodsl/_tracing/session.py new file mode 100644 index 000000000..ac2da5f9d --- /dev/null +++ b/ptodsl/ptodsl/_tracing/session.py @@ -0,0 +1,215 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +"""Trace-session objects shared by PTODSL tracing runtimes.""" + +from __future__ import annotations + +from contextlib import contextmanager +from dataclasses import dataclass + +from .control_flow import ( + build_carry_loop_frame, + finish_carry_loop_frame, + yield_carry_loop_state, +) +from .._surface_values import unwrap_surface_value, wrap_like_surface_value + +from mlir.dialects import arith, func +from mlir.dialects import pto as _pto +from mlir.ir import InsertionPoint, IntegerType, UnitAttr + + +@dataclass(frozen=True) +class HelperFunctionSpec: + """Declarative description of a helper function emitted during tracing.""" + + symbol_name: str + arg_types: tuple + result_types: tuple = () + attributes: tuple[tuple[str, object], ...] = () + + +@dataclass(frozen=True) +class SubkernelTraceFrame: + """Active inline-lowering frame for one PTODSL subkernel call.""" + + role: str + symbol_name: str + target: str + + +class TraceSession: + """Shared per-build state for a traced PTODSL module.""" + + def __init__(self, module_spec, module, entry_function): + self.module_spec = module_spec + self.module = module + self.entry_function = entry_function + self.entry_block = None + self._function_stack = [entry_function] + self._function_symbol_table = entry_function.operation.parent.regions[0].blocks[0] + self._helpers: dict[str, object] = {} + self._subkernel_stack: list[SubkernelTraceFrame] = [] + self._carry_loop_stack = [] + + @property + def current_function(self): + return self._function_stack[-1] + + @property + def current_subkernel(self): + if not self._subkernel_stack: + return None + return self._subkernel_stack[-1] + + @property + def subkernel_stack_depth(self): + return len(self._subkernel_stack) + + @property + def current_carry_loop(self): + if not self._carry_loop_stack: + return None + return self._carry_loop_stack[-1] + + def bind_entry_block(self, entry_block) -> None: + """Record the root entry block for the active trace.""" + self.entry_block = entry_block + + @contextmanager + def enter_function(self, ir_fn): + """Push *ir_fn* as the current active function in this session.""" + self._function_stack.append(ir_fn) + try: + yield ir_fn + finally: + popped = self._function_stack.pop() + if popped is not ir_fn: + raise RuntimeError("PTODSL trace-session function stack corruption detected") + + @contextmanager + def enter_inline_subkernel(self, role: str, symbol_name: str, target: str): + """Push one inline subkernel frame onto the active tracing stack.""" + frame = SubkernelTraceFrame( + role=role, + symbol_name=symbol_name, + target=target, + ) + self._subkernel_stack.append(frame) + try: + yield frame + finally: + popped = self._subkernel_stack.pop() + if popped is not frame: + raise RuntimeError("PTODSL trace-session subkernel stack corruption detected") + + @contextmanager + def enter_subkernel(self, subkernel): + """Push *subkernel* as the current active inline-lowering frame.""" + with self.enter_inline_subkernel( + subkernel.spec.role.value, + subkernel.spec.symbol_name, + subkernel.spec.target, + ) as frame: + yield frame + + def lower_inline_subkernel(self, subkernel, *args, **kwargs): + """Lower one inline PTODSL subkernel call through the shared session.""" + with self.enter_subkernel(subkernel): + return subkernel.emit_body(*args, **kwargs) + + def begin_carry_loop(self, start, stop, step, state_items): + """Materialize one authored ``pto.for_(...).carry(...)`` loop body.""" + frame = build_carry_loop_frame(start, stop, step, state_items) + self._carry_loop_stack.append(frame) + return frame + + def update_carry_loop(self, frame, **kwargs): + """Emit the one legal ``loop.update(...)`` for the active carry loop.""" + active = self.current_carry_loop + if active is None or active is not frame: + raise RuntimeError("loop.update(...) may only be called inside the active carry loop body") + yield_carry_loop_state(frame, **kwargs) + + def finish_carry_loop(self, frame, exc_type, exc, tb): + """Finalize one active authored carry loop and close its body insertion point.""" + if not self._carry_loop_stack: + raise RuntimeError("carry-loop exit without a matching active PTODSL trace-session frame") + popped = self._carry_loop_stack.pop() + if popped is not frame: + raise RuntimeError("PTODSL trace-session carry-loop stack corruption detected") + finish_carry_loop_frame(frame, exc_type, exc, tb) + + def lower_simt_helper_subkernel(self, subkernel, *args, **kwargs): + """Lower one ``@pto.simt`` call through a dedicated helper function.""" + outer_frame = self.current_subkernel + if outer_frame is not None and outer_frame.role == "simt": + raise RuntimeError("@pto.simt helper lowering does not support nested SIMT helper calls") + + arg_templates = tuple(args) + arg_types = tuple(unwrap_surface_value(arg).type for arg in arg_templates) + helper_spec = HelperFunctionSpec( + symbol_name=subkernel.spec.symbol_name, + arg_types=arg_types, + attributes=(("pto.simt_entry", UnitAttr.get()),), + ) + helper_fn, created = self.get_or_create_helper_function(helper_spec) + + if created: + entry_block = helper_fn.add_entry_block() + wrapped_args = tuple( + wrap_like_surface_value(template, value) + for template, value in zip(arg_templates, entry_block.arguments) + ) + with self.enter_function(helper_fn), self.enter_subkernel(subkernel), InsertionPoint(entry_block): + subkernel.emit_body(*wrapped_args, **kwargs) + func.ReturnOp([]) + + i32 = IntegerType.get_signless(32) + dim_z = arith.ConstantOp(i32, 1).result + dim_y = arith.ConstantOp(i32, 1).result + dim_x = arith.ConstantOp(i32, 1).result + _pto.StoreVfSimtInfoOp(dim_z, dim_y, dim_x) + func.CallOp(helper_fn, [unwrap_surface_value(arg) for arg in arg_templates]) + + def lookup_helper(self, symbol_name: str): + """Return a previously declared helper function, or ``None``.""" + return self._helpers.get(symbol_name) + + def get_or_create_helper_function(self, spec: HelperFunctionSpec): + """ + Look up or create a helper ``func.func`` in the current symbol table. + + Returns ``(helper_fn, created)`` where *created* reports whether a new + symbol was emitted in this trace session. + """ + helper = self._helpers.get(spec.symbol_name) + if helper is not None: + return helper, False + + fn_ty = func.FunctionType.get(list(spec.arg_types), list(spec.result_types)) + with InsertionPoint(self._function_symbol_table): + helper = func.FuncOp(spec.symbol_name, fn_ty) + for attr_name, attr_value in spec.attributes: + helper.attributes[attr_name] = attr_value + self._helpers[spec.symbol_name] = helper + return helper, True + + def validate_final_state(self) -> None: + """Check that tracing-time session stacks were fully unwound.""" + if self._subkernel_stack: + raise RuntimeError("PTODSL trace-session exited with an open subkernel lowering frame") + if self._carry_loop_stack: + raise RuntimeError("PTODSL trace-session exited with an open loop-carry lowering frame") + + +__all__ = [ + "HelperFunctionSpec", + "SubkernelTraceFrame", + "TraceSession", +] diff --git a/ptodsl/ptodsl/_types.py b/ptodsl/ptodsl/_types.py new file mode 100644 index 000000000..7423613f1 --- /dev/null +++ b/ptodsl/ptodsl/_types.py @@ -0,0 +1,475 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +Lazy MLIR type descriptors and eager type constructors. + +Type descriptors (``_DType`` subclasses) can be created *before* any MLIR +Context exists – they only resolve to concrete ``mlir.ir.Type`` objects when +``_resolve()`` is called inside an active context. This lets users write:: + + def softmax(arg0: pto.ptr(pto.float32, "GM"), ...): + ... + +where the annotation is evaluated at *import* time (no active context), and +the actual type is materialised later by the ``@pto.jit`` decorator. +""" + +from ._bootstrap import make_context # ensure MLIR is on sys.path + +from mlir.dialects import pto as _pto +from mlir.dialects import arith +from mlir.dialects.builtin import UnrealizedConversionCastOp +from mlir.ir import ( + BF16Type, + F16Type, + F32Type, + Float8E4M3FNType, + Float8E5M2Type, + FloatAttr, + IndexType, + IntegerType, + ShapedType, + Type, +) + +# ── Address-space name → AddressSpace enum ─────────────────────────────────── +_ADDR_SPACE = { + "ub": _pto.AddressSpace.VEC, # UB == unified buffer == VEC in PTO + "gm": _pto.AddressSpace.GM, + "vec": _pto.AddressSpace.VEC, + "mat": _pto.AddressSpace.MAT, + "left": _pto.AddressSpace.LEFT, + "right": _pto.AddressSpace.RIGHT, + "acc": _pto.AddressSpace.ACC, + "bias": _pto.AddressSpace.BIAS, + "scaling": _pto.AddressSpace.SCALING, + "GM": _pto.AddressSpace.GM, + "UB": _pto.AddressSpace.VEC, + "VEC": _pto.AddressSpace.VEC, + "MAT": _pto.AddressSpace.MAT, + "LEFT": _pto.AddressSpace.LEFT, + "RIGHT": _pto.AddressSpace.RIGHT, + "ACC": _pto.AddressSpace.ACC, + "BIAS": _pto.AddressSpace.BIAS, + "SCALING": _pto.AddressSpace.SCALING, +} + + +# ── Lazy type descriptor base ───────────────────────────────────────────────── + +class _DType: + """Deferred MLIR type: only resolves inside an active MLIR context.""" + + def __init__(self, factory): + self._factory = factory + + def resolve(self) -> Type: + return self._factory() + + def __call__(self, value): + target_type = self.resolve() + kind = _classify_scalar_type(target_type) + if kind == "float": + return arith.ConstantOp(target_type, _parse_float_attr(target_type, value)).result + if kind == "integer": + return _materialize_integer_literal(target_type, value) + raise TypeError(f"unsupported eager constructor target type {target_type}") + + def __repr__(self): + return f"" + + +class _PtrDescriptor(_DType): + def __init__(self, elem, space: str): + self._elem = elem + self._space = space + + def resolve(self) -> Type: + elem = _ensure_non_storage_only_dtype(self._elem, context="pto.ptr(...)") + space_enum = _normalize_address_space(self._space) + if space_enum is None: + raise ValueError( + f"Unknown address space '{self._space}'; " + f"known: {list(_ADDR_SPACE)}" + ) + space_attr = _pto.AddressSpaceAttr.get(space_enum) + try: + return _pto.PtrType.get(elem, memory_space=space_attr) + except TypeError: + ptr_get_impl = getattr(_pto, "_ptr_type_get_impl", None) + if ptr_get_impl is None: + raise + if space_enum != _pto.AddressSpace.GM: + raise TypeError( + "The current PTO Python bindings only expose the default-GM " + "PtrType builder. Non-GM pointer construction is not " + "available through ptodsl._types.ptr(...) yet." + ) + return ptr_get_impl(elem) + + def __repr__(self): + return f"" + + +class _VRegDescriptor(_DType): + def __init__(self, lanes: int, elem): + self._lanes = lanes + self._elem = elem + + def resolve(self) -> Type: + elem = _ensure_non_storage_only_dtype(self._elem, context="pto.vreg_type(...)") + vreg_type_cls = getattr(_pto, "VRegType", None) + if vreg_type_cls is None: + raise TypeError( + "The current PTO Python bindings do not expose VRegType. " + "Rebuild the PTO Python extension before using pto.vreg_type(...)." + ) + return vreg_type_cls.get(self._lanes, elem) + + def __repr__(self): + return f"" + + +class _MaskDescriptor(_DType): + def __init__(self, bits: str): + self._bits = bits + + def resolve(self) -> Type: + mask_type_cls = getattr(_pto, "MaskType", None) + if mask_type_cls is None: + raise TypeError( + "The current PTO Python bindings do not expose MaskType. " + "Rebuild the PTO Python extension before using pto.mask_type(...)." + ) + return mask_type_cls.get(self._bits) + + def __repr__(self): + return f"" + + +def _resolve(dtype) -> Type: + """Coerce a ``_DType`` descriptor or a concrete ``mlir.ir.Type`` to a Type.""" + if isinstance(dtype, _DType): + return dtype.resolve() + return dtype # already an mlir.ir.Type + + +def _classify_scalar_type(type_obj): + if F32Type.isinstance(type_obj) or F16Type.isinstance(type_obj) or BF16Type.isinstance(type_obj): + return "float" + if IndexType.isinstance(type_obj) or IntegerType.isinstance(type_obj): + return "integer" + return None + + +def _isinstance_pto_type(type_obj, type_name: str) -> bool: + cls = getattr(_pto, type_name, None) + if cls is None: + return False + try: + return cls.isinstance(type_obj) + except Exception: + return False + + +def _classify_storage_dtype(type_obj): + if _classify_scalar_type(type_obj) is not None: + return "compute" + if Float8E4M3FNType.isinstance(type_obj) or Float8E5M2Type.isinstance(type_obj): + return "storage_only" + if any(_isinstance_pto_type(type_obj, name) for name in ("HiF8Type", "F4E1M2x2Type", "F4E2M1x2Type")): + return "storage_only" + return "other" + + +def _is_storage_only_dtype(type_obj): + return _classify_storage_dtype(type_obj) == "storage_only" + + +def _is_storage_only_authored_dtype(dtype) -> bool: + if isinstance(dtype, _DType): + return dtype in _STORAGE_ONLY_DTYPE_DESCRIPTORS + return _is_storage_only_dtype(_resolve(dtype)) + + +def _ensure_tensor_storage_dtype(dtype, *, context: str): + type_obj = _resolve(dtype) + category = _classify_storage_dtype(type_obj) + if category not in {"compute", "storage_only"}: + raise TypeError(f"{context} does not support element type {type_obj}") + return type_obj + + +def _ensure_non_storage_only_dtype(dtype, *, context: str): + type_obj = _resolve(dtype) + if _is_storage_only_dtype(type_obj): + raise TypeError( + f"{context} does not accept storage-only low-precision type {type_obj}; " + "these dtypes are only supported in Tile / TensorView / PartitionTensorView construction" + ) + return type_obj + + +def _ensure_non_storage_only_authored_dtype(dtype, *, context: str): + if _is_storage_only_authored_dtype(dtype): + raise TypeError( + f"{context} does not accept storage-only low-precision types; " + "these dtypes are only supported in Tile / TensorView / PartitionTensorView construction" + ) + return dtype + + +def _integer_signedness(type_obj): + if not IntegerType.isinstance(type_obj): + raise TypeError(f"expected integer type, got {type_obj}") + text = str(type_obj) + if text.startswith("si"): + return "signed" + if text.startswith("ui"): + return "unsigned" + return "signless" + + +def _signless_integer_type(type_obj): + if not IntegerType.isinstance(type_obj): + raise TypeError(f"expected integer type, got {type_obj}") + return IntegerType.get_signless(IntegerType(type_obj).width) + + +def _strip_integer_signedness(value): + value_type = getattr(value, "type", None) + if value_type is None or not IntegerType.isinstance(value_type): + return value + signless_type = _signless_integer_type(value_type) + if value_type == signless_type: + return value + return UnrealizedConversionCastOp([signless_type], [value]).results[0] + + +def _restore_integer_signedness(value, target_type): + if not IntegerType.isinstance(target_type): + raise TypeError(f"expected integer target type, got {target_type}") + signless_type = _signless_integer_type(target_type) + if target_type == signless_type: + return value + return UnrealizedConversionCastOp([target_type], [value]).results[0] + + +def _materialize_integer_literal(target_type, value): + if not IntegerType.isinstance(target_type): + raise TypeError(f"unsupported eager integer constructor target type {target_type}") + signless_type = _signless_integer_type(target_type) + raw_value = _parse_integer_value(value, target_type=target_type) + constant = arith.ConstantOp(signless_type, raw_value).result + if target_type == signless_type: + return constant + return _restore_integer_signedness(constant, target_type) + + +def _parse_integer_value(value, *, target_type=None): + if isinstance(value, bool): + raise TypeError("eager scalar constructors do not accept bool values") + if isinstance(value, int): + return value + if isinstance(value, str): + text = value.strip() + return _parse_integer_text(text) + raise TypeError(f"cannot materialize {value!r} as an integer constant of type {target_type}") + + +def _parse_integer_text(text: str): + if text.startswith(("0x", "0X", "-0x", "-0X")): + return int(text, 16) + return int(text, 0) + + +def _parse_float_attr(target_type, value): + if isinstance(value, bool): + raise TypeError("eager scalar constructors do not accept bool values") + if isinstance(value, str): + text = value.strip() + lower = text.lower() + if lower in {"inf", "+inf", "-inf", "nan"}: + numeric = float(lower) + elif text.startswith(("0x", "0X")): + return _float_attr_from_bit_pattern(target_type, text) + else: + numeric = float(text) + else: + numeric = float(value) + return FloatAttr.get(target_type, numeric) + + +def _float_attr_from_bit_pattern(target_type, text): + import math + import struct + + if F16Type.isinstance(target_type): + bits = int(text, 16) & 0xFFFF + as_bytes = bits.to_bytes(2, byteorder="little", signed=False) + numeric = struct.unpack(" _PtrDescriptor: + """Return a lazy descriptor for ``!pto.ptr``.""" + return _PtrDescriptor(elem, space) + + +def vreg_type(lanes: int, elem) -> _VRegDescriptor: + """Return a lazy descriptor for ``!pto.vreg``.""" + return _VRegDescriptor(lanes, elem) + + +def mask_type(bits: str = "b32") -> _MaskDescriptor: + """Return a lazy descriptor for ``!pto.mask``.""" + return _MaskDescriptor(bits) + + +def tile_buf_type(shape, dtype, valid_shape, *, + blayout: str = "RowMajor", + address_space: str = "ub", + slayout: str = "NoneBox", + fractal_size: int = 512, + pad: str = "Null") -> Type: + """ + Construct a ``!pto.tile_buf<…>`` type via the Python bindings. + + ``valid_shape`` entries may be ``-1`` for dynamic (``?``) dimensions. + ``blayout="ColMajor"`` prints as ``blayout=col_major``. + + Requires an active MLIR context. + """ + elem = _ensure_tensor_storage_dtype(dtype, context="pto.tile_buf_type(...)") + space_enum = _normalize_address_space(address_space) + if space_enum is None: + raise ValueError( + f"Unknown address_space '{address_space}'; known: {list(_ADDR_SPACE)}" + ) + space_attr = _pto.AddressSpaceAttr.get(space_enum) + cfg = _pto.TileBufConfigAttr.get( + _pto.BLayoutAttr.get(getattr(_pto.BLayout, blayout)), + _pto.SLayoutAttr.get(getattr(_pto.SLayout, slayout)), + fractal_size, + _pto.PadValueAttr.get(getattr(_pto.PadValue, pad)), + ) + return _pto.TileBufType.get(shape, elem, space_attr, valid_shape, cfg) + + +def tensor_view_type(rank: int, elem) -> Type: + """``!pto.tensor_view`` with *rank* all-dynamic dims.""" + return _pto.TensorViewType.get(rank, _ensure_tensor_storage_dtype(elem, context="pto.tensor_view_type(...)")) + + +def tensor_view_type_from_dims(dims, elem) -> Type: + """``!pto.tensor_view`` when every dimension is static.""" + resolved_elem = _ensure_tensor_storage_dtype(elem, context="pto.tensor_view_type_from_dims(...)") + if all(isinstance(dim, int) for dim in dims): + return _pto.TensorViewType.get(list(dims), resolved_elem) + return tensor_view_type(len(dims), resolved_elem) + + +def part_tensor_view_type(rank: int, elem) -> Type: + """``!pto.partition_tensor_view`` with *rank* all-dynamic dims.""" + kDynamic = ShapedType.get_dynamic_size() + return _pto.PartitionTensorViewType.get( + [kDynamic] * rank, + _ensure_tensor_storage_dtype(elem, context="pto.part_tensor_view_type(...)"), + ) + + +def part_tensor_view_type_from_dims(dims, elem) -> Type: + """``!pto.partition_tensor_view`` when every dimension is static.""" + resolved_elem = _ensure_tensor_storage_dtype(elem, context="pto.part_tensor_view_type_from_dims(...)") + if all(isinstance(dim, int) for dim in dims): + return _pto.PartitionTensorViewType.get(list(dims), resolved_elem) + return part_tensor_view_type(len(dims), resolved_elem) + + +__all__ = [ + "_DType", "_resolve", + "float32", "float16", "bf16", + "f8e4m3", "f8e5m2", "hif8", "f4e1m2x2", "f4e2m1x2", + "int1", "int8", "int16", "int32", "int64", + "si8", "si16", "si32", "si64", + "ui8", "ui16", "ui32", "ui64", + "index", + "ptr", "vreg_type", "mask_type", + "tile_buf_type", "tensor_view_type", "tensor_view_type_from_dims", + "part_tensor_view_type", "part_tensor_view_type_from_dims", +] diff --git a/ptodsl/ptodsl/pto.py b/ptodsl/ptodsl/pto.py new file mode 100644 index 000000000..da43f3c2b --- /dev/null +++ b/ptodsl/ptodsl/pto.py @@ -0,0 +1,126 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +``pto`` – the public DSL namespace. + +Import as:: + + import pto + +or as the sub-namespace ``pto`` from the ptodsl package:: + + from ptodsl import pto + +All user-facing symbols live here. Low-level MLIR bindings are accessed +internally as ``_pto`` (``from mlir.dialects import pto as _pto``). +""" + +from ._diagnostics import unsupported_public_surface_error + +# ── Types ───────────────────────────────────────────────────────────────────── +from ._types import ( # noqa: F401 + float32, float16, bf16, + f8e4m3, f8e5m2, hif8, f4e1m2x2, f4e2m1x2, + int1, int8, int16, int32, int64, + si8, si16, si32, si64, + ui8, ui16, ui32, ui64, + index, + ptr, vreg_type, mask_type, + _resolve, +) +from ._surface_types import ( # noqa: F401 + constexpr, + tensor_spec, + TensorSpec, + BarrierType, + Pipe, + MemorySpace, + MaskPattern, + CmpMode, + PredicatePart, + PredicateDist, + VStoreDist, + DeinterleaveDist, + InterleaveDist, + PostUpdate, + AlignType, + TensorView, + PartitionTensorView, + Tile, +) +from ._tensor_factories import empty_like # noqa: F401 +from ._tile_namespace import tile # noqa: F401 + +# ── Operations ──────────────────────────────────────────────────────────────── +from ._ops import ( # noqa: F401 + const, + castptr, addptr, + vlds, vldas, vldus, vldsx2, vsts, vstsx2, + init_align, + plt_b8, plt_b16, plt_b32, + pset_b8, pset_b16, pset_b32, + pge_b8, pge_b16, pge_b32, + make_mask, bytewidth, elements_per_vreg, + pand, por, pxor, pnot, psel, + pbitcast, + ppack, punpack, + pintlv_b8, pintlv_b16, pintlv_b32, + pdintlv_b8, pdintlv_b16, pdintlv_b32, + vgather2, vgather2_bc, vgatherb, vscatter, vsldb, vsstb, + vcmp, vcmps, + plds, psts, pstu, vstar, vstas, vstur, vstus, + vbitcast, + vbr, + vadd, vsub, vmul, vdiv, vmax, vmin, + vand, vor, vxor, vshl, vshr, + vcmax, vcadd, vcmin, vdup, vexpdif, + vexp, vln, vsqrt, vabs, vneg, vrec, vrsqrt, vrelu, vnot, + vcgmax, vcgadd, vcgmin, vcpadd, + vadds, vsubs, vmuls, vmaxs, vmins, vlrelu, + vaxpy, vaddrelu, vsubrelu, + vsel, + make_tensor_view, partition_view, + alloc_tile, + mte_load, mte_store, mte_gm_ub, mte_ub_gm, mte_ub_ub, mte_ub_l1, mem_bar, + mte_l1_l0a, mte_l1_l0b, mte_l0c_ub, + mad, mad_acc, mad_bias, mad_mx, mad_mx_acc, mad_mx_bias, + get_block_idx, get_block_num, get_subblock_idx, get_subblock_num, + store_vfsimt_info, get_tid_x, get_tid_y, get_tid_z, + pipe_barrier, + get_buf, rls_buf, + set_cross_flag, wait_cross_flag, set_intra_flag, wait_intra_flag, + set_flag, wait_flag, +) + +# ── Control flow ────────────────────────────────────────────────────────────── +from ._control_flow import ( # noqa: F401 + for_, if_, yield_, + LoopHandle, BranchHandle, +) + +# ── Decorator ───────────────────────────────────────────────────────────────── +from ._jit import jit, KernelHandle, merge_jit_modules # noqa: F401 +from ._subkernels import cube, simd, simt # noqa: F401 + +# ── Shorthand dtype aliases ─────────────────────────────────────────────────── +f32 = float32 +f16 = float16 +i1 = int1 +i8 = int8 +i16 = int16 +i32 = int32 +i64 = int64 +mask_b8 = mask_type("b8") +mask_b16 = mask_type("b16") +mask_b32 = mask_type("b32") + + +def __getattr__(name): + if name in {"ukernel", "tile_buf_type", "vecscope", "as_ptr", "vbrc_load", "vsts_1pt"}: + raise unsupported_public_surface_error(name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/ptodsl/ptodsl/scalar.py b/ptodsl/ptodsl/scalar.py new file mode 100644 index 000000000..fef02ff96 --- /dev/null +++ b/ptodsl/ptodsl/scalar.py @@ -0,0 +1,157 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +""" +Scalar arithmetic helpers – exposed as top-level ``scalar.*`` from the +``ptodsl`` package (for example ``from ptodsl import scalar``). + +Arithmetic helpers operate on raw ``mlir.ir.Value`` objects and emit the +corresponding arith dialect operations at the active insertion point. +Scalar memory helpers (`load` / `store`) also accept PTODSL surface-level +address views such as `tile[row, col]` and `tile.as_ptr() + offset`. +""" + +from ._bootstrap import make_context # ensure MLIR is on sys.path # noqa: F401 +from ._scalar_coercion import coerce_scalar_to_type +from ._runtime_scalar_ops import ( + classify_runtime_scalar_type, + emit_runtime_abs, + emit_runtime_binary_op, + emit_runtime_max, + emit_runtime_min, +) +from ._surface_values import resolve_address_access, unwrap_surface_value, wrap_surface_value +from ._types import _resolve + +from mlir.dialects import arith +from mlir.dialects import math +from mlir.ir import IndexType, MemRefType, Operation +from mlir.dialects import pto as _pto + + +def muli(lhs, rhs): + """arith.muli""" + return wrap_surface_value(emit_runtime_binary_op("mul", unwrap_surface_value(lhs), unwrap_surface_value(rhs))) + + +def addi(lhs, rhs): + """arith.addi""" + return wrap_surface_value(emit_runtime_binary_op("add", unwrap_surface_value(lhs), unwrap_surface_value(rhs))) + + +def subi(lhs, rhs): + """arith.subi""" + return wrap_surface_value(emit_runtime_binary_op("sub", unwrap_surface_value(lhs), unwrap_surface_value(rhs))) + + +def index_cast(type_or_val, val=None): + """ + arith.index_cast. + + Two calling conventions:: + + index_cast(result_type, value) # explicit result type + index_cast(value) # result type = index (1-arg shorthand) + """ + if val is None: + # 1-arg form: cast to index + return wrap_surface_value(arith.IndexCastOp(IndexType.get(), unwrap_surface_value(type_or_val)).result) + return wrap_surface_value(arith.IndexCastOp(_resolve(type_or_val), unwrap_surface_value(val)).result) + + +def select(cond, true_val, false_val): + """arith.select""" + return wrap_surface_value(arith.SelectOp( + unwrap_surface_value(cond), + unwrap_surface_value(true_val), + unwrap_surface_value(false_val), + ).result) + + +def max(lhs, rhs): + """Runtime scalar maximum across float / integer / index values.""" + return wrap_surface_value(emit_runtime_max( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + )) + + +def min(lhs, rhs): + """Runtime scalar minimum across float / integer / index values.""" + return wrap_surface_value(emit_runtime_min( + unwrap_surface_value(lhs), + unwrap_surface_value(rhs), + )) + + +def exp(value): + """Runtime scalar exponential for floating-point values.""" + raw_value = unwrap_surface_value(value) + kind = classify_runtime_scalar_type(raw_value.type) + if kind != "float": + raise TypeError(f"scalar.exp(...) expects a floating-point runtime scalar, got {raw_value.type}") + return wrap_surface_value(math.ExpOp(raw_value).result) + + +def log(value): + """Runtime scalar natural logarithm for floating-point values.""" + raw_value = unwrap_surface_value(value) + kind = classify_runtime_scalar_type(raw_value.type) + if kind != "float": + raise TypeError(f"scalar.log(...) expects a floating-point runtime scalar, got {raw_value.type}") + return wrap_surface_value(math.LogOp(raw_value).result) + + +def sqrt(value): + """Runtime scalar square root for floating-point values.""" + raw_value = unwrap_surface_value(value) + kind = classify_runtime_scalar_type(raw_value.type) + if kind != "float": + raise TypeError(f"scalar.sqrt(...) expects a floating-point runtime scalar, got {raw_value.type}") + return wrap_surface_value(math.SqrtOp(raw_value).result) + + +def abs(value): + """Runtime scalar absolute value across float / integer / index values.""" + return wrap_surface_value(emit_runtime_abs(unwrap_surface_value(value))) + + +def load(ptr_or_ref, offset=None): + """Load one scalar element from a PTODSL address view or tile element.""" + buffer_value, index_value = resolve_address_access(ptr_or_ref, offset) + result_type = _infer_buffer_element_type(buffer_value.type) + return wrap_surface_value(Operation.create( + "pto.load", + results=[result_type], + operands=[buffer_value, index_value], + ).results[0]) + + +def store(value, ptr_or_ref, offset=None): + """Store one scalar element to a PTODSL address view or tile element.""" + buffer_value, index_value = resolve_address_access(ptr_or_ref, offset) + elem_type = _infer_buffer_element_type(buffer_value.type) + Operation.create( + "pto.store", + operands=[buffer_value, index_value, coerce_scalar_to_type(value, elem_type, context="scalar.store(...)")], + ) + + +def _infer_buffer_element_type(buffer_type): + try: + return _pto.PtrType(buffer_type).element_type + except Exception: + return MemRefType(buffer_type).element_type + + +__all__ = [ + "muli", "addi", "subi", + "index_cast", + "select", + "max", "min", "exp", "log", "sqrt", "abs", + "load", "store", +] diff --git a/ptodsl/pyproject.toml b/ptodsl/pyproject.toml new file mode 100644 index 000000000..07762191c --- /dev/null +++ b/ptodsl/pyproject.toml @@ -0,0 +1,13 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "ptodsl" +version = "0.1.0" +description = "PTO MLIR DSL – Pythonic JIT-compiler-style IR builder for the PTO dialect" +requires-python = ">=3.9" + +[tool.setuptools.packages.find] +where = ["."] +include = ["ptodsl*"] diff --git a/ptodsl/tests/test_vector_cube_ops.py b/ptodsl/tests/test_vector_cube_ops.py new file mode 100644 index 000000000..49782bc04 --- /dev/null +++ b/ptodsl/tests/test_vector_cube_ops.py @@ -0,0 +1,381 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import unittest +import inspect +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from ptodsl.ptodsl import _ops, pto + + +def _identity(value): + return value + + +class VectorCubeSurfaceTest(unittest.TestCase): + def test_public_namespace_exports_new_vector_and_cube_apis(self): + names = [ + "vsub", "vmin", "vand", "vor", "vxor", "vshl", "vshr", + "vln", "vsqrt", "vabs", "vneg", "vrec", "vrsqrt", "vrelu", "vnot", + "vcmin", "vcgmin", "vcpadd", + "vadds", "vmuls", "vmaxs", "vmins", "vlrelu", + "vaxpy", "vaddrelu", "vsubrelu", "vsel", + "mad_acc", "mad_bias", "mad_mx", "mad_mx_acc", "mad_mx_bias", + ] + + for name in names: + self.assertTrue(hasattr(pto, name), name) + + def test_tile_bitwise_aliases_are_exposed_without_legacy_names(self): + preferred_names = [ + "bit_not", "bit_and", "bit_ands", "bit_or", "bit_ors", + "bit_xor", "bit_xors", "bit_shl", "bit_shls", "bit_shr", "bit_shrs", + ] + legacy_names = [ + "not_", "and_", "ands", "or_", "ors", "xor", "xors", "shl", "shls", "shr", "shrs", + ] + + for name in preferred_names: + with self.subTest(name=name): + self.assertTrue(hasattr(pto.tile, name), name) + + for name in legacy_names: + with self.subTest(name=name): + self.assertFalse(hasattr(pto.tile, name), name) + + def test_tile_partial_and_fillpad_names_are_exposed_without_legacy_names(self): + preferred_names = [ + "partadd", "partmul", "partmax", "partmin", + "fillpad", "fillpad_expand", "fillpad_inplace", + ] + legacy_names = [ + "part_add", "part_mul", "part_max", "part_min", + "fill_pad", "fill_pad_expand", "fill_pad_inplace", + ] + + for name in preferred_names: + with self.subTest(name=name): + self.assertTrue(hasattr(pto.tile, name), name) + + for name in legacy_names: + with self.subTest(name=name): + self.assertFalse(hasattr(pto.tile, name), name) + + def test_sync_flag_names_are_exposed_without_legacy_aliases(self): + preferred_names = [ + "set_cross_flag", "wait_cross_flag", + "set_intra_flag", "wait_intra_flag", + ] + legacy_names = [ + "set_cross_core", "wait_flag_dev", + "set_intra_block", "wait_intra_core", + ] + + for name in preferred_names: + with self.subTest(name=name): + self.assertTrue(hasattr(pto, name), name) + + for name in legacy_names: + with self.subTest(name=name): + self.assertFalse(hasattr(pto, name), name) + + def test_direct_vector_wrappers_dispatch_to_generated_ops(self): + lhs = SimpleNamespace(type="vec_ty") + rhs = SimpleNamespace(type="vec_ty") + mask = SimpleNamespace(type="mask_ty") + result = object() + + binary_cases = [ + ("vsub", "VsubOp", (lhs, rhs, mask)), + ("vmin", "VminOp", (lhs, rhs, mask)), + ("vand", "VandOp", (lhs, rhs, mask)), + ("vor", "VorOp", (lhs, rhs, mask)), + ("vxor", "VxorOp", (lhs, rhs, mask)), + ("vshl", "VshlOp", (lhs, rhs, mask)), + ("vshr", "VshrOp", (lhs, rhs, mask)), + ] + unary_cases = [ + ("vln", "VlnOp", (lhs, mask)), + ("vsqrt", "VsqrtOp", (lhs, mask)), + ("vabs", "VabsOp", (lhs, mask)), + ("vneg", "VnegOp", (lhs, mask)), + ("vrelu", "VreluOp", (lhs, mask)), + ("vnot", "VnotOp", (lhs, mask)), + ("vcmin", "VcminOp", (lhs, mask)), + ("vcpadd", "VcpaddOp", (lhs, mask)), + ] + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "wrap_surface_value", side_effect=_identity): + for func_name, op_name, args in binary_cases + unary_cases: + with self.subTest(func=func_name): + fake_op = SimpleNamespace(result=result) + with patch.object(_ops._pto, op_name, return_value=fake_op) as op_ctor: + output = getattr(_ops, func_name)(*args) + self.assertIs(output, result) + self.assertEqual(op_ctor.call_args.args[0], "vec_ty") + + def test_vec_scalar_wrappers_and_vaxpy_coerce_scalar_operands(self): + vec = SimpleNamespace(type="vec_ty") + other = SimpleNamespace(type="vec_ty") + mask = SimpleNamespace(type="mask_ty") + scalar = object() + coerced_scalar = object() + result = object() + + vec_scalar_cases = [ + ("vadds", "VaddsOp"), + ("vmuls", "VmulsOp"), + ("vmaxs", "VmaxsOp"), + ("vmins", "VminsOp"), + ("vlrelu", "VlreluOp"), + ] + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "wrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "_coerce_scalar_like_vector_element", return_value=coerced_scalar) as coerce_scalar: + for func_name, op_name in vec_scalar_cases: + with self.subTest(func=func_name): + fake_op = SimpleNamespace(result=result) + with patch.object(_ops._pto, op_name, return_value=fake_op) as op_ctor: + output = getattr(_ops, func_name)(vec, scalar, mask) + self.assertIs(output, result) + self.assertEqual(op_ctor.call_args.args, ("vec_ty", vec, coerced_scalar, mask)) + + fake_op = SimpleNamespace(result=result) + with patch.object(_ops._pto, "VaxpyOp", return_value=fake_op) as op_ctor: + output = _ops.vaxpy(scalar, vec, other, mask) + self.assertIs(output, result) + self.assertEqual(op_ctor.call_args.args, ("vec_ty", vec, other, coerced_scalar, mask)) + self.assertGreaterEqual(coerce_scalar.call_count, len(vec_scalar_cases) + 1) + + def test_composed_vector_wrappers_chain_existing_primitives(self): + vec = object() + rhs = object() + mask = object() + zero_vec = object() + one_vec = object() + sqrt_vec = object() + add_vec = object() + sub_vec = object() + relu_vec = object() + reciprocal_vec = object() + + with patch.object(_ops, "vmuls", return_value=zero_vec) as vmuls, \ + patch.object(_ops, "vadds", return_value=one_vec) as vadds, \ + patch.object(_ops, "vdiv", return_value=reciprocal_vec) as vdiv: + self.assertIs(_ops.vrec(vec, mask), reciprocal_vec) + vmuls.assert_called_once_with(vec, 0, mask) + vadds.assert_called_once_with(zero_vec, 1, mask) + vdiv.assert_called_once_with(one_vec, vec, mask) + + with patch.object(_ops, "vsqrt", return_value=sqrt_vec) as vsqrt, \ + patch.object(_ops, "vrec", return_value=reciprocal_vec) as vrec: + self.assertIs(_ops.vrsqrt(vec, mask), reciprocal_vec) + vsqrt.assert_called_once_with(vec, mask) + vrec.assert_called_once_with(sqrt_vec, mask) + + with patch.object(_ops, "vadd", return_value=add_vec) as vadd, \ + patch.object(_ops, "vrelu", return_value=relu_vec) as vrelu: + self.assertIs(_ops.vaddrelu(vec, rhs, mask), relu_vec) + vadd.assert_called_once_with(vec, rhs, mask) + vrelu.assert_called_once_with(add_vec, mask) + + with patch.object(_ops, "vsub", return_value=sub_vec) as vsub, \ + patch.object(_ops, "vrelu", return_value=relu_vec) as vrelu: + self.assertIs(_ops.vsubrelu(vec, rhs, mask), relu_vec) + vsub.assert_called_once_with(vec, rhs, mask) + vrelu.assert_called_once_with(sub_vec, mask) + + def test_vcgmin_and_vsel_dispatch_correctly(self): + vec = SimpleNamespace(type="vec_ty") + other = SimpleNamespace(type="vec_ty") + mask = SimpleNamespace(type="mask_ty") + reduced = object() + scalar = object() + selected = object() + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "wrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "_extract_lowest_lane_scalar", return_value=scalar) as extract_scalar, \ + patch.object(_ops._pto, "VcgminOp", return_value=SimpleNamespace(result=reduced)) as vcgmin_op: + output = _ops.vcgmin(vec, mask) + self.assertIs(output, scalar) + self.assertEqual(vcgmin_op.call_args.args, ("vec_ty", vec, mask)) + extract_scalar.assert_called_once_with(reduced, mask) + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "wrap_surface_value", side_effect=_identity), \ + patch.object(_ops._pto, "VselOp", return_value=SimpleNamespace(result=selected)) as vsel_op: + output = _ops.vsel(vec, other, mask) + self.assertIs(output, selected) + self.assertEqual(vsel_op.call_args.args, ("vec_ty", vec, other, mask)) + + def test_cube_variant_wrappers_dispatch_to_generated_ops(self): + lhs = object() + rhs = object() + dst = object() + bias = object() + + cube_cases = [ + ("mad_acc", "MadAccOp", (lhs, rhs, dst, 1, 2, 3), (lhs, rhs, dst, "i64:1", "i64:2", "i64:3")), + ("mad_bias", "MadBiasOp", (lhs, rhs, dst, bias, 1, 2, 3), (lhs, rhs, dst, bias, "i64:1", "i64:2", "i64:3")), + ("mad_mx", "MadMxOp", (lhs, rhs, dst, 1, 2, 3), (lhs, rhs, dst, "i64:1", "i64:2", "i64:3")), + ("mad_mx_acc", "MadMxAccOp", (lhs, rhs, dst, 1, 2, 3), (lhs, rhs, dst, "i64:1", "i64:2", "i64:3")), + ("mad_mx_bias", "MadMxBiasOp", (lhs, rhs, dst, bias, 1, 2, 3), (lhs, rhs, dst, bias, "i64:1", "i64:2", "i64:3")), + ] + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "_coerce_i64", side_effect=lambda value, *, context: f"i64:{value}"): + for func_name, op_name, args, expected_call in cube_cases: + with self.subTest(func=func_name): + op_ctor = MagicMock() + with patch.object(_ops._pto, op_name, op_ctor): + getattr(_ops, func_name)(*args) + self.assertEqual(op_ctor.call_args.args, expected_call) + + def test_tile_selection_surface_exposes_optional_tmp(self): + for func, expected in [ + (_ops.tsel, ["mask", "src0", "src1", "dst", "tmp"]), + (_ops.tsels, ["mask", "src", "scalar", "dst", "tmp"]), + (pto.tile.sel, ["mask", "src0", "src1", "dst", "tmp"]), + (pto.tile.sels, ["mask", "src", "scalar", "dst", "tmp"]), + ]: + with self.subTest(func=func): + signature = inspect.signature(func) + self.assertEqual(list(signature.parameters.keys()), expected) + self.assertEqual(signature.parameters["tmp"].kind, inspect.Parameter.KEYWORD_ONLY) + self.assertIsNone(signature.parameters["tmp"].default) + + def test_tile_selection_wrappers_use_explicit_tmp_or_synthesize_one(self): + mask = object() + src0 = object() + src1 = object() + src = object() + dst = object() + tmp = object() + scalar = object() + coerced_scalar = object() + synthesized_tmp = object() + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "_coerce_tile_scalar_operand", return_value=coerced_scalar): + with patch.object(_ops, "_resolve_selection_tmp", return_value=synthesized_tmp) as resolve_tmp, \ + patch.object(_ops._pto, "tsel") as tsel_op: + _ops.tsel(mask, src0, src1, dst) + resolve_tmp.assert_called_once_with(dst, None, context="tsel") + self.assertEqual(tsel_op.call_args.args, (mask, src0, src1, synthesized_tmp, dst)) + + with patch.object(_ops, "_resolve_selection_tmp", side_effect=AssertionError("should not synthesize")), \ + patch.object(_ops._pto, "tsel") as tsel_op: + _ops.tsel(mask, src0, src1, dst, tmp=tmp) + self.assertEqual(tsel_op.call_args.args, (mask, src0, src1, tmp, dst)) + + with patch.object(_ops, "_resolve_selection_tmp", return_value=synthesized_tmp) as resolve_tmp, \ + patch.object(_ops._pto, "tsels") as tsels_op: + _ops.tsels(mask, src, scalar, dst) + resolve_tmp.assert_called_once_with(dst, None, context="tsels") + self.assertEqual(tsels_op.call_args.args, (mask, src, synthesized_tmp, coerced_scalar, dst)) + + with patch.object(_ops, "_resolve_selection_tmp", side_effect=AssertionError("should not synthesize")), \ + patch.object(_ops._pto, "tsels") as tsels_op: + _ops.tsels(mask, src, scalar, dst, tmp=tmp) + self.assertEqual(tsels_op.call_args.args, (mask, src, tmp, coerced_scalar, dst)) + + def test_tile_row_reductions_expose_optional_tmp_and_synthesize_one(self): + src = SimpleNamespace(type="src_ty") + dst = object() + tmp = object() + synthesized_tmp = object() + + row_cases = [ + ("rowsum", "trowsum"), + ("rowmax", "trowmax"), + ("rowmin", "trowmin"), + ("rowprod", "trowprod"), + ("rowargmax", "trowargmax"), + ("rowargmin", "trowargmin"), + ] + + for name, low_level_name in row_cases: + with self.subTest(func=name): + signature = inspect.signature(getattr(pto.tile, name)) + self.assertEqual(list(signature.parameters.keys()), ["src", "dst", "tmp"]) + self.assertEqual(signature.parameters["tmp"].kind, inspect.Parameter.KEYWORD_ONLY) + self.assertIsNone(signature.parameters["tmp"].default) + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "alloc_tile", return_value=synthesized_tmp) as alloc_tile, \ + patch.object(_ops, low_level_name) as low_level_op: + getattr(pto.tile, name)(src, dst) + alloc_tile.assert_called_once_with(tile_type="src_ty") + low_level_op.assert_called_once_with(src, synthesized_tmp, dst) + + with patch.object(_ops, "unwrap_surface_value", side_effect=_identity), \ + patch.object(_ops, "alloc_tile", side_effect=AssertionError("should not synthesize")), \ + patch.object(_ops, low_level_name) as low_level_op: + getattr(pto.tile, name)(src, dst, tmp=tmp) + low_level_op.assert_called_once_with(src, tmp, dst) + + def test_sync_event_id_rejects_out_of_range_static_values(self): + cases = [ + (_ops.set_flag, ("MTE2", "V"), {"event_id": 8}, "set_flag(..., event_id=...)"), + (_ops.wait_flag, ("MTE2", "V"), {"event_id": -1}, "wait_flag(..., event_id=...)"), + (_ops.set_cross_flag, (pto.Pipe.FIX, 8), {}, "set_cross_flag(..., event_id=...)"), + (_ops.wait_cross_flag, (pto.Pipe.FIX, -1), {}, "wait_cross_flag(..., event_id=...)"), + (_ops.set_intra_flag, (pto.Pipe.MTE3, 9), {}, "set_intra_flag(..., event_id=...)"), + (_ops.wait_intra_flag, (pto.Pipe.V, -2), {}, "wait_intra_flag(..., event_id=...)"), + ] + + with patch.object(_ops._pto, "set_flag") as set_flag_op, \ + patch.object(_ops._pto, "set_flag_dyn") as set_flag_dyn_op, \ + patch.object(_ops._pto, "wait_flag") as wait_flag_op, \ + patch.object(_ops._pto, "wait_flag_dyn") as wait_flag_dyn_op, \ + patch.object(_ops._pto, "sync_set") as sync_set_op, \ + patch.object(_ops._pto, "sync_wait") as sync_wait_op: + for func, args, kwargs, context in cases: + with self.subTest(func=func.__name__, event_id=kwargs.get("event_id", args[-1])): + with self.assertRaises(ValueError) as exc: + func(*args, **kwargs) + message = str(exc.exception) + self.assertIn(context, message) + self.assertIn("[0, 7]", message) + + set_flag_op.assert_not_called() + set_flag_dyn_op.assert_not_called() + wait_flag_op.assert_not_called() + wait_flag_dyn_op.assert_not_called() + sync_set_op.assert_not_called() + sync_wait_op.assert_not_called() + + def test_sync_facades_reject_illegal_pipe_endpoints(self): + cases = [ + (_ops.set_cross_flag, (pto.Pipe.V, 0), "set_cross_flag(pipe, event_id)", "", ""), + (_ops.wait_cross_flag, (pto.Pipe.MTE3, 0), "wait_cross_flag(pipe, event_id)", "", ""), + (_ops.set_intra_flag, (pto.Pipe.FIX, 0), "set_intra_flag(pipe, event_id)", "", ""), + (_ops.wait_intra_flag, (pto.Pipe.MTE3, 0), "wait_intra_flag(pipe, event_id)", "", ""), + ] + + with patch.object(_ops._pto, "sync_set") as sync_set_op, \ + patch.object(_ops._pto, "sync_wait") as sync_wait_op: + for func, args, context, expected, actual in cases: + with self.subTest(func=func.__name__, pipe=args[0]): + with self.assertRaises(ValueError) as exc: + func(*args) + message = str(exc.exception) + self.assertIn(context, message) + self.assertIn(expected, message) + self.assertIn(actual, message) + + sync_set_op.assert_not_called() + sync_wait_op.assert_not_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/pyproject.toml b/pyproject.toml index 56f81ad59..2a03b7276 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,22 @@ # pyproject.toml - Project metadata # -# NOTE: This project has a complex build process that requires LLVM/MLIR to be -# built first. The wheel is created from MLIR's python packages directory, not -# from this repo root. See .github/workflows/build_wheel.yml for the full build -# process. +# Build flow (requires LLVM already built): +# pip install . +# +# This will: +# 1. CMake configure + Ninja build + install +# 2. Package Python bindings into a wheel via docker/create_wheel.sh +# +# Environment variables (all optional): +# LLVM_BUILD_DIR Path to LLVM build dir +# (default: /llvm-workspace/llvm-project/build-shared) +# PTO_INSTALL_DIR Install prefix (default: /install) +# PTOAS_PYTHON_PACKAGE_VERSION Wheel version override + +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "_ptoas_build_backend" +backend-path = ["."] [project] name = "ptoas" @@ -11,7 +24,10 @@ version = "0.1.0" description = "PTO Assembler & Optimizer" readme = "README.md" requires-python = ">=3.9" -license = {text = "Apache-2.0"} +license = "Apache-2.0" +dependencies = [ + "numpy", +] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index c56756f56..ab5788683 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -51,6 +51,9 @@ def get_op_result_or_value(value): register_dialect = _pto_mod.register_dialect PtrType = _pto_mod.PtrType +VRegType = _pto_mod.VRegType +MaskType = _pto_mod.MaskType +AlignType = _pto_mod.AlignType AsyncSessionType = _pto_mod.AsyncSessionType AsyncEventType = _pto_mod.AsyncEventType HiF8Type = _pto_mod.HiF8Type @@ -115,9 +118,19 @@ def _ptr_type_get_compat(cls, element_type, memory_space=None, context=None): raise TypeError("PtrType.get got multiple context arguments") context = memory_space memory_space = None - return _ptr_type_get_impl( - element_type, memory_space=memory_space, context=context - ) + if memory_space is None: + if context is None: + return _ptr_type_get_impl(element_type) + return _ptr_type_get_impl(element_type, context=context) + try: + return _ptr_type_get_impl( + element_type, memory_space=memory_space, context=context + ) + except TypeError as exc: + raise TypeError( + "PtrType.get(element_type, memory_space=...) requires a PTO Python " + "extension built with non-default address-space pointer support" + ) from exc PtrType.get = classmethod(_ptr_type_get_compat) @@ -162,6 +175,9 @@ def compat_init(self, *args, __orig_init=original_init, precision_mode=None, **k "register_dialect", # Types "PtrType", + "VRegType", + "MaskType", + "AlignType", "AsyncSessionType", "AsyncEventType", "HiF8Type", diff --git a/quick_install.sh b/quick_install.sh new file mode 100755 index 000000000..38509b61d --- /dev/null +++ b/quick_install.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# For quick development, build and install ptoas and its python bindings +# on top of Docker image https://github.com/learning-chip/agent_docker_npu/pull/8 +# assume MLIR is already installed to save time, takes <3min to finish the build of pto extension +# +# Optional env: +# LLVM_BUILD_DIR - default: ${LLVM_SOURCE_DIR:-/llvm-workspace/llvm-project}/build-shared +# PTO_INSTALL_DIR - default: /install + +set -euo pipefail + +PTO_SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PTO_INSTALL_DIR="${PTO_INSTALL_DIR:-${PTO_SOURCE_DIR}/install}" + +LLVM_SOURCE_DIR="${LLVM_SOURCE_DIR:-/llvm-workspace/llvm-project}" +LLVM_BUILD_DIR="${LLVM_BUILD_DIR:-${LLVM_SOURCE_DIR}/build-shared}" + +PY_ROOT="$(python -c 'import sys; print(sys.prefix)')" + +for d in "$LLVM_BUILD_DIR/lib/cmake/llvm" "$LLVM_BUILD_DIR/lib/cmake/mlir"; do + test -d "$d" || { echo "error: missing $d (set LLVM_BUILD_DIR?)" >&2; exit 1; } +done + +PYBIND11_DIR="$(python -m pybind11 --cmakedir)" +MLIR_PY_PKG="${LLVM_BUILD_DIR}/tools/mlir/python_packages/mlir_core" +test -d "$MLIR_PY_PKG" || { echo "error: MLIR python package dir missing: $MLIR_PY_PKG" >&2; exit 1; } + +PTOAS_VERSION="${PTOAS_VERSION:-$(python "${PTO_SOURCE_DIR}/.github/scripts/compute_ptoas_version.py" --cmake-file "${PTO_SOURCE_DIR}/CMakeLists.txt" --mode dev)}" + +cd "$PTO_SOURCE_DIR" + +cmake -C "${PTO_SOURCE_DIR}/cmake/LinuxHardeningCache.cmake" -G Ninja \ + -S . \ + -B build \ + -DLLVM_DIR="${LLVM_BUILD_DIR}/lib/cmake/llvm" \ + -DMLIR_DIR="${LLVM_BUILD_DIR}/lib/cmake/mlir" \ + -DPython3_ROOT_DIR="${PY_ROOT}" \ + -DPython3_EXECUTABLE=python \ + -DPython3_FIND_STRATEGY=LOCATION \ + -Dpybind11_DIR="${PYBIND11_DIR}" \ + -DMLIR_PYTHON_PACKAGE_DIR="${MLIR_PY_PKG}" \ + -DPTOAS_RELEASE_VERSION_OVERRIDE="${PTOAS_VERSION}" \ + -DCMAKE_INSTALL_PREFIX="${PTO_INSTALL_DIR}" + +ninja -C build +ninja -C build install + +export PTO_SOURCE_DIR PTO_INSTALL_DIR LLVM_BUILD_DIR +export PTOAS_PYTHON_PACKAGE_VERSION="${PTOAS_PYTHON_PACKAGE_VERSION:-${PTOAS_VERSION}}" +bash "${PTO_SOURCE_DIR}/docker/create_wheel.sh" + +shopt -s nullglob +wheels=("${MLIR_PY_PKG}/dist/ptoas-"*.whl) +shopt -u nullglob +((${#wheels[@]} > 0)) || { echo "error: no ptoas-*.whl under ${MLIR_PY_PKG}/dist" >&2; exit 1; } +pip install --force-reinstall "${wheels[0]}" + +export PATH="${PTO_SOURCE_DIR}/build/tools/ptoas:${PATH}" +export LD_LIBRARY_PATH="${LLVM_BUILD_DIR}/lib:${PTO_INSTALL_DIR}/lib:${LD_LIBRARY_PATH:-}" + +python -c "import mlir.ir" +python -c "from mlir.dialects import pto" + +which ptoas + +PTOAS_ENV_TMP="${PTO_SOURCE_DIR}/tmp/set_ptoas_env" +mkdir -p "${PTOAS_ENV_TMP}/MatMul" "${PTOAS_ENV_TMP}/Abs" +(cd "${PTO_SOURCE_DIR}/test/samples/MatMul" && python ./tmatmulk.py > "${PTOAS_ENV_TMP}/MatMul/tmatmulk.pto" && ptoas "${PTOAS_ENV_TMP}/MatMul/tmatmulk.pto" -o "${PTOAS_ENV_TMP}/MatMul/tmatmulk.cpp") +(cd "${PTO_SOURCE_DIR}/test/samples/Abs" && python ./abs.py > "${PTOAS_ENV_TMP}/Abs/abs.pto" && ptoas --enable-insert-sync "${PTOAS_ENV_TMP}/Abs/abs.pto" -o "${PTOAS_ENV_TMP}/Abs/abs.cpp") + +echo "quick_install.sh: OK" diff --git a/scripts/sim_dsl.sh b/scripts/sim_dsl.sh new file mode 100755 index 000000000..26fe4c658 --- /dev/null +++ b/scripts/sim_dsl.sh @@ -0,0 +1,113 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)" + +usage() { + cat <<'EOF' +Run a PTODSL JIT example under `msprof op simulator`. + +Usage: + scripts/sim_dsl.sh [options] [-- ] + +Options: + --output Override msprof output directory. + --soc-version Override simulator soc version. Default: Ascend950PR_9599 + -h, --help Show this help. + +Examples: + scripts/sim_dsl.sh ptodsl/examples/jit/tadd_launch.py + scripts/sim_dsl.sh \ + --output "$PWD/build/msprof_res/flash_softmax" \ + ptodsl/examples/jit/flash_attention_softmax_launch.py +EOF +} + +die() { + echo "error: $*" >&2 + exit 1 +} + +SOC_VERSION="Ascend950PR_9599" +OUTPUT_DIR="" +EXAMPLE_PATH="" +EXAMPLE_ARGS=() + +while [[ $# -gt 0 ]]; do + case "$1" in + --output) + [[ $# -ge 2 ]] || die "--output requires a value" + OUTPUT_DIR="$2" + shift 2 + ;; + --soc-version) + [[ $# -ge 2 ]] || die "--soc-version requires a value" + SOC_VERSION="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + --) + shift + EXAMPLE_ARGS=("$@") + break + ;; + -*) + die "unknown option: $1" + ;; + *) + if [[ -z "${EXAMPLE_PATH}" ]]; then + EXAMPLE_PATH="$1" + else + EXAMPLE_ARGS+=("$1") + fi + shift + ;; + esac +done + +[[ -n "${EXAMPLE_PATH}" ]] || die "missing " + +if [[ "${EXAMPLE_PATH}" != /* ]]; then + EXAMPLE_PATH="${REPO_ROOT}/${EXAMPLE_PATH}" +fi +[[ -f "${EXAMPLE_PATH}" ]] || die "example script not found: ${EXAMPLE_PATH}" + +if [[ -z "${ASCEND_HOME_PATH:-}" ]]; then + die "ASCEND_HOME_PATH is not set; source CANN setenv or export it first" +fi + +if [[ -z "${OUTPUT_DIR}" ]]; then + EXAMPLE_STEM="$(basename -- "${EXAMPLE_PATH}" .py)" + OUTPUT_DIR="${REPO_ROOT}/build/msprof_res/${EXAMPLE_STEM}" +fi + +SIM_LIB_DIR="${ASCEND_HOME_PATH}/tools/simulator/${SOC_VERSION}/lib" +[[ -d "${SIM_LIB_DIR}" ]] || die "simulator library directory not found: ${SIM_LIB_DIR}" + +mkdir -p "${OUTPUT_DIR}" + +source "${ASCEND_HOME_PATH}/bin/setenv.bash" +source "${REPO_ROOT}/set_ptoas_env.sh" +export LD_LIBRARY_PATH="${SIM_LIB_DIR}:${LD_LIBRARY_PATH:-}" +ulimit -n 65535 + +# msprof rejects group/other-writable working directories, so always launch +# from a private directory and use an absolute path for the example script. +cd "${HOME}" + +exec msprof op simulator \ + --soc-version="${SOC_VERSION}" \ + --output="${OUTPUT_DIR}" \ + python3 "${EXAMPLE_PATH}" "${EXAMPLE_ARGS[@]}" diff --git a/set_ptoas_env.sh b/set_ptoas_env.sh new file mode 100644 index 000000000..996d84ad4 --- /dev/null +++ b/set_ptoas_env.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +# after `quick_install.sh`, run `source set_ptoas_env.sh` in a new shell to find the lib +export PTO_SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PTO_INSTALL_DIR="${PTO_INSTALL_DIR:-${PTO_SOURCE_DIR}/install}" +export PATH="${PTO_SOURCE_DIR}/build/tools/ptoas:${PATH}" +export LD_LIBRARY_PATH="${LLVM_BUILD_DIR}/lib:${PTO_INSTALL_DIR}/lib:${LD_LIBRARY_PATH:-}" + +PTOAS_ENV_TMP="${PTO_SOURCE_DIR}/tmp/set_ptoas_env" +mkdir -p "${PTOAS_ENV_TMP}/MatMul" "${PTOAS_ENV_TMP}/Abs" +(cd "${PTO_SOURCE_DIR}/test/samples/MatMul" && python ./tmatmulk.py > "${PTOAS_ENV_TMP}/MatMul/tmatmulk.pto" && ptoas "${PTOAS_ENV_TMP}/MatMul/tmatmulk.pto" -o "${PTOAS_ENV_TMP}/MatMul/tmatmulk.cpp") +(cd "${PTO_SOURCE_DIR}/test/samples/Abs" && python ./abs.py > "${PTOAS_ENV_TMP}/Abs/abs.pto" && ptoas --enable-insert-sync "${PTOAS_ENV_TMP}/Abs/abs.pto" -o "${PTOAS_ENV_TMP}/Abs/abs.cpp") + +echo "test set_env: OK" diff --git a/temp_docs/standalone_st.md b/temp_docs/standalone_st.md new file mode 100644 index 000000000..4538e9486 --- /dev/null +++ b/temp_docs/standalone_st.md @@ -0,0 +1,112 @@ +Minimum commands to run a single standalone st + + +``` +# 0) env +cd /workdir/ptoas_a5 +source set_ptoas_env.sh +source "${ASCEND_HOME_PATH}/bin/setenv.bash" +export LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/tools/simulator/Ascend950PR_9599/lib:${ASCEND_HOME_PATH}/runtime/lib64/stub:${LD_LIBRARY_PATH}" + +# 1) build (tadd currently fails here; tload succeeds) +ST=/workdir/ptoas_a5/test/tilelang_st/npu/a5/src/st +cd "$ST" && rm -rf build && mkdir build && cd build +cmake .. -DRUN_MODE=sim -DSOC_VERSION=Ascend950PR_9599 -DTEST_CASE=tadd \ + -DPTOAS_BIN=/workdir/ptoas_a5/build/tools/ptoas/ptoas +make -j"$(nproc)" tadd # ← ✅ works now with beta1 + +export LD_LIBRARY_PATH="${ST}/build/lib:${LD_LIBRARY_PATH}" + +# 2) gen golden + inputs +WORK="${ST}/build/testcase/tadd" +mkdir -p "$WORK" +cp "${ST}/testcase/st_common.py" "$WORK/" +cp "${ST}/testcase/tadd/"{cases.py,gen_data.py,compare.py} "$WORK/" +cd "$WORK" && python3 gen_data.py # ✅ verified + +# 3) run main (blocked until build succeeds) +../../bin/tadd # ✅ runs CA model now + +# 4) validate +python3 compare.py # ✅ verified +``` + + +Equivalent, plain CLI, no cmake/make: + + +```bash +# 0) env +cd /workdir/ptoas_a5 +source set_ptoas_env.sh +source "${ASCEND_HOME_PATH}/bin/setenv.bash" +export LD_LIBRARY_PATH="${ASCEND_HOME_PATH}/tools/simulator/Ascend950PR_9599/lib:${ASCEND_HOME_PATH}/runtime/lib64/stub:${LD_LIBRARY_PATH}" + +ST=/workdir/ptoas_a5/test/tilelang_st/npu/a5/src/st +TC="$ST/testcase/tadd" +BUILD="$ST/build" +PTOAS=/workdir/ptoas_a5/build/tools/ptoas/ptoas + +# 1) build (plain commands — no cmake/make) +rm -rf "$BUILD" && mkdir -p "$BUILD/bin" "$BUILD/lib" "$BUILD/testcase/tadd" +cd "$BUILD/testcase/tadd" + +# 1a) ptoas: tadd.pto -> tadd_kernel.o +"$PTOAS" --pto-arch=a5 --pto-backend=vpto --enable-insert-sync --enable-tile-op-expand \ + "$TC/tadd.pto" -o tadd_kernel.o + +# 1b) compile launch.cpp + link kernel shared library +bisheng -fPIC -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -Wno-unknown-attributes \ + -fstack-protector-strong -fPIC \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --cce-aicore-arch=dav-c310-vec -std=gnu++17 \ + -Dtadd_kernel_EXPORTS \ + -I"${ASCEND_HOME_PATH}/include" \ + -I/usr/local/Ascend/driver/kernel/inc \ + -I"${ASCEND_HOME_PATH}/pkg_inc" \ + -I"${ASCEND_HOME_PATH}/pkg_inc/profiling" \ + -I"${ASCEND_HOME_PATH}/pkg_inc/runtime/runtime" \ + -c "$TC/launch.cpp" -o launch.cpp.o + +bisheng -fPIC -s -Wl,-z,relro -Wl,-z,now --cce-fatobj-link -shared \ + -Wl,-soname,libtadd_kernel.so \ + -o ../../lib/libtadd_kernel.so launch.cpp.o tadd_kernel.o + +# 1c) compile main.cpp + link host executable +bisheng -fPIE -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -Wno-unknown-attributes \ + -fstack-protector-strong -fPIC \ + -xc++ -include stdint.h -include stddef.h -std=gnu++17 \ + -I"${ASCEND_HOME_PATH}/include" \ + -I/usr/local/Ascend/driver/kernel/inc \ + -I"$ST/common" \ + -c "$TC/main.cpp" -o main.cpp.o + +bisheng -s -Wl,-z,relro -Wl,-z,now main.cpp.o -o ../../bin/tadd \ + -L"${ASCEND_HOME_PATH}/lib64" \ + -L"${ASCEND_HOME_PATH}/tools/simulator/Ascend950PR_9599/lib" \ + -Wl,-rpath,"${ASCEND_HOME_PATH}/lib64:${ASCEND_HOME_PATH}/tools/simulator/Ascend950PR_9599/lib:${BUILD}/lib" \ + ../../lib/libtadd_kernel.so \ + -lruntime_camodel -lstdc++ -lascendcl -lm -ltiling_api -lplatform -lc_sec -ldl -lnnopbase -lpthread + +export LD_LIBRARY_PATH="${BUILD}/lib:${LD_LIBRARY_PATH}" + +# 2) gen golden + inputs +WORK="${BUILD}/testcase/tadd" +mkdir -p "$WORK" +cp "${ST}/testcase/st_common.py" "$WORK/" +cp "${TC}/"{cases.py,gen_data.py,compare.py} "$WORK/" +cd "$WORK" && python3 gen_data.py + +# 3) run main +../../bin/tadd + +# 4) validate +python3 compare.py +``` diff --git a/test/python/ptodsl_docs_as_test.py b/test/python/ptodsl_docs_as_test.py new file mode 100644 index 000000000..c920dc5ff --- /dev/null +++ b/test/python/ptodsl_docs_as_test.py @@ -0,0 +1,588 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable +import json +import re +import shutil +import subprocess +import sys +import tempfile +from unittest import mock + + +REPO_ROOT = Path(__file__).resolve().parents[2] +USER_GUIDE_ROOT = REPO_ROOT / "ptodsl" / "docs" / "user_guide" +sys.path.insert(0, str(REPO_ROOT / "ptodsl")) + +from ptodsl import pto, scalar +from ptodsl._bootstrap import make_context +from ptodsl._runtime.launch import LaunchHandle, _marshal_launch_args +from mlir.ir import Module +from ptodsl_docs_fragment_fixtures import FRAGMENT_FIXTURES, render_fragment_fixture + +FENCE_RE = re.compile(r"^```(?P[A-Za-z0-9_+-]*)\s*$") +META_RE = re.compile(r"^\s*\s*$") + + +@dataclass(frozen=True) +class MarkdownCodeBlock: + path: Path + start_line: int + end_line: int + language: str + lines: tuple[str, ...] + metadata: "DocBlockMetadata | None" + + @property + def text(self) -> str: + return "".join(self.lines) + + +@dataclass(frozen=True) +class MarkdownScanResult: + path: Path + blocks: tuple[MarkdownCodeBlock, ...] + + +@dataclass(frozen=True) +class DocBlockMetadata: + kind: str + body: str + line: int + raw: str + + +@dataclass(frozen=True) +class DocTestDirective: + mode: str + symbol: str | None = None + compile_kwargs: dict[str, object] | None = None + fixture: str | None = None + + +@dataclass(frozen=True) +class LaunchRecord: + compiled: object + grid: int + stream: object + args: tuple[object, ...] + marshaled_arg_count: int + + +def expect(condition: bool, message: str) -> None: + if not condition: + raise AssertionError(message) + + +def format_doc_context(path: Path, start_line: int, symbol: str | None = None) -> str: + symbol_text = symbol if symbol is not None else "" + return f"{path}:{start_line} [symbol={symbol_text}]" + + +def fail_doc(path: Path, start_line: int, message: str, symbol: str | None = None) -> None: + raise AssertionError(f"{format_doc_context(path, start_line, symbol)}: {message}") + + +def iter_markdown_files(root: Path) -> Iterable[Path]: + yield from sorted(root.glob("*.md")) + + +def parse_metadata_line(path: Path, line: str, line_number: int) -> DocBlockMetadata | None: + match = META_RE.match(line) + if match is None: + return None + + kind = match.group("kind") + body = match.group("body").strip() + expect(body, f"{format_doc_context(path, line_number)}: ptodsl-doc-{kind} metadata must not be empty") + if kind == "test": + try: + json.loads(body) + except json.JSONDecodeError as exc: + raise AssertionError( + f"{format_doc_context(path, line_number)}: ptodsl-doc-test metadata must be valid JSON: {exc.msg}" + ) from exc + return DocBlockMetadata(kind=kind, body=body, line=line_number, raw=line.rstrip("\n")) + + +def find_block_metadata(path: Path, lines: list[str], fence_line: int) -> DocBlockMetadata | None: + candidate = fence_line - 2 + while candidate >= 0 and not lines[candidate].strip(): + candidate -= 1 + if candidate < 0: + return None + line = lines[candidate] + if line.lstrip().startswith("