diff --git a/.gitignore b/.gitignore index 740d5464fb9..a2fb1473abb 100644 --- a/.gitignore +++ b/.gitignore @@ -81,7 +81,23 @@ CMakeUserPresets.json # Python cache __pycache__/ +# Cache directories .cache/ +.ck_tile_cache/ +ck_tile_cache/ +**/kernel_cache/ +**/.kernel_cache/ + +# Dispatcher kernel cache (user-generated, can be large) +dispatcher/**/kernel_cache/ +dispatcher/**/.kernel_cache/ +dispatcher/**/cached_kernels/ +dispatcher/**/*.hsaco +dispatcher/**/*.co + +# Dispatcher generated JSON exports +dispatcher/**/*_kernels.json +dispatcher/**/dispatcher_kernels.json # Generated test data test_data/* diff --git a/CHANGELOG.md b/CHANGELOG.md index c3a257e464a..54c8b776ddb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with C++ and Python frontends starting with GEMM support. * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. * Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM. * Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM. diff --git a/dispatcher/CMakeLists.txt b/dispatcher/CMakeLists.txt new file mode 100644 index 00000000000..2acc73d1d50 --- /dev/null +++ b/dispatcher/CMakeLists.txt @@ -0,0 +1,117 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +cmake_minimum_required(VERSION 3.16) + +project(ck_tile_dispatcher VERSION 1.0.0 LANGUAGES CXX) + +# C++17 required +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +# Find HIP for headers (needed for validation kernels) +find_package(hip QUIET) +if(NOT hip_FOUND) + list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip) + find_package(hip REQUIRED) +endif() + +# Dispatcher library +add_library(ck_tile_dispatcher + src/registry.cpp + src/dispatcher.cpp +) + +# Enable PIC for Python bindings +set_target_properties(ck_tile_dispatcher PROPERTIES + POSITION_INDEPENDENT_CODE ON +) + +target_include_directories(ck_tile_dispatcher + PUBLIC + $ + $ +) + +# Link against CK Tile headers (header-only) +target_include_directories(ck_tile_dispatcher + PUBLIC + $ + $ +) + +# Link against HIP headers if available +if(hip_FOUND) + target_link_libraries(ck_tile_dispatcher PUBLIC hip::host) +endif() + +# Compiler warnings +if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + target_compile_options(ck_tile_dispatcher PRIVATE + -Wall -Wextra -Wpedantic + ) +elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + target_compile_options(ck_tile_dispatcher PRIVATE + /W4 + ) +endif() + +# Optional: Build tests +option(BUILD_DISPATCHER_TESTS "Build dispatcher unit tests" OFF) +if(BUILD_DISPATCHER_TESTS) + enable_testing() + add_subdirectory(tests) +endif() + +# Optional: Build Python bindings +option(BUILD_DISPATCHER_PYTHON "Build Python bindings for dispatcher" OFF) +if(BUILD_DISPATCHER_PYTHON) + add_subdirectory(python) +endif() + +# Optional: Codegen for tile_engine integration +option(DISPATCHER_AUTO_GENERATE_WRAPPERS "Auto-generate wrappers from tile_engine" OFF) +if(DISPATCHER_AUTO_GENERATE_WRAPPERS) + add_subdirectory(codegen) +endif() + +# Optional: Build examples +option(BUILD_DISPATCHER_EXAMPLES "Build dispatcher examples" OFF) +if(BUILD_DISPATCHER_EXAMPLES) + add_subdirectory(examples) +endif() + +# Optional: 
Build ctypes bindings +option(BUILD_DISPATCHER_BINDINGS "Build language bindings for dispatcher" OFF) +if(BUILD_DISPATCHER_BINDINGS) + add_subdirectory(bindings/ctypes) +endif() + +# If codegen is enabled, add generated include directory +if(DISPATCHER_AUTO_GENERATE_WRAPPERS AND DISPATCHER_GENERATED_INCLUDE_DIR) + target_include_directories(ck_tile_dispatcher + PUBLIC + $ + ) +endif() + +# Installation +install(TARGETS ck_tile_dispatcher + EXPORT ck_tile_dispatcher_targets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin +) + +install(DIRECTORY include/ + DESTINATION include + FILES_MATCHING PATTERN "*.hpp" +) + +install(EXPORT ck_tile_dispatcher_targets + FILE ck_tile_dispatcher_targets.cmake + NAMESPACE ck_tile:: + DESTINATION lib/cmake/ck_tile_dispatcher +) + diff --git a/dispatcher/README.md b/dispatcher/README.md new file mode 100644 index 00000000000..fa3fbd3a596 --- /dev/null +++ b/dispatcher/README.md @@ -0,0 +1,736 @@ +# CK Tile Dispatcher + +A unified kernel dispatch system for AMD GPUs with C++ and Python frontends. + +**Validated Platform:** AMD Instinct MI300 series (gfx942) + + +--- + +## Table of Contents + +1. [Quick Start](#quick-start) +2. [Docker Setup](#docker-setup-recommended) +3. [Prerequisites](#prerequisites) +4. [Step-by-Step Build Guide](#step-by-step-build-guide) +5. [Running Examples](#running-examples) +6. [External Integration](#external-integration) +7. [Core Concepts](#core-concepts) +8. [Troubleshooting](#troubleshooting) +9. [File Structure](#file-structure) + +--- + +## Quick Start + +**Complete setup from scratch (5 minutes):** + +```bash +# From the composable_kernel root directory +cd dispatcher + +# Step 1: Create build directory +mkdir -p build && cd build + +# Step 2: Configure CMake +cmake .. 
\ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Step 3: Generate kernels and build (CMake handles this automatically) +make -j$(nproc) + +# Step 4: Run C++ examples +./examples/gemm_01_basic + +# Step 5: Build Python libraries (required for Python examples) +make python_libs + +# Step 6: Run Python examples (from dispatcher directory) +cd .. +python3 examples/gemm/python/01_basic_gemm.py +``` + +--- + +## Docker Setup (Recommended) + +For a reproducible build environment, use the official ROCm Docker image: + +### Step 1: Pull and Run Container + +```bash +# Pull the CK Docker image +docker pull rocm/composable_kernel:ck_ub24.04_rocm7.0.1 + +# Run container with GPU access +docker run \ + -it \ + --privileged \ + --device=/dev/kfd \ + --device=/dev/dri \ + --group-add video \ + --group-add render \ + -w /root/workspace \ + -v $(pwd):/root/workspace \ + rocm/composable_kernel:ck_ub24.04_rocm7.0.1 \ + /bin/bash +``` + +> **Note:** Omit `--device` flags if building without GPU access. + +### Step 2: Clone and Build + +```bash +# Inside the container +git clone https://github.com/ROCm/composable_kernel.git +cd composable_kernel +git checkout builder-dispatch-tile-gemm + +# Set up Python environment +python3 -m venv .venv +source .venv/bin/activate +pip install numpy + +# Build dispatcher +cd dispatcher +mkdir -p build && cd build +cmake .. 
\ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +make -j$(nproc) +``` + +### One-Liner Build (inside container) + +```bash +git clone https://github.com/ROCm/composable_kernel.git && \ +cd composable_kernel && git checkout builder-dispatch-tile-gemm && \ +python3 -m venv .venv && source .venv/bin/activate && pip install numpy && \ +cd dispatcher && mkdir -p build && cd build && \ +cmake .. -DCMAKE_PREFIX_PATH=/opt/rocm -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release -DGPU_TARGETS="gfx942" -DBUILD_DISPATCHER_EXAMPLES=ON && \ +make -j$(nproc) +``` + +--- + +## Prerequisites + +### Required Software + +| Software | Minimum Version | Check Command | +|----------|-----------------|---------------| +| ROCm | 6.4+ | `rocminfo` | +| CMake | 3.16+ | `cmake --version` | +| Python | 3.8+ | `python3 --version` | +| NumPy | 1.20+ | `pip show numpy` | +| hipcc | (from ROCm) | `/opt/rocm/bin/hipcc --version` | + +> **Note:** Newer GPU targets (gfx950, gfx1201) require ROCm 6.3+. For ROCm 6.4+, you can also use `amdclang++` instead of `hipcc`. + +### Check Your GPU Architecture + +```bash +# Find your GPU architecture +rocminfo | grep -i "gfx" +# Example output: "gfx942" +``` + +**Supported architectures:** +- **gfx942** - MI300X, MI300A, MI308, MI325 (Instinct MI300 series) +- **gfx90a** - MI200 series (MI250, MI250X) +- **gfx950** - MI350 series +- **gfx1101** - RDNA3 series +- **gfx1201** - RDNA4 series + +### Install Python Dependencies + +NumPy is required for Python examples and kernel generation. 
We recommend using a virtual environment: + +**Option 1: Using standard venv** +```bash +# Create virtual environment +python3 -m venv .venv + +# Activate virtual environment +source .venv/bin/activate # Linux/macOS +# .venv\Scripts\activate # Windows + +# Install NumPy +pip install numpy +``` + +**Option 2: Using uv (faster alternative)** +```bash +# Install uv if not already installed +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Create and activate virtual environment +uv venv .venv +source .venv/bin/activate # Linux/macOS +# .venv\Scripts\activate # Windows + +# Install NumPy +uv pip install numpy +``` + +**Option 3: System-wide install (not recommended)** +```bash +pip install numpy +``` + +> **Note:** Always activate your virtual environment before running CMake or Python examples. + +### Supported Data Types + +CK Tile supports a wide range of data types for GEMM operations: + +| A dtype | B dtype | Acc dtype | Warp Tile Sizes | Notes | +|---------|---------|-----------|-----------------|-------| +| `fp32` | `fp32` | `fp32` | 16x16x4, 16x16x16 | Full precision | +| `fp16` | `fp16` | `fp32` | 32x32x8, 32x32x16, 16x16x16, 16x16x32 | Standard half | +| `bf16` | `bf16` | `fp32` | 32x32x8, 32x32x16, 16x16x16, 16x16x32 | Brain float 16 | +| `fp8` | `fp8` | `fp32` | 32x32x16, 32x32x32, 16x16x32, 16x16x64 | FP8 E4M3 | +| `fp8` | `bf8` | `fp32` | 32x32x16, 16x16x32 | Mixed FP8/BF8 | +| `bf8` | `fp8` | `fp32` | 32x32x16, 16x16x128 | Mixed BF8/FP8 | +| `bf8` | `bf8` | `fp32` | 32x32x16, 32x32x32, 16x16x32 | BF8 E5M2 | +| `int8` | `int8` | `int32` | 32x32x16, 16x16x32, 16x16x16 | Integer GEMM | +| `pk_fp4` | `pk_fp4` | `fp32` | 16x16x128 | Packed 4-bit float | + +**Notes:** +- Accumulator is always `fp32` except for `int8` which uses `int32` +- FP8 types: `fp8` = E4M3, `bf8` = E5M2 +- `pk_fp4` = Packed 4-bit float (2 values per byte) +- Some dtypes require specific GPU architectures (e.g., FP8 requires MI300+) + +--- + +## Step-by-Step Build Guide + +### Step 1: 
Navigate to Dispatcher Directory + +```bash +# From composable_kernel root +cd dispatcher + +# Verify you're in the right place +ls CMakeLists.txt # Should exist +``` + +### Step 2: Create Build Directory + +```bash +mkdir -p build +cd build +``` + +### Step 3: Configure CMake + +**Basic configuration (library only):** +```bash +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" +``` + +**Full configuration (with examples and tests):** +```bash +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON \ + -DBUILD_DISPATCHER_TESTS=ON +``` + +**Expected output:** +``` +-- Found hip: /opt/rocm (found suitable version "6.x.x") +-- Generating GEMM kernels... +-- Built: gemm_01 through gemm_06, dispatcher_gemm_lib.so +-- Configuring done +``` + +### Step 4: Build + +```bash +# Build all targets (generates kernels automatically, then compiles) +make -j$(nproc) + +# Or build specific targets +make gemm_01_basic # Single GEMM example +make dispatcher_gemm_lib # GEMM shared library for Python + +# Build ONLY Python libraries (faster if you don't need C++ examples) +make python_libs -j$(nproc) +``` + +### Kernel Generation Targets + +Kernels are generated automatically during `make`, but you can also control generation explicitly: + +```bash +# Generate all kernels only (no compilation) +make generate_all_kernels + +# Generate GEMM kernels only +make generate_gemm_kernels + +# Force regenerate (even if kernels exist) +make regenerate_all_kernels +make regenerate_gemm_kernels + +# Generate for specific GPU architecture +make generate_kernels_gfx942 # MI300X +make generate_kernels_gfx90a # MI200 +make generate_kernels_gfx1100 # RDNA3 +``` + +### Step 5: Verify Build + +```bash +# Check executables were built +ls examples/gemm_* + +# Check shared libraries were 
built +ls examples/libdispatcher_gemm_lib.so +``` + +### CMake Options Reference + +| Flag | Default | Description | +|------|---------|-------------| +| `CMAKE_BUILD_TYPE` | Debug | **Use `Release` for performance!** | +| `GPU_TARGETS` | None | Target GPU: `"gfx942"`, `"gfx90a"`, etc. | +| `BUILD_DISPATCHER_EXAMPLES` | OFF | Build C++ examples and Python libs | +| `BUILD_DISPATCHER_TESTS` | OFF | Build unit tests | +| `CMAKE_PREFIX_PATH` | - | ROCm installation path | +| `CMAKE_CXX_COMPILER` | - | Path to hipcc compiler | + +⚠️ **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower. +⚠️ **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories). + +--- + +## Running Examples + +### C++ Examples + +After building, executables are in `build/examples/`: + +```bash +cd build/examples + +# GEMM Examples +./gemm_01_basic # Basic GEMM with autofill/autocorrect +./gemm_02_multi_size # Wildcard expansion +./gemm_03_benchmark_validation # Benchmarking + validation +./gemm_04_heuristics # Heuristic kernel selection +./gemm_05_json_export # Registry JSON export +./gemm_06_multi_registry # Multiple registries +``` + +### Python Examples + +Run from the `dispatcher` directory: + +```bash +cd /path/to/composable_kernel/dispatcher + +# GEMM Examples +python3 examples/gemm/python/01_basic_gemm.py # Basic multi-kernel GEMM +python3 examples/gemm/python/04_validation.py # CPU reference validation +python3 examples/gemm/python/07_stress_test.py # Stress test (48 kernels) +python3 examples/gemm/python/08_heuristics.py # Heuristic selection +``` + +### Example Output + +**Expected C++ output (`gemm_01_basic`):** +``` +====================================================================== +Example 01: Basic GEMM with Declarative Kernel Definition 
+====================================================================== + +Step 1: Declared Kernels +------------------------ +Kernel Set: fp16_gemm_kernels + Architecture: gfx942 + Configurations: 1 + - gemm_fp16_rcr_compv4_cshuffle_intrawave_128x128x32 + +Step 2: Create Registry and Dispatcher +-------------------------------------- + Registered 1 kernels + +Step 3: Define Problem +---------------------- + M=1024, N=1024, K=1024 + +Step 4: GPU Execution +--------------------- + *** GPU EXECUTION *** + Time: ms + TFLOPS: +``` + +> **Note:** Timing values vary by GPU model and system configuration. + +--- + +## Benchmark Parameters + +The dispatcher supports fine-grained control over benchmarking, matching CK Tile's `stream_config`: + +### Available Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `warmup` | int | 5 | Warmup iterations (discarded from timing) | +| `repeat` | int | 20 | Benchmark iterations (averaged) | +| `flush_cache` | bool | false | Flush GPU L2 cache between iterations | +| `rotating_count` | int | 1 | Rotating buffer count (for cache simulation) | +| `timer` | string | "gpu" | Timer type: "gpu" (HIP events) or "cpu" | +| `init` | string | "random" | Matrix initialization: "random", "linear", "constant" | +| `split_k` | int | 1 | Split-K parallelism factor | + +### Python Usage + +```python +from ctypes_utils import DispatcherLib + +# Basic usage (default benchmark settings) +lib = DispatcherLib.load() + +# Advanced benchmark settings via command line +python3 examples/gemm/python/10_advanced_benchmark.py \ + --warmup 10 \ + --repeat 100 \ + --flush-cache +``` + +### C++ Usage + +```cpp +// Basic timing +ck_tile::stream_config cfg{nullptr, true}; + +// Advanced benchmark settings +ck_tile::stream_config cfg{ + nullptr, // stream_id (nullptr = default stream) + true, // time_kernel + 1, // log_level + 10, // cold_niters (warmup) + 100, // nrepeat + true, // is_gpu_timer + true, // 
flush_cache + 4 // rotating_count +}; + +float avg_time = kernel.run(args, cfg); +``` + +### Command Line (Python Examples) + +```bash +# Basic run +python3 examples/gemm/python/10_advanced_benchmark.py + +# With benchmark parameters +python3 examples/gemm/python/10_advanced_benchmark.py \ + --warmup 10 \ + --repeat 100 \ + --flush-cache \ + --rotating-count 4 \ + --timer gpu +``` + +### When to Use Each Parameter + +| Use Case | Recommended Settings | +|----------|---------------------| +| Quick test | `warmup=1, repeat=3` | +| Stable benchmark | `warmup=10, repeat=100` | +| Memory-bound analysis | `flush_cache=True, rotating_count=4` | +| Compute-bound analysis | `flush_cache=False` (default) | +| Debug timing | `timer="cpu"` | +| Production | `timer="gpu"` (default) | + +--- + +## External Integration + +### Using Dispatcher in Your Own Project + +#### Option 1: CMake Integration (Recommended) + +Add to your `CMakeLists.txt`: + +```cmake +# Set path to composable_kernel +set(CK_ROOT "/path/to/composable_kernel") + +# Add dispatcher subdirectory +add_subdirectory(${CK_ROOT}/dispatcher dispatcher_build) + +# Link to your target +target_link_libraries(your_target PRIVATE ck_tile_dispatcher) +target_include_directories(your_target PRIVATE + ${CK_ROOT}/dispatcher/include + ${CK_ROOT}/include +) +``` + +#### Option 2: Include as Pre-built Library + +```cmake +# Find the pre-built library +find_library(CK_DISPATCHER ck_tile_dispatcher + PATHS /path/to/composable_kernel/dispatcher/build) + +# Include directories +set(CK_INCLUDE_DIRS + /path/to/composable_kernel/include + /path/to/composable_kernel/dispatcher/include +) + +target_link_libraries(your_target PRIVATE ${CK_DISPATCHER}) +target_include_directories(your_target PRIVATE ${CK_INCLUDE_DIRS}) +``` + +#### Option 3: Python Integration + +```python +import sys +sys.path.insert(0, "/path/to/composable_kernel/dispatcher/examples/gemm/python") + +# For GEMM +from ctypes_utils import DispatcherLib, Dispatcher, 
KernelConfig +``` + +### Required Include Paths + +When integrating, you need these include paths: + +``` +/path/to/composable_kernel/include # CK Tile core headers +/path/to/composable_kernel/dispatcher/include # Dispatcher headers +/path/to/composable_kernel/dispatcher/build/generated_kernels # Generated kernels +``` + +### Required Compile Flags + +```bash +# Minimum flags for hipcc +-std=c++17 +-D__HIP_PLATFORM_AMD__=1 +--offload-arch=gfx942 # Your target GPU + +# Recommended flags +-O3 +-mllvm -enable-noalias-to-md-conversion=0 +-Wno-undefined-func-template +-Wno-float-equal +-Wall +-Werror +``` + +### Python Path Setup + +For Python scripts outside the dispatcher directory: + +```bash +# Option 1: Environment variable +export PYTHONPATH="/path/to/composable_kernel/dispatcher/examples/gemm/python:$PYTHONPATH" + +# Option 2: In your Python script +import sys +sys.path.insert(0, "/path/to/composable_kernel/dispatcher/examples/gemm/python") +``` + +### Library Search Paths + +The Python utilities search for the shared library in these locations: + +```python +# For GEMM (ctypes_utils.py) +SEARCH_PATHS = [ + "build/examples/libdispatcher_gemm_lib.so", + "../build/examples/libdispatcher_gemm_lib.so", + "../../build/examples/libdispatcher_gemm_lib.so", +] +``` + +If using from a different location, set the library path explicitly: + +```python +# GEMM +from ctypes_utils import DispatcherLib +lib = DispatcherLib.load("/absolute/path/to/libdispatcher_gemm_lib.so") +``` + +--- + +## Core Concepts + +### Data Flow + +``` +KernelConfig → Registry → Dispatcher → GPU Execution +``` + +1. **KernelConfig**: Defines kernel parameters (tile sizes, data types, layouts) +2. **Registry**: Stores multiple kernel configurations +3. 
**Dispatcher**: Selects best kernel for a given problem and executes it + +### GEMM Layouts + +| Layout | A | B | C | Use Case | +|--------|---|---|---|----------| +| RCR | Row | Col | Row | Most common (PyTorch default) | +| RRR | Row | Row | Row | Both inputs row-major | +| CRR | Col | Row | Row | A transposed | +| CCR | Col | Col | Row | Both inputs column-major | + +### Split-K Support + +Split-K divides the K dimension across multiple thread blocks, useful for large K dimensions. + +**Usage (C++):** +```cpp +// GEMM with 4-way K split +auto problem = ProblemBuilder() + .m(1024).n(1024).k(8192) + .split_k(4) + .build(); +``` + +--- + +## Troubleshooting + +### Build Issues + +| Problem | Solution | +|---------|----------| +| `hipcc not found` | Set `-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc` | +| `hip not found` | Set `-DCMAKE_PREFIX_PATH=/opt/rocm` | +| Very slow performance | Use `-DCMAKE_BUILD_TYPE=Release` | +| `gfx942 not supported` | Check ROCm version (need 6.0+) | +| Kernel generation fails | Ensure Python 3.8+ with NumPy installed in active venv | +| Build errors | First verify CK builds without dispatcher (see main CK README) | + +### Runtime Issues + +| Problem | Solution | +|---------|----------| +| `Library not found` | Build with `-DBUILD_DISPATCHER_EXAMPLES=ON` | +| `No kernel found` | Check GPU arch matches build target | +| Python `ModuleNotFoundError` | Add paths to `PYTHONPATH` (see above) | +| Wrong results | Verify layout matches your data | + +### Debug Commands + +```bash +# Check ROCm installation +rocminfo | head -20 + +# Check GPU architecture +rocminfo | grep "Name:" + +# Verify library exists +ls -la build/examples/libdispatcher_*.so + +# Run with verbose output +./build/examples/gemm_01_basic 2>&1 + +# Python: Check library loading +python3 -c " +import ctypes +lib = ctypes.CDLL('/path/to/libdispatcher_gemm_lib.so') +print('Library loaded successfully') +" +``` + +### Clean Rebuild + +If you encounter issues, try a clean rebuild: + 
+```bash +cd dispatcher +rm -rf build +mkdir build && cd build +cmake .. [your options] +make -j$(nproc) +``` + +--- + +## File Structure + +``` +dispatcher/ +├── README.md # This file +├── CMakeLists.txt # Build configuration +│ +├── include/ck_tile/dispatcher/ # C++ headers +│ ├── dispatcher.hpp # GEMM dispatcher +│ ├── registry.hpp # Kernel registry +│ └── kernel_key.hpp # Kernel configuration +│ +├── src/ # C++ implementation +│ +├── codegen/ # Kernel generation +│ ├── unified_gemm_codegen.py # GEMM kernel generator +│ └── arch_specs.json # GPU specifications +│ +├── bindings/ctypes/ # Python ctypes interface +│ └── gemm_ctypes_lib.cpp # GEMM Python library +│ +├── examples/ # Examples +│ └── gemm/ +│ ├── cpp/ # C++ GEMM examples (01-06) +│ └── python/ # Python GEMM examples (01-11) +│ +├── scripts/ # Build scripts +│ +└── tests/ # Unit tests +``` + +--- + +## Example Documentation + +| Directory | README | +|-----------|--------| +| GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) | +| GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) | +| Codegen | [codegen/README.md](codegen/README.md) | + +--- + +## Archived Content + +Convolution examples and utilities have been archived to `ck-2/conv_archive/dispatcher/`: +- `examples/conv/cpp/` - 11 C++ convolution examples +- `examples/conv/python/` - 14 Python convolution examples +- `codegen/unified_conv_codegen.py` - Conv kernel generator +- `include/ck_tile/dispatcher/conv_*.hpp` - Conv headers +- `python/conv_utils.py` - Conv Python utilities + +--- + +## License + +MIT License - Copyright (c) 2025, Advanced Micro Devices, Inc. diff --git a/dispatcher/bindings/README.md b/dispatcher/bindings/README.md new file mode 100644 index 00000000000..7cda21f6ec2 --- /dev/null +++ b/dispatcher/bindings/README.md @@ -0,0 +1,109 @@ +# CK Tile Dispatcher - Language Bindings + +This directory contains language bindings for the CK Tile Dispatcher. 
+ +## Structure + +``` +bindings/ +├── ctypes/ # Python ctypes bindings (C API) +│ ├── gemm_ctypes_lib.cpp # GEMM dispatcher C API +│ ├── conv_ctypes_lib.cpp # Convolution dispatcher C API (fwd + bwd_data) +│ ├── conv_bwdw_ctypes_lib.cpp # Convolution backward weight C API +│ ├── gpu_helper.cpp # CLI helper for Python +│ └── CMakeLists.txt +└── README.md +``` + +## ctypes Bindings + +The ctypes bindings provide a C API that Python can load via `ctypes.CDLL()`. + +### Building + +```bash +cd build +cmake .. -DCMAKE_PREFIX_PATH=/opt/rocm +make dispatcher_gemm_lib dispatcher_conv_lib gpu_helper +``` + +### Usage from Python + +```python +import ctypes + +# Load the library +lib = ctypes.CDLL("path/to/libdispatcher_gemm_lib.so") + +# Initialize +lib.dispatcher_init() + +# Check if problem is supported +is_supported = lib.dispatcher_is_supported(M, N, K) + +# Run GEMM +time_ms = ctypes.c_float() +result = lib.dispatcher_run_gemm( + A_ptr, B_ptr, C_ptr, + M, N, K, + ctypes.byref(time_ms) +) + +# Cleanup +lib.dispatcher_cleanup() +``` + +### GEMM API + +| Function | Description | +|----------|-------------| +| `dispatcher_init()` | Initialize the dispatcher | +| `dispatcher_is_supported(M, N, K)` | Check if problem size is supported | +| `dispatcher_select_kernel(M, N, K, name_buf, buf_size)` | Get kernel name for problem | +| `dispatcher_run_gemm(A, B, C, M, N, K, time_ms)` | Execute GEMM | +| `dispatcher_get_kernel_count()` | Get number of registered kernels | +| `dispatcher_export_registry_json()` | Export registry as JSON | +| `dispatcher_cleanup()` | Release resources | + +### Convolution API + +| Function | Description | +|----------|-------------| +| `conv_dispatcher_init()` | Initialize the dispatcher | +| `conv_dispatcher_is_supported(prob)` | Check if problem is supported | +| `conv_dispatcher_select_kernel(prob, name_buf, buf_size)` | Get kernel name | +| `conv_dispatcher_run(input, weight, output, prob, stream)` | Execute convolution | +| 
`conv_dispatcher_get_kernel_count()` | Get number of registered kernels | +| `conv_dispatcher_cleanup()` | Release resources | + +## GPU Helper + +The `gpu_helper` executable provides a CLI interface for Python: + +```bash +./gpu_helper 1024 1024 1024 --validate +``` + +Output is JSON for easy parsing: +```json +{ + "problem": {"M": 1024, "N": 1024, "K": 1024}, + "kernel": "gemm_fp16_rcr_...", + "execution": { + "time_ms": 0.5, + "tflops": 4.2 + }, + "validation": { + "accuracy": 100.0 + }, + "status": "success" +} +``` + +## Examples + +See the examples that use these bindings: + +- **GEMM**: `dispatcher/examples/gemm/python/` +- **Conv**: `dispatcher/examples/conv/python/` + diff --git a/dispatcher/bindings/ctypes/CMakeLists.txt b/dispatcher/bindings/ctypes/CMakeLists.txt new file mode 100644 index 00000000000..804e5e9bd70 --- /dev/null +++ b/dispatcher/bindings/ctypes/CMakeLists.txt @@ -0,0 +1,181 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# ============================================================================= +# CK Tile Dispatcher - ctypes Bindings +# ============================================================================= +# +# Provides shared libraries with C API for Python ctypes integration. 
+# +# Targets: +# - dispatcher_gemm_lib : GEMM dispatcher library +# - dispatcher_conv_lib : Convolution dispatcher library (forward + bwd_data) +# - dispatcher_conv_bwdw_lib : Convolution backward weight library +# - gpu_helper : GPU helper executable for Python +# + +cmake_minimum_required(VERSION 3.16) + +# Helper function to add a ctypes library +function(add_ctypes_library TARGET_NAME SOURCE_FILE) + cmake_parse_arguments(ARG "CONV" "KERNEL_HEADER" "" ${ARGN}) + + add_library(${TARGET_NAME} SHARED ${SOURCE_FILE}) + + target_include_directories(${TARGET_NAME} PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + + target_link_libraries(${TARGET_NAME} PRIVATE + hip::device + ) + + # Force-include kernel header if provided + if(ARG_KERNEL_HEADER AND EXISTS ${ARG_KERNEL_HEADER}) + target_compile_options(${TARGET_NAME} PRIVATE + -include ${ARG_KERNEL_HEADER} + ) + if(ARG_CONV) + target_compile_definitions(${TARGET_NAME} PRIVATE CONV_KERNEL_AVAILABLE) + endif() + endif() + + set_target_properties(${TARGET_NAME} PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + ) +endfunction() + +# ============================================================================= +# GEMM ctypes Library +# ============================================================================= + +# Find a generated GEMM kernel header for the library +file(GLOB GEMM_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/gemm_*.hpp") +if(GEMM_KERNEL_HEADERS) + list(GET GEMM_KERNEL_HEADERS 0 GEMM_KERNEL_HEADER) + message(STATUS "Found GEMM kernel for ctypes lib: ${GEMM_KERNEL_HEADER}") + + add_ctypes_library(dispatcher_gemm_lib + gemm_ctypes_lib.cpp + KERNEL_HEADER ${GEMM_KERNEL_HEADER} + ) +else() + message(STATUS "No GEMM kernel found for ctypes lib - building without kernel") + add_library(dispatcher_gemm_lib SHARED gemm_ctypes_lib.cpp) + target_include_directories(dispatcher_gemm_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + 
${PROJECT_SOURCE_DIR}/dispatcher/include + ) + target_link_libraries(dispatcher_gemm_lib PRIVATE hip::device) +endif() + +# ============================================================================= +# Convolution ctypes Library (supports forward + bwd_data) +# ============================================================================= + +# Look for forward kernels +file(GLOB CONV_FWD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_fwd_*.hpp") +# Look for backward data kernels +file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwdd_*.hpp") +# Fallback: any conv kernel (for backwards compatibility) +file(GLOB CONV_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*.hpp") + +add_library(dispatcher_conv_lib SHARED conv_ctypes_lib.cpp) +target_include_directories(dispatcher_conv_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include +) +target_link_libraries(dispatcher_conv_lib PRIVATE hip::device) +set_target_properties(dispatcher_conv_lib PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 +) + +# Add forward kernel if available +if(CONV_FWD_KERNEL_HEADERS) + list(GET CONV_FWD_KERNEL_HEADERS 0 CONV_FWD_KERNEL_HEADER) + message(STATUS "Found Conv FWD kernel for ctypes lib: ${CONV_FWD_KERNEL_HEADER}") + target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_FWD_KERNEL_HEADER}) + target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_KERNEL_AVAILABLE) +elseif(CONV_KERNEL_HEADERS) + # Fallback to any conv kernel + list(GET CONV_KERNEL_HEADERS 0 CONV_KERNEL_HEADER) + message(STATUS "Found Conv kernel for ctypes lib: ${CONV_KERNEL_HEADER}") + target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_KERNEL_HEADER}) + target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_KERNEL_AVAILABLE) +else() + message(STATUS "No Conv FWD kernel found for ctypes lib - building without kernel") +endif() + +# Add backward data kernel if available 
+if(CONV_BWDD_KERNEL_HEADERS) + list(GET CONV_BWDD_KERNEL_HEADERS 0 CONV_BWDD_KERNEL_HEADER) + message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}") + target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_BWDD_KERNEL_HEADER}) + target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_BWD_DATA_AVAILABLE) +endif() + +# ============================================================================= +# Convolution Backward Weight ctypes Library (separate lib for bwd_weight) +# ============================================================================= + +file(GLOB CONV_BWDW_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*bwd_weight*.hpp") +if(CONV_BWDW_KERNEL_HEADERS) + list(GET CONV_BWDW_KERNEL_HEADERS 0 CONV_BWDW_KERNEL_HEADER) + message(STATUS "Found Conv BwdWeight kernel for ctypes lib: ${CONV_BWDW_KERNEL_HEADER}") + + add_library(dispatcher_conv_bwdw_lib SHARED conv_bwdw_ctypes_lib.cpp) + target_include_directories(dispatcher_conv_bwdw_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + target_link_libraries(dispatcher_conv_bwdw_lib PRIVATE hip::device) + target_compile_options(dispatcher_conv_bwdw_lib PRIVATE + -include ${CONV_BWDW_KERNEL_HEADER} + ) + target_compile_definitions(dispatcher_conv_bwdw_lib PRIVATE CONV_BWD_WEIGHT_AVAILABLE) + set_target_properties(dispatcher_conv_bwdw_lib PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + ) +else() + message(STATUS "No Conv BwdWeight kernel found for ctypes lib - building without kernel") + add_library(dispatcher_conv_bwdw_lib SHARED conv_bwdw_ctypes_lib.cpp) + target_include_directories(dispatcher_conv_bwdw_lib PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + target_link_libraries(dispatcher_conv_bwdw_lib PRIVATE hip::device) + set_target_properties(dispatcher_conv_bwdw_lib PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 17 + ) +endif() + +# 
============================================================================= +# GPU Helper Executable +# ============================================================================= + +if(GEMM_KERNEL_HEADERS) + add_executable(gpu_helper gpu_helper.cpp) + + target_include_directories(gpu_helper PRIVATE + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/dispatcher/include + ) + + target_link_libraries(gpu_helper PRIVATE + hip::device + ) + + target_compile_options(gpu_helper PRIVATE + -include ${GEMM_KERNEL_HEADER} + ) + + set_target_properties(gpu_helper PROPERTIES + CXX_STANDARD 17 + ) +endif() + diff --git a/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp new file mode 100644 index 00000000000..09e058f80fa --- /dev/null +++ b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp @@ -0,0 +1,175 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Convolution Backward Weight Dispatcher ctypes Library + * + * SEPARATE library for backward weight to avoid template conflicts with + * forward/backward_data kernels in the main conv_ctypes_lib. + * + * Usage from Python: + * lib = ctypes.CDLL("libdispatcher_conv_bwdw_lib.so") + * lib.conv_bwdw_init() + * lib.conv_bwdw_run(...) 
+ */ + +#include +#include +#include + +// Minimal includes - matching the C++ example +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/gemm.hpp" // Must be before grouped_convolution for TileGemmTraits +#include "ck_tile/ops/grouped_convolution.hpp" + +// Global state - minimal, no registry needed for direct launch +static bool g_bwdw_initialized = false; + +extern "C" { + +// ============================================================================= +// Initialization (minimal - just sets flag) +// ============================================================================= + +int conv_bwdw_init() +{ + g_bwdw_initialized = true; + return 0; // Return 0 on success (consistent with other init functions) +} + +void conv_bwdw_cleanup() { g_bwdw_initialized = false; } + +// ============================================================================= +// Problem Structure (same as main library) +// ============================================================================= + +struct ConvBwdwProblemC +{ + int N, G, C, K; + int input_d, input_h, input_w; + int filter_z, filter_y, filter_x; + int stride_d, stride_h, stride_w; + int pad_d, pad_h, pad_w; + int dilation_d, dilation_h, dilation_w; +}; + +// ============================================================================= +// Backward Weight Execution +// ============================================================================= + +#ifdef CONV_BWD_WEIGHT_AVAILABLE +static ck_tile::conv::ConvParam build_conv_param(const ConvBwdwProblemC* prob) +{ + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + + if(is_3d) + { + return ck_tile::conv::ConvParam{3, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_z, prob->filter_y, prob->filter_x}, + {prob->input_d, prob->input_h, prob->input_w}, + {prob->stride_d, prob->stride_h, prob->stride_w}, + {prob->dilation_d, prob->dilation_h, prob->dilation_w}, + 
{prob->pad_d, prob->pad_h, prob->pad_w}, + {prob->pad_d, prob->pad_h, prob->pad_w}}; + } + else + { + return ck_tile::conv::ConvParam{2, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_y, prob->filter_x}, + {prob->input_h, prob->input_w}, + {prob->stride_h, prob->stride_w}, + {prob->dilation_h, prob->dilation_w}, + {prob->pad_h, prob->pad_w}, + {prob->pad_h, prob->pad_w}}; + } +} + +static float run_bwd_weight_impl(const void* input_ptr, + const void* grad_output_ptr, + void* grad_weight_ptr, + const ConvBwdwProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + // Backward weight: A=input, B=grad_output, C=grad_weight + ck_tile::GroupedConvBwdWeightHostArgs args(conv_param, + input_ptr, // in_ptr = input + grad_weight_ptr, // wei_ptr = grad_weight (output) + {}, // ds_ptr + grad_output_ptr, // out_ptr = grad_output + 1 // k_batch + ); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + return SelectedConvBwdWeightLauncher::launch(args, stream_cfg); +} +#endif + +float conv_bwdw_run(const void* input_ptr, + const void* grad_output_ptr, + void* grad_weight_ptr, + const ConvBwdwProblemC* prob, + void* stream) +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + // Validate all required pointers before kernel launch + if(!g_bwdw_initialized || !prob) + return -1.0f; + if(!input_ptr || !grad_output_ptr || !grad_weight_ptr) + return -1.0f; // Null data pointer would cause kernel crash + return run_bwd_weight_impl(input_ptr, grad_output_ptr, grad_weight_ptr, prob, stream); +#else + return -1.0f; +#endif +} + +// ============================================================================= +// Info +// ============================================================================= + +const char* conv_bwdw_version() { return "1.0.0"; } + +int conv_bwdw_has_kernels() +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + return 1; +#else + return 0; +#endif +} + +int conv_bwdw_get_kernel_count() +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + 
return 1; +#else + return 0; +#endif +} + +int conv_bwdw_get_kernel_name(int index, char* buffer, int buffer_size) +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + if(index != 0 || !buffer || buffer_size <= 0) + return -1; + std::strncpy(buffer, CONV_BWD_WEIGHT_KERNEL_NAME, buffer_size - 1); + buffer[buffer_size - 1] = '\0'; + return 0; +#else + return -1; +#endif +} + +} // extern "C" diff --git a/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp new file mode 100644 index 00000000000..d3c64621a7b --- /dev/null +++ b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp @@ -0,0 +1,411 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Convolution Dispatcher ctypes Library + * + * Provides C API for Python ctypes integration. + * Supports forward convolution. Backward operations require additional headers. + * + * REQUIRED: Forward kernel header must be force-included via -include flag. + * OPTIONAL: Backward kernels can be added with CONV_BWD_DATA_AVAILABLE/CONV_BWD_WEIGHT_AVAILABLE + * + * Usage from Python: + * lib = ctypes.CDLL("libdispatcher_conv.so") + * lib.conv_dispatcher_init() + * lib.conv_dispatcher_run(...) 
+ */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher/conv_utils.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +using namespace ck_tile::dispatcher; + +// Global state (using shared_ptr for safe memory management) +static std::shared_ptr g_registry = nullptr; +static std::shared_ptr g_dispatcher = nullptr; +static std::vector g_kernels; + +extern "C" { + +// ============================================================================= +// Initialization +// ============================================================================= + +int conv_dispatcher_init() +{ + if(g_registry) + return 0; // Already initialized + + g_registry = std::make_shared(); + g_dispatcher = std::make_shared(g_registry.get()); + + // Register kernel configurations using simple ConvKernelSet + // (actual kernel launch uses the force-included SelectedConvKernelLauncher) + using namespace ck_tile::dispatcher::conv_decl; + + // Forward kernels (required - must be force-included) + // Must match: conv_fwd_fp16_nhwgc_2d_compv4_cshuffle_intrawave_128x128x64_2x2x1_32x32x16_dsb + ConvKernelSet fwd_set; + fwd_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + ConvAlgorithm() + .tile(128, 128, 64) // tile_m x tile_n x tile_k + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave"), + "gfx942"); + g_registry->register_set(fwd_set, ConvRegistry::Priority::High); + +#ifdef CONV_BWD_DATA_AVAILABLE + // Backward data kernels + // Must match: conv_bwdd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x64_2x2x1_32x32x16 + ConvKernelSet bwd_data_set; + bwd_data_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), + ConvAlgorithm() + .tile(128, 128, 64) // tile_m x tile_n x tile_k + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave"), + "gfx942"); + g_registry->register_set(bwd_data_set, ConvRegistry::Priority::High); +#endif + + return 0; 
+} + +int conv_dispatcher_cleanup() +{ + // shared_ptr automatically handles cleanup when reset + g_dispatcher.reset(); + g_registry.reset(); + g_kernels.clear(); + return 0; +} + +// ============================================================================= +// Registry Management +// ============================================================================= + +int conv_dispatcher_get_kernel_count() +{ + if(!g_registry) + return 0; + return static_cast(g_registry->size()); +} + +int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size) +{ + if(index < 0 || !buffer || buffer_size <= 0) + return -1; + + if(!g_registry) + return -1; + + // Use registry to get kernel names (they are registered with full names) + const auto& kernels = g_registry->all_kernels(); + if(static_cast(index) >= kernels.size()) + return -1; + + const auto* kernel = kernels[index]; + std::strncpy(buffer, kernel->name().c_str(), buffer_size - 1); + buffer[buffer_size - 1] = '\0'; + return 0; +} + +// ============================================================================= +// Problem Definition +// ============================================================================= + +struct ConvProblemC +{ + int N, G, C, K; + int input_d, input_h, input_w; + int filter_z, filter_y, filter_x; + int stride_d, stride_h, stride_w; + int pad_d, pad_h, pad_w; + int dilation_d, dilation_h, dilation_w; + int direction; // 0=forward, 1=bwd_data, 2=bwd_weight +}; + +// ============================================================================= +// Kernel Selection +// ============================================================================= + +int conv_dispatcher_is_supported(const ConvProblemC* prob) +{ + if(!g_registry || !prob) + return 0; + + ConvProblem problem; + problem.N = prob->N; + problem.G = prob->G; + problem.C = prob->C; + problem.K = prob->K; + problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; + problem.filter_spatial = {prob->filter_z, 
prob->filter_y, prob->filter_x}; + problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; + problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; + problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; + problem.op = static_cast(prob->direction); + problem.compute_output_size(); + + const auto* kernel = g_dispatcher->select(problem); + return kernel ? 1 : 0; +} + +int conv_dispatcher_select_kernel(const ConvProblemC* prob, char* kernel_name, int buffer_size) +{ + if(!g_registry || !prob || !kernel_name || buffer_size <= 0) + return -1; + + ConvProblem problem; + problem.N = prob->N; + problem.G = prob->G; + problem.C = prob->C; + problem.K = prob->K; + problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; + problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; + problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; + problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; + problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; + problem.op = static_cast(prob->direction); + problem.compute_output_size(); + + const auto* kernel = g_dispatcher->select(problem); + if(!kernel) + return -1; + + std::strncpy(kernel_name, kernel->name().c_str(), buffer_size - 1); + kernel_name[buffer_size - 1] = '\0'; + + return 0; +} + +// ============================================================================= +// Convolution Execution +// ============================================================================= + +// Helper to build ConvParam +static ck_tile::conv::ConvParam build_conv_param(const ConvProblemC* prob) +{ + // Determine if this is 2D or 3D convolution + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + + if(is_3d) + { + // 3D convolution: use all spatial dimensions + return ck_tile::conv::ConvParam{3, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_z, prob->filter_y, prob->filter_x}, + {prob->input_d, prob->input_h, 
prob->input_w}, + {prob->stride_d, prob->stride_h, prob->stride_w}, + {prob->dilation_d, prob->dilation_h, prob->dilation_w}, + {prob->pad_d, prob->pad_h, prob->pad_w}, + {prob->pad_d, prob->pad_h, prob->pad_w}}; + } + else + { + // 2D convolution: only use H, W dimensions + return ck_tile::conv::ConvParam{2, + prob->G, + prob->N, + prob->K, + prob->C, + {prob->filter_y, prob->filter_x}, + {prob->input_h, prob->input_w}, + {prob->stride_h, prob->stride_w}, + {prob->dilation_h, prob->dilation_w}, + {prob->pad_h, prob->pad_w}, + {prob->pad_h, prob->pad_w}}; + } +} + +// Forward convolution (required - kernel header must be force-included) +static float run_forward(const void* input_ptr, + const void* weight_ptr, + void* output_ptr, + const ConvProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + ck_tile::GroupedConvFwdHostArgs<> args(conv_param, input_ptr, weight_ptr, {}, output_ptr, 1); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + // SelectedConvKernelLauncher is defined in the force-included forward kernel header + return SelectedConvKernelLauncher::launch(args, stream_cfg); +} + +#ifdef CONV_BWD_DATA_AVAILABLE +// Backward data convolution (optional) +// Computes: grad_input = conv_bwd_data(weight, grad_output) +// +// Parameters: +// grad_output_ptr: dY - gradient from next layer (const, read-only INPUT) +// weight_ptr: W - frozen weights (const, read-only INPUT) +// grad_input_ptr: dX - gradient for input (writable, OUTPUT) +static float run_bwd_data(const void* grad_output_ptr, + const void* weight_ptr, + void* grad_input_ptr, + const ConvProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + // CK Tile API uses tensor POSITION names (from forward pass), not data flow: + // in_ptr = input tensor position = grad_input_ptr (dX, OUTPUT of bwd_data) + // wei_ptr = weight tensor = weight_ptr (W, const) + // out_ptr = output tensor position = grad_output_ptr (dY, INPUT to 
bwd_data) + ck_tile::GroupedConvBwdDataHostArgs args( + conv_param, grad_input_ptr, weight_ptr, {}, grad_output_ptr, 1); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + return SelectedConvBwdDataLauncher::launch(args, stream_cfg); +} +#endif + +#ifdef CONV_BWD_WEIGHT_AVAILABLE +// Backward weight convolution (optional) +// Parameters: +// input_ptr: original forward input X (const, read-only) +// grad_output_ptr: gradient from next layer dY (const, read-only) +// grad_weight_ptr: gradient of weights dW (writable, OUTPUT) +static float run_bwd_weight(const void* input_ptr, + const void* grad_output_ptr, + void* grad_weight_ptr, + const ConvProblemC* prob, + void* stream) +{ + auto conv_param = build_conv_param(prob); + + // GroupedConvBwdWeightHostArgs constructor order: + // (param, in=X, wei=dW (output), ds, out=dY (input), k_batch) + // Note: wei_ptr is the OUTPUT (grad_weight), out_ptr is the INPUT (grad_output) + ck_tile::GroupedConvBwdWeightHostArgs args( + conv_param, input_ptr, grad_weight_ptr, {}, grad_output_ptr, 1); + + ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; + + return SelectedConvBwdWeightLauncher::launch(args, stream_cfg); +} +#endif + +/** + * @brief Execute convolution based on direction specified in prob + * + * Parameter mapping varies by direction: + * Forward (direction=0): + * input_ptr = X (input tensor) + * weight_ptr = W (weight tensor) + * output_ptr = Y (output buffer) + * + * Backward Data (direction=1): + * input_ptr = dY (grad_output - gradient from next layer) + * weight_ptr = W (weight tensor, frozen) + * output_ptr = dX (grad_input buffer) + * + * Backward Weight (direction=2): + * input_ptr = X (forward input tensor) + * weight_ptr = dY (grad_output - gradient from next layer) + * output_ptr = dW (grad_weight buffer) + */ +float conv_dispatcher_run(const void* input_ptr, + const void* weight_ptr, + void* output_ptr, + const ConvProblemC* prob, + void* stream) +{ + // 
Validate all required pointers before kernel launch + if(!g_dispatcher || !prob) + return -1.0f; + if(!input_ptr || !weight_ptr || !output_ptr) + return -1.0f; // Null data pointer would cause kernel crash + + // Build problem for kernel selection + ConvProblem problem; + problem.N = prob->N; + problem.G = prob->G; + problem.C = prob->C; + problem.K = prob->K; + problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; + problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; + problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; + problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; + problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; + problem.op = static_cast(prob->direction); + problem.compute_output_size(); + + // Select kernel + const auto* kernel = g_dispatcher->select(problem); + if(!kernel) + return -1.0f; + + // Dispatch based on direction + switch(prob->direction) + { + case 0: // Forward (always available) + return run_forward(input_ptr, weight_ptr, output_ptr, prob, stream); + +#ifdef CONV_BWD_DATA_AVAILABLE + case 1: // Backward data + // Convention: caller passes (grad_output, weight, grad_input_buffer) + // in the (input_ptr, weight_ptr, output_ptr) slots respectively. + // run_bwd_data expects: (grad_output, weight, grad_input) + return run_bwd_data(input_ptr, weight_ptr, output_ptr, prob, stream); +#endif + +#ifdef CONV_BWD_WEIGHT_AVAILABLE + case 2: // Backward weight + // Convention: caller passes (input, grad_output, grad_weight_buffer) + // in the (input_ptr, weight_ptr, output_ptr) slots respectively. 
+ // run_bwd_weight expects: (input, grad_output, grad_weight) + return run_bwd_weight(input_ptr, weight_ptr, output_ptr, prob, stream); +#endif + + default: return -1.0f; + } +} + +// ============================================================================= +// Info +// ============================================================================= + +const char* conv_dispatcher_version() { return "1.0.0"; } + +int conv_dispatcher_has_kernels() +{ + return 1; // Forward kernel is required +} + +int conv_dispatcher_has_bwd_data() +{ +#ifdef CONV_BWD_DATA_AVAILABLE + return 1; +#else + return 0; +#endif +} + +int conv_dispatcher_has_bwd_weight() +{ +#ifdef CONV_BWD_WEIGHT_AVAILABLE + return 1; +#else + return 0; +#endif +} + +} // extern "C" diff --git a/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp b/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp new file mode 100644 index 00000000000..85c0c2f2c13 --- /dev/null +++ b/dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp @@ -0,0 +1,401 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * GEMM Dispatcher ctypes Library + * + * Provides C API for Python ctypes integration. + * Kernel header included via -include at compile time. + * + * Usage from Python: + * lib = ctypes.CDLL("libdispatcher_gemm.so") + * lib.dispatcher_init() + * lib.dispatcher_run_gemm(...) 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag +// Defines: ADataType, BDataType, CDataType, AccDataType, SelectedKernel, KERNEL_NAME + +// GPU architecture - can be overridden via -DGFX_ARCH="gfx90a" at compile time +#ifndef GFX_ARCH +#define GFX_ARCH "gfx942" +#endif + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +// Global dispatcher (initialized once, managed via shared_ptr for safe cleanup) +static std::shared_ptr g_dispatcher = nullptr; +static bool g_initialized = false; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + return -1; \ + } \ + } + +extern "C" { + +/** + * Initialize dispatcher with a kernel + * Must be called before run_gemm + * + * Returns: 0 on success, -1 on error + */ +int dispatcher_initialize() +{ + if(g_initialized) + { + return 0; // Already initialized + } + + // Create kernel key from the force-included kernel header + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + 
key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = GFX_ARCH; + + // Register kernel using types from force-included header + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + // Create dispatcher (using shared_ptr for safe memory management) + g_dispatcher = std::make_shared(); + g_initialized = true; + + return 0; +} + +/** + * Get kernel tile configuration + */ +int dispatcher_get_kernel_config(int* tile_m, + int* tile_n, + int* tile_k, + int* warp_tile_m, + int* warp_tile_n, + int* warp_tile_k, + int* warp_m, + int* warp_n, + int* warp_k) +{ + if(!g_initialized) + { + return -1; + } + + auto kernels = Registry::instance().get_all(); + if(kernels.empty()) + { + return -1; + } + + // Get configuration from first kernel + auto& key = kernels[0]->get_key(); + auto& algo = key.algorithm; + + if(tile_m) + *tile_m = algo.tile_shape.m; + if(tile_n) + *tile_n = algo.tile_shape.n; + if(tile_k) + *tile_k = algo.tile_shape.k; + if(warp_tile_m) + *warp_tile_m = algo.warp_tile_shape.m; + if(warp_tile_n) + *warp_tile_n = algo.warp_tile_shape.n; + if(warp_tile_k) + *warp_tile_k = algo.warp_tile_shape.k; + if(warp_m) + *warp_m = algo.wave_shape.m; + if(warp_n) + *warp_n = algo.wave_shape.n; + if(warp_k) + *warp_k = algo.wave_shape.k; + + return 0; +} + +/** + * Get the selected kernel name for a problem + */ +int dispatcher_select_kernel(int64_t M, int64_t N, int64_t K, char* name_buffer, int buffer_size) +{ + if(!g_initialized || !name_buffer || buffer_size <= 0) + { + return -1; + } + + Problem problem(M, N, K); + auto kernel = 
g_dispatcher->select_kernel(problem); + + if(!kernel) + { + return -1; + } + + std::string name = kernel->get_name(); + strncpy(name_buffer, name.c_str(), buffer_size - 1); + name_buffer[buffer_size - 1] = '\0'; + + return 0; +} + +/** + * Check if a problem size is supported by available kernels + */ +int dispatcher_is_supported(int64_t M, int64_t N, int64_t K) +{ + if(!g_initialized) + { + return 0; + } + + if(M <= 0 || N <= 0 || K <= 0) + { + return 0; + } + + Problem problem(M, N, K); + auto kernel = g_dispatcher->select_kernel(problem); + return kernel != nullptr ? 1 : 0; +} + +/** + * Run GEMM on GPU via dispatcher + */ +int dispatcher_run_gemm( + const void* A, const void* B, void* C, int64_t M, int64_t N, int64_t K, float* time_ms) +{ + if(!g_initialized || !A || !B || !C) + { + return -1; + } + + // First check if any kernel supports this problem + Problem problem(M, N, K); + auto kernel = g_dispatcher->select_kernel(problem); + if(!kernel) + { + if(time_ms) + { + *time_ms = -1.0f; + } + return -2; // No suitable kernel + } + + // Cast to correct types (from force-included header) + const ADataType* A_host = static_cast(A); + const BDataType* B_host = static_cast(B); + CDataType* C_host = static_cast(C); + + // Allocate GPU memory + ADataType* A_dev = nullptr; + BDataType* B_dev = nullptr; + CDataType* C_dev = nullptr; + + auto cleanup_gpu_mem = [&]() { + if(A_dev) + (void)hipFree(A_dev); + if(B_dev) + (void)hipFree(B_dev); + if(C_dev) + (void)hipFree(C_dev); + }; + + if(hipMalloc(&A_dev, M * K * sizeof(ADataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + if(hipMalloc(&B_dev, K * N * sizeof(BDataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + if(hipMalloc(&C_dev, M * N * sizeof(CDataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + + // Copy input data to GPU + if(hipMemcpy(A_dev, A_host, M * K * sizeof(ADataType), hipMemcpyHostToDevice) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + 
if(hipMemcpy(B_dev, B_host, K * N * sizeof(BDataType), hipMemcpyHostToDevice) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + if(hipMemset(C_dev, 0, M * N * sizeof(CDataType)) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + + // Run GEMM via dispatcher + float exec_time; + try + { + exec_time = g_dispatcher->run(A_dev, B_dev, C_dev, problem); + } + catch(const std::exception& e) + { + cleanup_gpu_mem(); + return -1; + } + + // Copy result back to host + if(hipMemcpy(C_host, C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost) != hipSuccess) + { + cleanup_gpu_mem(); + return -1; + } + + if(time_ms) + { + *time_ms = exec_time; + } + + cleanup_gpu_mem(); + return 0; +} + +/** + * Get kernel information + */ +const char* dispatcher_get_kernel_name() { return KERNEL_NAME; } + +/** + * Initialize dispatcher (alias) + */ +int dispatcher_init() { return dispatcher_initialize(); } + +/** + * Get the number of registered kernels + */ +int dispatcher_get_kernel_count() { return static_cast(Registry::instance().size()); } + +/** + * Export registry to JSON string + */ +static std::string g_json_buffer; + +const char* dispatcher_export_registry_json() +{ + auto& registry = Registry::instance(); + + std::ostringstream json; + json << "{\n"; + json << " \"metadata\": {\n"; + json << " \"timestamp\": \"" << __DATE__ << " " << __TIME__ << "\",\n"; + json << " \"total_kernels\": " << registry.size() << ",\n"; + json << " \"export_version\": \"1.0\",\n"; + json << " \"dispatcher_version\": \"1.0.0\"\n"; + json << " },\n"; + json << " \"statistics\": {\n"; + json << " \"by_datatype\": {},\n"; + json << " \"by_pipeline\": {},\n"; + json << " \"by_scheduler\": {}\n"; + json << " },\n"; + json << " \"kernels\": [\n"; + + auto kernels = registry.get_all(); + for(size_t i = 0; i < kernels.size(); ++i) + { + auto& kernel = kernels[i]; + auto& key = kernel->get_key(); + auto& algo = key.algorithm; + std::string name = kernel->get_name(); + + json << " {\n"; + json 
<< " \"identifier\": \"" << key.encode_identifier() << "\",\n"; + json << " \"name\": \"" << name << "\",\n"; + json << " \"algorithm\": {\n"; + json << " \"tile_shape\": {\"m\": " << algo.tile_shape.m + << ", \"n\": " << algo.tile_shape.n << ", \"k\": " << algo.tile_shape.k << "},\n"; + json << " \"wave_shape\": {\"m\": " << unsigned(algo.wave_shape.m) + << ", \"n\": " << unsigned(algo.wave_shape.n) + << ", \"k\": " << unsigned(algo.wave_shape.k) << "},\n"; + json << " \"warp_tile_shape\": {\"m\": " << unsigned(algo.warp_tile_shape.m) + << ", \"n\": " << unsigned(algo.warp_tile_shape.n) + << ", \"k\": " << unsigned(algo.warp_tile_shape.k) << "},\n"; + json << " \"block_size\": " << algo.block_size << ",\n"; + json << " \"persistent\": " << (algo.persistent ? "true" : "false") << ",\n"; + json << " \"double_buffer\": " << (algo.double_buffer ? "true" : "false") << ",\n"; + json << " \"preshuffle\": " << (algo.preshuffle ? "true" : "false") << ",\n"; + json << " \"transpose_c\": " << (algo.transpose_c ? "true" : "false") << "\n"; + json << " }\n"; + json << " }"; + if(i < kernels.size() - 1) + { + json << ","; + } + json << "\n"; + } + + json << " ]\n"; + json << "}\n"; + + g_json_buffer = json.str(); + return g_json_buffer.c_str(); +} + +/** + * Cleanup dispatcher resources + */ +void dispatcher_cleanup() +{ + g_dispatcher.reset(); + g_initialized = false; +} + +} // extern "C" diff --git a/dispatcher/bindings/ctypes/gpu_helper.cpp b/dispatcher/bindings/ctypes/gpu_helper.cpp new file mode 100644 index 00000000000..1c72c14e394 --- /dev/null +++ b/dispatcher/bindings/ctypes/gpu_helper.cpp @@ -0,0 +1,206 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * GPU Helper - C++ executable for GPU GEMM execution + * + * A CLI tool for Python to execute GPU GEMM with generated kernels. + * Usage: gpu_helper [--validate] + * + * Kernel header included via -include flag at compile time. 
+ */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP_ERROR: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +// CPU reference GEMM (for validation) +template +void cpu_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + // A: RowMajor, B: ColumnMajor + acc += float(A[m * K + k]) * float(B[k + n * K]); + } + C[m * N + n] = T(acc); + } + } +} + +int main(int argc, char** argv) +{ + // Parse arguments + if(argc < 4) + { + std::cerr << "Usage: " << argv[0] << " [--validate]\n"; + std::cerr << "\nOptions:\n"; + std::cerr << " M, N, K : Problem dimensions\n"; + std::cerr << " --validate : Compare GPU results with CPU reference\n"; + return 1; + } + + int M = std::atoi(argv[1]); + int N = std::atoi(argv[2]); + int K = std::atoi(argv[3]); + bool validate = (argc > 4 && std::string(argv[4]) == "--validate"); + + // Output in JSON-like format for easy Python parsing + std::cout << "{" << std::endl; + std::cout << " \"problem\": {\"M\": " << M << ", \"N\": " << N << ", \"K\": " << K << "}," + << std::endl; + std::cout << " \"kernel\": \"" << KERNEL_NAME << "\"," << std::endl; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + 
key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cout << " \"error\": \"No kernel selected\"" << std::endl; + std::cout << "}" << std::endl; + return 1; + } + + std::cout << " \"selected_kernel\": \"" << selected->get_name() << "\"," << std::endl; + + // Prepare data: A=1, B=1, so C should be K + std::vector A_host(M * K, ADataType(1.0f)); + std::vector B_host(K * N, BDataType(1.0f)); + std::vector C_gpu(M * N); + + // GPU execution + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + 
HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + float gpu_time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Calculate performance + double flops = 2.0 * M * N * K; + double tflops = (flops / (gpu_time * 1e-3)) / 1e12; + + std::cout << " \"execution\": {" << std::endl; + std::cout << " \"time_ms\": " << gpu_time << "," << std::endl; + std::cout << " \"tflops\": " << tflops << "," << std::endl; + std::cout << " \"flops\": " << (long long)flops << std::endl; + std::cout << " }," << std::endl; + + // Validation + if(validate) + { + std::vector C_cpu(M * N); + cpu_gemm(A_host, B_host, C_cpu, M, N, K); + + int correct = 0; + float max_error = 0.0f; + + for(int i = 0; i < M * N; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + float error = std::abs(gpu_val - cpu_val) / (std::abs(cpu_val) + 1e-5f); + + max_error = std::max(max_error, error); + + if(error < 0.02f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + + std::cout << " \"validation\": {" << std::endl; + std::cout << " \"accuracy\": " << accuracy << "," << std::endl; + std::cout << " \"max_error\": " << max_error << "," << std::endl; + std::cout << " \"correct_elements\": " << correct << "," << std::endl; + std::cout << " \"total_elements\": " << M * N << std::endl; + std::cout << " }," << std::endl; + } + + std::cout << " \"status\": \"success\"" << std::endl; + std::cout << "}" << std::endl; + + // Cleanup + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + return 0; +} diff --git a/dispatcher/codegen/ADDING_NEW_GPU.md b/dispatcher/codegen/ADDING_NEW_GPU.md new file mode 100644 index 00000000000..0bd2966a857 --- /dev/null +++ b/dispatcher/codegen/ADDING_NEW_GPU.md @@ -0,0 +1,197 @@ +# Adding New GPU Architecture Support + +Guide for adding support for a new AMD GPU architecture to the CK Tile 
Dispatcher. + +> **See also:** [Main Dispatcher README](../README.md) | [Codegen README](README.md) + +## Overview + +The dispatcher uses `arch_specs.json` as the **single source of truth** for GPU specifications: + +``` +arch_specs.json → generate_arch_specs.py → arch_specs_generated.py (Python) + → arch_specs_generated.hpp (C++) +``` + +## Quick Start + +```bash +# 1. Edit arch_specs.json +# 2. Run generator +python generate_arch_specs.py +# 3. Rebuild +cd ../build && cmake --build . -j8 +# 4. Test +ctest +``` + +## Step-by-Step Guide + +### Step 1: Edit arch_specs.json + +Add new architecture under `"architectures"`: + +```json +{ + "architectures": { + "gfx1100": { + "family": "rdna3", + "description": "AMD Radeon RX 7000 series (RDNA3)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]], + "bf16_bf16_bf16": [[16, 16, 16], [32, 32, 16]] + } + } + } +} +``` + +### Step 2: Configuration Fields + +| Field | Description | Example | +|-------|-------------|---------| +| `family` | GPU family | `"cdna3"`, `"rdna4"` | +| `description` | Human-readable name | `"AMD Instinct MI300"` | +| `warp_size` | Wave/warp size | `64` (CDNA), `32` (RDNA) | +| `lds_capacity_kb` | LDS memory in KB | `64` | +| `warp_configs` | Valid `[warp_m, warp_n, warp_k]` | `[[2,2,1], [4,4,1]]` | +| `warp_tile_combos` | Warp tiles per dtype | See below | + +### Step 3: Warp Tile Combinations + +Map data type combinations to valid warp tile sizes: + +```json +"warp_tile_combos": { + "fp16_fp16_fp16": [[32, 32, 8], [16, 16, 16], [32, 32, 16]], + "bf16_bf16_bf16": [[32, 32, 8], [16, 16, 16]], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]] +} +``` + +Key format: `{A_dtype}_{B_dtype}_{C_dtype}` + +### Step 4: Run Generator + +```bash +cd dispatcher/codegen +python generate_arch_specs.py +``` + +This generates: +- 
`arch_specs_generated.py` (Python module) +- `../include/ck_tile/dispatcher/arch_specs_generated.hpp` (C++ header) + +### Step 5: Rebuild and Test + +```bash +cd ../build +cmake --build . -j8 +ctest --output-on-failure +``` + +### Step 6: Verify + +```python +from arch_filter import ArchFilter + +filter = ArchFilter("gfx1100") +is_valid = filter.is_kernel_valid( + datatype_a="fp16", datatype_b="fp16", datatype_c="fp16", + tile_m=128, tile_n=128, tile_k=32, + warp_m=2, warp_n=2, warp_k=1, + warp_tile_m=16, warp_tile_n=16, warp_tile_k=16 +) +print(f"Valid: {is_valid}") +``` + +## Reference + +### Supported Data Types + +| Key | Description | +|-----|-------------| +| `fp16` | Half precision (16-bit) | +| `bf16` | Brain float 16 | +| `fp32` | Single precision (32-bit) | +| `fp64` | Double precision (64-bit) | +| `fp8` | 8-bit float (E4M3) | +| `bf8` | 8-bit brain float (E5M2) | +| `int8` | 8-bit integer | +| `int4` | 4-bit integer | + +### GPU Families + +| Family | Description | +|--------|-------------| +| `cdna2` | MI200 series (gfx90a) | +| `cdna3` | MI300 series (gfx942) | +| `cdna4` | MI350 series (gfx950) | +| `rdna3` | RX 7000 series (gfx1100) | +| `rdna4` | RX 9000 series (gfx1201) | + +### Pipeline LDS Limits + +| Pipeline | LDS Limit | +|----------|-----------| +| `compv4` | 32 KB | +| `preshufflev2` | 32 KB | +| `default` | 64 KB | + +## Troubleshooting + +### "Unknown GPU architecture" + +1. Check architecture key matches exactly (e.g., `"gfx942"` not `"GFX942"`) +2. Verify you ran `generate_arch_specs.py` +3. Rebuild C++ code + +### Kernels being rejected + +```python +from arch_filter import ArchFilter, KernelConfig + +filter = ArchFilter("gfx942") +result = filter.validate_kernel(config) +print(f"Valid: {result.valid}") +for error in result.errors: + print(f" Error: {error}") +``` + +### Missing warp tile combination + +1. Check `warp_tile_combos` in `arch_specs.json` +2. Ensure `[warp_tile_m, warp_tile_n, warp_tile_k]` is in the list +3. 
Verify data type key format + +## File Structure + +``` +codegen/ +├── arch_specs.json # Single source of truth (EDIT THIS) +├── generate_arch_specs.py # Generator script +├── arch_specs_generated.py # Generated Python module +└── ADDING_NEW_GPU.md # This file + +include/ck_tile/dispatcher/ +├── arch_specs_generated.hpp # Generated C++ header +└── arch_filter.hpp # C++ filter +``` + +## Best Practices + +1. **Test thoroughly** - Run all tests after adding a new GPU +2. **Start minimal** - Add only validated configurations +3. **Document sources** - Note where warp tile combinations came from +4. **Keep in sync** - If using tile_engine, keep both updated + +--- + +> **More info:** See [../README.md](../README.md) for full documentation. diff --git a/dispatcher/codegen/CMakeLists.txt b/dispatcher/codegen/CMakeLists.txt new file mode 100644 index 00000000000..e63dcaab67f --- /dev/null +++ b/dispatcher/codegen/CMakeLists.txt @@ -0,0 +1,125 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT

# CK Tile GEMM Unified Code Generator
#
# Drives unified_gemm_codegen.py at build time (custom target) and exposes a
# helper function for configure-time generation.

cmake_minimum_required(VERSION 3.16)

# Find Python
find_package(Python3 COMPONENTS Interpreter REQUIRED)

# Configuration
set(CODEGEN_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/unified_gemm_codegen.py")
set(CODEGEN_CONFIG "${CMAKE_CURRENT_SOURCE_DIR}/default_config.json")
set(CODEGEN_OUTPUT_DIR "${CMAKE_BINARY_DIR}/generated/tile_gemm")

# Configurable options
set(CK_TILE_GEMM_DATATYPE "fp16" CACHE STRING "GEMM data type (fp16, bf16, fp32, fp8, bf8, int8)")
set(CK_TILE_GEMM_LAYOUT "rcr" CACHE STRING "GEMM layout (rcr, rrr, crr, ccr)")
set(CK_TILE_GEMM_VARIANTS "standard" CACHE STRING "GEMM variants (standard, preshuffle, multi_d)")
set(CK_TILE_GEMM_GPU_TARGET "gfx942" CACHE STRING "Target GPU architecture")
# option() is the idiomatic form of `set(... CACHE BOOL ...)` for booleans
option(CK_TILE_GEMM_PARALLEL "Enable parallel generation" ON)

# Custom target to run code generation:
#   cmake --build . --target generate_tile_gemm_kernels
add_custom_target(generate_tile_gemm_kernels
    COMMAND ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT}
            --output-dir ${CODEGEN_OUTPUT_DIR}
            --datatype ${CK_TILE_GEMM_DATATYPE}
            --layout ${CK_TILE_GEMM_LAYOUT}
            --gpu-target ${CK_TILE_GEMM_GPU_TARGET}
            --config ${CODEGEN_CONFIG}
            --variants ${CK_TILE_GEMM_VARIANTS}
            # NOTE(review): generator expression reconstructed (its "<...>"
            # parts were stripped by extraction); it appends --no-parallel
            # when CK_TILE_GEMM_PARALLEL is OFF — confirm against original.
            $<$<NOT:$<BOOL:${CK_TILE_GEMM_PARALLEL}>>:--no-parallel>
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
    COMMENT "Generating CK Tile GEMM kernels and dispatcher wrappers..."
    VERBATIM
)

# Create output directory at configure time so the target can write into it
file(MAKE_DIRECTORY ${CODEGEN_OUTPUT_DIR})

# Add generated headers to include path
include_directories(${CODEGEN_OUTPUT_DIR})

# Installation
install(FILES
    ${CODEGEN_SCRIPT}
    ${CODEGEN_CONFIG}
    README.md
    DESTINATION share/ck_tile/codegen
)

# Helper function for projects to generate kernels at configure time.
#
# Keyword arguments (all optional, defaults mirror the cache variables above):
#   OUTPUT_DIR DATATYPE LAYOUT GPU_TARGET CONFIG  - single-value
#   VARIANTS                                      - multi-value
#   PARALLEL                                      - flag; when absent,
#                                                   --no-parallel is passed
function(ck_tile_generate_gemm_kernels)
    set(options PARALLEL)
    set(oneValueArgs OUTPUT_DIR DATATYPE LAYOUT GPU_TARGET CONFIG)
    set(multiValueArgs VARIANTS)
    cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    # Set defaults
    if(NOT ARG_OUTPUT_DIR)
        set(ARG_OUTPUT_DIR "${CMAKE_BINARY_DIR}/generated/tile_gemm")
    endif()
    if(NOT ARG_DATATYPE)
        set(ARG_DATATYPE "fp16")
    endif()
    if(NOT ARG_LAYOUT)
        set(ARG_LAYOUT "rcr")
    endif()
    if(NOT ARG_GPU_TARGET)
        set(ARG_GPU_TARGET "gfx942")
    endif()
    if(NOT ARG_CONFIG)
        set(ARG_CONFIG "${CMAKE_CURRENT_SOURCE_DIR}/default_config.json")
    endif()
    if(NOT ARG_VARIANTS)
        set(ARG_VARIANTS "standard")
    endif()

    # Build command
    set(CMD ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT}
        --output-dir ${ARG_OUTPUT_DIR}
        --datatype ${ARG_DATATYPE}
        --layout ${ARG_LAYOUT}
        --gpu-target ${ARG_GPU_TARGET}
        --config ${ARG_CONFIG}
        --variants ${ARG_VARIANTS}
    )

    if(NOT ARG_PARALLEL)
        list(APPEND CMD --no-parallel)
    endif()

    # Execute at configure time; abort configuration on failure
    execute_process(
        COMMAND ${CMD}
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
        RESULT_VARIABLE RESULT
        OUTPUT_VARIABLE OUTPUT
        ERROR_VARIABLE ERROR
    )

    if(NOT RESULT EQUAL 0)
        message(FATAL_ERROR "Failed to generate GEMM kernels:\n${ERROR}")
    else()
        message(STATUS "Generated GEMM kernels: ${OUTPUT}")
    endif()
endfunction()

# Example usage documentation
message(STATUS "CK Tile GEMM Code Generator configured")
message(STATUS "  Script: ${CODEGEN_SCRIPT}")
message(STATUS "  Config: ${CODEGEN_CONFIG}")
message(STATUS "  Output: ${CODEGEN_OUTPUT_DIR}")
message(STATUS "")
message(STATUS "To generate kernels:")
message(STATUS "  cmake --build . --target generate_tile_gemm_kernels")
message(STATUS "")
message(STATUS "Or use CMake function:")
message(STATUS "  ck_tile_generate_gemm_kernels(")
message(STATUS "    OUTPUT_DIR ./generated")
message(STATUS "    DATATYPE fp16")
message(STATUS "    LAYOUT rcr")
message(STATUS "    VARIANTS standard preshuffle multi_d")
message(STATUS "    PARALLEL")
message(STATUS "  )")
diff --git a/dispatcher/codegen/README.md b/dispatcher/codegen/README.md
new file mode 100644
index 00000000000..2d753924f58
--- /dev/null
+++ b/dispatcher/codegen/README.md
@@ -0,0 +1,123 @@
+# CK Tile GEMM Unified Code Generator
+
+Single source of truth for all GEMM kernel generation.
+
+> **See also:** [Main Dispatcher README](../README.md) for installation and core concepts.
+
+## Quick Start
+
+```bash
+cd dispatcher/codegen
+
+# Generate standard FP16 kernels
+python3 unified_gemm_codegen.py \
+    --output-dir ../build/generated_kernels \
+    --datatype fp16 \
+    --layout rcr \
+    --variants standard
+
+# Generate all variants
+python3 unified_gemm_codegen.py \
+    --output-dir ../build/generated_kernels \
+    --variants standard preshuffle multi_d
+```
+
+## Using from Python
+
+```python
+from ctypes_utils import CodegenRunner, KernelConfig
+
+# Generate from specific config
+config = KernelConfig(tile_m=256, tile_n=256, tile_k=64)
+codegen = CodegenRunner()
+result = codegen.generate_from_config(config)
+
+# Generate variant
+result = codegen.generate("preshuffle")
+
+# Generate all
+results = codegen.generate_all()
+```
+
+## Command Line Options
+
+| Option | Values | Description |
+|--------|--------|-------------|
+| `--output-dir` | path | Output directory |
+| `--datatype` | `fp16`, `bf16`, `fp32`, `int8` | Data type |
+| `--layout` | `rcr`, `rrr`, `crr`, `ccr` | Matrix layouts |
+| `--gpu-target` | `gfx942`, `gfx90a`, `gfx950` | Target GPU |
+| `--variants` | `standard`, `preshuffle`, `multi_d` | Kernel variants |
+| 
`--preselected` | `fp16_rcr_essential`, etc. | Predefined kernel set | + +### Layout Notation + +- `R` = Row-major, `C` = Column-major +- Order: A, B, C (e.g., `rcr` = A row, B col, C row) + +## Variants + +### Standard +Basic GEMM: `C = A × B` + +### PreShuffle +Optimized weight access with LDS pre-shuffling. Best for large matrices. + +### Multi-D +Element-wise fusion: `C = op(A × B + D0 + D1 + ...)` + +Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh` + +## Output Structure + +``` +generated_kernels/ +├── gemm_fp16_rcr_compv4_..._128x128x32_....hpp +├── gemm_fp16_rcr_compv4_..._preshuffle.hpp +├── gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp +└── ... +``` + +## Configuration Files + +### arch_specs.json + +GPU architecture specifications (single source of truth): + +```json +{ + "architectures": { + "gfx942": { + "family": "cdna3", + "warp_size": 64, + "warp_configs": [[2, 2, 1], [4, 4, 1]], + ... + } + } +} +``` + +### preselected_kernels.py + +Curated kernel sets for common use cases. + +## Adding New GPU Support + +See [ADDING_NEW_GPU.md](ADDING_NEW_GPU.md) for complete guide. + +Quick steps: +1. Edit `arch_specs.json` +2. Run `python generate_arch_specs.py` +3. Rebuild + +## Troubleshooting + +| Issue | Solution | +|-------|----------| +| "Arguments not supported" | Check tile config validity | +| Missing element-wise op | Check `elementwise_ops.hpp` | +| Compilation errors | Verify C++17, include paths | + +--- + +> **More info:** See [../README.md](../README.md) for full documentation. diff --git a/dispatcher/codegen/arch_filter.py b/dispatcher/codegen/arch_filter.py new file mode 100644 index 00000000000..67f146045b4 --- /dev/null +++ b/dispatcher/codegen/arch_filter.py @@ -0,0 +1,1012 @@ +#!/usr/bin/env python + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +""" +Architecture-Specific Kernel Filtering for CK Tile Dispatcher + +Unified filtering mechanism for validating kernel configurations against +GPU architecture capabilities. Uses arch_specs.json as single source of truth. + +Key Features: +- GPU architecture-specific warp tile and warp configuration validation +- Data type compatibility checking +- Trait combination validation (pipeline, epilogue, scheduler) +- LDS capacity validation +- Single source of truth (arch_specs.json) + +Usage: + from arch_filter import ArchFilter, get_supported_archs + + # Create filter for specific architecture + filter = ArchFilter("gfx942") + + # Validate a kernel configuration + is_valid = filter.is_kernel_valid( + datatype_a="fp16", datatype_b="fp16", datatype_c="fp16", + tile_m=256, tile_n=256, tile_k=64, + warp_m=2, warp_n=2, warp_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", epilogue="cshuffle", scheduler="intrawave" + ) + + # Get detailed validation results + result = filter.validate_kernel_detailed(...) 
+ print(result.valid, result.errors) +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Any +from enum import Enum +import logging + +logger = logging.getLogger(__name__) + + +class OperatorType(Enum): + """Supported operator types for kernel validation""" + + GEMM = "gemm" + GEMM_PRESHUFFLE = "gemm_preshuffle" + GEMM_MULTI_D = "gemm_multi_d" + CONV_FWD = "conv_fwd" + CONV_BWD_DATA = "conv_bwd_data" + CONV_BWD_WEIGHT = "conv_bwd_weight" + CONV3D_FWD = "conv3d_fwd" + CONV3D_BWD_DATA = "conv3d_bwd_data" + CONV3D_BWD_WEIGHT = "conv3d_bwd_weight" + + +# Operator-specific tile constraints +# Different operators may have different minimum tile sizes or alignment requirements +OPERATOR_TILE_CONSTRAINTS = { + OperatorType.GEMM: { + "min_tile_m": 16, + "min_tile_n": 16, + "min_tile_k": 8, + "tile_m_alignment": 16, + "tile_n_alignment": 16, + "tile_k_alignment": 8, + }, + OperatorType.GEMM_PRESHUFFLE: { + "min_tile_m": 64, + "min_tile_n": 64, + "min_tile_k": 32, + "tile_m_alignment": 32, + "tile_n_alignment": 32, + "tile_k_alignment": 16, + }, + OperatorType.GEMM_MULTI_D: { + "min_tile_m": 16, + "min_tile_n": 16, + "min_tile_k": 8, + "tile_m_alignment": 16, + "tile_n_alignment": 16, + "tile_k_alignment": 8, + }, + OperatorType.CONV_FWD: { + "min_tile_m": 1, # N dimension can be 1 + "min_tile_n": 16, # K (output channels) should be reasonable + "min_tile_k": 16, # C (input channels) should be reasonable + "tile_m_alignment": 1, + "tile_n_alignment": 16, + "tile_k_alignment": 16, + }, + OperatorType.CONV_BWD_DATA: { + "min_tile_m": 1, + "min_tile_n": 16, # C (input channels) + "min_tile_k": 16, # K (output channels) + "tile_m_alignment": 1, + "tile_n_alignment": 16, + "tile_k_alignment": 16, + }, + OperatorType.CONV_BWD_WEIGHT: { + "min_tile_m": 16, # K (output channels) + "min_tile_n": 16, # C (input channels) + "min_tile_k": 1, # Spatial reduction dimension + "tile_m_alignment": 16, + "tile_n_alignment": 16, + 
"tile_k_alignment": 1, + }, +} + +# Add 3D convolution constraints (same as 2D for now) +OPERATOR_TILE_CONSTRAINTS[OperatorType.CONV3D_FWD] = OPERATOR_TILE_CONSTRAINTS[ + OperatorType.CONV_FWD +] +OPERATOR_TILE_CONSTRAINTS[OperatorType.CONV3D_BWD_DATA] = OPERATOR_TILE_CONSTRAINTS[ + OperatorType.CONV_BWD_DATA +] +OPERATOR_TILE_CONSTRAINTS[OperatorType.CONV3D_BWD_WEIGHT] = OPERATOR_TILE_CONSTRAINTS[ + OperatorType.CONV_BWD_WEIGHT +] + +# ============================================================================= +# Import from Generated Module (Single Source of Truth) +# ============================================================================= + +# Try to import from the generated module (created from arch_specs.json) +try: + from arch_specs_generated import ( + ARCH_FAMILY_MAP, + ELEMENT_SIZE_MAP, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS, + PRESHUFFLE_PIPELINES, + LDS_CAPACITY_LIMITS, + TRAIT_UNSUPPORTED_COMBINATIONS, + DTYPE_COMBINATIONS, + ) + + _USING_GENERATED = True +except ImportError: + # Fallback to hardcoded values if generated module not available + logger.warning( + "arch_specs_generated.py not found, using fallback values. " + "Run 'python generate_arch_specs.py' to generate." 
+ ) + _USING_GENERATED = False + + # Fallback data (minimal subset for basic operation) + ARCH_FAMILY_MAP = { + "gfx90a": "cdna2", + "gfx942": "cdna3", + "gfx950": "cdna4", + "gfx1201": "rdna4", + } + + ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, + "int4": 0.5, + "int32": 4, + } + + WARP_SUPPORTED_COMBINATIONS = { + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], + } + + WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": { + # Key format: A_B_Acc (e.g., fp16_fp16_fp32 = A/B are fp16, accumulator is fp32) + # These match tile_engine's GEMM_WARP_TILE_SUPPORTED_COMBINATIONS + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + } + + # Preshuffle-specific warp tile combinations (no [4, 64, 16]) + PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + }, + } + + PRESHUFFLE_PIPELINES = ["preshufflev2"] + + LDS_CAPACITY_LIMITS = {"compv4": 32768, "preshufflev2": 32768, "default": 65536} + + TRAIT_UNSUPPORTED_COMBINATIONS = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + } + + DTYPE_COMBINATIONS = { + "fp32_fp32": {"acc": "fp32", "notes": "Full precision"}, + "fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"}, + "bf16_bf16": {"acc": "fp32", 
"notes": "Brain float 16"}, + "fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"}, + "fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"}, + "bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"}, + "bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"}, + "int8_int8": {"acc": "int32", "notes": "Integer GEMM"}, + "pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"}, + } + + +# ============================================================================= +# GPU Family Enum (for backwards compatibility) +# ============================================================================= + + +class GpuFamily(Enum): + """GPU architecture families""" + + CDNA2 = "cdna2" + CDNA3 = "cdna3" + CDNA4 = "cdna4" + RDNA4 = "rdna4" + + +# ============================================================================= +# Dtype Validation Helpers +# ============================================================================= + + +def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool: + """Check if a dtype combination is valid for GEMM.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + return key in DTYPE_COMBINATIONS + + +def get_dtype_acc(dtype_a: str, dtype_b: str) -> str: + """Get the accumulator type for a dtype combination.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + info = DTYPE_COMBINATIONS.get(key, {"acc": "fp32"}) + return info["acc"] + + +def get_valid_dtype_combos() -> List[str]: + """Get list of all valid dtype combinations.""" + return list(DTYPE_COMBINATIONS.keys()) + + +# ============================================================================= +# Validation Result Types +# ============================================================================= + + +@dataclass +class ValidationResult: + """Result of kernel configuration validation""" + + valid: bool + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + + def __bool__(self) -> bool: + return self.valid + + def add_error(self, msg: str): + 
self.errors.append(msg) + self.valid = False + + def add_warning(self, msg: str): + self.warnings.append(msg) + + +@dataclass +class KernelConfig: + """Kernel configuration for validation""" + + # Data types + datatype_a: str + datatype_b: str + datatype_c: str + + # Tile dimensions + tile_m: int + tile_n: int + tile_k: int + + # Warp configuration + warp_m: int + warp_n: int + warp_k: int + + # Warp tile dimensions + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + # Traits + pipeline: str = "compv4" + epilogue: str = "cshuffle" + scheduler: str = "intrawave" + + # Layout (for whole-workgroup cover validation) + layout: str = "rcr" + + # Operator type (affects validation rules) + operator: OperatorType = OperatorType.GEMM + + @property + def dtype_key(self) -> str: + """Generate data type combination key for warp tile lookup. + + Uses accumulator dtype (not output C type) to match the format + used in WARP_TILE_SUPPORTED_COMBINATIONS dictionaries which are + keyed as {datatype_a}_{datatype_b}_{accumulator_dtype}. + """ + acc_dtype = get_dtype_acc(self.datatype_a, self.datatype_b) + return f"{self.datatype_a}_{self.datatype_b}_{acc_dtype}" + + +# ============================================================================= +# Architecture Filter Class +# ============================================================================= + + +class ArchFilter: + """ + Architecture-specific kernel configuration filter. + + Validates kernel configurations against GPU architecture capabilities + to ensure only compatible kernels are registered. 
+ + Example: + filter = ArchFilter("gfx942") + + # Quick validation + if filter.is_kernel_valid(config): + registry.register_kernel(kernel) + + # Detailed validation with error messages + result = filter.validate_kernel(config) + if not result.valid: + for error in result.errors: + print(f"Validation failed: {error}") + """ + + def __init__(self, gpu_arch: str, strict_mode: bool = True): + """ + Initialize architecture filter. + + Args: + gpu_arch: GPU architecture string (e.g., "gfx942", "gfx90a") + strict_mode: If True, unknown configurations are rejected. + If False, unknown configurations pass with warnings. + """ + self.gpu_arch = gpu_arch.lower() + self.strict_mode = strict_mode + self.family = ARCH_FAMILY_MAP.get(self.gpu_arch) + + if self.family is None and strict_mode: + raise ValueError( + f"Unknown GPU architecture: {gpu_arch}. " + f"Supported: {list(ARCH_FAMILY_MAP.keys())}" + ) + + def validate_kernel(self, config: KernelConfig) -> ValidationResult: + """ + Validate a kernel configuration against architecture constraints. + + Validation is performed based on the operator type, as different + operators (GEMM, Conv FWD, Conv BWD) have different constraints. 
+ + Args: + config: Kernel configuration to validate + + Returns: + ValidationResult with valid flag and error/warning messages + """ + result = ValidationResult(valid=True) + + # Operator-specific tile constraint validation + self._validate_operator_constraints(config, result) + if not result.valid and self.strict_mode: + return result + + # Basic sanity checks + self._validate_dimensions(config, result) + if not result.valid and self.strict_mode: + return result + + # Warp configuration validation + self._validate_warp_config(config, result) + + # Warp tile combination validation + self._validate_warp_tile_combo(config, result) + + # Trait combination validation + self._validate_trait_combo(config, result) + + # LDS capacity validation + self._validate_lds_capacity(config, result) + + # Dimension alignment validation + self._validate_dimension_alignment(config, result) + + return result + + def _validate_operator_constraints( + self, config: KernelConfig, result: ValidationResult + ): + """Validate operator-specific tile constraints""" + constraints = OPERATOR_TILE_CONSTRAINTS.get(config.operator) + + if constraints is None: + # Unknown operator - add warning but don't fail + result.add_warning( + f"Unknown operator type: {config.operator}. " + f"Skipping operator-specific validation." 
+ ) + return + + # Validate minimum tile sizes + min_tile_m = constraints.get("min_tile_m", 1) + min_tile_n = constraints.get("min_tile_n", 1) + min_tile_k = constraints.get("min_tile_k", 1) + + if config.tile_m < min_tile_m: + result.add_error( + f"Operator {config.operator.value}: tile_m ({config.tile_m}) " + f"< minimum ({min_tile_m})" + ) + if config.tile_n < min_tile_n: + result.add_error( + f"Operator {config.operator.value}: tile_n ({config.tile_n}) " + f"< minimum ({min_tile_n})" + ) + if config.tile_k < min_tile_k: + result.add_error( + f"Operator {config.operator.value}: tile_k ({config.tile_k}) " + f"< minimum ({min_tile_k})" + ) + + # Validate tile alignment + tile_m_align = constraints.get("tile_m_alignment", 1) + tile_n_align = constraints.get("tile_n_alignment", 1) + tile_k_align = constraints.get("tile_k_alignment", 1) + + if tile_m_align > 1 and config.tile_m % tile_m_align != 0: + result.add_error( + f"Operator {config.operator.value}: tile_m ({config.tile_m}) " + f"must be aligned to {tile_m_align}" + ) + if tile_n_align > 1 and config.tile_n % tile_n_align != 0: + result.add_error( + f"Operator {config.operator.value}: tile_n ({config.tile_n}) " + f"must be aligned to {tile_n_align}" + ) + if tile_k_align > 1 and config.tile_k % tile_k_align != 0: + result.add_error( + f"Operator {config.operator.value}: tile_k ({config.tile_k}) " + f"must be aligned to {tile_k_align}" + ) + + def is_kernel_valid( + self, + datatype_a: str = "fp16", + datatype_b: str = "fp16", + datatype_c: str = "fp16", + tile_m: int = 256, + tile_n: int = 256, + tile_k: int = 64, + warp_m: int = 2, + warp_n: int = 2, + warp_k: int = 1, + warp_tile_m: int = 32, + warp_tile_n: int = 32, + warp_tile_k: int = 16, + pipeline: str = "compv4", + epilogue: str = "cshuffle", + scheduler: str = "intrawave", + layout: str = "rcr", + operator: Optional[OperatorType] = None, + ) -> bool: + """ + Quick validation check for a kernel configuration. 
+ + Args: + datatype_a, datatype_b, datatype_c: Data types for A, B, C matrices + tile_m, tile_n, tile_k: Block tile dimensions + warp_m, warp_n, warp_k: Warp/wave configuration + warp_tile_m, warp_tile_n, warp_tile_k: Warp tile dimensions + pipeline, epilogue, scheduler: Kernel traits + layout: Matrix layout (e.g., "rcr") + operator: Operator type (GEMM, CONV_FWD, CONV_BWD_DATA, etc.) + Affects validation rules for tile constraints. + Defaults to GEMM if not specified. + + Returns: + True if configuration is valid for this architecture + """ + config = KernelConfig( + datatype_a=datatype_a.lower(), + datatype_b=datatype_b.lower(), + datatype_c=datatype_c.lower(), + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + pipeline=pipeline.lower(), + epilogue=epilogue.lower(), + scheduler=scheduler.lower(), + layout=layout.lower(), + operator=operator if operator is not None else OperatorType.GEMM, + ) + return self.validate_kernel(config).valid + + def _validate_dimensions(self, config: KernelConfig, result: ValidationResult): + """Validate basic dimension constraints""" + if config.tile_m <= 0 or config.tile_n <= 0 or config.tile_k <= 0: + result.add_error( + f"Tile dimensions must be positive: " + f"{config.tile_m}x{config.tile_n}x{config.tile_k}" + ) + + if config.warp_m <= 0 or config.warp_n <= 0 or config.warp_k <= 0: + result.add_error( + f"Warp dimensions must be positive: " + f"{config.warp_m}x{config.warp_n}x{config.warp_k}" + ) + + if ( + config.warp_tile_m <= 0 + or config.warp_tile_n <= 0 + or config.warp_tile_k <= 0 + ): + result.add_error( + f"Warp tile dimensions must be positive: " + f"{config.warp_tile_m}x{config.warp_tile_n}x{config.warp_tile_k}" + ) + + # Check warp tiles fit within block tiles + if config.warp_m * config.warp_tile_m > config.tile_m: + result.add_error( + f"warp_m * warp_tile_m 
({config.warp_m}*{config.warp_tile_m}=" + f"{config.warp_m * config.warp_tile_m}) > tile_m ({config.tile_m})" + ) + if config.warp_n * config.warp_tile_n > config.tile_n: + result.add_error( + f"warp_n * warp_tile_n ({config.warp_n}*{config.warp_tile_n}=" + f"{config.warp_n * config.warp_tile_n}) > tile_n ({config.tile_n})" + ) + if config.warp_k * config.warp_tile_k > config.tile_k: + result.add_error( + f"warp_k * warp_tile_k ({config.warp_k}*{config.warp_tile_k}=" + f"{config.warp_k * config.warp_tile_k}) > tile_k ({config.tile_k})" + ) + + def _validate_warp_config(self, config: KernelConfig, result: ValidationResult): + """Validate warp configuration against architecture""" + allowed = WARP_SUPPORTED_COMBINATIONS.get(self.gpu_arch, []) + current = [config.warp_m, config.warp_n, config.warp_k] + + if not allowed: + msg = f"No warp configurations defined for {self.gpu_arch}" + if self.strict_mode: + result.add_error(msg) + else: + result.add_warning(msg) + return + + if current not in allowed: + result.add_error( + f"Invalid warp configuration {current} for {self.gpu_arch}. 
" + f"Allowed: {allowed}" + ) + + def _validate_warp_tile_combo(self, config: KernelConfig, result: ValidationResult): + """Validate warp tile combination against architecture and data types""" + # Use preshuffle-specific warp tiles for preshuffle operator + if config.operator == OperatorType.GEMM_PRESHUFFLE: + gpu_combos = PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS.get( + self.gpu_arch, {} + ) + combo_source = "preshuffle" + else: + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(self.gpu_arch, {}) + combo_source = "standard" + + if not gpu_combos: + msg = ( + f"No {combo_source} warp tile combinations defined for {self.gpu_arch}" + ) + if self.strict_mode: + result.add_error(msg) + else: + result.add_warning(msg) + return + + dtype_combos = gpu_combos.get(config.dtype_key, []) + if not dtype_combos: + # Data type combo not explicitly listed - may still be valid + result.add_warning( + f"No {combo_source} warp tile combinations defined for {config.dtype_key} on {self.gpu_arch}" + ) + return + + current = [config.warp_tile_m, config.warp_tile_n, config.warp_tile_k] + if current not in dtype_combos: + result.add_error( + f"Invalid warp tile {current} for {config.dtype_key} on {self.gpu_arch} ({combo_source}). 
" + f"Allowed: {dtype_combos}" + ) + + def _validate_trait_combo(self, config: KernelConfig, result: ValidationResult): + """Validate trait (pipeline, epilogue, scheduler) combination""" + # Preshuffle requires specific pipelines + if config.operator == OperatorType.GEMM_PRESHUFFLE: + if config.pipeline not in PRESHUFFLE_PIPELINES: + result.add_error( + f"Preshuffle GEMM requires pipeline in {PRESHUFFLE_PIPELINES}, " + f"got {config.pipeline}" + ) + + # Conv backward operations only support compv3/mem pipelines + # (compv4/compv5 have template issues: transpose_tile2d for bwd_weight, + # get_length for bwd_data in ck_tile kernels) + conv_bwd_operators = { + OperatorType.CONV_BWD_DATA, + OperatorType.CONV_BWD_WEIGHT, + OperatorType.CONV3D_BWD_DATA, + OperatorType.CONV3D_BWD_WEIGHT, + } + conv_bwd_supported_pipelines = {"compv3", "mem"} + if config.operator in conv_bwd_operators: + if config.pipeline not in conv_bwd_supported_pipelines: + result.add_error( + f"Conv backward operations require pipeline in " + f"{conv_bwd_supported_pipelines}, got {config.pipeline}. 
" + f"(compv4/compv5 have ck_tile template compatibility issues)" + ) + + combo = (config.pipeline, config.epilogue, config.scheduler) + if combo in TRAIT_UNSUPPORTED_COMBINATIONS: + result.add_error( + f"Unsupported trait combination: pipeline={config.pipeline}, " + f"epilogue={config.epilogue}, scheduler={config.scheduler}" + ) + + def _validate_lds_capacity(self, config: KernelConfig, result: ValidationResult): + """Validate LDS (Local Data Share) memory capacity""" + elem_size_a = ELEMENT_SIZE_MAP.get(config.datatype_a, 2) + elem_size_b = ELEMENT_SIZE_MAP.get(config.datatype_b, 2) + + matrix_a_size = config.tile_m * config.tile_k * elem_size_a + matrix_b_size = config.tile_n * config.tile_k * elem_size_b + total_lds = matrix_a_size + matrix_b_size + + max_lds = LDS_CAPACITY_LIMITS.get( + config.pipeline, LDS_CAPACITY_LIMITS["default"] + ) + + if total_lds > max_lds: + result.add_error( + f"LDS capacity exceeded: {total_lds} bytes > {max_lds} bytes limit. " + f"Matrix A: {config.tile_m}x{config.tile_k}x{elem_size_a}={matrix_a_size}B, " + f"Matrix B: {config.tile_n}x{config.tile_k}x{elem_size_b}={matrix_b_size}B" + ) + + def _validate_dimension_alignment( + self, config: KernelConfig, result: ValidationResult + ): + """Validate tile dimensions are aligned with warp dimensions""" + if config.tile_m % (config.warp_m * config.warp_tile_m) != 0: + result.add_error( + f"tile_m ({config.tile_m}) must be divisible by " + f"warp_m*warp_tile_m ({config.warp_m}*{config.warp_tile_m}=" + f"{config.warp_m * config.warp_tile_m})" + ) + + if config.tile_n % (config.warp_n * config.warp_tile_n) != 0: + result.add_error( + f"tile_n ({config.tile_n}) must be divisible by " + f"warp_n*warp_tile_n ({config.warp_n}*{config.warp_tile_n}=" + f"{config.warp_n * config.warp_tile_n})" + ) + + if config.tile_k % (config.warp_k * config.warp_tile_k) != 0: + result.add_error( + f"tile_k ({config.tile_k}) must be divisible by " + f"warp_k*warp_tile_k ({config.warp_k}*{config.warp_tile_k}=" + 
f"{config.warp_k * config.warp_tile_k})" + ) + + def get_supported_warp_configs(self) -> List[List[int]]: + """Get list of supported warp configurations for this architecture""" + return WARP_SUPPORTED_COMBINATIONS.get(self.gpu_arch, []) + + def get_supported_warp_tiles(self, dtype_key: str) -> List[List[int]]: + """Get list of supported warp tile configurations for given data types""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(self.gpu_arch, {}) + return gpu_combos.get(dtype_key, []) + + def get_supported_datatypes(self) -> List[str]: + """Get list of data type combinations supported on this architecture""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(self.gpu_arch, {}) + return list(gpu_combos.keys()) + + +# ============================================================================= +# Registry Filter Integration +# ============================================================================= + + +class RegistryFilter: + """ + Filter wrapper for integrating with dispatcher Registry. + + Provides a callable interface that can be used with Registry.filter() + or during kernel registration. + + Example: + # Create filter for gfx942 + filter = RegistryFilter("gfx942") + + # Use with registry + registry = Registry() + registry.set_kernel_filter(filter) # Auto-filter on registration + + # Or filter existing kernels + valid_kernels = registry.filter(filter.accepts_kernel) + """ + + def __init__(self, gpu_arch: str, strict_mode: bool = False): + """ + Initialize registry filter. + + Args: + gpu_arch: Target GPU architecture + strict_mode: If True, reject unknown configurations + """ + self.arch_filter = ArchFilter(gpu_arch, strict_mode=strict_mode) + self.gpu_arch = gpu_arch + self._rejected_count = 0 + self._accepted_count = 0 + + def accepts_kernel(self, kernel_config: Dict[str, Any]) -> bool: + """ + Check if a kernel configuration should be accepted into the registry. 
+ + Args: + kernel_config: Dictionary with kernel configuration values + + Returns: + True if kernel is valid for target architecture + """ + try: + is_valid = self.arch_filter.is_kernel_valid( + datatype_a=kernel_config.get("dtype_a", "fp16"), + datatype_b=kernel_config.get("dtype_b", "fp16"), + datatype_c=kernel_config.get("dtype_c", "fp16"), + tile_m=kernel_config.get("tile_m", 256), + tile_n=kernel_config.get("tile_n", 256), + tile_k=kernel_config.get("tile_k", 64), + warp_m=kernel_config.get("warp_m", 2), + warp_n=kernel_config.get("warp_n", 2), + warp_k=kernel_config.get("warp_k", 1), + warp_tile_m=kernel_config.get("warp_tile_m", 32), + warp_tile_n=kernel_config.get("warp_tile_n", 32), + warp_tile_k=kernel_config.get("warp_tile_k", 16), + pipeline=kernel_config.get("pipeline", "compv4"), + epilogue=kernel_config.get("epilogue", "cshuffle"), + scheduler=kernel_config.get("scheduler", "intrawave"), + layout=kernel_config.get("layout", "rcr"), + ) + + if is_valid: + self._accepted_count += 1 + else: + self._rejected_count += 1 + + return is_valid + + except Exception as e: + logger.warning(f"Error validating kernel config: {e}") + self._rejected_count += 1 + return False + + def get_stats(self) -> Dict[str, int]: + """Get filtering statistics""" + return { + "accepted": self._accepted_count, + "rejected": self._rejected_count, + "total": self._accepted_count + self._rejected_count, + } + + def reset_stats(self): + """Reset filtering statistics""" + self._accepted_count = 0 + self._rejected_count = 0 + + def __call__(self, kernel_config: Dict[str, Any]) -> bool: + """Callable interface for use with filter functions""" + return self.accepts_kernel(kernel_config) + + +# ============================================================================= +# Convenience Functions +# ============================================================================= + + +def get_supported_archs() -> List[str]: + """Get list of all supported GPU architectures""" + return 
list(ARCH_FAMILY_MAP.keys()) + + +def get_arch_family(gpu_arch: str) -> Optional[str]: + """Get the GPU family for an architecture""" + family = ARCH_FAMILY_MAP.get(gpu_arch.lower()) + return family if family else None # ARCH_FAMILY_MAP contains strings, not Enums + + +def create_filter_for_current_gpu() -> Optional[ArchFilter]: + """ + Create a filter for the current GPU (auto-detect). + + Returns: + ArchFilter for detected GPU, or None if detection fails + """ + try: + import subprocess + + result = subprocess.run(["rocminfo"], capture_output=True, text=True, timeout=5) + + for line in result.stdout.split("\n"): + if "gfx" in line.lower(): + for arch in ARCH_FAMILY_MAP.keys(): + if arch in line.lower(): + return ArchFilter(arch) + + return None + except Exception: + return None + + +def filter_kernel_list( + kernels: List[Dict[str, Any]], gpu_arch: str +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Filter a list of kernel configurations for a specific architecture. + + Args: + kernels: List of kernel configuration dictionaries + gpu_arch: Target GPU architecture + + Returns: + Tuple of (valid_kernels, rejected_kernels) + """ + reg_filter = RegistryFilter(gpu_arch) + valid = [] + rejected = [] + + for kernel in kernels: + if reg_filter.accepts_kernel(kernel): + valid.append(kernel) + else: + rejected.append(kernel) + + return valid, rejected + + +# ============================================================================= +# Main (for testing) +# ============================================================================= + +if __name__ == "__main__": + # Test the filter + print("Testing ArchFilter for gfx942...\n") + + filter_942 = ArchFilter("gfx942") + + # Test valid configuration + print("Test 1: Valid FP16 GEMM kernel") + result = filter_942.validate_kernel( + KernelConfig( + datatype_a="fp16", + datatype_b="fp16", + datatype_c="fp16", + tile_m=256, + tile_n=256, + tile_k=64, + warp_m=2, + warp_n=2, + warp_k=1, + warp_tile_m=32, + 
warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + ) + ) + print(f" Valid: {result.valid}") + if result.errors: + print(f" Errors: {result.errors}") + print() + + # Test invalid warp configuration + print("Test 2: Invalid warp configuration") + result = filter_942.validate_kernel( + KernelConfig( + datatype_a="fp16", + datatype_b="fp16", + datatype_c="fp16", + tile_m=256, + tile_n=256, + tile_k=64, + warp_m=3, + warp_n=3, + warp_k=1, # Invalid! + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + ) + ) + print(f" Valid: {result.valid}") + if result.errors: + print(f" Errors: {result.errors}") + print() + + # Test LDS overflow + print("Test 3: LDS capacity overflow") + result = filter_942.validate_kernel( + KernelConfig( + datatype_a="fp16", + datatype_b="fp16", + datatype_c="fp16", + tile_m=512, + tile_n=512, + tile_k=256, # Too large! + warp_m=2, + warp_n=2, + warp_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + ) + ) + print(f" Valid: {result.valid}") + if result.errors: + print(f" Errors: {result.errors}") + print() + + # Test quick validation + print("Test 4: Quick validation (is_kernel_valid)") + is_valid = filter_942.is_kernel_valid( + tile_m=128, + tile_n=128, + tile_k=32, + warp_m=2, + warp_n=2, + warp_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=16, + ) + print(f" Valid: {is_valid}") + print() + + # Show supported configurations + print("Supported warp configurations for gfx942:") + for cfg in filter_942.get_supported_warp_configs(): + print(f" {cfg}") + print() + + print("Supported data types for gfx942:") + for dtype in filter_942.get_supported_datatypes(): + print(f" {dtype}") diff --git a/dispatcher/codegen/arch_specs.json b/dispatcher/codegen/arch_specs.json new file mode 100644 index 00000000000..7d8c83fbf75 --- /dev/null +++ b/dispatcher/codegen/arch_specs.json @@ -0,0 +1,270 @@ +{ + "_comment": "Single source of truth for GPU architecture 
specifications. Edit this file to add new GPU support.", + "_version": "1.2.0", + "_instructions": "See ADDING_NEW_GPU.md for instructions on adding new GPU support.", + "_supported_arch_note": "CK Tile supports: GFX9 (gfx908, gfx90a, gfx942, gfx950), GFX10.3 (gfx103x), GFX11 (gfx110x, gfx115x), GFX12 (gfx120x)", + + "architectures": { + "gfx908": { + "family": "cdna1", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI100", + "warp_size": 64, + "lds_capacity_kb": 64, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]] + } + }, + + "gfx90a": { + "family": "cdna2", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI200 series", + "warp_size": 64, + "lds_capacity_kb": 64, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]] + } + }, + + "gfx942": { + "family": "cdna3", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI300 series", + "warp_size": 64, + "lds_capacity_kb": 64, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 
16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "fp8_bf8_fp32": [[32, 32, 16], [16, 16, 32], [32, 32, 32]], + "bf8_fp8_fp32": [[32, 32, 16]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]] + } + }, + + "gfx950": { + "family": "cdna4", + "target_family": "gfx9", + "architecture": "cdna", + "description": "AMD Instinct MI350 series", + "warp_size": 64, + "lds_capacity_kb": 160, + "warp_configs": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1] + ], + "warp_tile_combos": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + "fp8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 128], [32, 32, 64]], + "bf8_fp8_fp32": [[32, 32, 16], [16, 16, 128], [32, 32, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + "pk_fp4_pk_fp4_fp32": [[16, 16, 128]] + } + }, + + "gfx1100": { + "family": "rdna3", + "target_family": "gfx11", + "architecture": "rdna", + "description": "AMD Radeon RX 7900 series (RDNA3)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [1, 8, 1], + [8, 1, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]] + } + }, + + "gfx1200": { + "family": "rdna4", + "target_family": "gfx12", + "architecture": "rdna", + "description": "AMD Radeon RX 9000 series (RDNA4)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [1, 8, 1], + 
[8, 1, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]] + } + }, + + "gfx1201": { + "family": "rdna4", + "target_family": "gfx12", + "architecture": "rdna", + "description": "AMD Radeon RX 9000 series (RDNA4)", + "warp_size": 32, + "lds_capacity_kb": 64, + "warp_configs": [ + [2, 4, 1], + [1, 8, 1], + [8, 1, 1], + [4, 2, 1] + ], + "warp_tile_combos": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]] + } + } + }, + + "element_sizes": { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, + "int4": 0.5, + "pk_fp4": 0.5, + "int32": 4 + }, + + "datatype_cpp_map": { + "_comment": "Maps dtype string to CK Tile C++ type for code generation", + "fp16": "ck_tile::half_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", + "fp64": "double", + "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", + "int8": "ck_tile::int8_t", + "int4": "ck_tile::pk_int4_t", + "pk_fp4": "ck_tile::pk_fp4_t", + "int32": "ck_tile::int32_t" + }, + + "dtype_combinations": { + "_comment": "All valid (A, B) -> Acc combinations for GEMM from warp_gemm_dispatcher.hpp", + "fp32_fp32": {"acc": "fp32", "notes": "Full precision"}, + "fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"}, + "bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"}, + "fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"}, + "fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"}, + "bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"}, + "bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"}, + "int8_int8": {"acc": "int32", "notes": "Integer GEMM"}, + "pk_fp4_pk_fp4": 
{"acc": "fp32", "notes": "Packed 4-bit float"} + }, + + "layout_cpp_map": { + "_comment": "Maps layout character to CK Tile C++ type", + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor" + }, + + "pipeline_lds_limits": { + "_comment": "LDS capacity limits in bytes for different pipeline types", + "mem": 65536, + "compv1": 65536, + "compv2": 65536, + "compv3": 65536, + "compv4": 32768, + "compv5": 65536, + "preshufflev1": 32768, + "preshufflev2": 32768, + "default": 65536 + }, + + "unsupported_trait_combos": { + "_comment": "Only 'mem' pipeline supports interwave scheduler. All compute pipelines only support intrawave.", + "combinations": [ + ["compv3", "cshuffle", "interwave"], + ["compv3", "default", "interwave"], + ["compv4", "cshuffle", "interwave"], + ["compv4", "default", "interwave"], + ["compv5", "cshuffle", "interwave"], + ["compv5", "default", "interwave"], + ["compv6", "cshuffle", "interwave"], + ["compv6", "default", "interwave"], + ["comp_async", "cshuffle", "interwave"], + ["comp_async", "default", "interwave"] + ] + }, + + "preshuffle_warp_tile_combos": { + "_comment": "Preshuffle-specific warp tile combinations (subset of standard GEMM, no [4, 64, 16])", + "gfx90a": { + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]] + }, + "gfx942": { + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]] + }, + "gfx950": { + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 
32, 16], [16, 16, 32], [64, 4, 16]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] + } + }, + + "preshuffle_pipelines": { + "_comment": "Pipelines supported for preshuffle GEMM variant", + "supported": ["preshufflev2"] + } +} diff --git a/dispatcher/codegen/arch_specs_generated.py b/dispatcher/codegen/arch_specs_generated.py new file mode 100644 index 00000000000..97f17e97241 --- /dev/null +++ b/dispatcher/codegen/arch_specs_generated.py @@ -0,0 +1,358 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + +Generated from: arch_specs.json +Generated at: 2026-01-05T19:34:01.224422 + +To update this file: +1. Edit arch_specs.json +2. Run: python generate_arch_specs.py + +This module provides architecture-specific configurations for kernel filtering. 
+""" + +from typing import Dict, List, Set, Tuple + +# ============================================================================= +# Architecture Data (Generated from arch_specs.json) +# ============================================================================= + +# GPU architecture to family mapping +ARCH_FAMILY_MAP: Dict[str, str] = { + "gfx908": "cdna1", + "gfx90a": "cdna2", + "gfx942": "cdna3", + "gfx950": "cdna4", + "gfx1100": "rdna3", + "gfx1200": "rdna4", + "gfx1201": "rdna4", +} + +# Element size in bytes for each data type +ELEMENT_SIZE_MAP: Dict[str, float] = { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, + "int4": 0.5, + "pk_fp4": 0.5, + "int32": 4, +} + +# Supported warp configurations per architecture [warp_m, warp_n, warp_k] +WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = { + "gfx908": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx1100": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], + "gfx1200": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], + "gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]], +} + +# Supported warp tile combinations: arch -> dtype_key -> [[warp_tile_m, n, k], ...] 
+WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = { + "gfx908": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + }, + "gfx90a": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + }, + "gfx942": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "fp8_bf8_fp32": [[32, 32, 16], [16, 16, 32], [32, 32, 32]], + "bf8_fp8_fp32": [[32, 32, 16]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + }, + "gfx950": { + "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]], + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "fp8_bf8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_fp8_fp32": [[32, 32, 16], [16, 16, 128], [32, 32, 
64]], + "bf8_bf8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "int8_int8_int32": [[32, 32, 16], [16, 16, 32]], + "pk_fp4_pk_fp4_fp32": [[16, 16, 128]], + }, + "gfx1100": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]], + }, + "gfx1200": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]], + }, + "gfx1201": { + "fp16_fp16_fp32": [[16, 16, 16]], + "bf16_bf16_fp32": [[16, 16, 16]], + "fp8_fp8_fp32": [[16, 16, 16]], + "bf8_bf8_fp32": [[16, 16, 16]], + "fp8_bf8_fp32": [[16, 16, 16]], + "bf8_fp8_fp32": [[16, 16, 16]], + "int8_int8_int32": [[16, 16, 16]], + }, +} + +# Preshuffle-specific warp tile combinations (subset of standard GEMM) +PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = { + "gfx90a": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]], + }, + "gfx942": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [64, 4, 16], + ], + "bf16_bf16_fp32": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + 
[16, 16, 32], + [64, 4, 16], + ], + "fp8_fp8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp32": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +# Preshuffle-supported pipelines +PRESHUFFLE_PIPELINES: List[str] = ["preshufflev2"] + +# LDS capacity limits per pipeline type (in bytes) +LDS_CAPACITY_LIMITS: Dict[str, int] = { + "mem": 65536, + "compv1": 65536, + "compv2": 65536, + "compv3": 65536, + "compv4": 32768, + "compv5": 65536, + "preshufflev1": 32768, + "preshufflev2": 32768, + "default": 65536, +} + +# Unsupported trait combinations: (pipeline, epilogue, scheduler) +TRAIT_UNSUPPORTED_COMBINATIONS: Set[Tuple[str, str, str]] = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + ("compv5", "cshuffle", "interwave"), + ("compv5", "default", "interwave"), + ("compv6", "cshuffle", "interwave"), + ("compv6", "default", "interwave"), + ("comp_async", "cshuffle", "interwave"), + ("comp_async", "default", "interwave"), +} + +# Valid dtype combinations: (A_dtype, B_dtype) -> acc_dtype and notes +DTYPE_COMBINATIONS: Dict[str, Dict[str, str]] = { + "fp32_fp32": {"acc": "fp32", "notes": "Full precision"}, + "fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"}, + "bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"}, + "fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"}, + "fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"}, + "bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"}, + "bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"}, + "int8_int8": {"acc": "int32", "notes": "Integer GEMM"}, + "pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"}, +} + +# ============================================================================= +# Helper Functions +# 
============================================================================= + + +def get_supported_archs() -> List[str]: + """Get list of all supported GPU architectures.""" + return list(ARCH_FAMILY_MAP.keys()) + + +def get_arch_family(gpu_arch: str) -> str: + """Get the GPU family for an architecture.""" + return ARCH_FAMILY_MAP.get(gpu_arch.lower(), "unknown") + + +def get_element_size(dtype: str) -> float: + """Get element size in bytes for a data type.""" + return ELEMENT_SIZE_MAP.get(dtype.lower(), 2.0) + + +def get_warp_configs(gpu_arch: str) -> List[List[int]]: + """Get supported warp configurations for an architecture.""" + return WARP_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), []) + + +def get_warp_tile_combos(gpu_arch: str, dtype_key: str) -> List[List[int]]: + """Get supported warp tile combinations for arch and data types.""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), {}) + return gpu_combos.get(dtype_key.lower(), []) + + +def get_lds_limit(pipeline: str) -> int: + """Get LDS capacity limit for a pipeline type.""" + return LDS_CAPACITY_LIMITS.get(pipeline.lower(), LDS_CAPACITY_LIMITS["default"]) + + +def is_trait_combo_unsupported(pipeline: str, epilogue: str, scheduler: str) -> bool: + """Check if a trait combination is unsupported.""" + return ( + pipeline.lower(), + epilogue.lower(), + scheduler.lower(), + ) in TRAIT_UNSUPPORTED_COMBINATIONS + + +def get_dtype_info(dtype_a: str, dtype_b: str) -> Dict[str, str]: + """Get accumulator type and notes for a dtype combination.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + return DTYPE_COMBINATIONS.get(key, {"acc": "fp32", "notes": "unknown"}) + + +def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool: + """Check if a dtype combination is valid.""" + key = f"{dtype_a.lower()}_{dtype_b.lower()}" + return key in DTYPE_COMBINATIONS + + +def get_valid_dtype_combos() -> List[str]: + """Get list of all valid dtype combinations.""" + return 
list(DTYPE_COMBINATIONS.keys()) diff --git a/dispatcher/codegen/default_config.json b/dispatcher/codegen/default_config.json new file mode 100644 index 00000000000..3ef823fcc2d --- /dev/null +++ b/dispatcher/codegen/default_config.json @@ -0,0 +1,27 @@ +{ + "tile_config": { + "tile_m": [128, 256], + "tile_n": [128, 256], + "tile_k": [32, 64], + "warp_m": [2, 4], + "warp_n": [2, 4], + "warp_k": [1], + "warp_tile_m": [16, 32], + "warp_tile_n": [16, 32], + "warp_tile_k": [16] + }, + "trait_config": { + "pipeline": ["compv4"], + "epilogue": ["cshuffle"], + "scheduler": ["intrawave"], + "pad_m": [false], + "pad_n": [false], + "pad_k": [false], + "persistent": [false, true] + }, + "multi_d_config": { + "elementwise_ops": ["MultiDAdd", "Relu", "Gelu"], + "num_d_tensors": [1, 2] + } +} + diff --git a/dispatcher/codegen/generate_arch_specs.py b/dispatcher/codegen/generate_arch_specs.py new file mode 100644 index 00000000000..5b6fc2971b2 --- /dev/null +++ b/dispatcher/codegen/generate_arch_specs.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Architecture Specs Generator + +Generates both Python and C++ code from a single JSON source of truth. +This ensures consistency between Python codegen and C++ runtime filtering. + +Usage: + python generate_arch_specs.py [--json arch_specs.json] [--output-dir .] 
+ + # Regenerate after editing arch_specs.json: + python generate_arch_specs.py + +Output: + - arch_specs_generated.py (Python module with arch data) + - arch_specs_generated.hpp (C++ header with arch data) +""" + +import json +import argparse +from pathlib import Path +from datetime import datetime +from typing import Dict, Any + +SCRIPT_DIR = Path(__file__).parent + + +def load_arch_specs(json_path: Path) -> Dict[str, Any]: + """Load architecture specifications from JSON file.""" + with open(json_path) as f: + return json.load(f) + + +def generate_python_module(specs: Dict[str, Any], output_path: Path): + """Generate Python module from arch specs.""" + + timestamp = datetime.now().isoformat() + + # Extract data + archs = specs["architectures"] + element_sizes = specs["element_sizes"] + pipeline_limits = specs["pipeline_lds_limits"] + unsupported = specs["unsupported_trait_combos"]["combinations"] + + # Build warp configs dict + warp_configs_str = "{\n" + for arch, data in archs.items(): + warp_configs_str += f' "{arch}": {data["warp_configs"]},\n' + warp_configs_str += "}" + + # Build warp tile combos dict + warp_tile_str = "{\n" + for arch, data in archs.items(): + warp_tile_str += f' "{arch}": {{\n' + for dtype, combos in data["warp_tile_combos"].items(): + warp_tile_str += f' "{dtype}": {combos},\n' + warp_tile_str += " },\n" + warp_tile_str += "}" + + # Build arch family map + arch_family_str = "{\n" + for arch, data in archs.items(): + arch_family_str += f' "{arch}": "{data["family"]}",\n' + arch_family_str += "}" + + # Build unsupported combos set + unsupported_str = "{\n" + for combo in unsupported: + unsupported_str += f' ("{combo[0]}", "{combo[1]}", "{combo[2]}"),\n' + unsupported_str += "}" + + # Pipeline LDS limits + pipeline_limits_clean = { + k: v for k, v in pipeline_limits.items() if not k.startswith("_") + } + + # Build dtype combinations dict + dtype_combos = specs.get("dtype_combinations", {}) + dtype_combos_str = "{\n" + for key, info in 
dtype_combos.items(): + if not key.startswith("_"): + dtype_combos_str += f' "{key}": {{"acc": "{info["acc"]}", "notes": "{info["notes"]}"}},\n' + dtype_combos_str += "}" + + # Build preshuffle warp tile combos dict (operator-specific) + preshuffle_combos = specs.get("preshuffle_warp_tile_combos", {}) + preshuffle_warp_tile_str = "{\n" + for arch, dtype_combos_dict in preshuffle_combos.items(): + if not arch.startswith("_"): + preshuffle_warp_tile_str += f' "{arch}": {{\n' + for dtype, combos in dtype_combos_dict.items(): + preshuffle_warp_tile_str += f' "{dtype}": {combos},\n' + preshuffle_warp_tile_str += " },\n" + preshuffle_warp_tile_str += "}" + + # Build preshuffle pipelines list + preshuffle_pipelines = specs.get("preshuffle_pipelines", {}).get( + "supported", ["preshufflev2"] + ) + preshuffle_pipelines_str = str(preshuffle_pipelines) + + content = f'''# SPDX-License-Identifier: MIT + +""" +AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + +Generated from: arch_specs.json +Generated at: {timestamp} + +To update this file: +1. Edit arch_specs.json +2. Run: python generate_arch_specs.py + +This module provides architecture-specific configurations for kernel filtering. +""" + +from typing import Dict, List, Set, Tuple + +# ============================================================================= +# Architecture Data (Generated from arch_specs.json) +# ============================================================================= + +# GPU architecture to family mapping +ARCH_FAMILY_MAP: Dict[str, str] = {arch_family_str} + +# Element size in bytes for each data type +ELEMENT_SIZE_MAP: Dict[str, float] = {element_sizes} + +# Supported warp configurations per architecture [warp_m, warp_n, warp_k] +WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = {warp_configs_str} + +# Supported warp tile combinations: arch -> dtype_key -> [[warp_tile_m, n, k], ...] 
+WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {warp_tile_str} + +# Preshuffle-specific warp tile combinations (subset of standard GEMM) +PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {preshuffle_warp_tile_str} + +# Preshuffle-supported pipelines +PRESHUFFLE_PIPELINES: List[str] = {preshuffle_pipelines_str} + +# LDS capacity limits per pipeline type (in bytes) +LDS_CAPACITY_LIMITS: Dict[str, int] = {pipeline_limits_clean} + +# Unsupported trait combinations: (pipeline, epilogue, scheduler) +TRAIT_UNSUPPORTED_COMBINATIONS: Set[Tuple[str, str, str]] = {unsupported_str} + +# Valid dtype combinations: (A_dtype, B_dtype) -> acc_dtype and notes +DTYPE_COMBINATIONS: Dict[str, Dict[str, str]] = {dtype_combos_str} + +# ============================================================================= +# Helper Functions +# ============================================================================= + +def get_supported_archs() -> List[str]: + """Get list of all supported GPU architectures.""" + return list(ARCH_FAMILY_MAP.keys()) + + +def get_arch_family(gpu_arch: str) -> str: + """Get the GPU family for an architecture.""" + return ARCH_FAMILY_MAP.get(gpu_arch.lower(), "unknown") + + +def get_element_size(dtype: str) -> float: + """Get element size in bytes for a data type.""" + return ELEMENT_SIZE_MAP.get(dtype.lower(), 2.0) + + +def get_warp_configs(gpu_arch: str) -> List[List[int]]: + """Get supported warp configurations for an architecture.""" + return WARP_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), []) + + +def get_warp_tile_combos(gpu_arch: str, dtype_key: str) -> List[List[int]]: + """Get supported warp tile combinations for arch and data types.""" + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), {{}}) + return gpu_combos.get(dtype_key.lower(), []) + + +def get_lds_limit(pipeline: str) -> int: + """Get LDS capacity limit for a pipeline type.""" + return 
LDS_CAPACITY_LIMITS.get(pipeline.lower(), LDS_CAPACITY_LIMITS["default"]) + + +def is_trait_combo_unsupported(pipeline: str, epilogue: str, scheduler: str) -> bool: + """Check if a trait combination is unsupported.""" + return (pipeline.lower(), epilogue.lower(), scheduler.lower()) in TRAIT_UNSUPPORTED_COMBINATIONS + + +def get_dtype_info(dtype_a: str, dtype_b: str) -> Dict[str, str]: + """Get accumulator type and notes for a dtype combination.""" + key = f"{{dtype_a.lower()}}_{{dtype_b.lower()}}" + return DTYPE_COMBINATIONS.get(key, {{"acc": "fp32", "notes": "unknown"}}) + + +def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool: + """Check if a dtype combination is valid.""" + key = f"{{dtype_a.lower()}}_{{dtype_b.lower()}}" + return key in DTYPE_COMBINATIONS + + +def get_valid_dtype_combos() -> List[str]: + """Get list of all valid dtype combinations.""" + return list(DTYPE_COMBINATIONS.keys()) +''' + + output_path.write_text(content) + print(f"Generated: {output_path}") + + +def generate_cpp_header(specs: Dict[str, Any], output_path: Path): + """Generate C++ header from arch specs.""" + + timestamp = datetime.now().isoformat() + + # Extract data + archs = specs["architectures"] + element_sizes = specs["element_sizes"] + pipeline_limits = specs["pipeline_lds_limits"] + specs["unsupported_trait_combos"]["combinations"] + + # Build arch enum and string functions + arch_enums = [] + arch_to_string_cases = [] + string_to_arch_cases = [] + + for arch, data in archs.items(): + enum_name = arch.upper().replace("GFX", "GFX_") + arch_enums.append(f" {enum_name}, // {data['description']}") + arch_to_string_cases.append( + f' case GpuArch::{enum_name}: return "{arch}";' + ) + string_to_arch_cases.append( + f' if (arch_str == "{arch}") return GpuArch::{enum_name};' + ) + + # Build warp configs switch + warp_config_cases = [] + for arch, data in archs.items(): + enum_name = arch.upper().replace("GFX", "GFX_") + configs = ", ".join( + [f"{{{c[0]}, {c[1]}, {c[2]}}}" for 
c in data["warp_configs"]] + ) + warp_config_cases.append( + f" case GpuArch::{enum_name}: return {{{configs}}};" + ) + + # Build element size switch + # Include all data types defined in kernel_key.hpp DataType enum + elem_size_cases = [] + dtype_enum_map = { + "fp16": "FP16", + "bf16": "BF16", + "fp32": "FP32", + "fp64": "FP64", + "fp8": "FP8", + "bf8": "BF8", + "int8": "INT8", + "int4": "INT4", + "int32": "INT32", + } + for dtype, size in element_sizes.items(): + if dtype in dtype_enum_map: + elem_size_cases.append( + f" case DataType::{dtype_enum_map[dtype]}: return {float(size)}f;" + ) + + # Build LDS limits + lds_limit_cases = [] + pipeline_enum_map = { + "mem": "Mem", + "compv1": "CompV1", + "compv2": "CompV2", + "compv3": "CompV3", + "compv4": "CompV4", + "compv5": "CompV5", + "preshufflev1": "PreShuffleV1", + "preshufflev2": "PreShuffleV2", + } + default_lds = pipeline_limits.get("default", 65536) + for pipeline, limit in pipeline_limits.items(): + if pipeline in pipeline_enum_map: + lds_limit_cases.append( + f" if (pipeline == Pipeline::{pipeline_enum_map[pipeline]}) return {limit};" + ) + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +/** + * AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + * + * Generated from: arch_specs.json + * Generated at: {timestamp} + * + * To update this file: + * 1. Edit arch_specs.json + * 2. 
Run: python generate_arch_specs.py + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include + +namespace ck_tile {{ +namespace dispatcher {{ +namespace arch_specs {{ + +// ============================================================================= +// GPU Architecture Enum (Generated) +// ============================================================================= + +enum class GpuArch : std::uint8_t {{ +{chr(10).join(arch_enums)} + UNKNOWN +}}; + +// ============================================================================= +// String Conversion Functions (Generated) +// ============================================================================= + +inline std::string arch_to_string(GpuArch arch) {{ + switch (arch) {{ +{chr(10).join(arch_to_string_cases)} + default: return "unknown"; + }} +}} + +inline GpuArch string_to_arch(const std::string& arch_str) {{ +{chr(10).join(string_to_arch_cases)} + return GpuArch::UNKNOWN; +}} + +// ============================================================================= +// Element Size (Generated) +// ============================================================================= + +inline float element_size(DataType dtype) {{ + switch (dtype) {{ +{chr(10).join(elem_size_cases)} + default: return 2.0f; + }} +}} + +// ============================================================================= +// Warp Configurations (Generated) +// ============================================================================= + +using WarpConfig = std::array; + +inline std::vector get_supported_warp_configs(GpuArch arch) {{ + switch (arch) {{ +{chr(10).join(warp_config_cases)} + default: return {{}}; + }} +}} + +// ============================================================================= +// LDS Capacity Limits (Generated) +// ============================================================================= + +inline std::size_t get_lds_capacity(Pipeline pipeline) {{ 
+{chr(10).join(lds_limit_cases)} + return {default_lds}; // Default +}} + +// ============================================================================= +// Unsupported Trait Combinations (Generated) +// ============================================================================= + +inline bool is_trait_unsupported(Pipeline pipeline, [[maybe_unused]] Epilogue epilogue, Scheduler scheduler) {{ + // Generated from unsupported_trait_combos in arch_specs.json + if (scheduler == Scheduler::Interwave) {{ + if (pipeline == Pipeline::CompV3 || pipeline == Pipeline::CompV4) {{ + return true; + }} + }} + return false; +}} + +}} // namespace arch_specs +}} // namespace dispatcher +}} // namespace ck_tile +""" + + output_path.write_text(content) + print(f"Generated: {output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate Python and C++ code from arch_specs.json" + ) + parser.add_argument( + "--json", + type=Path, + default=SCRIPT_DIR / "arch_specs.json", + help="Path to arch_specs.json", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=SCRIPT_DIR, + help="Output directory for generated files", + ) + parser.add_argument( + "--cpp-output-dir", + type=Path, + default=None, + help="Output directory for C++ header (defaults to dispatcher/include/...)", + ) + + args = parser.parse_args() + + # Load specs + print(f"Loading: {args.json}") + specs = load_arch_specs(args.json) + + # Generate Python module + py_output = args.output_dir / "arch_specs_generated.py" + generate_python_module(specs, py_output) + + # Generate C++ header + if args.cpp_output_dir: + cpp_output = args.cpp_output_dir / "arch_specs_generated.hpp" + else: + cpp_output = ( + SCRIPT_DIR.parent + / "include" + / "ck_tile" + / "dispatcher" + / "arch_specs_generated.hpp" + ) + + cpp_output.parent.mkdir(parents=True, exist_ok=True) + generate_cpp_header(specs, cpp_output) + + print("\nDone! To apply changes:") + print(" 1. 
Python code will automatically use arch_specs_generated.py") + print(" 2. C++ code includes arch_specs_generated.hpp") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/codegen/generate_dispatcher_registration.py b/dispatcher/codegen/generate_dispatcher_registration.py new file mode 100644 index 00000000000..024ec4a7c8c --- /dev/null +++ b/dispatcher/codegen/generate_dispatcher_registration.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Generate dispatcher registration code for CK Tile kernels + +This script generates C++ registration code that instantiates TileKernelInstance +templates for each generated kernel, solving the "cannot instantiate from parsed headers" problem. +""" + +import json +import argparse +from pathlib import Path +from typing import List +from dataclasses import dataclass + + +@dataclass +class KernelConfig: + """Kernel configuration for registration""" + + name: str + header_file: str + tile_m: int + tile_n: int + tile_k: int + warp_m: int + warp_n: int + warp_k: int + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + block_size: int + pipeline: str + epilogue: str + scheduler: str + pad_m: bool + pad_n: bool + pad_k: bool + persistent: bool + double_buffer: bool + transpose_c: bool + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + layout_a: str = "row" + layout_b: str = "col" + layout_c: str = "row" + + +def generate_registration_header(kernels: List[KernelConfig], output_file: Path): + """Generate registration header file""" + + content = """// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+// +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated by generate_dispatcher_registration.py + +#pragma once + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/tile_backend.hpp" +#include "ck_tile/dispatcher/backends/kernel_registration.hpp" + +// Include all generated kernel headers +""" + + # Add includes for all kernel headers + for kernel in kernels: + content += f'#include "{kernel.header_file}"\n' + + content += """ + +namespace ck_tile { +namespace dispatcher { +namespace generated { + +/// Register all generated kernels with the dispatcher +inline void register_all_kernels(Registry& registry) +{ +""" + + # Add registration calls for each kernel + for kernel in kernels: + # Extract the SelectedKernel type name from the header file + # Assuming the header defines a type like: using SelectedKernel = ... + kernel_type = f"SelectedKernel_{kernel.name}" + + content += f""" // Register {kernel.name} + register_tile_kernel<{kernel_type}>(registry, "{kernel.name}"); +""" + + content += """} + +/// Register all generated kernels with the global registry +inline void register_all_kernels() +{ + auto& registry = Registry::instance(); + register_all_kernels(registry); +} + +} // namespace generated +} // namespace dispatcher +} // namespace ck_tile +""" + + output_file.write_text(content) + print(f"✓ Generated registration header: {output_file}") + + +def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path): + """Generate registration implementation file""" + + content = """// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+// +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated by generate_dispatcher_registration.py + +#include "dispatcher_registration.hpp" + +namespace ck_tile { +namespace dispatcher { +namespace generated { + +// Explicit instantiations to reduce compile time +// These ensure the templates are instantiated once + +""" + + for kernel in kernels: + kernel_type = f"SelectedKernel_{kernel.name}" + content += f"template class backends::TileKernelInstance<{kernel_type}>;\n" + + content += """ +} // namespace generated +} // namespace dispatcher +} // namespace ck_tile +""" + + output_file.write_text(content) + print(f"✓ Generated registration implementation: {output_file}") + + +def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path): + """Generate a wrapper header that defines SelectedKernel type""" + + wrapper_file = output_dir / f"{kernel.name}_wrapper.hpp" + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+// +// AUTO-GENERATED FILE - DO NOT EDIT +// Generated by generate_dispatcher_registration.py + +#pragma once + +#include "{kernel.header_file}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +// Type alias for dispatcher registration +// This allows the registration code to reference the kernel type +using SelectedKernel_{kernel.name} = /* Actual kernel type from generated header */; + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile +""" + + wrapper_file.write_text(content) + + +def load_kernel_manifest(manifest_file: Path) -> List[KernelConfig]: + """Load kernel configurations from manifest file""" + + with open(manifest_file, "r") as f: + data = json.load(f) + + kernels = [] + for kernel_data in data.get("kernels", []): + kernel = KernelConfig( + name=kernel_data["name"], + header_file=kernel_data["header_file"], + tile_m=kernel_data["tile_m"], + tile_n=kernel_data["tile_n"], + tile_k=kernel_data["tile_k"], + warp_m=kernel_data.get("warp_m", 2), + warp_n=kernel_data.get("warp_n", 2), + warp_k=kernel_data.get("warp_k", 1), + warp_tile_m=kernel_data.get("warp_tile_m", 32), + warp_tile_n=kernel_data.get("warp_tile_n", 32), + warp_tile_k=kernel_data.get("warp_tile_k", 16), + block_size=kernel_data.get("block_size", 256), + pipeline=kernel_data.get("pipeline", "compv4"), + epilogue=kernel_data.get("epilogue", "cshuffle"), + scheduler=kernel_data.get("scheduler", "intrawave"), + pad_m=kernel_data.get("pad_m", False), + pad_n=kernel_data.get("pad_n", False), + pad_k=kernel_data.get("pad_k", False), + persistent=kernel_data.get("persistent", False), + double_buffer=kernel_data.get("double_buffer", True), + transpose_c=kernel_data.get("transpose_c", False), + dtype_a=kernel_data.get("dtype_a", "fp16"), + dtype_b=kernel_data.get("dtype_b", "fp16"), + dtype_c=kernel_data.get("dtype_c", "fp16"), + dtype_acc=kernel_data.get("dtype_acc", "fp32"), + ) + kernels.append(kernel) + + return kernels + + +def 
scan_generated_headers(generated_dir: Path) -> List[KernelConfig]: + """Scan generated headers and extract kernel configurations""" + + import re + + kernels = [] + + for header_file in generated_dir.glob("**/*.hpp"): + try: + content = header_file.read_text() + + # Extract kernel name + name_match = re.search( + r'constexpr const char\* KERNEL_NAME\s*=\s*"([^"]+)"', content + ) + if not name_match: + continue + + kernel_name = name_match.group(1) + + # Extract tile configuration (support ck_tile::index_t) + tile_m_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+TileM\s*=\s*(\d+)", + content, + ) + tile_n_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+TileN\s*=\s*(\d+)", + content, + ) + tile_k_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+TileK\s*=\s*(\d+)", + content, + ) + + tile_m = int(tile_m_match.group(1)) if tile_m_match else 256 + tile_n = int(tile_n_match.group(1)) if tile_n_match else 256 + tile_k = int(tile_k_match.group(1)) if tile_k_match else 32 + + # Extract warp configuration + warp_m_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpPerBlock_M\s*=\s*(\d+)", + content, + ) + warp_n_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpPerBlock_N\s*=\s*(\d+)", + content, + ) + warp_k_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpPerBlock_K\s*=\s*(\d+)", + content, + ) + + warp_m = int(warp_m_match.group(1)) if warp_m_match else 2 + warp_n = int(warp_n_match.group(1)) if warp_n_match else 2 + warp_k = int(warp_k_match.group(1)) if warp_k_match else 1 + + # Extract warp tile configuration + warp_tile_m_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpTileM\s*=\s*(\d+)", + content, + ) + warp_tile_n_match = re.search( + 
r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpTileN\s*=\s*(\d+)", + content, + ) + warp_tile_k_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+WarpTileK\s*=\s*(\d+)", + content, + ) + + warp_tile_m = int(warp_tile_m_match.group(1)) if warp_tile_m_match else 32 + warp_tile_n = int(warp_tile_n_match.group(1)) if warp_tile_n_match else 32 + warp_tile_k = int(warp_tile_k_match.group(1)) if warp_tile_k_match else 16 + + # Extract other parameters (with defaults) + block_size_match = re.search( + r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+BlockSize\s*=\s*(\d+)", + content, + ) + block_size = int(block_size_match.group(1)) if block_size_match else 256 + + # Extract boolean flags + pad_m = re.search(r"kPadM\s*=\s*true", content) is not None + pad_n = re.search(r"kPadN\s*=\s*true", content) is not None + pad_k = re.search(r"kPadK\s*=\s*true", content) is not None + persistent = ( + re.search(r"UsePersistentKernel\s*=\s*true", content) is not None + ) + double_buffer = ( + re.search(r"DoubleSmemBuffer\s*=\s*true", content) is not None + ) + transpose_c = re.search(r"TransposeC\s*=\s*true", content) is not None + + kernel = KernelConfig( + name=kernel_name, + header_file=str(header_file.relative_to(generated_dir.parent)), + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + block_size=block_size, + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=pad_m, + pad_n=pad_n, + pad_k=pad_k, + persistent=persistent, + double_buffer=double_buffer, + transpose_c=transpose_c, + ) + + kernels.append(kernel) + + except Exception as e: + print(f"Warning: Failed to parse {header_file}: {e}") + continue + + return kernels + + +def main(): + parser = argparse.ArgumentParser( + description="Generate dispatcher registration code" + ) + 
parser.add_argument( + "--generated-dir", + type=str, + required=True, + help="Directory containing generated kernel headers", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory for registration code", + ) + parser.add_argument( + "--manifest", type=str, help="Optional manifest file with kernel configurations" + ) + parser.add_argument( + "--scan", + action="store_true", + help="Scan generated headers instead of using manifest", + ) + + args = parser.parse_args() + + generated_dir = Path(args.generated_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Load kernel configurations + if args.manifest: + print(f"Loading kernels from manifest: {args.manifest}") + kernels = load_kernel_manifest(Path(args.manifest)) + elif args.scan: + print(f"Scanning generated headers in: {generated_dir}") + kernels = scan_generated_headers(generated_dir) + else: + print("Error: Must specify either --manifest or --scan") + return 1 + + print(f"Found {len(kernels)} kernels") + + # Generate registration code + registration_header = output_dir / "dispatcher_registration.hpp" + registration_cpp = output_dir / "dispatcher_registration.cpp" + + generate_registration_header(kernels, registration_header) + generate_registration_cpp(kernels, registration_cpp) + + # Generate manifest for Python + manifest_output = output_dir / "kernels_manifest.json" + manifest_data = { + "kernels": [ + { + "name": k.name, + "header_file": k.header_file, + "tile_m": k.tile_m, + "tile_n": k.tile_n, + "tile_k": k.tile_k, + "block_size": k.block_size, + "persistent": k.persistent, + } + for k in kernels + ] + } + + with open(manifest_output, "w") as f: + json.dump(manifest_data, f, indent=2) + + print(f"✓ Generated manifest: {manifest_output}") + print("\n✓ Registration code generation complete!") + print(f" Total kernels: {len(kernels)}") + print(" Output files:") + print(f" - {registration_header}") + print(f" - 
{registration_cpp}") + print(f" - {manifest_output}") + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/dispatcher/codegen/generate_kernel_wrappers.py b/dispatcher/codegen/generate_kernel_wrappers.py new file mode 100644 index 00000000000..53a9bff3edc --- /dev/null +++ b/dispatcher/codegen/generate_kernel_wrappers.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Generate one .cpp wrapper file per kernel header for maximum parallel compilation. + +Each kernel becomes its own translation unit, enabling: + - Maximum parallelism with make -j$(nproc) + - Per-kernel build progress (e.g., [5/128] Building kernel: gemm_fp16_128x128) + - Incremental rebuilds (only changed kernels recompile) + - Fine-grained build time analysis + +Usage: + python3 generate_kernel_wrappers.py --kernel-dir build/generated_kernels --output-dir build/kernel_wrappers + +Output structure: + build/kernel_wrappers/ + ├── gemm_fp16_rcr_128x128x32.cpp + ├── gemm_fp16_rcr_256x256x64.cpp + ├── conv_fwd_fp16_2d_128x128.cpp + └── ... + +Each .cpp simply includes its corresponding .hpp and forces symbol emission. 
+""" + +import argparse +import sys +from pathlib import Path +from typing import List, Tuple +import concurrent.futures + + +WRAPPER_TEMPLATE_GEMM = """// SPDX-License-Identifier: MIT +// Auto-generated wrapper for: {kernel_name} +// This file enables per-kernel parallel compilation + +#include "{kernel_hpp}" + +// Force symbol emission for kernel registration +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +// Marker to prevent dead code elimination +volatile bool _{kernel_id}_registered = true; + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile +""" + +WRAPPER_TEMPLATE_CONV = """// SPDX-License-Identifier: MIT +// Auto-generated wrapper for: {kernel_name} +// This file enables per-kernel parallel compilation + +#include "{kernel_hpp}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +volatile bool _{kernel_id}_registered = true; + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile +""" + + +def generate_wrapper( + kernel_hpp: Path, output_dir: Path, index: int, total: int +) -> Tuple[Path, bool]: + """Generate a .cpp wrapper for a single kernel header.""" + kernel_name = kernel_hpp.stem + kernel_id = kernel_name.replace("-", "_").replace(".", "_") + + # Select template based on kernel type + if kernel_name.startswith("gemm"): + template = WRAPPER_TEMPLATE_GEMM + else: + template = WRAPPER_TEMPLATE_CONV + + content = template.format( + kernel_name=kernel_name, + kernel_hpp=kernel_hpp.name, + kernel_id=kernel_id, + ) + + output_cpp = output_dir / f"{kernel_name}.cpp" + + # Only write if content changed (for incremental builds) + if output_cpp.exists(): + existing = output_cpp.read_text() + if existing == content: + return output_cpp, False # No change + + output_cpp.write_text(content) + return output_cpp, True # Written + + +def generate_cmake_list( + wrappers: List[Path], output_dir: Path, kernel_dir: Path +) -> Path: + """Generate CMakeLists.txt that 
compiles each wrapper as a separate object.""" + + num_kernels = len(wrappers) + + cmake_content = f'''# SPDX-License-Identifier: MIT +# Auto-generated CMakeLists.txt for per-kernel parallel compilation +# Generated {num_kernels} kernel translation units + +cmake_minimum_required(VERSION 3.16) + +# ============================================================================= +# Per-Kernel Object Targets ({num_kernels} kernels) +# ============================================================================= +# Each kernel is compiled as a separate OBJECT library for maximum parallelism. +# Build with: make -j$(nproc) all_kernels +# +# Progress output: +# [ 1/{num_kernels}] Building kernel: gemm_fp16_rcr_128x128x32 +# [ 2/{num_kernels}] Building kernel: gemm_fp16_rcr_256x256x64 +# ... + +set(KERNEL_INCLUDE_DIR "{kernel_dir}") +set(ALL_KERNEL_OBJECTS "") + +''' + + for idx, wrapper in enumerate(wrappers, 1): + kernel_name = wrapper.stem + obj_target = f"kobj_{kernel_name}" + + cmake_content += f""" +# [{idx}/{num_kernels}] {kernel_name} +add_library({obj_target} OBJECT {wrapper.name}) +target_include_directories({obj_target} PRIVATE ${{KERNEL_INCLUDE_DIR}} ${{CK_INCLUDE_DIR}}) +target_compile_options({obj_target} PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) +set_target_properties({obj_target} PROPERTIES POSITION_INDEPENDENT_CODE ON) +if(hip_FOUND) + target_link_libraries({obj_target} PRIVATE hip::device hip::host) +endif() +list(APPEND ALL_KERNEL_OBJECTS $) +""" + + cmake_content += f""" + +# ============================================================================= +# Combined Kernel Library +# ============================================================================= +# Links all {num_kernels} kernel objects into a single shared library + +add_library(all_kernels SHARED ${{ALL_KERNEL_OBJECTS}}) +if(hip_FOUND) + target_link_libraries(all_kernels PRIVATE hip::device hip::host) 
+endif() +set_target_properties(all_kernels PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME "dispatcher_kernels" +) + +message(STATUS "Configured {num_kernels} kernel objects for parallel compilation") +message(STATUS "Build with: make -j$(nproc) all_kernels") +""" + + cmake_file = output_dir / "CMakeLists.txt" + cmake_file.write_text(cmake_content) + return cmake_file + + +def generate_ninja_build( + wrappers: List[Path], output_dir: Path, kernel_dir: Path +) -> Path: + """Generate build.ninja for even faster parallel compilation.""" + + num_kernels = len(wrappers) + + ninja_content = f"""# SPDX-License-Identifier: MIT +# Auto-generated build.ninja for per-kernel parallel compilation +# {num_kernels} kernel translation units + +# Variables +cxx = hipcc +cxxflags = -fPIC -std=c++17 -O3 -mllvm -enable-noalias-to-md-conversion=0 -Wno-undefined-func-template -Wno-float-equal --offload-compress +includes = -I{kernel_dir} -I/opt/rocm/include + +# Rules +rule compile + command = $cxx $cxxflags $includes -c $in -o $out + description = [{num_kernels}] Building kernel: $kernel_name + +rule link + command = $cxx -shared $in -o $out -L/opt/rocm/lib -lamdhip64 + description = Linking: $out + +# Kernel objects +""" + + obj_files = [] + for idx, wrapper in enumerate(wrappers, 1): + kernel_name = wrapper.stem + obj_file = f"{kernel_name}.o" + obj_files.append(obj_file) + + ninja_content += f""" +build {obj_file}: compile {wrapper.name} + kernel_name = {kernel_name} +""" + + ninja_content += f""" + +# Shared library +build libdispatcher_kernels.so: link {" ".join(obj_files)} + +# Default target +default libdispatcher_kernels.so +""" + + ninja_file = output_dir / "build.ninja" + ninja_file.write_text(ninja_content) + return ninja_file + + +def generate_makefile(wrappers: List[Path], output_dir: Path, kernel_dir: Path) -> Path: + """Generate Makefile for per-kernel parallel compilation.""" + + num_kernels = len(wrappers) + kernel_names = [w.stem for w in wrappers] + 
obj_files = [f"{name}.o" for name in kernel_names] + + makefile_content = f"""# SPDX-License-Identifier: MIT +# Auto-generated Makefile for per-kernel parallel compilation +# {num_kernels} kernel translation units +# +# Usage: +# make -j$(nproc) # Build all kernels in parallel +# make -j$(nproc) VERBOSE=1 # With per-kernel progress +# make clean # Remove all objects + +CXX = hipcc +CXXFLAGS = -fPIC -std=c++17 -O3 -mllvm -enable-noalias-to-md-conversion=0 \\ + -Wno-undefined-func-template -Wno-float-equal --offload-compress +INCLUDES = -I{kernel_dir} -I/opt/rocm/include +LDFLAGS = -shared -L/opt/rocm/lib -lamdhip64 + +TARGET = libdispatcher_kernels.so +OBJECTS = {" ".join(obj_files)} + +# Progress counter (only works with make -j1, use ninja for parallel progress) +TOTAL_KERNELS = {num_kernels} +CURRENT = 0 + +.PHONY: all clean + +all: $(TARGET) +\t@echo "Built $(TARGET) with {num_kernels} kernels" + +$(TARGET): $(OBJECTS) +\t@echo "[LINK] Linking {num_kernels} kernel objects -> $@" +\t$(CXX) $(LDFLAGS) $^ -o $@ + +""" + + for idx, (wrapper, obj) in enumerate(zip(wrappers, obj_files), 1): + kernel_name = wrapper.stem + makefile_content += f""" +{obj}: {wrapper.name} +\t@echo "[{idx}/{num_kernels}] Building kernel: {kernel_name}" +\t$(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ +""" + + makefile_content += f""" + +clean: +\trm -f $(OBJECTS) $(TARGET) +\t@echo "Cleaned {num_kernels} kernel objects" +""" + + makefile = output_dir / "Makefile" + makefile.write_text(makefile_content) + return makefile + + +def main(): + parser = argparse.ArgumentParser( + description="Generate per-kernel wrapper .cpp files for parallel compilation" + ) + parser.add_argument( + "--kernel-dir", + type=Path, + required=True, + help="Directory containing generated kernel .hpp files", + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Output directory for wrapper .cpp files", + ) + parser.add_argument( + "--pattern", + type=str, + default="*.hpp", + help="Glob 
pattern for kernel headers (default: *.hpp)", + ) + parser.add_argument( + "--generate-cmake", + action="store_true", + help="Generate CMakeLists.txt for the wrappers", + ) + parser.add_argument( + "--generate-ninja", + action="store_true", + help="Generate build.ninja for ninja builds", + ) + parser.add_argument( + "--generate-makefile", + action="store_true", + help="Generate Makefile for make builds", + ) + parser.add_argument( + "--parallel", + action="store_true", + default=True, + help="Generate wrappers in parallel (default: True)", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Verbose output", + ) + + args = parser.parse_args() + + # Find kernel headers + kernel_dir = args.kernel_dir.resolve() + if not kernel_dir.exists(): + print(f"Error: Kernel directory not found: {kernel_dir}", file=sys.stderr) + return 1 + + kernel_headers = sorted(kernel_dir.glob(args.pattern)) + if not kernel_headers: + print( + f"Error: No kernel headers found matching {args.pattern} in {kernel_dir}", + file=sys.stderr, + ) + return 1 + + num_kernels = len(kernel_headers) + print(f"Found {num_kernels} kernel headers in {kernel_dir}") + + # Create output directory + output_dir = args.output_dir.resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + # Generate wrappers + print(f"Generating {num_kernels} wrapper .cpp files...") + + wrappers = [] + written = 0 + + if args.parallel and num_kernels > 1: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = { + executor.submit( + generate_wrapper, hpp, output_dir, idx, num_kernels + ): hpp + for idx, hpp in enumerate(kernel_headers, 1) + } + for future in concurrent.futures.as_completed(futures): + wrapper_path, was_written = future.result() + wrappers.append(wrapper_path) + if was_written: + written += 1 + if args.verbose: + print(f" Generated: {wrapper_path.name}") + else: + for idx, hpp in enumerate(kernel_headers, 1): + wrapper_path, was_written = generate_wrapper( + hpp, 
output_dir, idx, num_kernels + ) + wrappers.append(wrapper_path) + if was_written: + written += 1 + if args.verbose: + print(f" [{idx}/{num_kernels}] Generated: {wrapper_path.name}") + + wrappers.sort(key=lambda p: p.name) + + print( + f" Total: {num_kernels} wrappers ({written} written, {num_kernels - written} unchanged)" + ) + + # Generate build files + if args.generate_cmake: + cmake_file = generate_cmake_list(wrappers, output_dir, kernel_dir) + print(f" Generated: {cmake_file}") + + if args.generate_ninja: + ninja_file = generate_ninja_build(wrappers, output_dir, kernel_dir) + print(f" Generated: {ninja_file}") + + if args.generate_makefile: + makefile = generate_makefile(wrappers, output_dir, kernel_dir) + print(f" Generated: {makefile}") + + print(f"\nOutput directory: {output_dir}") + print(f"Kernels ready for parallel compilation: {num_kernels}") + print("\nTo build:") + print(f" cd {output_dir}") + if args.generate_makefile: + print(" make -j$(nproc) # Parallel build with progress") + if args.generate_ninja: + print(" ninja # Fast parallel build") + if args.generate_cmake: + print(" cmake -B build && cmake --build build -j$(nproc)") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/codegen/kernel_config_loader.py b/dispatcher/codegen/kernel_config_loader.py new file mode 100644 index 00000000000..537fc40581e --- /dev/null +++ b/dispatcher/codegen/kernel_config_loader.py @@ -0,0 +1,798 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Kernel Configuration Loader + +Load kernel configurations from JSON files for generating specific kernel sets. +Compatible with tile_engine JSON format. 
+ +Usage: + from kernel_config_loader import load_kernel_configs, KernelConfigSet + + # Load configs from JSON + config_set = load_kernel_configs("my_kernels.json") + + # Get all configurations (cartesian product of all parameter values) + for config in config_set.generate_configs(): + print(config) + + # Use with codegen + from unified_gemm_codegen import UnifiedGemmCodegen + codegen = UnifiedGemmCodegen(...) + codegen.generate_from_configs(config_set.generate_configs()) +""" + +import json +import itertools +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Dict, Any, Optional, Iterator + + +@dataclass +class TileConfig: + """Tile configuration for a kernel""" + + tile_m: int = 128 + tile_n: int = 128 + tile_k: int = 32 + warp_m: int = 2 + warp_n: int = 2 + warp_k: int = 1 + warp_tile_m: int = 32 + warp_tile_n: int = 32 + warp_tile_k: int = 16 + + +@dataclass +class TraitConfig: + """Trait configuration for a kernel (order matches GEMM/Conv TraitConfig)""" + + pipeline: str = "compv4" + epilogue: str = "cshuffle" + scheduler: str = "intrawave" + pad_m: bool = False + pad_n: bool = False + pad_k: bool = False + + +@dataclass +class KernelConfig: + """Complete kernel configuration""" + + tile: TileConfig = field(default_factory=TileConfig) + trait: TraitConfig = field(default_factory=TraitConfig) + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + layout: str = "rcr" + gpu_target: str = "gfx942" + variant: str = "standard" + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for codegen""" + return { + "tile_m": self.tile.tile_m, + "tile_n": self.tile.tile_n, + "tile_k": self.tile.tile_k, + "warp_m": self.tile.warp_m, + "warp_n": self.tile.warp_n, + "warp_k": self.tile.warp_k, + "warp_tile_m": self.tile.warp_tile_m, + "warp_tile_n": self.tile.warp_tile_n, + "warp_tile_k": self.tile.warp_tile_k, + "pipeline": self.trait.pipeline, + "scheduler": 
self.trait.scheduler, + "epilogue": self.trait.epilogue, + "pad_m": self.trait.pad_m, + "pad_n": self.trait.pad_n, + "pad_k": self.trait.pad_k, + "dtype_a": self.dtype_a, + "dtype_b": self.dtype_b, + "dtype_c": self.dtype_c, + "dtype_acc": self.dtype_acc, + "layout": self.layout, + "gpu_target": self.gpu_target, + "variant": self.variant, + } + + def kernel_name(self) -> str: + """Generate kernel name from config""" + name = f"gemm_{self.dtype_a}_{self.layout}_{self.trait.pipeline}" + name += f"_{self.trait.epilogue}_{self.trait.scheduler}" + name += f"_{str(self.trait.pad_m).capitalize()}" + name += f"_{str(self.trait.pad_n).capitalize()}" + name += f"_{str(self.trait.pad_k).capitalize()}" + name += "_False" # preshuffle + name += f"_{self.tile.tile_m}x{self.tile.tile_n}x{self.tile.tile_k}" + name += f"_{self.tile.warp_m}x{self.tile.warp_n}x{self.tile.warp_k}" + name += ( + f"_{self.tile.warp_tile_m}x{self.tile.warp_tile_n}x{self.tile.warp_tile_k}" + ) + return name + + +@dataclass +class KernelConfigSet: + """A set of kernel configurations loaded from JSON""" + + name: str = "default" + configs: List[KernelConfig] = field(default_factory=list) + + # Parameter ranges for generation + tile_m_values: List[int] = field(default_factory=lambda: [128]) + tile_n_values: List[int] = field(default_factory=lambda: [128]) + tile_k_values: List[int] = field(default_factory=lambda: [32]) + warp_m_values: List[int] = field(default_factory=lambda: [2]) + warp_n_values: List[int] = field(default_factory=lambda: [2]) + warp_k_values: List[int] = field(default_factory=lambda: [1]) + warp_tile_m_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_n_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_k_values: List[int] = field(default_factory=lambda: [16]) + + pipeline_values: List[str] = field(default_factory=lambda: ["compv4"]) + scheduler_values: List[str] = field(default_factory=lambda: ["intrawave"]) + epilogue_values: List[str] = 
field(default_factory=lambda: ["cshuffle"]) + pad_m_values: List[bool] = field(default_factory=lambda: [False]) + pad_n_values: List[bool] = field(default_factory=lambda: [False]) + pad_k_values: List[bool] = field(default_factory=lambda: [False]) + + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + layout: str = "rcr" + gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"]) + variant: str = "standard" + + def generate_configs(self) -> Iterator[KernelConfig]: + """Generate all kernel configurations (cartesian product)""" + # Tile parameters + tile_params = itertools.product( + self.tile_m_values, + self.tile_n_values, + self.tile_k_values, + self.warp_m_values, + self.warp_n_values, + self.warp_k_values, + self.warp_tile_m_values, + self.warp_tile_n_values, + self.warp_tile_k_values, + ) + + # Trait parameters + trait_params = itertools.product( + self.pipeline_values, + self.scheduler_values, + self.epilogue_values, + self.pad_m_values, + self.pad_n_values, + self.pad_k_values, + ) + + # Convert to lists for reuse + tile_list = list(tile_params) + trait_list = list(trait_params) + + # Generate for each GPU target + for gpu_target in self.gpu_targets: + for tile in tile_list: + for trait in trait_list: + tile_cfg = TileConfig( + tile_m=tile[0], + tile_n=tile[1], + tile_k=tile[2], + warp_m=tile[3], + warp_n=tile[4], + warp_k=tile[5], + warp_tile_m=tile[6], + warp_tile_n=tile[7], + warp_tile_k=tile[8], + ) + trait_cfg = TraitConfig( + pipeline=trait[0], + scheduler=trait[1], + epilogue=trait[2], + pad_m=trait[3], + pad_n=trait[4], + pad_k=trait[5], + ) + yield KernelConfig( + tile=tile_cfg, + trait=trait_cfg, + dtype_a=self.dtype_a, + dtype_b=self.dtype_b, + dtype_c=self.dtype_c, + dtype_acc=self.dtype_acc, + layout=self.layout, + gpu_target=gpu_target, + variant=self.variant, + ) + + def config_count(self) -> int: + """Get total number of configurations""" + tile_count = ( + len(self.tile_m_values) + * 
len(self.tile_n_values) + * len(self.tile_k_values) + * len(self.warp_m_values) + * len(self.warp_n_values) + * len(self.warp_k_values) + * len(self.warp_tile_m_values) + * len(self.warp_tile_n_values) + * len(self.warp_tile_k_values) + ) + trait_count = ( + len(self.pipeline_values) + * len(self.scheduler_values) + * len(self.epilogue_values) + * len(self.pad_m_values) + * len(self.pad_n_values) + * len(self.pad_k_values) + ) + return tile_count * trait_count * len(self.gpu_targets) + + +def _get_values(config: Dict, key: str, default: List) -> List: + """Extract values from config dict, handling range specifications""" + if key not in config: + return default + + item = config[key] + + # Explicit values list + if "values" in item: + return item["values"] + + # Range specification (min, max, step) + if "min" in item and "max" in item: + min_val = item["min"] + max_val = item["max"] + step = item.get("step", 1) + return list(range(min_val, max_val + 1, step)) + + return default + + +def load_kernel_configs(json_path: str | Path) -> KernelConfigSet: + """ + Load kernel configurations from a JSON file. + + Supports both tile_engine format and dispatcher format. 
+ + Args: + json_path: Path to JSON configuration file + + Returns: + KernelConfigSet with all parameter values loaded + """ + json_path = Path(json_path) + + with open(json_path) as f: + data = json.load(f) + + config_set = KernelConfigSet() + + # Name + config_set.name = data.get("kernel_set_name", json_path.stem) + + # Data types + if "datatype" in data: + dt = data["datatype"] + config_set.dtype_a = dt.get("a", "fp16") + config_set.dtype_b = dt.get("b", "fp16") + config_set.dtype_c = dt.get("c", "fp16") + config_set.dtype_acc = dt.get("acc", "fp32") + + # Layout + config_set.layout = data.get("layout", "rcr") + + # GPU targets + if "gpu_targets" in data: + config_set.gpu_targets = data["gpu_targets"] + elif "gpu_target" in data: + config_set.gpu_targets = [data["gpu_target"]] + + # Variant + config_set.variant = data.get("variant", "standard") + + # Tile config + tile_cfg = data.get("tile_config", {}) + config_set.tile_m_values = _get_values(tile_cfg, "tile_m", [128]) + config_set.tile_n_values = _get_values(tile_cfg, "tile_n", [128]) + config_set.tile_k_values = _get_values(tile_cfg, "tile_k", [32]) + config_set.warp_m_values = _get_values(tile_cfg, "warp_m", [2]) + config_set.warp_n_values = _get_values(tile_cfg, "warp_n", [2]) + config_set.warp_k_values = _get_values(tile_cfg, "warp_k", [1]) + config_set.warp_tile_m_values = _get_values(tile_cfg, "warp_tile_m", [32]) + config_set.warp_tile_n_values = _get_values(tile_cfg, "warp_tile_n", [32]) + config_set.warp_tile_k_values = _get_values(tile_cfg, "warp_tile_k", [16]) + + # Trait config + trait_cfg = data.get("trait_config", {}) + config_set.pipeline_values = _get_values(trait_cfg, "pipeline", ["compv4"]) + config_set.scheduler_values = _get_values(trait_cfg, "scheduler", ["intrawave"]) + config_set.epilogue_values = _get_values(trait_cfg, "epilogue", ["cshuffle"]) + config_set.pad_m_values = _get_values(trait_cfg, "pad_m", [False]) + config_set.pad_n_values = _get_values(trait_cfg, "pad_n", [False]) + 
config_set.pad_k_values = _get_values(trait_cfg, "pad_k", [False]) + + return config_set + + +# ============================================================================= +# Convolution Configuration Classes +# ============================================================================= + + +@dataclass +class ConvTileConfig: + """Tile configuration for a convolution kernel""" + + tile_m: int = 128 # M dimension (N * spatial_out for fwd) + tile_n: int = 128 # N dimension (K output channels for fwd) + tile_k: int = 32 # K dimension (C * filter for fwd) + warp_m: int = 2 + warp_n: int = 2 + warp_k: int = 1 + warp_tile_m: int = 32 + warp_tile_n: int = 32 + warp_tile_k: int = 16 + + +@dataclass +class ConvTraitConfig: + """Trait configuration for a convolution kernel""" + + pipeline: str = "compv3" + scheduler: str = "intrawave" + epilogue: str = "cshuffle" + pad_m: bool = True + pad_n: bool = True + pad_k: bool = True + double_smem_buffer: bool = False + num_groups_to_merge: int = 1 + + +@dataclass +class ConvKernelConfig: + """Complete convolution kernel configuration""" + + tile: ConvTileConfig = field(default_factory=ConvTileConfig) + trait: ConvTraitConfig = field(default_factory=ConvTraitConfig) + dtype_input: str = "fp16" + dtype_weight: str = "fp16" + dtype_output: str = "fp16" + dtype_acc: str = "fp32" + variant: str = "forward" # forward, bwd_data, bwd_weight + ndim: int = 2 # 1, 2, or 3 + layout: str = "nhwgc" + gpu_target: str = "gfx942" + + # Vector sizes + vector_size_a: int = 4 + vector_size_b: int = 8 + vector_size_c: int = 8 + + # Occupancy + block_per_cu: int = 1 + num_wave_groups: int = 1 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for codegen""" + return { + "tile_m": self.tile.tile_m, + "tile_n": self.tile.tile_n, + "tile_k": self.tile.tile_k, + "warp_m": self.tile.warp_m, + "warp_n": self.tile.warp_n, + "warp_k": self.tile.warp_k, + "warp_tile_m": self.tile.warp_tile_m, + "warp_tile_n": self.tile.warp_tile_n, + 
"warp_tile_k": self.tile.warp_tile_k, + "pipeline": self.trait.pipeline, + "scheduler": self.trait.scheduler, + "epilogue": self.trait.epilogue, + "pad_m": self.trait.pad_m, + "pad_n": self.trait.pad_n, + "pad_k": self.trait.pad_k, + "double_smem_buffer": self.trait.double_smem_buffer, + "num_groups_to_merge": self.trait.num_groups_to_merge, + "dtype_input": self.dtype_input, + "dtype_weight": self.dtype_weight, + "dtype_output": self.dtype_output, + "dtype_acc": self.dtype_acc, + "variant": self.variant, + "ndim": self.ndim, + "layout": self.layout, + "gpu_target": self.gpu_target, + "vector_size_a": self.vector_size_a, + "vector_size_b": self.vector_size_b, + "vector_size_c": self.vector_size_c, + "block_per_cu": self.block_per_cu, + "num_wave_groups": self.num_wave_groups, + } + + def kernel_name(self) -> str: + """Generate kernel name from config""" + variant_map = {"forward": "fwd", "bwd_data": "bwdd", "bwd_weight": "bwdw"} + var_str = variant_map.get(self.variant, self.variant) + + name = f"conv_{var_str}_{self.dtype_input}_{self.ndim}d" + name += f"_{self.trait.pipeline}_{self.trait.epilogue}_{self.trait.scheduler}" + name += f"_{self.tile.tile_m}x{self.tile.tile_n}x{self.tile.tile_k}" + name += f"_{self.tile.warp_m}x{self.tile.warp_n}x{self.tile.warp_k}" + name += ( + f"_{self.tile.warp_tile_m}x{self.tile.warp_tile_n}x{self.tile.warp_tile_k}" + ) + return name + + +@dataclass +class ConvKernelConfigSet: + """A set of convolution kernel configurations loaded from JSON""" + + name: str = "default" + configs: List[ConvKernelConfig] = field(default_factory=list) + + # Tile parameter ranges + tile_m_values: List[int] = field(default_factory=lambda: [128]) + tile_n_values: List[int] = field(default_factory=lambda: [128]) + tile_k_values: List[int] = field(default_factory=lambda: [32]) + warp_m_values: List[int] = field(default_factory=lambda: [2]) + warp_n_values: List[int] = field(default_factory=lambda: [2]) + warp_k_values: List[int] = 
field(default_factory=lambda: [1]) + warp_tile_m_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_n_values: List[int] = field(default_factory=lambda: [32]) + warp_tile_k_values: List[int] = field(default_factory=lambda: [16]) + + # Trait parameter ranges + pipeline_values: List[str] = field(default_factory=lambda: ["compv3"]) + scheduler_values: List[str] = field(default_factory=lambda: ["intrawave"]) + epilogue_values: List[str] = field(default_factory=lambda: ["cshuffle"]) + pad_m_values: List[bool] = field(default_factory=lambda: [True]) + pad_n_values: List[bool] = field(default_factory=lambda: [True]) + pad_k_values: List[bool] = field(default_factory=lambda: [True]) + double_smem_buffer_values: List[bool] = field(default_factory=lambda: [False]) + num_groups_to_merge_values: List[int] = field(default_factory=lambda: [1]) + + # Vector sizes + vector_size_a_values: List[int] = field(default_factory=lambda: [4]) + vector_size_b_values: List[int] = field(default_factory=lambda: [8]) + vector_size_c_values: List[int] = field(default_factory=lambda: [8]) + + # Occupancy + block_per_cu_values: List[int] = field(default_factory=lambda: [1]) + num_wave_groups_values: List[int] = field(default_factory=lambda: [1]) + + # Data types + dtype_input: str = "fp16" + dtype_weight: str = "fp16" + dtype_output: str = "fp16" + dtype_acc: str = "fp32" + + # Conv specific + variant: str = "forward" + ndim: int = 2 + layout: str = "nhwgc" + gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"]) + + def generate_configs(self) -> Iterator[ConvKernelConfig]: + """Generate all kernel configurations (cartesian product)""" + # Tile parameters + tile_params = itertools.product( + self.tile_m_values, + self.tile_n_values, + self.tile_k_values, + self.warp_m_values, + self.warp_n_values, + self.warp_k_values, + self.warp_tile_m_values, + self.warp_tile_n_values, + self.warp_tile_k_values, + ) + + # Trait parameters + trait_params = itertools.product( + 
self.pipeline_values, + self.scheduler_values, + self.epilogue_values, + self.pad_m_values, + self.pad_n_values, + self.pad_k_values, + self.double_smem_buffer_values, + self.num_groups_to_merge_values, + ) + + # Vector/occupancy parameters + extra_params = itertools.product( + self.vector_size_a_values, + self.vector_size_b_values, + self.vector_size_c_values, + self.block_per_cu_values, + self.num_wave_groups_values, + ) + + # Convert to lists for reuse + tile_list = list(tile_params) + trait_list = list(trait_params) + extra_list = list(extra_params) + + # Generate for each GPU target + for gpu_target in self.gpu_targets: + for tile in tile_list: + for trait in trait_list: + for extra in extra_list: + tile_cfg = ConvTileConfig( + tile_m=tile[0], + tile_n=tile[1], + tile_k=tile[2], + warp_m=tile[3], + warp_n=tile[4], + warp_k=tile[5], + warp_tile_m=tile[6], + warp_tile_n=tile[7], + warp_tile_k=tile[8], + ) + trait_cfg = ConvTraitConfig( + pipeline=trait[0], + scheduler=trait[1], + epilogue=trait[2], + pad_m=trait[3], + pad_n=trait[4], + pad_k=trait[5], + double_smem_buffer=trait[6], + num_groups_to_merge=trait[7], + ) + yield ConvKernelConfig( + tile=tile_cfg, + trait=trait_cfg, + dtype_input=self.dtype_input, + dtype_weight=self.dtype_weight, + dtype_output=self.dtype_output, + dtype_acc=self.dtype_acc, + variant=self.variant, + ndim=self.ndim, + layout=self.layout, + gpu_target=gpu_target, + vector_size_a=extra[0], + vector_size_b=extra[1], + vector_size_c=extra[2], + block_per_cu=extra[3], + num_wave_groups=extra[4], + ) + + def config_count(self) -> int: + """Get total number of configurations""" + tile_count = ( + len(self.tile_m_values) + * len(self.tile_n_values) + * len(self.tile_k_values) + * len(self.warp_m_values) + * len(self.warp_n_values) + * len(self.warp_k_values) + * len(self.warp_tile_m_values) + * len(self.warp_tile_n_values) + * len(self.warp_tile_k_values) + ) + trait_count = ( + len(self.pipeline_values) + * len(self.scheduler_values) + * 
len(self.epilogue_values) + * len(self.pad_m_values) + * len(self.pad_n_values) + * len(self.pad_k_values) + * len(self.double_smem_buffer_values) + * len(self.num_groups_to_merge_values) + ) + extra_count = ( + len(self.vector_size_a_values) + * len(self.vector_size_b_values) + * len(self.vector_size_c_values) + * len(self.block_per_cu_values) + * len(self.num_wave_groups_values) + ) + return tile_count * trait_count * extra_count * len(self.gpu_targets) + + +def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: + """ + Load convolution kernel configurations from a JSON file. + + Args: + json_path: Path to JSON configuration file + + Returns: + ConvKernelConfigSet with all parameter values loaded + """ + json_path = Path(json_path) + + with open(json_path) as f: + data = json.load(f) + + config_set = ConvKernelConfigSet() + + # Name + config_set.name = data.get("kernel_set_name", json_path.stem) + + # Data types + if "datatype" in data: + dt = data["datatype"] + config_set.dtype_input = dt.get("input", "fp16") + config_set.dtype_weight = dt.get("weight", "fp16") + config_set.dtype_output = dt.get("output", "fp16") + config_set.dtype_acc = dt.get("acc", "fp32") + + # Conv specific + config_set.variant = data.get("variant", "forward") + config_set.ndim = data.get("ndim", 2) + config_set.layout = data.get("layout", "nhwgc") + + # GPU targets + if "gpu_targets" in data: + config_set.gpu_targets = data["gpu_targets"] + elif "gpu_target" in data: + config_set.gpu_targets = [data["gpu_target"]] + + # Tile config + tile_cfg = data.get("tile_config", {}) + config_set.tile_m_values = _get_values(tile_cfg, "tile_m", [128]) + config_set.tile_n_values = _get_values(tile_cfg, "tile_n", [128]) + config_set.tile_k_values = _get_values(tile_cfg, "tile_k", [32]) + config_set.warp_m_values = _get_values(tile_cfg, "warp_m", [2]) + config_set.warp_n_values = _get_values(tile_cfg, "warp_n", [2]) + config_set.warp_k_values = _get_values(tile_cfg, "warp_k", [1]) + 
config_set.warp_tile_m_values = _get_values(tile_cfg, "warp_tile_m", [32]) + config_set.warp_tile_n_values = _get_values(tile_cfg, "warp_tile_n", [32]) + config_set.warp_tile_k_values = _get_values(tile_cfg, "warp_tile_k", [16]) + + # Trait config + trait_cfg = data.get("trait_config", {}) + config_set.pipeline_values = _get_values(trait_cfg, "pipeline", ["compv3"]) + config_set.scheduler_values = _get_values(trait_cfg, "scheduler", ["intrawave"]) + config_set.epilogue_values = _get_values(trait_cfg, "epilogue", ["cshuffle"]) + config_set.pad_m_values = _get_values(trait_cfg, "pad_m", [True]) + config_set.pad_n_values = _get_values(trait_cfg, "pad_n", [True]) + config_set.pad_k_values = _get_values(trait_cfg, "pad_k", [True]) + config_set.double_smem_buffer_values = _get_values( + trait_cfg, "double_smem_buffer", [False] + ) + config_set.num_groups_to_merge_values = _get_values( + trait_cfg, "num_groups_to_merge", [1] + ) + + # Vector config + vec_cfg = data.get("vector_config", {}) + config_set.vector_size_a_values = _get_values(vec_cfg, "vector_size_a", [4]) + config_set.vector_size_b_values = _get_values(vec_cfg, "vector_size_b", [8]) + config_set.vector_size_c_values = _get_values(vec_cfg, "vector_size_c", [8]) + + # Occupancy config + occ_cfg = data.get("occupancy_config", {}) + config_set.block_per_cu_values = _get_values(occ_cfg, "block_per_cu", [1]) + config_set.num_wave_groups_values = _get_values(occ_cfg, "num_wave_groups", [1]) + + return config_set + + +def generate_cpp_conv_kernel_set_declaration( + config_set: ConvKernelConfigSet, + set_name: Optional[str] = None, +) -> str: + """ + Generate C++ DECL_CONV_KERNEL_SET code from a ConvKernelConfigSet. 
+ """ + name = set_name or config_set.name + + lines = [f"DECL_CONV_KERNEL_SET({name},"] + + for config in config_set.generate_configs(): + line = f' .add("{config.dtype_input}", "{config.variant}", {config.ndim}, ' + line += f"{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k})" + lines.append(line) + + lines.append(");") + + return "\n".join(lines) + + +# ============================================================================= +# GEMM Configuration Export Functions +# ============================================================================= + + +def generate_cpp_kernel_set_declaration( + config_set: KernelConfigSet, + set_name: Optional[str] = None, +) -> str: + """ + Generate C++ DECL_KERNEL_SET code from a KernelConfigSet. + + Args: + config_set: The kernel configuration set + set_name: Optional name override for the kernel set + + Returns: + C++ code string with DECL_KERNEL_SET declaration + """ + name = set_name or config_set.name + + lines = [f"DECL_KERNEL_SET({name},"] + + for config in config_set.generate_configs(): + # Generate .add() call for each config + line = f' .add("{config.dtype_a}", "{config.layout}", ' + line += f"{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k})" + lines.append(line) + + lines.append(");") + + return "\n".join(lines) + + +# CLI for testing +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("Usage: python kernel_config_loader.py ") + print("\nLoads kernel configurations from JSON and prints summary.") + sys.exit(1) + + json_path = sys.argv[1] + + try: + config_set = load_kernel_configs(json_path) + + print(f"Kernel Set: {config_set.name}") + print( + f"Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}, Acc={config_set.dtype_acc}" + ) + print(f"Layout: {config_set.layout}") + print(f"GPU Targets: {config_set.gpu_targets}") + print(f"Variant: {config_set.variant}") + print() + print("Tile Configurations:") + print(f" tile_m: 
{config_set.tile_m_values}") + print(f" tile_n: {config_set.tile_n_values}") + print(f" tile_k: {config_set.tile_k_values}") + print(f" warp_m: {config_set.warp_m_values}") + print(f" warp_n: {config_set.warp_n_values}") + print(f" warp_k: {config_set.warp_k_values}") + print( + f" warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}" + ) + print() + print("Trait Configurations:") + print(f" pipeline: {config_set.pipeline_values}") + print(f" scheduler: {config_set.scheduler_values}") + print(f" epilogue: {config_set.epilogue_values}") + print( + f" padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}" + ) + print() + print(f"Total configurations: {config_set.config_count()}") + print() + + # Print first few config names + print("Sample kernel names:") + for i, config in enumerate(config_set.generate_configs()): + if i >= 5: + print(f" ... and {config_set.config_count() - 5} more") + break + print(f" {config.kernel_name()}") + print() + + # Generate C++ code + if "--cpp" in sys.argv: + print("C++ Declaration:") + print("-" * 60) + print(generate_cpp_kernel_set_declaration(config_set)) + + except Exception as e: + print(f"Error: {e}") + sys.exit(1) diff --git a/dispatcher/codegen/preselected_kernels.py b/dispatcher/codegen/preselected_kernels.py new file mode 100644 index 00000000000..010d930639e --- /dev/null +++ b/dispatcher/codegen/preselected_kernels.py @@ -0,0 +1,518 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +""" +Preselected, Benchmarked Kernel Configurations + +Curated kernel sets optimized for different workload characteristics: +- Compute-friendly: Large tiles, high arithmetic intensity +- Memory-friendly: Smaller tiles, better memory access patterns +- Latency-friendly: Minimal tiles, low latency for small problems +""" + +from functools import partial, lru_cache +from typing import List +from unified_gemm_codegen import KernelConfig, TileConfig, TraitConfig, GemmVariant + + +# ============================================================================ +# Base Configurations +# ============================================================================ + + +def _base_fp16_rcr_compute() -> partial: + """Base configuration for compute-intensive FP16 RCR kernels""" + return partial( + KernelConfig, + tile=None, # Will be overridden + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +def _base_fp16_rcr_memory() -> partial: + """Base configuration for memory-intensive FP16 RCR kernels""" + # Note: Use 'mem' pipeline for interwave scheduler (compv3/compv4/compv5/compv6 only support intrawave) + return partial( + KernelConfig, + tile=None, # Will be overridden + trait=TraitConfig( + pipeline="mem", + epilogue="cshuffle", + scheduler="interwave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=128, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +def _base_fp16_rcr_latency() -> partial: + """Base configuration for latency-sensitive FP16 RCR kernels""" + return partial( + KernelConfig, + tile=None, # Will be overridden + trait=TraitConfig( + pipeline="mem", + epilogue="default", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + 
variant=GemmVariant.STANDARD, + block_size=128, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +# ============================================================================ +# Preselected FP16 RCR Kernels +# ============================================================================ + + +@lru_cache(None) +def preselected_fp16_rcr_compute() -> List[KernelConfig]: + """ + Compute-friendly FP16 RCR kernels + + Optimized for: + - Large M, N dimensions (>= 128) + - High arithmetic intensity + - Good occupancy + - Maximum throughput + """ + base = _base_fp16_rcr_compute() + + return [ + # Large tiles for maximum compute + base(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(256, 128, 32, 4, 2, 1, 32, 32, 16)), + base(tile=TileConfig(128, 256, 32, 2, 4, 1, 32, 32, 16)), + # Balanced tiles + base(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)), + # With persistent kernel for large batches + base( + tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16), + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=False, + pad_n=False, + pad_k=False, + persistent=True, + ), + ), + ] + + +@lru_cache(None) +def preselected_fp16_rcr_memory() -> List[KernelConfig]: + """ + Memory-friendly FP16 RCR kernels + + Optimized for: + - Small to medium M, N dimensions + - Memory-bound workloads + - Better cache utilization + - Lower register pressure + """ + base = _base_fp16_rcr_memory() + + return [ + # Small tiles for memory efficiency + base(tile=TileConfig(16, 32, 32, 1, 1, 1, 16, 16, 16)), + base(tile=TileConfig(32, 16, 32, 1, 1, 1, 16, 16, 16)), + base(tile=TileConfig(16, 64, 32, 1, 2, 1, 16, 16, 16)), + base(tile=TileConfig(64, 16, 32, 2, 1, 1, 16, 16, 16)), + # Medium tiles + base(tile=TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16)), + base(tile=TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16)), + 
base(tile=TileConfig(32, 128, 32, 1, 2, 1, 32, 32, 16)), + base(tile=TileConfig(128, 32, 32, 2, 1, 1, 32, 32, 16)), + ] + + +@lru_cache(None) +def preselected_fp16_rcr_latency() -> List[KernelConfig]: + """ + Latency-friendly FP16 RCR kernels + + Optimized for: + - Very small M, N dimensions (< 64) + - Minimal launch overhead + - Low latency + - Quick execution + """ + base = _base_fp16_rcr_latency() + + return [ + # Minimal tiles for low latency + base(tile=TileConfig(16, 32, 32, 1, 1, 1, 16, 16, 16)), + base(tile=TileConfig(32, 16, 32, 1, 1, 1, 16, 16, 16)), + ] + + +# ============================================================================ +# Preselected Multi-D Kernels +# ============================================================================ + + +@lru_cache(None) +def preselected_fp16_rcr_multi_d() -> List[KernelConfig]: + """ + Multi-D GEMM kernels with element-wise fusion + + Common fusions: + - MultiDAdd: E = C + D0 + D1 + - Relu: E = max(C, 0) + - Gelu: E = gelu(C) + """ + base = _base_fp16_rcr_compute() + + configs = [] + + # Best-performing tile for fused operations + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + + # Common element-wise operations + for ew_op in ["MultiDAdd", "Relu", "Gelu", "FastGelu"]: + for num_d in [1, 2]: + configs.append( + base( + tile=tile, + variant=GemmVariant.MULTI_D, + elementwise_op=ew_op, + num_d_tensors=num_d, + ) + ) + + return configs + + +@lru_cache(None) +def preselected_fp16_rcr_preshuffle() -> List[KernelConfig]: + """ + Preshuffle GEMM kernels for weight optimization + + Best for: + - Repeated use of same weights + - Inference workloads + - Batch size > 1 + """ + base = _base_fp16_rcr_compute() + + return [ + base( + tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16), + variant=GemmVariant.PRESHUFFLE, + preshuffle=True, + ), + base( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + variant=GemmVariant.PRESHUFFLE, + preshuffle=True, + ), + ] + + +# 
============================================================================ +# Unified Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_fp16_rcr_all() -> List[KernelConfig]: + """All preselected FP16 RCR kernels""" + return ( + preselected_fp16_rcr_compute() + + preselected_fp16_rcr_memory() + + preselected_fp16_rcr_latency() + + preselected_fp16_rcr_multi_d() + + preselected_fp16_rcr_preshuffle() + ) + + +@lru_cache(None) +def preselected_fp16_rcr_essential() -> List[KernelConfig]: + """ + Essential FP16 RCR kernels - minimal set for most workloads + + Covers: + - 90% of common GEMM sizes + - Key fusion operations + - Balanced performance + """ + base_compute = _base_fp16_rcr_compute() + base_memory = _base_fp16_rcr_memory() + + return [ + # Top compute kernels + base_compute(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base_compute(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + # Top memory kernels + base_memory(tile=TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16)), + base_memory(tile=TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16)), + # Essential fusions + base_compute( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + variant=GemmVariant.MULTI_D, + elementwise_op="Relu", + num_d_tensors=1, + ), + base_compute( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + variant=GemmVariant.MULTI_D, + elementwise_op="Gelu", + num_d_tensors=1, + ), + ] + + +# ============================================================================ +# Default Fallback +# ============================================================================ + + +def default_kernel() -> KernelConfig: + """ + Default fallback kernel - guaranteed to work + + Known-good configuration tested on gfx942 + """ + return KernelConfig( + tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16), + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + 
pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + k_block_per_cu=1, + num_wave_groups=1, + ) + + +# ============================================================================ +# BF16 Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_bf16_rcr_essential() -> List[KernelConfig]: + """Essential BF16 RCR kernels""" + base_compute = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base_compute(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base_compute(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# INT8 Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_int8_rcr_essential() -> List[KernelConfig]: + """Essential INT8 RCR kernels for quantized inference""" + base = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# FP8 Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_fp8_rcr_essential() -> List[KernelConfig]: + """Essential FP8 RCR kernels for AI training""" + base = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + 
pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# Mixed Precision Preselected Sets +# ============================================================================ + + +@lru_cache(None) +def preselected_mixed_precision() -> List[KernelConfig]: + """Mixed-precision kernels (FP16 inputs, FP32 output)""" + base = partial( + KernelConfig, + tile=None, + trait=TraitConfig( + pipeline="compv4", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + persistent=False, + ), + variant=GemmVariant.STANDARD, + block_size=256, + ) + + return [ + base(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)), + base(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)), + ] + + +# ============================================================================ +# Registry +# ============================================================================ + +PRESELECTED_SETS = { + # FP16 sets + "fp16_rcr_compute": preselected_fp16_rcr_compute, + "fp16_rcr_memory": preselected_fp16_rcr_memory, + "fp16_rcr_latency": preselected_fp16_rcr_latency, + "fp16_rcr_multi_d": preselected_fp16_rcr_multi_d, + "fp16_rcr_preshuffle": preselected_fp16_rcr_preshuffle, + "fp16_rcr_all": preselected_fp16_rcr_all, + "fp16_rcr_essential": preselected_fp16_rcr_essential, + # BF16 sets + "bf16_rcr_essential": preselected_bf16_rcr_essential, + # INT8 sets + "int8_rcr_essential": preselected_int8_rcr_essential, + # FP8 sets + "fp8_rcr_essential": preselected_fp8_rcr_essential, + # Mixed precision + "mixed_precision": preselected_mixed_precision, +} + + +def get_preselected_set(name: str) -> List[KernelConfig]: + """Get a 
preselected kernel set by name""" + if name not in PRESELECTED_SETS: + raise ValueError( + f"Unknown preselected set: {name}. Available: {list(PRESELECTED_SETS.keys())}" + ) + return PRESELECTED_SETS[name]() + + +def list_preselected_sets() -> List[str]: + """List all available preselected sets""" + return list(PRESELECTED_SETS.keys()) + + +# ============================================================================ +# CLI for testing +# ============================================================================ + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="List preselected kernel configurations" + ) + parser.add_argument( + "--set", + type=str, + default="fp16_rcr_essential", + choices=list_preselected_sets(), + help="Preselected set to display", + ) + parser.add_argument("--count-only", action="store_true", help="Only show count") + + args = parser.parse_args() + + configs = get_preselected_set(args.set) + + if args.count_only: + print(f"{args.set}: {len(configs)} kernels") + else: + print(f"Preselected set: {args.set}") + print(f"Total kernels: {len(configs)}\n") + for i, cfg in enumerate(configs, 1): + print(f"{i}. {cfg.variant.value}") + print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}") + print(f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}") + if cfg.variant == GemmVariant.MULTI_D: + print( + f" Element-wise: {cfg.elementwise_op}, D tensors: {cfg.num_d_tensors}" + ) + print() diff --git a/dispatcher/codegen/unified_gemm_codegen.py b/dispatcher/codegen/unified_gemm_codegen.py new file mode 100755 index 00000000000..b0dd961be7c --- /dev/null +++ b/dispatcher/codegen/unified_gemm_codegen.py @@ -0,0 +1,1713 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +""" +Unified GEMM Code Generator - Single Source of Truth + +This is THE unified code generator for all GEMM kernel variants: +- Standard GEMM (C = A × B) +- Preshuffle GEMM (optimized weight access) +- Multi-D GEMM (element-wise fusion) + +Generates both CK Tile kernels AND dispatcher wrappers in one pass. +Replaces all tile_engine GEMM codegen. +""" + +import json +import argparse +import itertools +import logging +from pathlib import Path +from typing import Dict, List, Tuple, Optional +from dataclasses import dataclass, asdict +from enum import Enum +import concurrent.futures + +# Import architecture filter for GPU-specific validation +try: + from arch_filter import ArchFilter, KernelConfig as ArchKernelConfig, OperatorType + + HAS_ARCH_FILTER = True +except ImportError: + HAS_ARCH_FILTER = False + ArchFilter = None + ArchKernelConfig = None + OperatorType = None + + +# ============================================================================= +# Preshuffle Validation (copied from tile_engine/ops/commons/gemm_validation_utils.py) +# ============================================================================= + +ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "fp32": 4, + "fp64": 8, + "fp8": 1, + "bf8": 1, + "int8": 1, +} + + +def _validate_preshuffle_vector_load( + warp_tile_m: int, + warp_tile_k: int, + datatype: str, + m_iter_per_warp: float, + wave_size: int = 64, + vector_load_size: int = 16, +) -> bool: + """ + Validate vector load alignment for preshuffle pipeline. 
+ + Checks: (warp_tile_m * warp_tile_k * elem_size * m_iter_per_warp / wave_size) % vector_load_size == 0 + """ + elem_size = ELEMENT_SIZE_MAP.get(datatype, 2) + access_size = (warp_tile_m * warp_tile_k * elem_size * m_iter_per_warp) / wave_size + return access_size % vector_load_size == 0 + + +def _validate_preshuffle_m0_m1_m2( + tile_m: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + datatype: str, + vector_load_size: int = 16, + warp_size: int = 64, +) -> bool: + """ + Validate M0, M1, M2 configuration for preshuffle matrix A row-major layout. + Ensures proper memory access pattern alignment. + """ + try: + elem_size = ELEMENT_SIZE_MAP.get(datatype, 2) + MPerBlock = tile_m + + # Calculate K1 + K1 = vector_load_size / elem_size + if K1 != int(K1): + return False + K1 = int(K1) + + # Calculate K0 + if tile_k % K1 != 0: + return False + K0 = tile_k // K1 + + # Calculate M2 + if warp_size % K0 != 0: + return False + M2 = warp_size // K0 + + # Calculate number of warps + NumWarps = warp_m * warp_n * warp_k + M0 = NumWarps + + # Calculate M1 + if (M2 * M0) == 0: + return False + if MPerBlock % (M2 * M0) != 0: + return False + M1 = MPerBlock // (M2 * M0) + + # Validate: M0 * M1 * M2 == MPerBlock + return (M0 * M1 * M2) == MPerBlock + + except (ZeroDivisionError, ValueError): + return False + + +def is_preshuffle_config_valid( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + datatype: str, +) -> bool: + """ + Comprehensive preshuffle configuration validation. 
+ Copied from tile_engine/ops/commons/gemm_validation_utils.py + """ + # Basic divisibility checks + if tile_m % (warp_m * warp_tile_m) != 0: + return False + if tile_n % (warp_n * warp_tile_n) != 0: + return False + if tile_k % (warp_k * warp_tile_k) != 0: + return False + + # Calculate m_iter_per_warp + m_iter_per_warp = tile_m / (warp_m * warp_tile_m) + + # Validate vector load alignment + if not _validate_preshuffle_vector_load( + warp_tile_m, + warp_tile_k, + datatype, + m_iter_per_warp, + wave_size=64, + vector_load_size=16, + ): + return False + + # Validate M0/M1/M2 configuration + if not _validate_preshuffle_m0_m1_m2( + tile_m, + tile_k, + warp_m, + warp_n, + warp_k, + datatype, + vector_load_size=16, + warp_size=64, + ): + return False + + return True + + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + +log = logging.getLogger(__name__) + + +# ============================================================================ +# Configuration and Data Structures +# ============================================================================ + + +class GemmVariant(Enum): + """GEMM kernel variants""" + + STANDARD = "standard" + PRESHUFFLE = "preshuffle" + MULTI_D = "multi_d" + + +@dataclass +class TileConfig: + """Tile configuration parameters""" + + tile_m: int + tile_n: int + tile_k: int + warp_m: int + warp_n: int + warp_k: int + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + def is_valid(self) -> bool: + """Validate tile configuration""" + return ( + self.tile_m % (self.warp_m * self.warp_tile_m) == 0 + and self.tile_n % (self.warp_n * self.warp_tile_n) == 0 + and self.tile_k % (self.warp_k * self.warp_tile_k) == 0 + and self.tile_m > 0 + and self.tile_n > 0 + and self.tile_k > 0 + ) + + +@dataclass +class TraitConfig: + """Kernel trait configuration""" + + pipeline: str # mem, compv3, compv4 + epilogue: str # default, cshuffle + scheduler: str # intrawave, interwave + pad_m: bool + pad_n: bool + pad_k: bool + 
persistent: bool + + def is_valid(self) -> bool: + """Check if trait combination is valid""" + # Unsupported combinations + # Only 'mem' pipeline supports interwave scheduler. + # All compute pipelines (compv3/v4/v5/v6/async) only support intrawave. + unsupported = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + ("compv5", "cshuffle", "interwave"), + ("compv5", "default", "interwave"), + ("compv6", "cshuffle", "interwave"), + ("compv6", "default", "interwave"), + ("comp_async", "cshuffle", "interwave"), + ("comp_async", "default", "interwave"), + } + return (self.pipeline, self.epilogue, self.scheduler) not in unsupported + + +@dataclass +class KernelConfig: + """Complete kernel configuration""" + + tile: TileConfig + trait: TraitConfig + variant: GemmVariant = GemmVariant.STANDARD + + # Variant-specific + preshuffle: bool = False + elementwise_op: str = "PassThrough" + num_d_tensors: int = 0 + d_layout: str = "r" # Layout for D tensors (r=row, c=col) - same for all D tensors + + # Fixed parameters + block_size: int = 256 + k_block_per_cu: int = 1 + num_wave_groups: int = 1 + + def name(self, datatype: str, layout: str) -> str: + """C++ alias for template instance""" + return f"ck_tile_gemm_{self.key_name(datatype, layout)}" + + def key_name(self, datatype: str, layout: str) -> str: + """ + Unique identifier for this kernel configuration. 
+ + All parameters that affect kernel behavior MUST be included to ensure + unique names for unique configurations: + - Data type and layout (signature) + - Tile, warp, warp_tile dimensions (algorithm) + - Pipeline, epilogue, scheduler (traits) + - Padding flags (affects divisibility requirements) + - Persistent mode + - Preshuffle variant + - Multi-D: elementwise op, num D tensors, D layout + - Occupancy: wave groups, k_block_per_cu (if non-default) + """ + parts = [] + # Signature + parts.append(f"dt_{datatype}") + parts.append(f"ly_{layout}") + + # Tile configuration + parts.append(f"tile_{self.tile.tile_m}x{self.tile.tile_n}x{self.tile.tile_k}") + parts.append(f"warp_{self.tile.warp_m}x{self.tile.warp_n}x{self.tile.warp_k}") + parts.append( + f"wtile_{self.tile.warp_tile_m}x{self.tile.warp_tile_n}x{self.tile.warp_tile_k}" + ) + + # Traits + parts.append(f"pipe_{self.trait.pipeline}") + parts.append(f"epi_{self.trait.epilogue}") + parts.append(f"sched_{self.trait.scheduler}") + + # Padding flags (only if not all True - the common case) + if not (self.trait.pad_m and self.trait.pad_n and self.trait.pad_k): + parts.append( + f"pad{int(self.trait.pad_m)}{int(self.trait.pad_n)}{int(self.trait.pad_k)}" + ) + + # Persistent mode + if self.trait.persistent: + parts.append("persist") + + # Preshuffle variant + if self.preshuffle: + parts.append("preshuffle") + + # Multi-D variant: include elementwise op, num tensors, and D layout + if self.variant == GemmVariant.MULTI_D: + parts.append(f"ew_{self.elementwise_op}") + parts.append(f"nd{self.num_d_tensors}") + parts.append(f"dly_{self.d_layout}") + + # Occupancy parameters (only if non-default) + if self.num_wave_groups != 1: + parts.append(f"wg{self.num_wave_groups}") + if self.k_block_per_cu != 1: + parts.append(f"kbpc{self.k_block_per_cu}") + + return "_".join(parts) + + def dict_items(self): + """Iterator over (field, value) pairs""" + return asdict(self).items() + + +# 
============================================================================ +# Type Mappings +# ============================================================================ + + +class TypeMappings: + """Centralized type mappings for code generation""" + + DTYPE_TO_CK = { + "fp16": "fp16_t", + "bf16": "bf16_t", + "fp32": "float", + "fp8": "fp8_t", + "bf8": "bf8_t", + "int8": "int8_t", + } + + # Fully-qualified types for use outside of 'using namespace ck_tile' scope + DTYPE_TO_CK_QUALIFIED = { + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", # Built-in type, no namespace + "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", + "int8": "int8_t", # Built-in type + } + + DTYPE_TO_DISPATCHER = { + "fp16": "DataType::FP16", + "bf16": "DataType::BF16", + "fp32": "DataType::FP32", + "fp8": "DataType::FP8", + "bf8": "DataType::BF8", + "int8": "DataType::INT8", + } + + LAYOUT_TO_CK = { + "r": "tensor_layout::gemm::RowMajor", + "c": "tensor_layout::gemm::ColumnMajor", + } + + LAYOUT_TO_DISPATCHER = { + "r": "LayoutTag::RowMajor", + "c": "LayoutTag::ColMajor", + } + + PIPELINE_TO_CK = { + "mem": "GemmPipelineAgBgCrMem", + "compv3": "GemmPipelineAgBgCrCompV3", + "compv4": "GemmPipelineAgBgCrCompV4", + "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2", + } + + PIPELINE_TO_BASE = { + "mem": "BaseGemmPipelineAgBgCrMem", + "compv3": "BaseGemmPipelineAgBgCrCompV3", + "compv4": "BaseGemmPipelineAgBgCrCompV4", + "preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2", + } + + PIPELINE_TO_DISPATCHER = { + "mem": "Pipeline::Mem", + "compv3": "Pipeline::CompV3", + "compv4": "Pipeline::CompV4", + "preshufflev2": "Pipeline::PreShuffleV2", + } + + SCHEDULER_TO_CK = { + "intrawave": "GemmPipelineScheduler::Intrawave", + "interwave": "GemmPipelineScheduler::Interwave", + "default": "GemmPipelineScheduler::Default", + } + + SCHEDULER_TO_DISPATCHER = { + "intrawave": "Scheduler::Intrawave", + "interwave": "Scheduler::Interwave", + "default": 
"Scheduler::Auto", + } + + EPILOGUE_TO_DISPATCHER = { + "cshuffle": "Epilogue::CShuffle", + "default": "Epilogue::Default", + } + + @staticmethod + def get_output_dtype(dtype: str) -> str: + """Get output datatype (fp8/bf8 -> fp16)""" + return "fp16" if dtype in ["fp8", "bf8"] else dtype + + +# ============================================================================ +# Kernel Name Generator +# ============================================================================ + + +class KernelNaming: + """Unified kernel naming""" + + @staticmethod + def generate(config: KernelConfig, datatype: str, layout: str) -> str: + """Generate kernel name following tile_engine convention""" + t = config.tile + tr = config.trait + + # For multi-d, use 4-char layout (abcd), otherwise use 3-char layout (abc) + if config.variant == GemmVariant.MULTI_D: + full_layout = layout + config.d_layout # e.g., "rcr" + "r" = "rcrr" + else: + full_layout = layout + + name = ( + f"gemm_{datatype}_{full_layout}_{tr.pipeline}_{tr.epilogue}_{tr.scheduler}" + ) + name += f"_{str(tr.pad_m).capitalize()}_{str(tr.pad_n).capitalize()}" + name += f"_{str(tr.pad_k).capitalize()}_{str(tr.persistent).capitalize()}" + name += f"_{t.tile_m}x{t.tile_n}x{t.tile_k}" + name += f"_{t.warp_m}x{t.warp_n}x{t.warp_k}" + name += f"_{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}" + + # Add variant suffix + if config.variant == GemmVariant.PRESHUFFLE: + name += "_preshuffle" + elif config.variant == GemmVariant.MULTI_D: + name += f"_multid_{config.elementwise_op}_d{config.num_d_tensors}" + + return name + + +# ============================================================================ +# CK Tile Kernel Generator +# ============================================================================ + + +class CKTileKernelGenerator: + """Generates CK Tile kernel instance code""" + + def __init__(self, datatype: str, layout: str): + self.datatype = datatype + self.layout = layout + self.tm = TypeMappings() + + def generate(self, 
config: KernelConfig) -> str: + """Generate complete CK Tile kernel""" + kernel_name = KernelNaming.generate(config, self.datatype, self.layout) + + return f"""{self._header(kernel_name, config)} +{self._types(config, kernel_name)} +{self._selected_kernel_struct(config, kernel_name)} +""" + + def _header(self, kernel_name: str, config: KernelConfig) -> str: + """Generate header includes""" + includes = """// SPDX-License-Identifier: MIT +// Auto-generated CK Tile GEMM kernel +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" + +""" + + if config.variant == GemmVariant.MULTI_D: + includes += """ +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" +""" + + if config.preshuffle: + includes += """ +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" +""" + + return includes + + def _types(self, config: KernelConfig, kernel_name: str) -> str: + """Generate type definitions - just the namespace import, types are in kernel namespace""" + # Note: Data types and layouts are now defined inside each kernel's unique namespace + # to avoid type alias redefinition conflicts when mixing layouts (e.g., RCR + RRR) + types = """ +// Use ck_tile namespace for generated code +using namespace ck_tile; +""" + return types + + def _kernel_local_types(self, config: KernelConfig) -> str: + """Generate data type and layout definitions inside kernel namespace""" + output_dtype = self.tm.get_output_dtype(self.datatype) + + return f""" + // Data types (inside namespace to avoid conflicts across layouts) + using ADataType = 
{self.tm.DTYPE_TO_CK[self.datatype]}; + using BDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + using AccDataType = float; + using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]}; + + // Layouts (inside namespace to avoid conflicts when mixing layouts) + using ALayout = {self.tm.LAYOUT_TO_CK[self.layout[0]]}; + using BLayout = {self.tm.LAYOUT_TO_CK[self.layout[1]]}; + using CLayout = {self.tm.LAYOUT_TO_CK[self.layout[2]]}; +""" + + def _multi_d_types(self, config: KernelConfig) -> str: + """Generate multi-d type definitions (inside namespace to avoid conflicts)""" + if config.variant != GemmVariant.MULTI_D: + return "" + + d_types = ", ".join(["CDataType"] * config.num_d_tensors) + d_layout_ck = self.tm.LAYOUT_TO_CK[config.d_layout] + d_layouts = ", ".join([d_layout_ck] * config.num_d_tensors) + + return f""" +// Multi-D types (defined in namespace to avoid conflicts) +using DsDataType = tuple<{d_types}>; +using DLayout = {d_layout_ck}; // D tensor layout (can differ from C) +using DsLayout = tuple<{d_layouts}>; +using ElementWiseFn = element_wise::{config.elementwise_op}; +static constexpr index_t NumDTensor = {config.num_d_tensors}; +using GemmMultiDArgs = GemmMultiDHostArgs; +""" + + def _selected_kernel_struct(self, config: KernelConfig, kernel_name: str) -> str: + """Generate SelectedKernel struct with unique name in unique namespace""" + t = config.tile + tr = config.trait + output_dtype = self.tm.get_output_dtype(self.datatype) + + # Generate unique struct name and namespace from kernel name + struct_name = f"Kernel_{kernel_name}" + # Create valid C++ namespace name (replace invalid chars) + ns_name = "ns_" + kernel_name.replace("-", "_") + + multi_d_types = self._multi_d_types(config) + + return f""" +namespace {ns_name} {{ +constexpr const char* KERNEL_NAME = "{kernel_name}"; + +// Data types (inside namespace to avoid conflicts across different kernels) +using ADataType = {self.tm.DTYPE_TO_CK[self.datatype]}; +using BDataType = 
{self.tm.DTYPE_TO_CK[self.datatype]}; +using AccDataType = float; +using CDataType = {self.tm.DTYPE_TO_CK[output_dtype]}; + +// Layouts (inside namespace to avoid conflicts when mixing layouts like RCR + RRR) +using ALayout = {self.tm.LAYOUT_TO_CK[self.layout[0]]}; +using BLayout = {self.tm.LAYOUT_TO_CK[self.layout[1]]}; +using CLayout = {self.tm.LAYOUT_TO_CK[self.layout[2]]}; +{multi_d_types} +struct {struct_name} {{ + // Data types (required by backend as member types) + using ADataType = {ns_name}::ADataType; + using BDataType = {ns_name}::BDataType; + using CDataType = {ns_name}::CDataType; + using AccDataType = {ns_name}::AccDataType; + + // Configuration + static constexpr index_t BlockSize = {config.block_size}; + static constexpr index_t TileM = {t.tile_m}; + static constexpr index_t TileN = {t.tile_n}; + static constexpr index_t TileK = {t.tile_k}; + static constexpr index_t WarpPerBlock_M = {t.warp_m}; + static constexpr index_t WarpPerBlock_N = {t.warp_n}; + static constexpr index_t WarpPerBlock_K = {t.warp_k}; + static constexpr index_t WarpTileM = {t.warp_tile_m}; + static constexpr index_t WarpTileN = {t.warp_tile_n}; + static constexpr index_t WarpTileK = {t.warp_tile_k}; + + // Traits + static constexpr bool kPadM = {str(tr.pad_m).lower()}; + static constexpr bool kPadN = {str(tr.pad_n).lower()}; + static constexpr bool kPadK = {str(tr.pad_k).lower()}; + static constexpr bool TransposeC = false; + static constexpr bool UsePersistentKernel = {str(tr.persistent).lower()}; + static constexpr bool DoubleSmemBuffer = {str(tr.pipeline == "compv4" or tr.pipeline == "preshufflev2").lower()}; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Preshuffle = {str(config.preshuffle).lower()}; + static constexpr index_t NumWaveGroups = {config.num_wave_groups}; + + {self._tile_types(config, ns_name)} + {self._launch_function(config)} +}}; + +// Alias for tile_engine style compatibility (when used with -include) +using SelectedKernel = 
{struct_name}; +using SelectedKernelLauncher = {struct_name}; +}} // namespace {ns_name} + +// Export to global namespace ONLY for single-kernel includes +// Define CK_TILE_SINGLE_KERNEL_INCLUDE before including this header to enable these aliases +#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE +using {struct_name} = {ns_name}::{struct_name}; +using SelectedKernel = {ns_name}::{struct_name}; +constexpr const char* KERNEL_NAME = {ns_name}::KERNEL_NAME; +using ADataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]}; +using BDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.datatype]}; +using CDataType = {self.tm.DTYPE_TO_CK_QUALIFIED[self.tm.get_output_dtype(self.datatype)]}; +using AccDataType = float; +#endif // CK_TILE_SINGLE_KERNEL_INCLUDE +""" + + def _tile_types(self, config: KernelConfig, ns_name: str) -> str: + """Generate tile type definitions - uses namespace-qualified types""" + return ( + f"""// Tile shape + using TileShape = TileGemmShape< + sequence, + sequence, + sequence, + false, false>; + + using TilePartitioner = GemmSpatiallyLocalTilePartitioner; + using Traits = TileGemmTraits; + using GemmPipelineProblem = GemmPipelineProblem; + using BaseGemmPipeline = """ + + self.tm.PIPELINE_TO_BASE[config.trait.pipeline] + + """;""" + ) + + def _launch_function(self, config: KernelConfig) -> str: + """Generate launch function""" + if config.variant == GemmVariant.MULTI_D: + return self._launch_function_multi_d(config) + if config.preshuffle: + return self._launch_function_preshuffle(config) + return self._launch_function_standard(config) + + def _launch_function_standard(self, config: KernelConfig) -> str: + """Generate launch function for standard GEMM""" + return f""" + static float launch(const GemmHostArgs& args, const stream_config& stream) {{ + const index_t k_grain = args.k_batch * TileK; + const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = 
BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = {self.tm.SCHEDULER_TO_CK[config.trait.scheduler]}; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + ADataType, BDataType, AccDataType, TileShape, + TileGemmUniversalTraits, + scheduler>; + + using GemmPipeline = {self.tm.PIPELINE_TO_CK[config.trait.pipeline]}; + {self._epilogue_code(config)} + + using GemmKernel = ck_tile::GemmKernel; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported!"); + }} + + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if config.trait.persistent else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + constexpr int kBlockPerCu = {config.k_block_per_cu}; + ave_time = launch_kernel(stream, + make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }}""" + + def _launch_function_preshuffle(self, config: KernelConfig) -> str: + """Generate launch function for preshuffle GEMM (weight preshuffle variant) + + Preshuffle uses WeightPreshufflePipelineAGmemBGmemCRegV2 which has a different + API than standard pipelines. It's designed for weight-preshuffled GEMM operations. 
+ """ + return f""" + static float launch(const GemmHostArgs& args, const stream_config& stream) {{ + const index_t k_grain = args.k_batch * TileK; + const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = GemmPipelineScheduler::Default; // Preshuffle uses Default scheduler + + // Preshuffle uses TileFlatmmShape instead of TileGemmShape for the problem + using UniversalGemmProblem = UniversalGemmPipelineProblem< + ADataType, BDataType, AccDataType, TileShape, + TileGemmUniversalTraits, + scheduler>; + + using GemmPipeline = WeightPreshufflePipelineAGmemBGmemCRegV2; + {self._epilogue_code(config)} + + using GemmKernel = ck_tile::GemmKernel; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported for preshuffle kernel!"); + }} + + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if config.trait.persistent else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 blocks = GemmKernel::BlockSize(); + + constexpr int kBlockPerCu = {config.k_block_per_cu}; + ave_time = launch_kernel(stream, + make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }}""" + + def _launch_function_multi_d(self, config: KernelConfig) -> str: + """Generate launch function for Multi-D GEMM""" + return f""" + // Multi-D launch function - takes GemmMultiDHostArgs with D tensor pointers + static float launch(const GemmMultiDArgs& args, const stream_config& stream) {{ + const index_t k_grain = args.k_batch * 
TileK; + const index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = {self.tm.SCHEDULER_TO_CK[config.trait.scheduler]}; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + ADataType, BDataType, AccDataType, TileShape, + TileGemmUniversalTraits, + scheduler>; + + using GemmPipeline = {self.tm.PIPELINE_TO_CK[config.trait.pipeline]}; + {self._epilogue_code(config)} + + // Use GemmKernelMultiD for Multi-D variant + using GemmKernel = ck_tile::GemmKernelMultiD; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = GemmKernel::MakeKernelArgs(args); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported! Multi-D currently doesn't support k_batch > 1"); + }} + + const dim3 grids = GemmKernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernel::BlockSize(); + + constexpr int kBlockPerCu = {config.k_block_per_cu}; + ave_time = launch_kernel(stream, + make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }} + + // Overload for standard GemmHostArgs (converts to Multi-D args with empty D tensors) + static float launch(const GemmHostArgs& args, const stream_config& stream) {{ + std::array empty_ds{{}}; + std::array empty_strides{{}}; + for (index_t i = 0; i < NumDTensor; ++i) {{ + empty_ds[i] = nullptr; + empty_strides[i] = 0; + }} + GemmMultiDArgs multi_d_args{{ + args.a_ptr, + args.b_ptr, + empty_ds, + args.e_ptr, + args.k_batch, + args.M, + args.N, + args.K, + args.stride_A, + args.stride_B, + empty_strides, + args.stride_C + }}; + return launch(multi_d_args, stream); 
+ }}""" + + def _epilogue_code(self, config: KernelConfig) -> str: + """Generate epilogue code""" + if config.variant == GemmVariant.MULTI_D: + return """ + using EpilogueProblem = CShuffleEpilogueProblem< + ADataType, BDataType, DsDataType, AccDataType, CDataType, + DsLayout, CLayout, ElementWiseFn, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK, + TransposeC, NumWaveGroups, false, 1, false, 1, DoubleSmemBuffer>; + using GemmEpilogue = CShuffleEpilogue;""" + elif config.trait.epilogue == "cshuffle": + return """ + using EpilogueProblem = CShuffleEpilogueProblem< + ADataType, BDataType, tuple<>, AccDataType, CDataType, + tuple<>, CLayout, element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + WarpPerBlock_M, WarpPerBlock_N, WarpTileM, WarpTileN, WarpTileK, + TransposeC, NumWaveGroups, false, 1, false, 1, DoubleSmemBuffer>; + using GemmEpilogue = CShuffleEpilogue;""" + else: + return """ + using EpilogueProblem = DefaultGemm2DEpilogueProblem< + ADataType, BDataType, tuple<>, AccDataType, CDataType, + tuple<>, CLayout, element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + kPadM, kPadN, WarpTileM, WarpTileN, WarpTileK, TransposeC>; + using GemmEpilogue = DefaultGemm2DEpilogue;""" + + +# ============================================================================ +# Dispatcher Wrapper Generator +# ============================================================================ + + +class DispatcherWrapperGenerator: + """Generates dispatcher wrapper code""" + + def __init__(self, datatype: str, layout: str): + self.datatype = datatype + self.layout = layout + self.tm = TypeMappings() + + def generate( + self, config: KernelConfig, kernel_path: Path, output_dir: Path + ) -> str: + """Generate dispatcher wrapper""" + kernel_name = KernelNaming.generate(config, self.datatype, self.layout) + output_dtype = 
self.tm.get_output_dtype(self.datatype) + rel_path = kernel_path.relative_to(output_dir) + + return f"""// SPDX-License-Identifier: MIT +// Auto-generated dispatcher wrapper +#pragma once + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/backends/generated_kernel_backend.hpp" +#include "{rel_path}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +using ::ck_tile::dispatcher::KernelInstancePtr; +using ::ck_tile::dispatcher::KernelKey; +using ::ck_tile::dispatcher::DataType; +using ::ck_tile::dispatcher::LayoutTag; +using ::ck_tile::dispatcher::Pipeline; +using ::ck_tile::dispatcher::Scheduler; +using ::ck_tile::dispatcher::Epilogue; +using Priority = ::ck_tile::dispatcher::Registry::Priority; +namespace backends = ::ck_tile::dispatcher::backends; + +inline KernelInstancePtr make_{kernel_name}(const std::string& gfx_arch = "gfx942") {{ + // Use the unique kernel struct name + using KernelStruct = Kernel_{kernel_name}; + + KernelKey key; + + // Signature + key.signature.dtype_a = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]}; + key.signature.dtype_b = {self.tm.DTYPE_TO_DISPATCHER[self.datatype]}; + key.signature.dtype_c = {self.tm.DTYPE_TO_DISPATCHER[output_dtype]}; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[0]]}; + key.signature.layout_b = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[1]]}; + key.signature.layout_c = {self.tm.LAYOUT_TO_DISPATCHER[self.layout[2]]}; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "{config.elementwise_op}"; + key.signature.num_d_tensors = {config.num_d_tensors}; + key.signature.structured_sparsity = false; + + // Algorithm + key.algorithm.tile_shape = {{{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k}}}; + key.algorithm.wave_shape = {{{config.tile.warp_m}, {config.tile.warp_n}, 
{config.tile.warp_k}}}; + key.algorithm.warp_tile_shape = {{{config.tile.warp_tile_m}, {config.tile.warp_tile_n}, {config.tile.warp_tile_k}}}; + key.algorithm.pipeline = {self.tm.PIPELINE_TO_DISPATCHER[config.trait.pipeline]}; + key.algorithm.scheduler = {self.tm.SCHEDULER_TO_DISPATCHER[config.trait.scheduler]}; + key.algorithm.epilogue = {self.tm.EPILOGUE_TO_DISPATCHER[config.trait.epilogue]}; + key.algorithm.block_size = {config.block_size}; + key.algorithm.double_buffer = {str(config.trait.pipeline == "compv4").lower()}; + key.algorithm.persistent = {str(config.trait.persistent).lower()}; + key.algorithm.preshuffle = {str(config.preshuffle).lower()}; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = {config.num_wave_groups}; + + key.gfx_arch = gfx_arch; + + return std::make_shared>(key, "{kernel_name}"); +}} + +}}}}}} +""" + + +# ============================================================================ +# Main Unified Generator +# ============================================================================ + + +class UnifiedGemmCodegen: + """Unified GEMM code generator - single entry point""" + + def __init__( + self, + output_dir: Path, + datatype: str, + layout: str, + gpu_target: str = "gfx942", + config_file: Optional[Path] = None, + variants: List[GemmVariant] = None, + use_preselected: Optional[str] = None, + enable_arch_filter: bool = True, + kernel_set_name: Optional[str] = None, + ): + self.output_dir = Path(output_dir) + self.datatype = datatype + # Support 3-char (rcr) or 4-char (rcrr) layout codes + # 4th char specifies D tensor layout for multi-d + self.layout = layout[:3] # A, B, C layouts + self.d_layout = ( + layout[3] if len(layout) >= 4 else layout[2] + ) # D layout (default = C layout) + self.gpu_target = gpu_target + self.variants = variants or [GemmVariant.STANDARD] + self.use_preselected = use_preselected + self.kernel_set_name = kernel_set_name + + # Create directories - optionally with kernel set subdirectory + if 
kernel_set_name: + self.kernel_dir = self.output_dir / kernel_set_name + else: + self.kernel_dir = self.output_dir + self.kernel_dir.mkdir(parents=True, exist_ok=True) + self.wrapper_dir = self.kernel_dir / "dispatcher_wrappers" + self.wrapper_dir.mkdir(parents=True, exist_ok=True) + + # Load configuration + self.config = self._load_config(config_file) + + # Initialize architecture filter for GPU-specific validation + self.arch_filter = None + if enable_arch_filter and HAS_ARCH_FILTER: + try: + self.arch_filter = ArchFilter(gpu_target, strict_mode=False) + log.info(f"Architecture filter enabled for {gpu_target}") + except ValueError as e: + log.warning(f"Could not create arch filter: {e}") + + # Initialize generators (use self.layout which is the 3-char A,B,C layout) + self.ck_gen = CKTileKernelGenerator(datatype, self.layout) + self.disp_gen = DispatcherWrapperGenerator(datatype, self.layout) + + def _load_config(self, config_file: Optional[Path]) -> Dict: + """Load or create default configuration""" + if config_file and config_file.exists(): + with open(config_file) as f: + return json.load(f) + + # Match tile_engine default configs for GEMM/Preshuffle/Multi-D + # See: tile_engine/ops/gemm/configs/default_config.json + # tile_engine/ops/gemm_preshuffle/configs/default_config.json + # tile_engine/ops/gemm_multi_d/configs/default_config.json + return { + "tile_config": { + # tile_m/n/k: 64-256 step 64 = [64, 128, 192, 256] + "tile_m": [64, 128, 192, 256], + "tile_n": [64, 128, 192, 256], + "tile_k": [64, 128, 192, 256], + # warp configs matching tile_engine + "warp_m": [1, 2, 4], + "warp_n": [1, 2, 4], + "warp_k": [1], + # warp_tile configs matching tile_engine + "warp_tile_m": [4, 16, 32], + "warp_tile_n": [16, 32, 64], + "warp_tile_k": [8, 16, 32, 64, 128], + }, + "trait_config": { + "pipeline": ["compv3", "compv4", "mem"], + "epilogue": ["cshuffle", "default"], + "scheduler": ["intrawave", "interwave"], + "pad_m": [False], + "pad_n": [False], + "pad_k": [False], 
+ "persistent": [False, True], + }, + "multi_d_config": { + # Note: Only MultiDAdd and MultiDMultiply are compatible with multi-D GEMM. + # Relu/Gelu are unary ops with signature (y, x), not multi-D signature (e, c, ds...) + "elementwise_ops": ["MultiDAdd", "MultiDMultiply"], + "num_d_tensors": [1, 2], + }, + } + + def generate_all(self, parallel: bool = True) -> Dict: + """Generate all kernels""" + log.info("Generating GEMM kernels:") + log.info(f" Datatype: {self.datatype}") + log.info(f" Layout: {self.layout}") + log.info(f" Variants: {[v.value for v in self.variants]}") + if self.use_preselected: + log.info(f" Using preselected set: {self.use_preselected}") + + results = {"kernels": [], "wrappers": [], "failed": []} + + # Get configurations + if self.use_preselected: + configs = self._get_preselected_configs() + log.info(f" Total configurations: {len(configs)}") + else: + for variant in self.variants: + log.info(f"\nGenerating {variant.value} kernels...") + configs = self._get_configs_for_variant(variant) + log.info(f" Configurations: {len(configs)}") + + if parallel: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(self._generate_one, cfg) for cfg in configs + ] + for future in concurrent.futures.as_completed(futures): + try: + k, w = future.result() + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + else: + for cfg in configs: + try: + k, w = self._generate_one(cfg) + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + + # Generate registration header + if results["wrappers"]: + self._generate_registration_header(results["wrappers"]) + + return results + + # Generate from preselected set + if parallel: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(self._generate_one, cfg) for 
cfg in configs] + for future in concurrent.futures.as_completed(futures): + try: + k, w = future.result() + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + else: + for cfg in configs: + try: + k, w = self._generate_one(cfg) + results["kernels"].append(k) + results["wrappers"].append(w) + except Exception as e: + results["failed"].append(str(e)) + log.error(f"Failed: {e}") + + # Generate registration header + if results["wrappers"]: + self._generate_registration_header(results["wrappers"]) + + return results + + def _get_preselected_configs(self) -> List[KernelConfig]: + """Get preselected kernel configurations""" + try: + from preselected_kernels import get_preselected_set + + return get_preselected_set(self.use_preselected) + except ImportError: + log.warning( + "preselected_kernels module not found, falling back to config-based generation" + ) + return [] + except ValueError as e: + log.error(f"Invalid preselected set: {e}") + return [] + + def _get_configs_for_variant(self, variant: GemmVariant) -> List[KernelConfig]: + """Get all configurations for a variant + + Args: + variant: GEMM variant (STANDARD, PRESHUFFLE, MULTI_D) + + Returns: + List of valid kernel configurations for the variant + """ + configs = [] + + # Get base configs + tile_configs = self._get_tile_configs() + trait_configs = self._get_trait_configs() + + for tile, trait in itertools.product(tile_configs, trait_configs): + # Perform variant-specific architecture validation + if self.arch_filter and HAS_ARCH_FILTER: + if not self._is_tile_arch_valid(tile, variant): + continue + + if variant == GemmVariant.STANDARD: + configs.append(KernelConfig(tile=tile, trait=trait, variant=variant)) + + elif variant == GemmVariant.PRESHUFFLE: + # Preshuffle needs specific pipeline (preshufflev2) and scheduler (default) + # Skip configs that don't use preshuffle-compatible traits + preshuffle_trait = 
TraitConfig( + pipeline="preshufflev2", + epilogue="cshuffle", + scheduler="default", + pad_m=trait.pad_m, + pad_n=trait.pad_n, + pad_k=trait.pad_k, + persistent=trait.persistent, + ) + # Only generate one preshuffle config per tile (not per trait) + # since preshuffle has fixed pipeline/scheduler + if trait.pipeline == "compv3" and trait.scheduler == "intrawave": + configs.append( + KernelConfig( + tile=tile, + trait=preshuffle_trait, + variant=variant, + preshuffle=True, + ) + ) + + elif variant == GemmVariant.MULTI_D: + multi_d = self.config.get("multi_d_config", {}) + for ew_op, num_d in itertools.product( + multi_d.get("elementwise_ops", ["MultiDAdd"]), + multi_d.get("num_d_tensors", [1]), + ): + configs.append( + KernelConfig( + tile=tile, + trait=trait, + variant=variant, + elementwise_op=ew_op, + num_d_tensors=num_d, + d_layout=self.d_layout, # Use extracted D layout + ) + ) + + return configs + + def _get_tile_configs(self) -> List[TileConfig]: + """Get valid tile configurations, filtered by architecture constraints""" + tc = self.config["tile_config"] + configs = [] + rejected_count = 0 + + for params in itertools.product( + tc["tile_m"], + tc["tile_n"], + tc["tile_k"], + tc["warp_m"], + tc["warp_n"], + tc["warp_k"], + tc["warp_tile_m"], + tc["warp_tile_n"], + tc["warp_tile_k"], + ): + tile = TileConfig(*params) + + # Basic validation + if not tile.is_valid(): + rejected_count += 1 + continue + + # Architecture-specific validation + if self.arch_filter and HAS_ARCH_FILTER: + if not self._is_tile_arch_valid(tile): + rejected_count += 1 + continue + + configs.append(tile) + + if rejected_count > 0: + log.debug(f"Rejected {rejected_count} tile configs for {self.gpu_target}") + + return configs + + def _is_tile_arch_valid( + self, tile: TileConfig, variant: GemmVariant = None + ) -> bool: + """Check if tile configuration is valid for target architecture + + Args: + tile: Tile configuration to validate + variant: GEMM variant (affects operator-specific 
constraints) + """ + if not self.arch_filter or not HAS_ARCH_FILTER: + return True + + # Determine data types based on self.datatype + # Note: dtype_c is the ACCUMULATOR type, not output type (which may be fp16) + # WMMA instructions on gfx942 always use fp32 accumulator for fp16 inputs + dtype_map = { + "fp16": ("fp16", "fp16", "fp32"), # A=fp16, B=fp16, Acc=fp32 + "bf16": ("bf16", "bf16", "fp32"), # A=bf16, B=bf16, Acc=fp32 + "fp8": ("fp8", "fp8", "fp32"), # A=fp8, B=fp8, Acc=fp32 + "bf8": ("bf8", "bf8", "fp32"), # A=bf8, B=bf8, Acc=fp32 + "int8": ("int8", "int8", "int32"), # A=int8, B=int8, Acc=int32 + } + dtype_a, dtype_b, dtype_c = dtype_map.get( + self.datatype, ("fp16", "fp16", "fp32") + ) + + # Map GEMM variant to operator type for validation + operator = None + pipeline = "compv4" # Default + scheduler = "intrawave" # Default + + if OperatorType is not None and variant is not None: + variant_to_operator = { + GemmVariant.STANDARD: OperatorType.GEMM, + GemmVariant.PRESHUFFLE: OperatorType.GEMM_PRESHUFFLE, + GemmVariant.MULTI_D: OperatorType.GEMM_MULTI_D, + } + operator = variant_to_operator.get(variant, OperatorType.GEMM) + + # Preshuffle requires specific pipeline and scheduler + if variant == GemmVariant.PRESHUFFLE: + pipeline = "preshufflev2" + scheduler = "default" + + # Use preshuffle-specific validation (comprehensive CK-specific checks) + if variant == GemmVariant.PRESHUFFLE: + if not is_preshuffle_config_valid( + tile_m=tile.tile_m, + tile_n=tile.tile_n, + tile_k=tile.tile_k, + warp_m=tile.warp_m, + warp_n=tile.warp_n, + warp_k=tile.warp_k, + warp_tile_m=tile.warp_tile_m, + warp_tile_n=tile.warp_tile_n, + warp_tile_k=tile.warp_tile_k, + datatype=self.datatype, + ): + return False + + return self.arch_filter.is_kernel_valid( + datatype_a=dtype_a, + datatype_b=dtype_b, + datatype_c=dtype_c, + tile_m=tile.tile_m, + tile_n=tile.tile_n, + tile_k=tile.tile_k, + warp_m=tile.warp_m, + warp_n=tile.warp_n, + warp_k=tile.warp_k, + warp_tile_m=tile.warp_tile_m, 
+ warp_tile_n=tile.warp_tile_n, + warp_tile_k=tile.warp_tile_k, + pipeline=pipeline, + scheduler=scheduler, + layout=self.layout, + operator=operator, + ) + + def _get_trait_configs(self) -> List[TraitConfig]: + """Get valid trait configurations, filtered by architecture constraints""" + tc = self.config["trait_config"] + configs = [] + rejected_count = 0 + + for params in itertools.product( + tc["pipeline"], + tc["epilogue"], + tc["scheduler"], + tc["pad_m"], + tc["pad_n"], + tc["pad_k"], + tc["persistent"], + ): + trait = TraitConfig(*params) + + # Basic trait validation (unsupported combinations) + if not trait.is_valid(): + rejected_count += 1 + continue + + configs.append(trait) + + if rejected_count > 0: + log.debug(f"Rejected {rejected_count} trait configs") + + return configs + + def _generate_one(self, config: KernelConfig) -> Tuple[str, str]: + """Generate one kernel and wrapper""" + kernel_name = KernelNaming.generate(config, self.datatype, self.layout) + + # Generate CK Tile kernel + kernel_code = self.ck_gen.generate(config) + kernel_path = self.kernel_dir / f"{kernel_name}.hpp" + kernel_path.write_text(kernel_code) + + # Generate dispatcher wrapper + wrapper_code = self.disp_gen.generate(config, kernel_path, self.kernel_dir) + wrapper_path = self.wrapper_dir / f"dispatcher_wrapper_{kernel_name}.hpp" + wrapper_path.write_text(wrapper_code) + + # Generate .cpp compilation unit for per-kernel parallel builds + cpp_path = self.kernel_dir / f"{kernel_name}.cpp" + cpp_code = f'''// SPDX-License-Identifier: MIT +// Auto-generated compilation unit for: {kernel_name} +// Enables per-kernel parallel compilation with make -j + +#include "{kernel_name}.hpp" + +namespace ck_tile {{ namespace generated {{ + volatile bool _{kernel_name.replace("-", "_")}_loaded = true; +}} }} +''' + cpp_path.write_text(cpp_code) + + return str(kernel_path), str(wrapper_path) + + def _generate_registration_header(self, wrapper_paths: List[str]): + """Generate master registration 
header""" + kernel_names = [ + Path(w).stem.replace("dispatcher_wrapper_", "") for w in wrapper_paths + ] + + includes = "\n".join( + [f'#include "dispatcher_wrapper_{n}.hpp"' for n in kernel_names] + ) + registrations = "\n ".join( + [ + f"registry.register_kernel(generated::make_{n}(gfx_arch), priority);" + for n in kernel_names + ] + ) + + content = f"""// SPDX-License-Identifier: MIT +// Auto-generated master registration +#pragma once + +#include "ck_tile/dispatcher.hpp" +{includes} + +namespace ck_tile {{ +namespace dispatcher {{ + +using ::ck_tile::dispatcher::Registry; +using Priority = ::ck_tile::dispatcher::Registry::Priority; + +inline void register_all_tile_gemm_kernels( + const std::string& gfx_arch = "gfx942", + Priority priority = Priority::Normal) +{{ + auto& registry = Registry::instance(); + {registrations} +}} + +inline std::size_t get_tile_gemm_kernel_count() {{ return {len(kernel_names)}; }} + +}}}} +""" + + reg_path = self.wrapper_dir / "register_all_kernels.hpp" + reg_path.write_text(content) + logging.info(f"Generated registration header: {reg_path}") + + +# ============================================================================ +# CLI +# ============================================================================ + + +def _show_arch_info(gpu_target: str, datatype: str): + """Display supported configurations for a GPU architecture""" + if not HAS_ARCH_FILTER: + print("Architecture filter module not available") + return + + try: + from arch_filter import ( + get_supported_archs, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + LDS_CAPACITY_LIMITS, + TRAIT_UNSUPPORTED_COMBINATIONS, + ) + + print(f"\n=== Architecture Info for {gpu_target} ===\n") + + # Supported architectures + print(f"Supported GPUs: {get_supported_archs()}") + + # Warp configurations + warp_cfgs = WARP_SUPPORTED_COMBINATIONS.get(gpu_target, []) + print("\nWarp configurations [warp_m, warp_n, warp_k]:") + for cfg in warp_cfgs: + print(f" {cfg}") + + # 
Warp tile configurations for data type + dtype_map = { + "fp16": "fp16_fp16_fp16", + "bf16": "bf16_bf16_bf16", + "fp8": "fp8_fp8_fp16", + "bf8": "bf8_bf8_fp16", + "int8": "int8_int8_int32", + } + dtype_key = dtype_map.get(datatype, "fp16_fp16_fp16") + + gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_target, {}) + warp_tiles = gpu_combos.get(dtype_key, []) + print( + f"\nWarp tile configurations for {dtype_key} [warp_tile_m, warp_tile_n, warp_tile_k]:" + ) + for cfg in warp_tiles: + print(f" {cfg}") + + # All supported data types + print(f"\nAll supported data types on {gpu_target}:") + for dtype in gpu_combos.keys(): + print(f" {dtype}") + + # LDS limits + print("\nLDS capacity limits:") + for pipeline, limit in LDS_CAPACITY_LIMITS.items(): + print(f" {pipeline}: {limit // 1024}KB") + + # Unsupported trait combinations + print("\nUnsupported trait combinations (pipeline, epilogue, scheduler):") + for combo in TRAIT_UNSUPPORTED_COMBINATIONS: + print(f" {combo}") + + print() + + except Exception as e: + print(f"Error showing arch info: {e}") + + +def main(): + parser = argparse.ArgumentParser( + description="Unified GEMM Code Generator - Single Source of Truth" + ) + parser.add_argument( + "--output-dir", type=Path, required=True, help="Output directory" + ) + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "bf16", "fp32", "fp8", "bf8", "int8", "pk_fp4"], + help="Data type (fp16, bf16, fp32, fp8, bf8, int8, pk_fp4)", + ) + parser.add_argument( + "--layout", + type=str, + default="rcr", + help="Layout (e.g., rcr for A=row, B=col, C=row; or rcrr for multi-d with D=row)", + ) + parser.add_argument( + "--gpu-target", + type=str, + default="gfx942", + help="Target GPU (gfx90a, gfx942, gfx950, gfx1201)", + ) + parser.add_argument("--config", type=Path, help="Configuration JSON file") + parser.add_argument( + "--variants", + nargs="+", + choices=["standard", "preshuffle", "multi_d"], + default=["standard"], + help="Variants to 
generate", + ) + parser.add_argument( + "--preselected", + type=str, + help="Use preselected kernel set (e.g., fp16_rcr_essential)", + ) + parser.add_argument( + "--no-parallel", action="store_true", help="Disable parallel generation" + ) + parser.add_argument( + "--register", action="store_true", help="Generate dispatcher registration code" + ) + parser.add_argument( + "--no-arch-filter", + action="store_true", + help="Disable architecture-specific kernel filtering", + ) + parser.add_argument( + "--show-arch-info", + action="store_true", + help="Show supported configurations for target GPU and exit", + ) + parser.add_argument( + "--kernel-set", + type=str, + help="Kernel set name (creates subdirectory for organization)", + ) + parser.add_argument( + "--tile-config-json", + type=str, + help="JSON string specifying exact tile configuration (for minimal builds)", + ) + + args = parser.parse_args() + + # Handle inline tile config JSON for minimal/single-kernel builds + if args.tile_config_json: + try: + cfg = json.loads(args.tile_config_json) + + # Build proper config structure + full_config = {} + + # Extract tile config + tile_keys = [ + "tile_m", + "tile_n", + "tile_k", + "warp_m", + "warp_n", + "warp_k", + "warp_tile_m", + "warp_tile_n", + "warp_tile_k", + "block_size", + ] + tile_config = {k: cfg[k] for k in tile_keys if k in cfg} + if tile_config: + full_config["tile_config"] = tile_config + + # Extract trait config + trait_keys = ["pipeline", "epilogue", "scheduler"] + trait_config = {k: cfg[k] for k in trait_keys if k in cfg} + # Add default pad/persistent values + trait_config.setdefault("pad_m", [False]) + trait_config.setdefault("pad_n", [False]) + trait_config.setdefault("pad_k", [False]) + trait_config.setdefault("persistent", [False]) + if trait_config: + full_config["trait_config"] = trait_config + + # Extract multi_d config (for multi_d variant) + if "elementwise_ops" in cfg or "num_d_tensors" in cfg: + multi_d_config = {} + if "elementwise_ops" in 
cfg: + multi_d_config["elementwise_ops"] = cfg["elementwise_ops"] + if "num_d_tensors" in cfg: + multi_d_config["num_d_tensors"] = cfg["num_d_tensors"] + full_config["multi_d_config"] = multi_d_config + + # Use already structured config if provided + if "tile_config" in cfg: + full_config = cfg + + # Write to temp file and use as config + import tempfile + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as f: + json.dump(full_config, f) + args.config = Path(f.name) + except json.JSONDecodeError as e: + logging.error(f"Invalid tile-config-json: {e}") + return 1 + except KeyError as e: + logging.error(f"Missing required key in tile-config-json: {e}") + return 1 + + # Show architecture info if requested + if args.show_arch_info: + _show_arch_info(args.gpu_target, args.datatype) + return 0 + + variants = [GemmVariant(v) for v in args.variants] if not args.preselected else None + + codegen = UnifiedGemmCodegen( + output_dir=args.output_dir, + datatype=args.datatype, + layout=args.layout, + gpu_target=args.gpu_target, + config_file=args.config, + variants=variants, + use_preselected=args.preselected, + enable_arch_filter=not args.no_arch_filter, + kernel_set_name=args.kernel_set, + ) + + results = codegen.generate_all(parallel=not args.no_parallel) + + logging.info("\n✅ Generation complete!") + logging.info(f" Kernels: {len(results['kernels'])}") + logging.info(f" Wrappers: {len(results['wrappers'])}") + logging.info(f" Failed: {len(results['failed'])}") + + if results["failed"]: + logging.error(f"\nFailed kernels: {len(results['failed'])}") + for err in results["failed"][:5]: + logging.error(f" {err}") + + # Generate dispatcher registration if requested + if args.register: + logging.info("\n📝 Generating dispatcher registration code...") + try: + from generate_dispatcher_registration import ( + scan_generated_headers, + generate_registration_header, + generate_registration_cpp, + ) + + kernels = scan_generated_headers(args.output_dir) + 
reg_dir = args.output_dir / "registration" + reg_dir.mkdir(exist_ok=True) + + generate_registration_header( + kernels, reg_dir / "dispatcher_registration.hpp" + ) + generate_registration_cpp(kernels, reg_dir / "dispatcher_registration.cpp") + + logging.info(f"✓ Generated registration code for {len(kernels)} kernels") + except Exception as e: + logging.error(f"Failed to generate registration code: {e}") + return 1 + + return 0 if not results["failed"] else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/dispatcher/examples/CMakeLists.txt b/dispatcher/examples/CMakeLists.txt new file mode 100644 index 00000000000..0359eb0d8d9 --- /dev/null +++ b/dispatcher/examples/CMakeLists.txt @@ -0,0 +1,448 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +cmake_minimum_required(VERSION 3.16) + +# Get processor count for parallel builds +include(ProcessorCount) +ProcessorCount(NPROC) +if(NPROC EQUAL 0) + set(NPROC 4) +endif() + +# GPU target architecture (passed from command line or default to gfx942) +if(NOT DEFINED GPU_TARGETS OR GPU_TARGETS STREQUAL "") + set(GPU_TARGETS "gfx942" CACHE STRING "GPU architecture target") +endif() +# Extract first target if multiple are provided (we only support single target builds) +string(REPLACE ";" " " GPU_TARGETS_SPACE "${GPU_TARGETS}") +string(REPLACE " " ";" GPU_TARGETS_LIST "${GPU_TARGETS_SPACE}") +list(GET GPU_TARGETS_LIST 0 GPU_TARGET) +message(STATUS "Building for GPU target: ${GPU_TARGET}") + +# NOTE: Per-kernel compilation is now automatic via declarative examples +# Each example generates only its declared kernels (from DECL_KERNEL_SET) + +# Link to dispatcher library +link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../build) + +# ============================================================================= +# Kernel Output Directory +# ============================================================================= + +set(KERNEL_OUTPUT_DIR 
"${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels") +file(MAKE_DIRECTORY ${KERNEL_OUTPUT_DIR}) + +# ============================================================================= +# Kernel Generation Targets (run during 'make', not 'cmake') +# ============================================================================= + +# Sentinel files to track generation +set(GEMM_SENTINEL "${KERNEL_OUTPUT_DIR}/.gemm_generated") + +# Generate GEMM kernels (standard + preshuffle + multi_d) - runs with internal parallelism +# Note: 4-char layout "rcrr" means A=row, B=col, C=row, D=row (for multi-d) +add_custom_command( + OUTPUT ${GEMM_SENTINEL} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcrr --variants standard preshuffle multi_d + --output ${KERNEL_OUTPUT_DIR} + COMMAND ${CMAKE_COMMAND} -E touch ${GEMM_SENTINEL} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating GEMM kernels (fp16, rcrr, standard + preshuffle + multi_d) with internal parallelism..." + VERBATIM +) + +add_custom_target(generate_gemm_kernels + DEPENDS ${GEMM_SENTINEL} + COMMENT "GEMM kernel generation target" +) + +# Alias for generate_all_kernels (GEMM only now) +add_custom_target(generate_all_kernels + DEPENDS generate_gemm_kernels +) + +# ============================================================================= +# Per-Kernel Compilation (Maximum Parallelism) +# ============================================================================= +# Enable with: cmake -DPER_KERNEL_COMPILATION=ON +# +# This creates ONE translation unit per kernel, enabling: +# 1. Maximum parallelism with make -j$(nproc) +# 2. Per-kernel build progress: "[1/128] Building kernel: gemm_fp16_128x128" +# 3. Incremental rebuilds (only changed kernels recompile) +# 4. Fine-grained build time analysis +# +# Build process: +# 1. Generate kernel headers (.hpp) +# 2. Generate wrapper files (.cpp) - one per kernel +# 3. 
Compile each wrapper in parallel +# 4. Link all objects into libdispatcher_kernels.so +# +# Example output: +# [ 1/128] Building kernel: gemm_fp16_rcr_128x128x32 +# [ 2/128] Building kernel: gemm_fp16_rcr_256x256x64 +# ... +# [128/128] Linking: libdispatcher_kernels.so +# ============================================================================= + +set(WRAPPER_DIR "${CMAKE_BINARY_DIR}/kernel_wrappers") +set(WRAPPER_SENTINEL "${WRAPPER_DIR}/.wrappers_generated") + +# Target: Generate wrapper .cpp files (one per kernel) +add_custom_command( + OUTPUT ${WRAPPER_SENTINEL} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/generate_kernel_wrappers.py + --kernel-dir ${KERNEL_OUTPUT_DIR} + --output-dir ${WRAPPER_DIR} + --generate-makefile + --generate-cmake + COMMAND ${CMAKE_COMMAND} -E touch ${WRAPPER_SENTINEL} + DEPENDS ${GEMM_SENTINEL} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating per-kernel wrapper .cpp files..." + VERBATIM +) + +add_custom_target(generate_kernel_wrappers + DEPENDS ${WRAPPER_SENTINEL} + COMMENT "Kernel wrapper generation target" +) + +# Target: Build kernels using generated Makefile (true per-kernel progress) +add_custom_target(build_kernels_parallel + COMMAND ${CMAKE_COMMAND} -E echo "Building kernels with per-kernel progress..." + COMMAND make -C ${WRAPPER_DIR} -j${NPROC} 2>&1 | grep -E "^\\[|Built|Linking|Error" + DEPENDS generate_kernel_wrappers + WORKING_DIRECTORY ${WRAPPER_DIR} + COMMENT "Compiling kernels in parallel (one translation unit per kernel)..." 
+ VERBATIM +) + +# Global kernel build (optional - prefer per-example builds for minimal compilation) +# This builds ALL kernels into a shared library - use for Python bindings or full library +# For C++ examples, use declarative approach which builds only needed kernels +add_custom_target(dispatcher_kernels + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/parallel_kernel_builder.py + --kernel-dir ${KERNEL_OUTPUT_DIR} + --output-dir ${CMAKE_BINARY_DIR} + --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include" + --jobs ${NPROC} + DEPENDS generate_all_kernels + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../scripts + COMMENT "Building ALL kernels in parallel (prefer per-example builds for minimal compilation)..." + VERBATIM +) + +# ============================================================================= +# Force regeneration targets (useful when you want to regenerate) +# ============================================================================= + +add_custom_target(regenerate_gemm_kernels + COMMAND ${CMAKE_COMMAND} -E remove -f ${GEMM_SENTINEL} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcr --variants standard preshuffle multi_d + --output ${KERNEL_OUTPUT_DIR} + COMMAND ${CMAKE_COMMAND} -E touch ${GEMM_SENTINEL} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Force regenerating GEMM kernels (standard + preshuffle + multi_d)..." + VERBATIM +) + +add_custom_target(regenerate_all_kernels + DEPENDS regenerate_gemm_kernels +) + +# Clean all per-example kernel directories +add_custom_target(clean_example_kernels + COMMAND ${CMAKE_COMMAND} -E echo "Removing per-example kernel directories..." + COMMAND find ${CMAKE_BINARY_DIR} -maxdepth 1 -type d -name "*_kernels" -exec rm -rf {} + + COMMENT "Cleaning all per-example kernel directories..." 
+ VERBATIM +) + +# ============================================================================= +# Helper function to add a GPU example with force-included kernel +# ============================================================================= + +# Helper for GPU examples that use the dispatcher registry +# KERNEL_HEADER can be: +# - A registration header (register_all_kernels.hpp) - included directly in source +# - A specific kernel header - force-included via compiler flag +function(add_gpu_example NAME SOURCE KERNEL_HEADER) + add_executable(${NAME} ${SOURCE}) + + target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include # CK root include + ${CMAKE_CURRENT_SOURCE_DIR}/../include # Dispatcher include + ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels # Generated kernels + ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels/dispatcher_wrappers # Wrapper headers + ) + + # Check if using registration header (no force-include needed) + get_filename_component(HEADER_NAME ${KERNEL_HEADER} NAME) + if(HEADER_NAME STREQUAL "register_all_kernels.hpp") + # Registration header - examples include it directly + target_compile_options(${NAME} PRIVATE + -DGEMM_KERNEL_AVAILABLE=1 + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + else() + # Specific kernel header - force-include it + target_compile_options(${NAME} PRIVATE + -include ${KERNEL_HEADER} + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + endif() + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() +endfunction() + +# Helper for standalone GPU examples (instantiate kernel directly, no pre-generated header) +function(add_standalone_gpu_example NAME SOURCE) + add_executable(${NAME} ${SOURCE}) + + target_link_libraries(${NAME} PRIVATE 
ck_tile_dispatcher) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include # CK root include + ${CMAKE_CURRENT_SOURCE_DIR}/../include # Dispatcher include + ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels # Generated kernels (optional) + ) + + target_compile_options(${NAME} PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() +endfunction() + +# Helper for declarative examples (configuration demo, still needs HIP compiler for CK headers) +function(add_declarative_example NAME SOURCE) + add_executable(${NAME} ${SOURCE}) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ) + + target_compile_options(${NAME} PRIVATE + -Wno-float-equal + -Wno-unused-variable + -Wno-undefined-func-template + -mllvm -enable-noalias-to-md-conversion=0 + ) + + target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher) + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() +endfunction() + +# ============================================================================= +# GEMM Examples +# ============================================================================= + +# Per-example kernel directories are created from DECL_KERNEL_SET declarations +# Each example gets its own: build/_kernels/ +# This prevents clashes during parallel compilation of multiple examples. 
+ +# Helper function to add example with declarative kernel support +# Parses DECL_KERNEL_SET from source and generates ONLY the declared kernels +# This enables minimal builds: only kernels needed by this example are generated +# +# Key features: +# - Per-example kernel directories: build/_kernels/ (no clashes) +# - Automatic header inclusion: No hardcoded #include needed in source +# - Minimal builds: Only declared kernels are generated +# - Auto-regeneration: Kernels regenerated if directory missing +# - Parallel compilation: Each kernel is a separate translation unit +function(add_declarative_gpu_example NAME SOURCE) + set(EXAMPLE_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE}") + get_filename_component(EXAMPLE_STEM ${SOURCE} NAME_WE) + + # Per-example kernel directories + set(EXAMPLE_KERNEL_DIR "${CMAKE_BINARY_DIR}/${NAME}_kernels") + set(EXAMPLE_HEADER "${EXAMPLE_KERNEL_DIR}/${EXAMPLE_STEM}_kernels.hpp") + set(EXAMPLE_LIB "${EXAMPLE_KERNEL_DIR}/lib${NAME}_kernels.a") + set(EXAMPLE_SENTINEL "${EXAMPLE_KERNEL_DIR}/.generated") + + # Generate AND compile kernels in parallel at make time + # This avoids slow cmake and gets per-kernel progress + add_custom_command( + OUTPUT ${EXAMPLE_SENTINEL} ${EXAMPLE_LIB} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/example_kernel_builder.py + ${EXAMPLE_SOURCE} + --output-dir ${EXAMPLE_KERNEL_DIR} + --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include" + --gpu-target ${GPU_TARGET} + --jobs ${NPROC} + --target-name ${NAME} + COMMAND ${CMAKE_COMMAND} -E touch ${EXAMPLE_SENTINEL} + DEPENDS ${EXAMPLE_SOURCE} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../scripts + COMMENT "[${NAME}] Generating and compiling kernels from DECL_KERNEL_SET..." 
+ VERBATIM + ) + + add_custom_target(generate_${NAME}_kernels DEPENDS ${EXAMPLE_SENTINEL}) + + # Add the executable + add_executable(${NAME} ${SOURCE}) + + target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher) + + # Link against the per-example kernel library + target_link_libraries(${NAME} PRIVATE ${EXAMPLE_LIB}) + + target_include_directories(${NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${EXAMPLE_KERNEL_DIR} + ${EXAMPLE_KERNEL_DIR}/dispatcher_wrappers + ) + + # Force-include the generated registration header + target_compile_options(${NAME} PRIVATE + -include ${EXAMPLE_HEADER} + -DGEMM_KERNEL_AVAILABLE=1 + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + + if(hip_FOUND) + target_link_libraries(${NAME} PRIVATE hip::device hip::host) + endif() + + # Only depends on generating THIS example's kernels + add_dependencies(${NAME} generate_${NAME}_kernels) +endfunction() + +# GEMM C++ examples with declarative kernel support +# Each example's C++ code contains DECL_KERNEL_SET which declares needed kernels +add_declarative_gpu_example(gemm_01_basic gemm/cpp/01_basic_gemm.cpp) +add_declarative_gpu_example(gemm_02_multi_size gemm/cpp/02_multi_size.cpp) +add_declarative_gpu_example(gemm_03_benchmark_validation gemm/cpp/03_benchmark_validation.cpp) +add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.cpp) +add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp) +add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp) + +# ============================================================================= +# GEMM Python Library - Single Fallback Kernel +# ============================================================================= + +# Generate a single fallback kernel for the Python library (fp16, rcr, compv4) +set(GEMM_FALLBACK_KERNEL_DIR 
"${CMAKE_CURRENT_BINARY_DIR}/gemm_python_fallback") +set(GEMM_FALLBACK_KERNEL "${GEMM_FALLBACK_KERNEL_DIR}/gemm_fp16_rcr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16.hpp") + +# Tile config JSON for single kernel generation +set(GEMM_FALLBACK_TILE_CONFIG "{\"tile_m\":[128],\"tile_n\":[128],\"tile_k\":[32],\"warp_m\":[2],\"warp_n\":[2],\"warp_k\":[1],\"warp_tile_m\":[32],\"warp_tile_n\":[32],\"warp_tile_k\":[16],\"pipeline\":[\"compv4\"],\"scheduler\":[\"intrawave\"],\"epilogue\":[\"cshuffle\"]}") + +# Generate single fallback kernel (not all 6000+ kernels) +add_custom_command( + OUTPUT ${GEMM_FALLBACK_KERNEL} + COMMAND ${CMAKE_COMMAND} -E make_directory ${GEMM_FALLBACK_KERNEL_DIR} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcr --variants standard + --gpu-target ${GPU_TARGET} + --output-dir ${GEMM_FALLBACK_KERNEL_DIR} + --tile-config-json "${GEMM_FALLBACK_TILE_CONFIG}" + COMMENT "Generating single fallback GEMM kernel for Python library" + VERBATIM +) + +add_custom_target(generate_gemm_fallback_kernel DEPENDS ${GEMM_FALLBACK_KERNEL}) + +# GEMM dynamic library for Python +add_library(dispatcher_gemm_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/gemm_ctypes_lib.cpp) +target_link_libraries(dispatcher_gemm_lib PRIVATE ck_tile_dispatcher) +target_include_directories(dispatcher_gemm_lib PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${GEMM_FALLBACK_KERNEL_DIR} +) +target_compile_options(dispatcher_gemm_lib PRIVATE + -DCK_TILE_SINGLE_KERNEL_INCLUDE + -include ${GEMM_FALLBACK_KERNEL} + -DGFX_ARCH="${GPU_TARGET}" + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) +if(hip_FOUND) + target_link_libraries(dispatcher_gemm_lib PRIVATE hip::device hip::host) +endif() +add_dependencies(dispatcher_gemm_lib generate_gemm_fallback_kernel) + +message(STATUS "GEMM 
examples configured - kernels will be generated during 'make'") + +# Convenience target to build all Python ctypes libraries +add_custom_target(python_libs + DEPENDS dispatcher_gemm_lib + COMMENT "Building Python ctypes libraries (GEMM)" +) + +# ============================================================================= +# Per-Architecture Kernel Generation Targets +# ============================================================================= + +set(SUPPORTED_GPU_ARCHS gfx942 gfx90a gfx1100 gfx1030) + +foreach(ARCH ${SUPPORTED_GPU_ARCHS}) + # GEMM kernels for this arch + add_custom_target(generate_gemm_kernels_${ARCH} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py + --datatype fp16 --layout rcr --gpu-target ${ARCH} + --output ${KERNEL_OUTPUT_DIR} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating GEMM kernels for ${ARCH}..." + VERBATIM + ) + + # Alias for kernels (GEMM only now) + add_custom_target(generate_kernels_${ARCH} + DEPENDS generate_gemm_kernels_${ARCH} + COMMENT "Generating all kernels for ${ARCH}..." 
+ ) +endforeach() + +# ============================================================================= +# Summary +# ============================================================================= + +message(STATUS "") +message(STATUS "=== Dispatcher Examples Configuration ===") +message(STATUS "") +message(STATUS "Kernels will be generated automatically during 'make'") +message(STATUS " Generated to: ${KERNEL_OUTPUT_DIR}") +message(STATUS "") +message(STATUS "Build targets:") +message(STATUS " make - Build all examples (generates kernels first)") +message(STATUS " make python_libs - Build Python ctypes libraries") +message(STATUS " make generate_all_kernels - Generate all kernels only") +message(STATUS " make regenerate_all_kernels - Force regenerate all kernels") +message(STATUS "") +message(STATUS "Per-architecture targets:") +message(STATUS " make generate_kernels_ - Generate for specific arch") +message(STATUS " Supported archs: ${SUPPORTED_GPU_ARCHS}") +message(STATUS "") diff --git a/dispatcher/examples/README.md b/dispatcher/examples/README.md new file mode 100644 index 00000000000..fdee9c35839 --- /dev/null +++ b/dispatcher/examples/README.md @@ -0,0 +1,210 @@ +# CK Tile Dispatcher Examples + +Comprehensive examples for GEMM operations with GPU execution. + +> **Note**: Convolution examples have been moved to `ck-2/conv_archive/` for reference. + +--- + +## Quick Start + +### Step 1: Build + +```bash +cd /path/to/composable_kernel/dispatcher +mkdir -p build && cd build + +cmake .. 
\ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DCMAKE_BUILD_TYPE=Release \ + -DGPU_TARGETS="gfx942" \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Build everything (C++ examples + Python libraries) +make -j$(nproc) + +# Or build ONLY Python libraries (faster) +make python_libs -j$(nproc) +``` + +### Step 2: Run C++ Examples + +```bash +cd build/examples + +# GEMM +./gemm_01_basic +./gemm_02_multi_size +./gemm_03_benchmark_validation +./gemm_04_heuristics +./gemm_05_json_export +./gemm_06_multi_registry +``` + +### Step 3: Run Python Examples + +```bash +cd /path/to/composable_kernel/dispatcher + +# GEMM +python3 examples/gemm/python/01_basic_gemm.py +python3 examples/gemm/python/04_validation.py +python3 examples/gemm/python/07_stress_test.py +python3 examples/gemm/python/08_heuristics.py +``` + +--- + +## Directory Structure + +``` +examples/ +├── gemm/ +│ ├── cpp/ # 6 C++ GEMM examples +│ └── python/ # 11 Python GEMM examples +│ +└── README.md +``` + +--- + +## GEMM Examples + +### C++ Examples + +| # | Example | Description | +|---|---------|-------------| +| 01 | `gemm_01_basic` | Basic GEMM with declarative API, autofill, autocorrect | +| 02 | `gemm_02_multi_size` | Wildcard expansion for multiple configurations | +| 03 | `gemm_03_benchmark_validation` | Performance benchmarking with CPU/GPU validation | +| 04 | `gemm_04_heuristics` | Heuristic-based kernel selection | +| 05 | `gemm_05_json_export` | Registry JSON export for external tools | +| 06 | `gemm_06_multi_registry` | Multiple registries with named kernel sets | + +**Details:** [gemm/cpp/README.md](gemm/cpp/README.md) + +--- + +### Python Examples + +| # | Example | Description | +|---|---------|-------------| +| 01 | `01_basic_gemm.py` | Basic GEMM with multi-kernel support | +| 02 | `02_batch_gemm.py` | Batched GEMM operations | +| 03 | `03_benchmark.py` | Performance benchmarking | +| 04 | `04_validation.py` | CPU reference validation | +| 05 | `05_numpy_integration.py` 
| NumPy array integration | +| 06 | `06_json_export.py` | Registry JSON export | +| 07 | `07_stress_test.py` | Multi-kernel stress testing (48 configs) | +| 08 | `08_heuristics.py` | Heuristic-based kernel selection (24 configs) | +| 09 | `09_multi_registry.py` | Multiple registries | +| 10 | `10_advanced_benchmark.py` | Advanced benchmark with full control | +| 11 | `11_json_import.py` | Import kernels from JSON | + +**Details:** [gemm/python/README.md](gemm/python/README.md) + +--- + +## Key Features + +### Declarative Kernel API + +Both C++ and Python examples use a declarative approach: + +**C++ (DECL_KERNEL_SET macro):** +```cpp +DECL_KERNEL_SET(my_kernels, + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16) + .pipeline("compv4").scheduler("intrawave"), + "gfx942" + ) +); +``` + +**Python (KernelConfig):** +```python +config = KernelConfig( + tile_m=256, tile_n=256, tile_k=32, + wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", scheduler="intrawave" +) +``` + +### Autofill and Autocorrect + +The build system automatically: +- **Autofills** missing parameters with sensible defaults +- **Autocorrects** invalid parameters based on architecture constraints +- **Expands** wildcards (`*`, `-1`, `ANY_INT`) to all valid configurations + +### Architecture Filtering + +Kernel configurations are validated against GPU architecture constraints: +- Tile divisibility requirements +- Warp tile constraints +- Pipeline compatibility + +Invalid configurations are automatically pruned during code generation. 
+ +--- + +## Validation Examples + +### C++ Validation + +```bash +./gemm_03_benchmark_validation --verify 1 # GEMM with CPU reference +./gemm_03_benchmark_validation --verify 2 # GEMM with GPU reference +``` + +### Python Validation + +```bash +python3 examples/gemm/python/04_validation.py +python3 examples/gemm/python/07_stress_test.py # Multi-kernel validation +``` + +--- + +## Troubleshooting + +### Python: Library not found + +```bash +# Run from dispatcher directory +cd /path/to/composable_kernel/dispatcher +python3 examples/gemm/python/01_basic_gemm.py +``` + +### C++: Executables not found + +```bash +# Build with examples enabled +cmake .. -DBUILD_DISPATCHER_EXAMPLES=ON +make -j$(nproc) + +# Run from build/examples +cd build/examples +./gemm_01_basic +``` + +### GPU not detected + +```bash +rocminfo | grep "Name:" +# Should show: gfx942, gfx90a, etc. +``` + +--- + +## Archived Examples + +Convolution examples have been archived to `ck-2/conv_archive/dispatcher/`: +- `examples/conv/cpp/` - 11 C++ convolution examples +- `examples/conv/python/` - 14 Python convolution examples + +See the archive for convolution functionality reference. diff --git a/dispatcher/examples/gemm/cpp/01_basic_gemm.cpp b/dispatcher/examples/gemm/cpp/01_basic_gemm.cpp new file mode 100644 index 00000000000..80b584a8425 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/01_basic_gemm.cpp @@ -0,0 +1,243 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 01: Basic GEMM - Autofill, Autocorrect, and Full Declaration + * + * Demonstrates THREE declaration patterns: + * + * 1. AUTOFILL: Minimal declaration - missing params filled with defaults + * .add(Signature().dtype("fp16").layout("rcr"), + * Algorithm().tile(128,128,64).pipeline("compv3").scheduler("intrawave"), + * "gfx942") + * -> wave(2,2,1), warp(32,32,16), epilogue("cshuffle") added automatically + * + * 2. 
AUTOCORRECT: Invalid params corrected to valid values + * .add(..., Algorithm().wave(1,1,1)...) + * -> wave(1,1,1) is invalid for gfx942, corrected to wave(2,2,1) + * + * 3. FULL: All parameters explicitly specified + * .add(..., Algorithm().tile().wave().warp().pipeline().scheduler().epilogue()...) + * + * Build: cd dispatcher/build && cmake .. && make gemm_01_basic + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// THREE KERNEL DECLARATION PATTERNS +// ============================================================================= + +DECL_KERNEL_SET( + basic_gemm_kernels, + // ------------------------------------------------------------------------- + // Pattern 1: AUTOFILL - Minimal declaration + // Only specify: dtype, layout, tile, pipeline, scheduler + // Auto-filled: wave(2,2,1), warp(32,32,16), epilogue("cshuffle"), pad(false,false,false) + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) // Required + .pipeline("compv3") // Required + .scheduler("intrawave"), // Required + "gfx942") + + // ------------------------------------------------------------------------- + // Pattern 2: AUTOCORRECT - Invalid wave config + // wave(1,1,1) is invalid for gfx942 WMMA, corrected to wave(2,2,1) + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) // Different tile_k to make unique kernel + .wave(1, 1, 1) // INVALID: autocorrected to (2,2,1) + 
.warp(32, 32, 16) // Valid warp for 128x128 tile + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + + // ------------------------------------------------------------------------- + // Pattern 3: FULL - All parameters explicitly specified + // No autofill or autocorrect needed + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) // Explicit tile + .wave(2, 2, 1) // Explicit wave (valid) + .warp(16, 16, 32) // Explicit warp tile + .pipeline("compv3") // Explicit pipeline + .scheduler("intrawave") // Explicit scheduler + .epilogue("cshuffle") // Explicit epilogue + .pad(false, false, false), // Explicit padding + "gfx942")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 01: GEMM Autofill/Autocorrect/Full", + "Three kernel declaration patterns"); + args.add_flag("--list", "List registered kernels"); + args.add_flag("--list-verbose", "List registered kernels with full details"); + args.add_option("--size", "1024", "Problem size MxNxK"); + args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 01: GEMM Declaration Patterns"); + + // ========================================================================= + // Show the Three Patterns + // ========================================================================= + std::cout << "\nTHREE DECLARATION PATTERNS:\n"; + std::cout << "============================\n\n"; + + std::cout << "1. 
AUTOFILL (minimal declaration):\n"; + std::cout << " .add(Signature().dtype(\"fp16\").layout(\"rcr\"),\n"; + std::cout + << " Algorithm().tile(128,128,64).pipeline(\"compv3\").scheduler(\"intrawave\"),\n"; + std::cout << " \"gfx942\")\n"; + std::cout << " -> Auto-filled: wave(2,2,1), warp(32,32,16), epilogue(\"cshuffle\")\n\n"; + + std::cout << "2. AUTOCORRECT (invalid params fixed):\n"; + std::cout << " .add(..., Algorithm().wave(1,1,1)...)\n"; + std::cout << " -> wave(1,1,1) invalid for gfx942, corrected to wave(2,2,1)\n\n"; + + std::cout << "3. FULL (all params explicit):\n"; + std::cout << " .add(..., " + "Algorithm().tile().wave().warp().pipeline().scheduler().epilogue().pad()...)\n"; + std::cout << " -> No changes needed\n\n"; + + std::string gfx_arch = args.get("--arch", "gfx942"); + + // ========================================================================= + // Step 1: Show Declared Kernel Sets + // ========================================================================= + std::cout << "Step 1: Declared Kernel Sets\n"; + KernelSetRegistry::instance().print(); + + const auto& decl_set = KernelSetRegistry::instance().get("basic_gemm_kernels"); + std::cout << " 'basic_gemm_kernels': " << decl_set.size() << " declaration(s)\n"; + + // ========================================================================= + // Step 2: Create Registry and Register Kernels + // ========================================================================= + std::cout << "\nStep 2: Register Kernels\n"; + + Registry registry; + // Use generic macro + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + // List kernels if requested + if(args.has("--list") || args.has("--list-verbose")) + { + std::cout << "\n"; + print_registered_kernels(registry, std::cout, args.has("--list-verbose")); + return 0; + } + + // ========================================================================= + // Step 3: Create Dispatcher + 
// ========================================================================= + std::cout << "\nStep 3: Create Dispatcher\n"; + Dispatcher dispatcher(®istry); + + // ========================================================================= + // Step 4: Setup Problem + // ========================================================================= + int size = args.get_int("--size", 1024); + const int M = size, N = size, K = size; + + std::cout << "\nStep 4: Setup Problem (" << M << "x" << N << "x" << K << ")\n"; + + Problem problem(M, N, K); + + using DataType = ck_tile::fp16_t; + GpuBuffer a_dev(M * K); + GpuBuffer b_dev(K * N); + GpuBuffer c_dev(M * N); + + std::vector a_host(M * K, DataType(1.0f)); + std::vector b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + // ========================================================================= + // Step 5: Select and Run + // ========================================================================= + std::cout << "\nStep 5: Select and Run\n"; + + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "ERROR: No kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->get_name() << "\n"; + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << calculate_tflops(M, N, K, time_ms) << "\n"; + + // ========================================================================= + // Step 6: Verify + // ========================================================================= + std::cout << "\nStep 6: Verify\n"; + std::vector c_host(M * N); + c_dev.copy_to_host(c_host.data()); + + const float expected = static_cast(K); + int errors = 0; + for(int i = 0; i < M * N; ++i) + { + if(std::abs(static_cast(c_host[i]) - expected) > 0.01f * 
expected + 1.0f) + ++errors; + } + + bool passed = (errors == 0); + std::cout << " Expected: " << expected << ", Errors: " << errors << "\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + + // ========================================================================= + // Summary + // ========================================================================= + print_separator(); + std::cout << "DECLARATION PATTERNS SUMMARY:\n"; + print_separator(); + std::cout << R"( + 1. AUTOFILL: Specify only required params, system fills defaults + - Useful for quick prototyping + - Guarantees valid configuration + + 2. AUTOCORRECT: System validates and fixes invalid params + - wave(1,1,1) -> wave(2,2,1) on gfx942 + - Invalid pipeline/scheduler combos fixed + - Logs corrections for debugging + + 3. FULL: All params explicit - no changes made + - Full control over configuration + - Best for production/tuning +)"; + print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/02_multi_size.cpp b/dispatcher/examples/gemm/cpp/02_multi_size.cpp new file mode 100644 index 00000000000..5e620209f4c --- /dev/null +++ b/dispatcher/examples/gemm/cpp/02_multi_size.cpp @@ -0,0 +1,215 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 02: Multi-Size GEMM with Wildcard Expansion + * + * Demonstrates the WILDCARD feature where specifying wildcards causes + * the build system to expand to ALL valid configurations for the architecture. + * + * WILDCARD SYNTAX: + * - Integer params: ANY_INT or -1 (both are equivalent, ANY_INT is just a #define for -1) + * - String params: "*" (for pipeline, scheduler) + * + * The kernel declaration: + * .add(..., Algorithm().tile(64,64,64).wave(ANY_INT,ANY_INT,1).warp(-1,-1,-1) + * .pipeline("*").scheduler("*"), ...) 
+ * + * Expands to multiple kernels: + * - wave: (1,4,1), (2,2,1), (4,1,1) -> 3 options + * - warp: (16,16,32), (32,32,16) -> 2 options + * - pipeline: "compv3" -> 1 option (compv4 requires special handling) + * - scheduler: "intrawave" -> 1 option + * + * Raw expansion: 3 × 2 = 6 configs, but arch filter validates each: + * - tile_m must be divisible by (warp_m × warp_tile_m) + * - tile_n must be divisible by (warp_n × warp_tile_n) + * - Some wave/warp combos invalid: (4,1,1)+(32,32,16), (1,4,1)+(32,32,16) + * Result: 4 valid wildcard kernels + 1 explicit = 5 total + * + * Build: cd dispatcher/build && cmake .. && make gemm_02_multi_size + * Usage: ./gemm_02_multi_size [--max-size N] [--help] + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: Demonstrates Wildcard Expansion +// ============================================================================= + +DECL_KERNEL_SET(multi_size_kernels, + // ------------------------------------------------------------------------- + // Kernel 1: Explicit - all parameters specified (no expansion) + // ------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + + // ------------------------------------------------------------------------- + // Kernel 2: WILDCARD - expands to multiple valid configurations + // Wildcards: ANY_INT == -1 (for integers), "*" (for strings) + // 
------------------------------------------------------------------------- + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 64) + .wave(ANY_INT, ANY_INT, 1) // ANY_INT → (1,4,1), (2,2,1), (4,1,1) + .warp(-1, -1, -1) // -1 same as ANY_INT → (16,16,32), (32,32,16) + .pipeline("*") // "*" → valid pipelines + .scheduler("*") // "*" → valid schedulers + .epilogue("cshuffle"), + "gfx942")); +// Raw: 3×2=6, arch filter removes 2 invalid → 4 valid kernels + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 02: Multi-Size GEMM with Wildcards", + "Demonstrates wildcard expansion for kernel generation"); + args.add_option("--max-size", "4096", "Maximum problem size to test"); + args.add_option("--arch", "gfx942", "GPU architecture"); + args.add_flag("--list", "List all registered kernels"); + args.add_flag("--list-verbose", "List kernels with full configuration details"); + + if(!args.parse(argc, argv)) + return 0; + + int max_size = args.get_int("--max-size", 4096); + std::string gfx_arch = args.get("--arch", "gfx942"); + + print_header("Example 02: Multi-Size GEMM with Wildcards"); + + // ========================================================================= + // Show Wildcard Expansion Concept + // ========================================================================= + std::cout << "\nWILDCARD EXPANSION:\n"; + std::cout << "===================\n"; + std::cout << R"( + Wildcard syntax: + ANY_INT or -1 -> expands integer params to all valid values + "*" -> expands string params (pipeline/scheduler) to valid values + + Declaration with wildcards: + .tile(64, 64, 64) -> fixed tile size (no wildcard) + .wave(ANY_INT, ANY_INT, 1) -> expands to (1,4,1), (2,2,1), (4,1,1) = 3 + .warp(-1, -1, -1) -> expands to (16,16,32), (32,32,16) = 2 + .pipeline("*") -> 
expands to valid pipelines = 1 + .scheduler("*") -> expands to valid schedulers = 1 + + Expanded: 3 × 2 = 6 configs, but arch filter validates each: + - wave×warp must divide tile: (4,1,1)×(32,32,16) invalid for 64x64 + - Result: 4 valid kernels from wildcard + 1 explicit = 5 total +)"; + + // ========================================================================= + // Setup Registry and Dispatcher + // ========================================================================= + std::cout << "\nStep 1: Register Kernels\n"; + std::cout << "------------------------\n"; + + Registry registry; + registry.set_name("multi_size_registry"); + + // Register kernels from generated header (includes expanded wildcards) + // Use generic macro - no need to hardcode example name + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s) from wildcard expansion\n"; + + if(args.has("--list") || args.has("--list-verbose")) + { + std::cout << "\n"; + print_registered_kernels(registry, std::cout, args.has("--list-verbose")); + return 0; + } + + Dispatcher dispatcher(®istry); + std::cout << " Max size: " << max_size << "\n"; + + // ========================================================================= + // Run Multiple Problem Sizes + // ========================================================================= + std::cout << "\nStep 2: Run Multiple Sizes\n"; + print_separator(); + std::cout << std::setw(12) << "M" << std::setw(12) << "N" << std::setw(12) << "K" + << std::setw(12) << "Time(ms)" << std::setw(12) << "TFLOPS" << "\n"; + print_separator(); + + std::vector> all_sizes = { + {256, 256, 256}, + {512, 512, 512}, + {1024, 1024, 1024}, + {2048, 2048, 2048}, + {4096, 4096, 4096}, + }; + + std::vector> sizes; + for(const auto& [M, N, K] : all_sizes) + { + if(std::max({M, N, K}) <= max_size) + sizes.push_back({M, N, K}); + } + + using DataType = ck_tile::fp16_t; + bool all_passed = true; + + for(const auto& [M, N, K] : sizes) 
+ { + Problem problem(M, N, K); + + GpuBuffer a_dev(M * K); + GpuBuffer b_dev(K * N); + GpuBuffer c_dev(M * N); + + std::vector a_host(M * K, DataType(1.0f)); + std::vector b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + double tflops = calculate_tflops(M, N, K, time_ms); + + std::cout << std::setw(12) << M << std::setw(12) << N << std::setw(12) << K << std::setw(12) + << std::fixed << std::setprecision(4) << time_ms << std::setw(12) + << std::setprecision(2) << tflops << "\n"; + + // Verify + std::vector c_host(M * N); + c_dev.copy_to_host(c_host.data()); + float expected = static_cast(K); + int errors = 0; + for(int i = 0; i < M * N; ++i) + { + if(std::abs(static_cast(c_host[i]) - expected) > 0.01f * expected + 1.0f) + ++errors; + } + if(errors > 0) + all_passed = false; + } + + print_separator(); + std::cout << "Status: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n"; + print_separator(); + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp b/dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp new file mode 100644 index 00000000000..61608c79149 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp @@ -0,0 +1,344 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 03: GEMM Benchmark & Validation + * + * Combined example demonstrating: + * 1. Benchmarking with statistics (warmup, iterations, min/max/mean/median) + * 2. Validation against CK Tile reference (CPU or GPU) + * + * Build: cd dispatcher/build && cmake .. 
&& make gemm_03_benchmark_validation + * Usage: ./gemm_03_benchmark_validation [--size N] [--verify MODE] [--benchmark] + * + * Options: + * --size N Problem size MxNxK (default: 512) + * --verify MODE 0=none, 1=CPU ref, 2=GPU ref (default: 1) + * --benchmark Run full benchmark with statistics + * --warmup N Warmup iterations (default: 5) + * --iterations N Benchmark iterations (default: 20) + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/reference/reference_gemm.hpp" + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using namespace ck_tile::literals; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: High-performance kernels for benchmarking/validation +// ============================================================================= + +DECL_KERNEL_SET(benchmark_validation_kernels, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// Helper: Layout detection +// ============================================================================= + +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 03: GEMM Benchmark & Validation", + "Benchmark and/or validate GEMM 
output against reference"); + args.add_option("--size", "512", "Problem size MxNxK"); + args.add_option("--verify", "1", "Verification: 0=none, 1=CPU ref, 2=GPU ref"); + args.add_flag("--benchmark", "Run benchmark with statistics"); + args.add_option("--warmup", "5", "Warmup iterations"); + args.add_option("--iterations", "20", "Benchmark iterations"); + args.add_option("--rtol", "0.01", "Relative tolerance"); + args.add_option("--atol", "0.01", "Absolute tolerance"); + args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + int M = args.get_int("--size", 512); + int N = M; + int K = M; + int verify = args.get_int("--verify", 1); + bool do_benchmark = args.has("--benchmark"); + int warmup = args.get_int("--warmup", 5); + int iterations = args.get_int("--iterations", 20); + float rtol = args.get_float("--rtol", 0.01f); + float atol = args.get_float("--atol", 0.01f); + std::string gfx_arch = args.get("--arch", "gfx942"); + + print_header("Example 03: GEMM Benchmark & Validation"); + + std::cout << "\nConfiguration:\n"; + std::cout << " Problem: " << M << " x " << N << " x " << K << "\n"; + std::cout << " Layout: RCR (A=row, B=col, C=row)\n"; + std::cout << " Verify: " << verify; + if(verify == 0) + std::cout << " (disabled)"; + else if(verify == 1) + std::cout << " (CPU reference)"; + else if(verify == 2) + std::cout << " (GPU reference)"; + std::cout << "\n"; + std::cout << " Benchmark: " << (do_benchmark ? 
"yes" : "no") << "\n"; + if(do_benchmark) + { + std::cout << " Warmup: " << warmup << " iterations\n"; + std::cout << " Measure: " << iterations << " iterations\n"; + } + + // ========================================================================= + // Setup Registry and Dispatcher + // ========================================================================= + Registry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + Dispatcher dispatcher(®istry); + + std::cout << " Kernels: " << registry.size() << " registered\n"; + print_registered_kernels(registry); + + // ========================================================================= + // Initialize data with proper tensor descriptors + // ========================================================================= + using ALayout = ck_tile::tensor_layout::gemm::RowMajor; + using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor; + using CLayout = ck_tile::tensor_layout::gemm::RowMajor; + + using ADataType = ck_tile::fp16_t; + using BDataType = ck_tile::fp16_t; + using CDataType = ck_tile::fp16_t; + using AccDataType = float; + + auto stride_a = ck_tile::get_default_stride(M, K, 0_uz, is_row_major(ALayout{})); + auto stride_b = ck_tile::get_default_stride(K, N, 0_uz, is_row_major(BLayout{})); + auto stride_c = ck_tile::get_default_stride(M, N, 0_uz, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_a, is_row_major(ALayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_b, is_row_major(BLayout{}))); + ck_tile::HostTensor c_m_n_dev( + ck_tile::host_tensor_descriptor(M, N, stride_c, is_row_major(CLayout{}))); + ck_tile::HostTensor c_m_n_ref( + ck_tile::host_tensor_descriptor(M, N, stride_c, is_row_major(CLayout{}))); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(b_k_n); + + std::cout << "\nData:\n"; + std::cout << " A: " << M << " x " << K << " (fp16, 
row-major)\n"; + std::cout << " B: " << K << " x " << N << " (fp16, col-major)\n"; + std::cout << " C: " << M << " x " << N << " (fp16, row-major)\n"; + + // GPU memory + ck_tile::DeviceMem a_dev(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_dev(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_dev(c_m_n_dev.get_element_space_size_in_bytes()); + + a_dev.ToDevice(a_m_k.data()); + b_dev.ToDevice(b_k_n.data()); + + // ========================================================================= + // Compute Reference (if needed) + // ========================================================================= + if(verify > 0) + { + std::cout << "\nComputing reference...\n"; + c_m_n_ref.SetZero(); + + if(verify == 1) + { + std::cout << " Using CPU reference (ck_tile::reference_gemm)\n"; + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_ref); + } + else if(verify == 2) + { + std::cout << " Using GPU reference (ck_tile::reference_gemm_gpu)\n"; + ck_tile::DeviceMem c_ref_dev(c_m_n_ref.get_element_space_size_in_bytes()); + c_ref_dev.SetZero(); + + ck_tile::reference_gemm_gpu( + static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_ref_dev.GetDeviceBuffer()), + M, + N, + K, + stride_a, + stride_b, + stride_c); + + (void)hipDeviceSynchronize(); + c_ref_dev.FromDevice(c_m_n_ref.data()); + } + std::cout << " Reference complete.\n"; + } + + // ========================================================================= + // Run Kernel + // ========================================================================= + Problem problem(M, N, K); + auto selected = dispatcher.select_kernel(problem); + + std::cout << "\nRunning kernel:\n"; + if(selected) + std::cout << " Selected: " << selected->get_name() << "\n"; + + c_dev.SetZero(); + float time_ms = 0.0f; + std::vector times; + + if(do_benchmark) + { + // Warmup + std::cout << " Warming up (" << warmup << " iterations)...\n"; + for(int i = 0; i < warmup; ++i) + { + 
c_dev.SetZero(); + (void)dispatcher.run(static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_dev.GetDeviceBuffer()), + problem, + nullptr); + } + + // Benchmark + std::cout << " Benchmarking (" << iterations << " iterations)...\n"; + times.reserve(iterations); + for(int i = 0; i < iterations; ++i) + { + c_dev.SetZero(); + float t = dispatcher.run(static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_dev.GetDeviceBuffer()), + problem, + nullptr); + times.push_back(t); + } + time_ms = *std::min_element(times.begin(), times.end()); + } + else + { + // Single run + time_ms = dispatcher.run(static_cast(a_dev.GetDeviceBuffer()), + static_cast(b_dev.GetDeviceBuffer()), + static_cast(c_dev.GetDeviceBuffer()), + problem, + nullptr); + } + + c_dev.FromDevice(c_m_n_dev.data()); + + // ========================================================================= + // Results + // ========================================================================= + double flops = 2.0 * M * N * K; + double tflops = flops / (time_ms * 1e9); + + print_separator(); + std::cout << "Performance:\n"; + print_separator(); + + if(do_benchmark && !times.empty()) + { + std::sort(times.begin(), times.end()); + float min_t = times.front(); + float max_t = times.back(); + float median_t = times[times.size() / 2]; + float mean_t = std::accumulate(times.begin(), times.end(), 0.0f) / times.size(); + + std::cout << std::fixed << std::setprecision(4); + std::cout << " Min: " << min_t << " ms (" << std::setprecision(2) + << (flops / (min_t * 1e9)) << " TFLOPS)\n"; + std::cout << std::setprecision(4); + std::cout << " Max: " << max_t << " ms\n"; + std::cout << " Mean: " << mean_t << " ms (" << std::setprecision(2) + << (flops / (mean_t * 1e9)) << " TFLOPS)\n"; + std::cout << std::setprecision(4); + std::cout << " Median: " << median_t << " ms (" << std::setprecision(2) + << (flops / (median_t * 1e9)) << " TFLOPS)\n"; + } + else + { 
+ std::cout << std::fixed << std::setprecision(4); + std::cout << " Time: " << time_ms << " ms\n"; + std::cout << std::setprecision(2); + std::cout << " TFLOPS: " << tflops << "\n"; + } + + // ========================================================================= + // Validation + // ========================================================================= + bool pass = true; + + if(verify > 0) + { + print_separator(); + std::cout << "Validation:\n"; + print_separator(); + std::cout << " Tolerance: rtol=" << rtol << ", atol=" << atol << "\n"; + + pass = ck_tile::check_err(c_m_n_dev, c_m_n_ref, "Validation Error!", rtol, atol); + + float max_abs_diff = 0.0f; + float max_rel_diff = 0.0f; + for(size_t i = 0; i < c_m_n_dev.get_element_space_size(); ++i) + { + float dev_val = static_cast(c_m_n_dev.mData[i]); + float ref_val = static_cast(c_m_n_ref.mData[i]); + float abs_diff = std::abs(dev_val - ref_val); + float rel_diff = (ref_val != 0.0f) ? abs_diff / std::abs(ref_val) : abs_diff; + max_abs_diff = std::max(max_abs_diff, abs_diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + } + + std::cout << " Max abs diff: " << max_abs_diff << "\n"; + std::cout << " Max rel diff: " << max_rel_diff << "\n"; + } + + // ========================================================================= + // Summary + // ========================================================================= + print_separator(); + std::cout << "Result: " << (pass ? "PASS" : "FAIL") << "\n"; + print_separator(); + + return pass ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/04_heuristics.cpp b/dispatcher/examples/gemm/cpp/04_heuristics.cpp new file mode 100644 index 00000000000..2a8753cdffb --- /dev/null +++ b/dispatcher/examples/gemm/cpp/04_heuristics.cpp @@ -0,0 +1,168 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 04: Custom Heuristics + * + * Demonstrates custom kernel selection heuristics for different workloads. 
+ * + * Build: cd dispatcher/build && cmake .. && make gemm_04_heuristics + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: Multiple tile sizes for heuristic-based selection +// ============================================================================= + +DECL_KERNEL_SET(heuristics_kernels, + // Small tile - low latency + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + // Medium tile - balanced + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// Custom Heuristic +// ============================================================================= + +std::vector size_based_heuristic(const Problem& problem) +{ + std::vector ranked_kernels; + int64_t total_elements = problem.M * problem.N; + + if(total_elements < 100000) + { + ranked_kernels = {"gemm_64x64", "gemm_128x128"}; + } + else + { + ranked_kernels = {"gemm_128x128", "gemm_64x64"}; + } + return ranked_kernels; +} + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 04: Custom Heuristics", + "Demonstrates custom kernel selection heuristics"); + 
args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 04: Custom Heuristics"); + + std::string gfx_arch = args.get("--arch", "gfx942"); + + // ========================================================================= + // Setup Registry and Dispatcher + // ========================================================================= + Registry registry; + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + Dispatcher dispatcher(®istry); + dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic); + dispatcher.set_heuristic(size_based_heuristic); + + std::cout << "\nSetup:\n"; + std::cout << " Registry: " << registry.size() << " kernel(s)\n"; + std::cout << " Strategy: Heuristic (size-based)\n"; + + // ========================================================================= + // Test Different Problem Sizes + // ========================================================================= + std::cout << "\nTesting heuristic selection:\n"; + print_separator(); + + using DataType = ck_tile::fp16_t; + + std::vector> sizes = { + {128, 128, 64}, + {512, 512, 256}, + {2048, 2048, 1024}, + }; + + bool all_passed = true; + + for(const auto& [M, N, K] : sizes) + { + Problem problem(M, N, K); + auto selected = dispatcher.select_kernel(problem); + + std::cout << "Problem " << M << "x" << N << "x" << K << ":\n"; + if(selected) + { + std::cout << " Selected: " << selected->get_name() << "\n"; + } + + GpuBuffer a_dev(M * K); + GpuBuffer b_dev(K * N); + GpuBuffer c_dev(M * N); + + std::vector a_host(M * K, DataType(1.0f)); + std::vector b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + double tflops = calculate_tflops(M, N, K, time_ms); + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " 
TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Verify + std::vector c_host(M * N); + c_dev.copy_to_host(c_host.data()); + float expected = static_cast(K); + int errors = 0; + for(int i = 0; i < M * N; ++i) + { + float actual = static_cast(c_host[i]); + if(std::abs(actual - expected) > 0.01f * expected + 1.0f) + ++errors; + } + bool pass = (errors == 0); + std::cout << " Verify: " << (pass ? "PASS" : "FAIL") << "\n"; + if(!pass) + all_passed = false; + print_separator(); + } + + std::cout << "Overall: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n"; + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/05_json_export.cpp b/dispatcher/examples/gemm/cpp/05_json_export.cpp new file mode 100644 index 00000000000..75ed7308af9 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/05_json_export.cpp @@ -0,0 +1,127 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 05: JSON Export + * + * Demonstrates exporting registry information to JSON format. + * + * Build: cd dispatcher/build && cmake .. 
&& make gemm_05_json_export + */ + +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SET: Multiple kernels for JSON export demo +// ============================================================================= + +DECL_KERNEL_SET(json_export_kernels, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 05: JSON Export", "Export registry information to JSON format"); + args.add_option("--output", "registry.json", "Output JSON file path"); + args.add_option("--arch", "gfx942", "GPU architecture"); + args.add_flag("--list", "List all kernel sets"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 05: JSON Export"); + + std::string gfx_arch = args.get("--arch", "gfx942"); + + if(args.has("--list")) + { + std::cout << "\nDeclared Kernel Sets:\n"; + KernelSetRegistry::instance().print(); + return 0; + } + + std::string output_file = args.get("--output", "registry.json"); + + // ========================================================================= + // Setup Registry + // 
========================================================================= + std::cout << "\nSetting up registry...\n"; + Registry registry; + registry.set_name("json_export_registry"); + + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + std::cout << " Registry: " << registry.get_name() << "\n"; + std::cout << " Kernels: " << registry.size() << "\n"; + + // ========================================================================= + // Export to JSON + // ========================================================================= + std::cout << "\nExporting to JSON...\n"; + + std::string json = registry.export_json(true); + + std::cout << "\nJSON Preview (first 500 chars):\n"; + print_separator(); + std::cout << json.substr(0, std::min(size_t(500), json.size())); + if(json.size() > 500) + std::cout << "\n..."; + std::cout << "\n"; + print_separator(); + + // Write to file + std::ofstream file(output_file); + if(file.is_open()) + { + file << json; + file.close(); + std::cout << "\nExported to: " << output_file << "\n"; + std::cout << "File size: " << json.size() << " bytes\n"; + } + else + { + std::cerr << "Failed to write to: " << output_file << "\n"; + return 1; + } + + // ========================================================================= + // Also show kernel set declarations + // ========================================================================= + std::cout << "\nKernel Set Declarations:\n"; + print_separator(); + KernelSetRegistry::instance().print(); + print_separator(); + + return 0; +} diff --git a/dispatcher/examples/gemm/cpp/06_multi_registry.cpp b/dispatcher/examples/gemm/cpp/06_multi_registry.cpp new file mode 100644 index 00000000000..3077f2d754f --- /dev/null +++ b/dispatcher/examples/gemm/cpp/06_multi_registry.cpp @@ -0,0 +1,294 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * Example 06: Multiple Registries and Multiple Kernel Sets + * + * Demonstrates: + * - Multiple DECL_KERNEL_SET declarations (each with multiple kernels) + * - Separate Registry instances for different workload types + * - Independent Dispatchers that select from their respective registries + * + * Registration patterns: + * - REGISTER_GENERATED_KERNELS(registry, arch) -> all kernels to one registry + * - REGISTER_KERNEL_SET("set_name", registry, arch) -> specific set by name + * - generated::get_kernel_set_names() -> list available set names + * + * Build: cd dispatcher/build && cmake .. && make gemm_06_multi_registry + * Usage: ./gemm_06_multi_registry [--list] [--help] + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// KERNEL SETS: Multiple sets with multiple kernels each +// ============================================================================= + +// Compute-bound kernel set: Large tiles for high arithmetic intensity +// Max tile with 32x32 warp is 128x128 (16 warps = 1024 threads) +DECL_KERNEL_SET(compute_bound_set, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) // Large tile, max for 32x32 warp + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) // Same tile, different K for variety + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// Memory-bound kernel set: Smaller tiles for 
better cache efficiency +DECL_KERNEL_SET(memory_bound_set, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942") + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 64, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// Latency-optimized: Minimal overhead tiles +DECL_KERNEL_SET(latency_set, + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx942")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 06: Multiple Registries", + "Separate registries for different workload types"); + args.add_flag("--list", "List all declared kernel sets"); + args.add_option("--arch", "gfx942", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + print_header("Example 06: Multiple Registries & Kernel Sets"); + + std::string gfx_arch = args.get("--arch", "gfx942"); + + // ========================================================================= + // Step 1: Show declared kernel sets (from DECL_KERNEL_SET macros) + // ========================================================================= + std::cout << "\nStep 1: Declared Kernel Sets\n"; + std::cout << "-----------------------------\n"; + KernelSetRegistry::instance().print(); + + if(args.has("--list")) + { + // Print detailed info + for(const auto& name : KernelSetRegistry::instance().names()) + { + const auto& set = KernelSetRegistry::instance().get(name); + std::cout << "\n " << name << ":\n"; + for(const auto& decl : 
set.declarations()) + { + std::cout << " - " << decl.name() << " (tile=" << decl.algorithm.tile_m_ << "x" + << decl.algorithm.tile_n_ << "x" << decl.algorithm.tile_k_ << ")\n"; + } + } + return 0; + } + + // ========================================================================= + // Step 2: Create registries and demonstrate MERGING + // ========================================================================= + std::cout << "\nStep 2: Create and Merge Registries\n"; + std::cout << "------------------------------------\n"; + + // Create individual registries first + Registry compute_registry; + Registry latency_registry; + Registry memory_registry; + + compute_registry.set_name("compute_bound"); + latency_registry.set_name("latency_optimized"); + memory_registry.set_name("memory_bound"); + + // Register kernels to individual registries using set names (no hardcoding) + REGISTER_KERNEL_SET("compute_bound_set", compute_registry, gfx_arch); + REGISTER_KERNEL_SET("latency_set", latency_registry, gfx_arch); + REGISTER_KERNEL_SET("memory_bound_set", memory_registry, gfx_arch); + + std::cout << " Individual registries:\n"; + std::cout << " compute_bound: " << compute_registry.size() << " kernel(s)\n"; + std::cout << " latency_optimized: " << latency_registry.size() << " kernel(s)\n"; + std::cout << " memory_bound: " << memory_registry.size() << " kernel(s)\n"; + + // MERGE compute + latency into a combined registry + Registry combined_registry; + combined_registry.set_name("compute_latency_combined"); + + // Register both sets into combined registry + REGISTER_KERNEL_SET("compute_bound_set", combined_registry, gfx_arch); + REGISTER_KERNEL_SET("latency_set", combined_registry, gfx_arch); + + std::cout << "\n After merging compute + latency:\n"; + std::cout << " combined: " << combined_registry.size() << " kernel(s)\n"; + std::cout << " memory (separate): " << memory_registry.size() << " kernel(s)\n"; + + // 
========================================================================= + // Step 3: Create dispatchers - one merged, one separate + // ========================================================================= + std::cout << "\nStep 3: Create Dispatchers\n"; + std::cout << "--------------------------\n"; + + Dispatcher combined_dispatcher(&combined_registry); // compute + latency merged + Dispatcher memory_dispatcher(&memory_registry); // memory separate + + std::cout << " combined_dispatcher: compute + latency kernels (" << combined_registry.size() + << " kernels)\n"; + std::cout << " memory_dispatcher: memory-bound kernels (" << memory_registry.size() + << " kernels)\n"; + + // ========================================================================= + // Step 4: Run with different dispatchers + // ========================================================================= + std::cout << "\nStep 4: Run Workloads\n"; + print_separator(); + + using DataType = ck_tile::fp16_t; + + struct WorkloadTest + { + const char* name; + Dispatcher* dispatcher; + int M, N, K; + }; + + std::vector tests = { + {"Compute-bound (combined)", &combined_dispatcher, 4096, 4096, 4096}, + {"Memory-bound (separate)", &memory_dispatcher, 1024, 1024, 1024}, + {"Latency-opt (combined)", &combined_dispatcher, 512, 512, 512}, + }; + + bool all_passed = true; + + for(const auto& test : tests) + { + Problem problem(test.M, test.N, test.K); + + // Allocate and initialize + GpuBuffer a_dev(test.M * test.K); + GpuBuffer b_dev(test.K * test.N); + GpuBuffer c_dev(test.M * test.N); + + std::vector a_host(test.M * test.K, DataType(1.0f)); + std::vector b_host(test.K * test.N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + // Select kernel and run + auto selected = test.dispatcher->select_kernel(problem); + float time_ms = + test.dispatcher->run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + double tflops = 
calculate_tflops(test.M, test.N, test.K, time_ms); + + std::cout << test.name << " (" << test.M << "x" << test.N << "x" << test.K << "):\n"; + if(selected) + std::cout << " Selected: " << selected->get_name() << "\n"; + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Verify ALL elements + std::vector c_host(test.M * test.N); + c_dev.copy_to_host(c_host.data()); + const float expected = static_cast(test.K); + + int num_errors = 0; + float max_error = 0.0f; + for(int i = 0; i < test.M * test.N; ++i) + { + float actual = static_cast(c_host[i]); + float error = std::abs(actual - expected); + max_error = std::max(max_error, error); + // Allow 1% relative tolerance for FP16 accumulation + if(error > 0.01f * expected + 1.0f) + ++num_errors; + } + + bool test_passed = (num_errors == 0); + std::cout << " Verify: " << (test.M * test.N) << " elements, errors=" << num_errors + << "\n"; + std::cout << " Status: " << (test_passed ? "PASS" : "FAIL") << "\n\n"; + + if(!test_passed) + all_passed = false; + } + + // ========================================================================= + // Summary + // ========================================================================= + print_separator(); + std::cout << "Multi-Registry Pattern Summary:\n"; + print_separator(); + std::cout << R"( +// 1. Declare multiple kernel sets +DECL_KERNEL_SET(compute_bound_set, .add(...)); +DECL_KERNEL_SET(memory_bound_set, .add(...)); +DECL_KERNEL_SET(latency_set, .add(...)); + +// 2. Create registries and register by set NAME (no hardcoding!) +Registry combined_reg, memory_reg; +REGISTER_KERNEL_SET("compute_bound_set", combined_reg, arch); // Add compute +REGISTER_KERNEL_SET("latency_set", combined_reg, arch); // Merge latency +REGISTER_KERNEL_SET("memory_bound_set", memory_reg, arch); // Separate + +// 3. 
Create dispatchers from merged/separate registries +Dispatcher combined_disp(&combined_reg); // Has both compute + latency +Dispatcher memory_disp(&memory_reg); // Has only memory-bound + +// 4. Choose dispatcher based on workload +if (problem.is_memory_bound()) + memory_disp.run(...); +else + combined_disp.run(...); // Handles both compute & latency workloads +)"; + print_separator(); + std::cout << "Overall Status: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n"; + + return all_passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/README.md b/dispatcher/examples/gemm/cpp/README.md new file mode 100644 index 00000000000..1d81a90a0e8 --- /dev/null +++ b/dispatcher/examples/gemm/cpp/README.md @@ -0,0 +1,229 @@ +# GEMM C++ Examples + +CK Tile Dispatcher C++ examples for GEMM (General Matrix Multiplication) operations. + +> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md) + +## Quick Start + +### Build and Run + +```bash +cd /path/to/composable_kernel/dispatcher +mkdir -p build && cd build + +cmake .. 
\ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Build (kernels generated automatically by CMake) +make -j$(nproc) + +# Run examples +cd examples +./gemm_01_basic +./gemm_03_benchmark_validation +./gemm_04_heuristics +``` + +## Examples + +| Example | Description | Complexity | +|---------|-------------|------------| +| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | ★☆☆☆☆ | +| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | ★★☆☆☆ | +| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | ★★☆☆☆ | +| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | ★★★☆☆ | +| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | ★★☆☆☆ | +| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ★★★☆☆ | + +## Example Details + +### 01_basic_gemm.cpp - Basic GEMM +Demonstrates the declarative kernel API with three patterns: + +1. **Autofill Pattern** - Minimal specification, defaults filled automatically +2. **Autocorrect Pattern** - Invalid parameters corrected at build time +3. 
**Full Specification Pattern** - Complete kernel configuration + +```cpp +DECL_KERNEL_SET(basic_kernels, + // Pattern 1: Autofill - minimal specification + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm(), // Defaults filled by autofill + "gfx942" + ) + // Pattern 2: Full specification + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16) + .pipeline("compv4").scheduler("intrawave"), + "gfx942" + ) +); +``` + +**Features:** +- Uses generic `REGISTER_GENERATED_KERNELS` macro +- `print_registered_kernels()` utility for debugging +- Demonstrates autofill messages during build + +### 02_multi_size.cpp - Wildcard Expansion +Demonstrates automatic generation of multiple kernel configurations: + +```cpp +DECL_KERNEL_SET(multi_kernels, + .add( + Signature().dtype("fp16").layout("rcr"), + Algorithm().tile(*, *, 32) // Wildcard tile M and N + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave"), + "gfx942" + ) +); +``` + +**Wildcard Values:** +- `*`, `-1`, or `ANY_INT` expand to all valid configurations +- Architecture filter prunes invalid combinations automatically +- Example generates 5 valid kernels after arch filtering (from 7 expansions) + +### 03_benchmark_validation.cpp - Benchmark + Validation +Consolidated example combining performance benchmarking with correctness validation: + +```bash +# Benchmark only +./gemm_03_benchmark_validation --warmup 10 --iterations 100 + +# With CPU validation +./gemm_03_benchmark_validation --verify 1 --rtol 1e-3 --atol 1e-3 + +# With GPU reference validation (faster for large matrices) +./gemm_03_benchmark_validation --verify 2 +``` + +**Features:** +- Warmup iterations (discarded from timing) +- Benchmark iterations with statistics (min/max/mean/median) +- CPU reference validation using `ck_tile::reference_gemm` +- GPU reference validation using `ck_tile::reference_gemm_gpu` +- Configurable tolerances + +### 04_heuristics.cpp - 
Heuristic Selection +Demonstrates custom kernel selection based on problem characteristics: + +```cpp +// Problem size analysis +auto heuristic = [](const Problem& p) -> std::optional { + if (p.M() * p.N() < 256 * 256) { + return small_kernel_key; // Memory-bound heuristic + } else { + return large_kernel_key; // Compute-bound heuristic + } +}; + +dispatcher.set_heuristic(heuristic); +``` + +**Features:** +- Problem size analysis (small vs large matrices) +- Compute-bound vs memory-bound selection +- Custom heuristic function registration + +### 05_json_export.cpp - JSON Export +Exports registry information to JSON for external tool integration: + +```cpp +auto json = registry.to_json(); +std::ofstream file("kernels.json"); +file << json; +``` + +**Use Cases:** +- Kernel metadata serialization +- External analysis tools +- Configuration management + +### 06_multi_registry.cpp - Multiple Registries +Demonstrates using multiple registries with named kernel sets: + +```cpp +// Define separate kernel sets +DECL_KERNEL_SET(compute_optimized, ...); +DECL_KERNEL_SET(latency_optimized, ...); + +// Register to specific registries +Registry compute_registry, latency_registry; +REGISTER_KERNEL_SET(compute_optimized, compute_registry); +REGISTER_KERNEL_SET(latency_optimized, latency_registry); + +// Use appropriate registry based on workload +Dispatcher compute_dispatcher(compute_registry); +Dispatcher latency_dispatcher(latency_registry); +``` + +**Features:** +- Named kernel set registration with `REGISTER_KERNEL_SET` macro +- Separate registries for different optimization goals +- Dynamic kernel set selection by name + +## Benchmark Parameters (stream_config) + +CK Tile uses `stream_config` for benchmark control: + +```cpp +ck_tile::stream_config cfg{ + nullptr, // stream_id - HIP stream (nullptr = default) + true, // time_kernel - Enable timing + 1, // log_level - Verbosity (0=quiet, 1=normal) + 5, // cold_niters - Warmup iterations + 20, // nrepeat - Benchmark iterations 
+ true, // is_gpu_timer - Use GPU events vs CPU chrono + false, // flush_cache - Flush L2 cache between iterations + 1 // rotating_count - Rotating buffers for cache simulation +}; +``` + +| Parameter | CLI Option | Default | Description | +|-----------|------------|---------|-------------| +| `cold_niters_` | `--warmup` | 5 | Warmup iterations | +| `nrepeat_` | `--iterations` | 100 | Benchmark iterations | +| `flush_cache_` | - | false | Flush L2 cache | +| `rotating_count_` | - | 1 | Rotating buffers | +| `is_gpu_timer_` | - | true | GPU timer vs CPU | + +## Declarative Kernel Pattern + +All examples use the declarative `DECL_KERNEL_SET` macro: + +```cpp +DECL_KERNEL_SET(my_kernels, + .add( + Signature() // WHAT: operation signature + .dtype("fp16") // Data type + .layout("rcr"), // Matrix layouts (A=row, B=col, C=row) + Algorithm() // HOW: implementation details + .tile(256, 256, 32) // Tile sizes (M, N, K) + .wave(2, 2, 1) // Wave configuration + .warp(32, 32, 16) // Warp tile sizes + .pipeline("compv4") // Pipeline type + .scheduler("intrawave"), // Scheduler type + "gfx942" // WHERE: target architecture + ) +); +``` + +**Key Macros:** +- `DECL_KERNEL_SET(name, ...)` - Declare a kernel set +- `REGISTER_GENERATED_KERNELS` - Register all kernels from this example +- `REGISTER_KERNEL_SET(name, registry)` - Register specific kernel set to a registry + +## Related Documentation + +- [Python GEMM Examples](../python/README.md) +- [Convolution Examples](../../conv/cpp/README.md) +- [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/gemm/python/01_basic_gemm.py b/dispatcher/examples/gemm/python/01_basic_gemm.py new file mode 100644 index 00000000000..93a78d24d1e --- /dev/null +++ b/dispatcher/examples/gemm/python/01_basic_gemm.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 01: Basic GEMM with Multiple Kernels + +Demonstrates: +1. 
Declaring multiple kernel configurations +2. Printing all registered kernels +3. Running each kernel and validating output +4. Comparing performance across kernels + +Complexity: ★★☆☆☆ + +Usage: + python3 01_basic_gemm.py + python3 01_basic_gemm.py --help + python3 01_basic_gemm.py --dtype bf16 + python3 01_basic_gemm.py --size 2048 +""" + +import sys +import argparse +from pathlib import Path +from dataclasses import dataclass +from typing import List + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +@dataclass +class KernelSpec: + """Specification for a kernel configuration""" + + name: str + tile_m: int + tile_n: int + tile_k: int + pipeline: str = "compv3" + scheduler: str = "intrawave" + + +# Define multiple kernel configurations to test (50+ kernels) +KERNEL_SPECS = [ + # Small tiles - compv3 + KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"), + KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"), + # Small tiles - compv4 + KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"), + KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"), + # Medium tiles - compv3 + KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"), + KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"), + KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"), + # Medium tiles - compv4 + KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"), + KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"), + KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"), + # Rectangular tiles - compv3 + KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"), + KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"), + KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"), + KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"), + # Rectangular tiles - compv4 + KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"), + 
KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"), + KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"), + KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"), + # Large tiles - compv3 + KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"), + KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"), + KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"), + KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"), + KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"), + KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"), + # Large tiles - compv4 + KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"), + KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"), + KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"), + KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"), + KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"), + KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"), + # Interwave scheduler variants + KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"), + KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"), + KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"), + KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"), + # More tile_k variations - compv3 + KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"), + KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"), + KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"), + # More tile_k variations - compv4 + KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"), + KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"), + KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"), + # Additional rectangular + KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"), + KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"), + KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"), + KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"), + # Additional compv4 variants + KernelSpec("rect_32x64_v4_k32", 32, 64, 
32, "compv4"), + KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"), + KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"), + KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"), +] + + +def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: + """Create a KernelConfig from a spec""" + # Adjust warp tiles based on tile size + if spec.tile_m <= 64: + warp_m, warp_n = 16, 16 + else: + warp_m, warp_n = 32, 32 + + return KernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=spec.tile_m, + tile_n=spec.tile_n, + tile_k=spec.tile_k, + wave_m=2, + wave_n=2, + wave_k=1, + warp_m=warp_m, + warp_n=warp_n, + warp_k=16, + pipeline=spec.pipeline, + scheduler=spec.scheduler, + epilogue="cshuffle", + gfx_arch=arch, + ) + + +def print_kernel_table(specs: List[KernelSpec], dtype: str): + """Print a formatted table of kernel configurations""" + print("\n" + "=" * 70) + print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)") + print("=" * 70) + print(f"\n {'#':<3} {'Name':<18} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}") + print(" " + "-" * 68) + + for i, spec in enumerate(specs, 1): + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + print( + f" {i:<3} {spec.name:<18} {tile:<14} {spec.pipeline:<10} {spec.scheduler:<12}" + ) + + print(" " + "-" * 68) + print(f" Data type: {dtype}") + + +def main(): + parser = argparse.ArgumentParser( + description="Basic GEMM Example with Multiple Kernels", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 01_basic_gemm.py # Default FP16 with 4 kernels + python3 01_basic_gemm.py --dtype bf16 # BF16 mode + python3 01_basic_gemm.py --size 2048 # Larger problem size + python3 01_basic_gemm.py --num-kernels 2 # Test only 2 kernels + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) 
+ parser.add_argument( + "--arch", + default="gfx942", + help="Target architecture (default: gfx942)", + ) + parser.add_argument( + "--size", + type=int, + default=512, + help="Problem size MxNxK (default: 512)", + ) + parser.add_argument( + "--num-kernels", + type=int, + default=0, + help="Number of kernels to test (0 = all)", + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 70) + print("Example 01: Basic GEMM with Multiple Kernels") + print("=" * 70) + + # Select kernels to test + specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS + + # ========================================================================= + # Step 1: Print all kernel configurations + # ========================================================================= + print_kernel_table(specs, args.dtype) + + # ========================================================================= + # Step 2: Setup and test each kernel + # ========================================================================= + print("\n" + "=" * 70) + print(" RUNNING KERNELS") + print("=" * 70) + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + M, N, K = args.size, args.size, args.size + + results = [] + + print(f"\n Problem size: {M}x{N}x{K}\n") + print( + f" {'#':<3} {'Name':<18} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}" + ) + print(" " + "-" * 78) + + for i, spec in enumerate(specs, 1): + # Create unique test data per kernel + np.random.seed(42 + i * 1000) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + # Create config and setup dispatcher + config = create_kernel_config(spec, args.dtype, args.arch) + + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"kernel_{spec.name}", + verbose=False, + auto_rebuild=True, + ) + + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + + if not setup.success: + print( + f" {i:<3} {spec.name:<18} 
{tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((spec.name, False, 0, 0, 0)) + cleanup_gemm() + continue + + dispatcher = setup.dispatcher + + # Check if size is supported + if not dispatcher.is_supported(M, N, K): + print( + f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'SKIP':<8}" + ) + results.append((spec.name, False, 0, 0, 0)) + cleanup_gemm() + continue + + # Run GEMM + result = dispatcher.run(A, B, M, N, K) + + if not result.success: + print( + f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((spec.name, False, 0, 0, 0)) + cleanup_gemm() + continue + + # Validate against NumPy reference + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + max_err = np.max(np.abs(result.output - C_ref)) + + # Check if within tolerance + passed = max_err < 1e-2 + status = "PASS" if passed else "FAIL" + + print( + f" {i:<3} {spec.name:<18} {tile:<14} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}" + ) + results.append((spec.name, passed, result.time_ms, result.tflops, max_err)) + + cleanup_gemm() + + # ========================================================================= + # Step 3: Summary + # ========================================================================= + print("\n" + "=" * 70) + print(" SUMMARY") + print("=" * 70) + + passed = sum(1 for r in results if r[1]) + failed = len(results) - passed + + print(f"\n Results: {passed}/{len(results)} kernels passed") + print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}") + + if results: + valid_results = [r for r in results if r[1]] + if valid_results: + best = max(valid_results, key=lambda x: x[3]) + print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)") + + if failed == 0: + print("\n *** ALL KERNELS PASSED ***") + else: + print(f"\n *** {failed} KERNELS FAILED ***") + + print("=" * 70) + + return 0 if failed == 0 else 1 + + +if __name__ == 
"__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/02_batch_gemm.py b/dispatcher/examples/gemm/python/02_batch_gemm.py new file mode 100644 index 00000000000..039aba2790f --- /dev/null +++ b/dispatcher/examples/gemm/python/02_batch_gemm.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 02: Batch GEMM + +Runs multiple GEMM operations with different sizes. + +Complexity: ★★☆☆☆ + +Usage: + python3 02_batch_gemm.py + python3 02_batch_gemm.py --help + python3 02_batch_gemm.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Batch GEMM Example - runs multiple sizes", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 02_batch_gemm.py # Default FP16 + python3 02_batch_gemm.py --dtype bf16 # BF16 GEMM + python3 02_batch_gemm.py --max-size 2048 # Limit max size + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--max-size", + type=int, + default=4096, + help="Maximum problem size (default: 4096)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 02: Batch GEMM") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + 
dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="batch_gemm", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + + # ========================================================================= + # Step 2: Run batch of different sizes + # ========================================================================= + print("\nStep 2: Run Batch") + + # Generate sizes up to max_size + all_sizes = [ + (256, 256, 256), + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + ] + sizes = [(m, n, k) for m, n, k in all_sizes if max(m, n, k) <= args.max_size] + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + print(f"\n {'Size':<20} | {'Time (ms)':>12} | {'TFLOPS':>10} | {'Status':>8}") + print(" " + "-" * 60) + + total_ops = 0 + total_time = 0 + + for M, N, K in sizes: + if not dispatcher.is_supported(M, N, K): + print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Skipped") + continue + + A = np.random.randn(M, K).astype(np_dtype) * 0.1 + B = np.random.randn(K, N).astype(np_dtype) * 0.1 + + result = dispatcher.run(A, B, M, N, K) + + if result.success: + total_ops += 2 * M * N * K + total_time += result.time_ms + print( + f" {M:>4}x{N:>4}x{K:<4} | {result.time_ms:>12.4f} | {result.tflops:>10.2f} | OK" + ) + else: + print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Error") + + print(" " + "-" * 60) + + if total_time > 0: + avg_tflops = (total_ops / 1e12) / (total_time / 1000) + print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS") + + # Cleanup + cleanup_gemm() + + print("\n" + "=" * 60) + print("Batch GEMM complete!") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/03_benchmark.py 
b/dispatcher/examples/gemm/python/03_benchmark.py new file mode 100644 index 00000000000..bec1b7e2fb4 --- /dev/null +++ b/dispatcher/examples/gemm/python/03_benchmark.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 03: Benchmark + +Performance benchmarking with compute-optimized kernel configuration. + +Complexity: ★★★☆☆ + +Usage: + python3 03_benchmark.py + python3 03_benchmark.py --help + python3 03_benchmark.py --size 4096 + python3 03_benchmark.py --dtype bf16 --iterations 20 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Benchmark Example - performance testing", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 03_benchmark.py # Default benchmark suite + python3 03_benchmark.py --size 4096 # Single size benchmark + python3 03_benchmark.py --dtype bf16 # BF16 benchmark + python3 03_benchmark.py --iterations 20 # More iterations + """, + ) + parser.add_argument( + "--dtype", + default="bf16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: bf16)", + ) + parser.add_argument( + "--size", + type=int, + default=0, + help="Single problem size MxNxK (default: run all sizes)", + ) + parser.add_argument( + "--warmup", type=int, default=3, help="Warmup iterations (default: 3)" + ) + parser.add_argument( + "--iterations", type=int, default=10, help="Benchmark iterations (default: 10)" + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 03: Benchmark") + print("=" * 60) 
+ + # ========================================================================= + # Step 1: Setup dispatcher with compute-optimized config + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + pipeline="compv4", + scheduler="intrawave", + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + + # ========================================================================= + # Step 2: Benchmark + # ========================================================================= + print("\nStep 2: Benchmark") + + if args.size > 0: + sizes = [(args.size, args.size, args.size)] + else: + sizes = [ + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + (1024, 2048, 512), + (2048, 1024, 2048), + ] + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + print(f" Warmup: {args.warmup}, Iterations: {args.iterations}\n") + + print(f" {'Size':<20} | {'Min (ms)':>10} | {'Avg (ms)':>10} | {'TFLOPS':>10}") + print(" " + "-" * 60) + + all_tflops = [] + + for M, N, K in sizes: + if not dispatcher.is_supported(M, N, K): + continue + + A = np.random.randn(M, K).astype(np_dtype) * 0.1 + B = np.random.randn(K, N).astype(np_dtype) * 0.1 + + # Warmup + for _ in range(args.warmup): + dispatcher.run(A, B, M, N, K) + + # Benchmark + times = [] + for _ in range(args.iterations): + result = dispatcher.run(A, B, M, N, K) + if result.success: + times.append(result.time_ms) + + if times: + min_time = min(times) + avg_time = sum(times) / len(times) + tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12 + all_tflops.append(tflops) + print( + f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | 
{avg_time:>10.4f} | {tflops:>10.2f}" + ) + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + print("Summary") + print("=" * 60) + + if all_tflops: + print(f" Average: {sum(all_tflops) / len(all_tflops):.2f} TFLOPS") + print(f" Peak: {max(all_tflops):.2f} TFLOPS") + + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/04_validation.py b/dispatcher/examples/gemm/python/04_validation.py new file mode 100644 index 00000000000..2fe54c53f75 --- /dev/null +++ b/dispatcher/examples/gemm/python/04_validation.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 04: Validation + +Validates GPU GEMM against NumPy reference. + +Complexity: ★★★☆☆ + +Usage: + python3 04_validation.py + python3 04_validation.py --help + python3 04_validation.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + Validator, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Validation Example - validates GPU results against NumPy", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 04_validation.py # Default FP16 validation + python3 04_validation.py --dtype bf16 # BF16 validation + python3 04_validation.py --rtol 1e-2 # Relaxed tolerance + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--rtol", type=float, default=1e-3, help="Relative tolerance (default: 1e-3)" + ) + parser.add_argument( + "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)" + ) + 
parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 04: Validation") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="validation", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + + # ========================================================================= + # Step 2: Run validation tests + # ========================================================================= + print("\nStep 2: Validation Tests") + + validator = Validator(rtol=args.rtol, atol=args.atol) + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + test_cases = [ + ("Identity", 128, 128, 128, "identity"), + ("Small", 256, 256, 256, "random"), + ("Medium", 512, 512, 512, "random"), + ("Large", 1024, 1024, 1024, "random"), + ("Non-square", 512, 1024, 256, "random"), + ] + + passed = 0 + failed = 0 + + print(f"\n {'Test':<15} | {'Size':<15} | {'Max Err':>10} | {'Status':>8}") + print(" " + "-" * 55) + + for name, M, N, K, pattern in test_cases: + if not dispatcher.is_supported(M, N, K): + print(f" {name:<15} | {M}x{N}x{K:<5} | {'N/A':>10} | Skipped") + continue + + np.random.seed(42) + if pattern == "identity": + A = np.eye(M, K, dtype=np_dtype) + B = np.eye(K, N, dtype=np_dtype) + else: + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + result = dispatcher.run(A, B, M, N, K) + if not result.success: + print(f" 
{name:<15} | {M}x{N}x{K:<5} | {'GPU Err':>10} | FAILED") + failed += 1 + continue + + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + is_valid, max_err, _ = validator.check(result.output, C_ref) + + if is_valid: + print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | PASSED") + passed += 1 + else: + print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | FAILED") + failed += 1 + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + total = passed + failed + print(f"Results: {passed}/{total} passed") + print(f"Settings: dtype={args.dtype}, rtol={args.rtol}, atol={args.atol}") + print("=" * 60) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/05_numpy_integration.py b/dispatcher/examples/gemm/python/05_numpy_integration.py new file mode 100644 index 00000000000..493ce46d223 --- /dev/null +++ b/dispatcher/examples/gemm/python/05_numpy_integration.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 05: NumPy Integration + +Shows how to create a GPU-accelerated matmul wrapper. 
+ +Complexity: ★★☆☆☆ + +Usage: + python3 05_numpy_integration.py + python3 05_numpy_integration.py --help + python3 05_numpy_integration.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + Dispatcher, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +class GPUMatmul: + """GPU-accelerated matrix multiplication wrapper.""" + + def __init__(self, dispatcher: Dispatcher): + self.dispatcher = dispatcher + + def __call__(self, A: np.ndarray, B: np.ndarray) -> np.ndarray: + """Compute C = A @ B on GPU with CPU fallback.""" + M, K = A.shape + K2, N = B.shape + + if K != K2: + raise ValueError(f"Dimension mismatch: {A.shape} @ {B.shape}") + + if not self.dispatcher.is_supported(M, N, K): + return np.matmul(A, B) + + result = self.dispatcher.run(A, B, M, N, K) + return result.output if result.success else np.matmul(A, B) + + +def main(): + parser = argparse.ArgumentParser( + description="NumPy Integration Example - GPU-accelerated matmul wrapper", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 05_numpy_integration.py # Default FP16 + python3 05_numpy_integration.py --dtype bf16 # BF16 mode + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 05: NumPy Integration") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + 
dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="numpy", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 2: Create GPU matmul wrapper + # ========================================================================= + print("\nStep 2: Create GPUMatmul") + + gpu_matmul = GPUMatmul(dispatcher=dispatcher) + print(" gpu_matmul ready") + + # ========================================================================= + # Step 3: Demo - Simple multiplication using gpu_matmul + # ========================================================================= + print("\nStep 3: Demo - Simple Multiplication") + + A = np.random.randn(1024, 512).astype(np_dtype) * 0.1 + B = np.random.randn(512, 256).astype(np_dtype) * 0.1 + + # Use the gpu_matmul wrapper + C = gpu_matmul(A, B) + print(f" gpu_matmul result: {C.shape}, sum={C.sum():.4f}") + + M, K = A.shape + _, N = B.shape + result = dispatcher.run(A, B, M, N, K) + + print(f" A: {A.shape}, B: {B.shape} -> C: {result.output.shape}") + print(f" GPU: {result.time_ms:.4f} ms, {result.tflops:.2f} TFLOPS") + + # ========================================================================= + # Step 4: Demo - FFN block + # ========================================================================= + print("\nStep 4: Demo - FFN Block") + + batch, hidden, ffn = 128, 768, 3072 + X = np.random.randn(batch, hidden).astype(np_dtype) * 0.02 + W1 = np.random.randn(hidden, ffn).astype(np_dtype) * 0.02 + W2 = np.random.randn(ffn, hidden).astype(np_dtype) * 0.02 + + result1 = dispatcher.run(X, W1, batch, ffn, hidden) + H = result1.output + result2 = dispatcher.run(H, W2, batch, hidden, ffn) + + print(f" X: {X.shape} -> H: 
{H.shape} -> Y: {result2.output.shape}") + print(f" Total: {result1.time_ms + result2.time_ms:.4f} ms") + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + print("NumPy Integration Pattern:") + print("=" * 60) + print(" 1. setup_gemm_dispatcher(config)") + print(" 2. GPUMatmul(dispatcher)") + print(" 3. C = gpu_matmul(A, B)") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/06_json_export.py b/dispatcher/examples/gemm/python/06_json_export.py new file mode 100644 index 00000000000..9e062e507b3 --- /dev/null +++ b/dispatcher/examples/gemm/python/06_json_export.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 06: JSON Export + +Exports registry configuration to JSON. + +Complexity: ★★☆☆☆ + +Usage: + python3 06_json_export.py + python3 06_json_export.py --help + python3 06_json_export.py --output my_kernels.json +""" + +import sys +import json +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="JSON Export Example - exports registry to JSON", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 06_json_export.py # Default output to kernels.json + python3 06_json_export.py --output my.json # Custom output file + """, + ) + parser.add_argument( + "--output", + "-o", + default="kernels.json", + help="Output JSON file (default: kernels.json)", + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args 
= parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 06: JSON Export") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup dispatcher + # ========================================================================= + print("\nStep 1: Setup Dispatcher") + + config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(config, registry_name="export_demo", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + # ========================================================================= + # Step 2: Define additional configs for export + # ========================================================================= + print("\nStep 2: Define Additional Configs") + + configs = [ + config, + KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=256, + tile_n=256, + tile_k=64, + gfx_arch=args.arch, + ), + KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=64, + tile_n=64, + tile_k=32, + gfx_arch=args.arch, + ), + ] + + for cfg in configs: + print(f" - {cfg.tile_str}") + + # ========================================================================= + # Step 3: Export to JSON + # ========================================================================= + print("\nStep 3: Export to JSON") + + export_data = { + "registry": setup.registry.name, + "kernel_count": len(configs), + "kernels": [], + } + + for cfg in configs: + kernel_info = { + "tile": cfg.tile_str, + "dtypes": {"A": cfg.dtype_a, "B": cfg.dtype_b, "C": cfg.dtype_c}, + "layout": cfg.layout, + "pipeline": cfg.pipeline, + "target": cfg.gfx_arch, + } + export_data["kernels"].append(kernel_info) + + # Include C++ library info + if setup.lib: + cpp_json = setup.lib.export_registry_json() + try: + 
export_data["cpp_registry"] = json.loads(cpp_json) + except json.JSONDecodeError: + pass + + json_str = json.dumps(export_data, indent=2) + + with open(args.output, "w") as f: + f.write(json_str) + print(f" Saved to: {args.output}") + + # Preview + print("\nStep 4: Preview") + print("-" * 60) + print(json_str[:500] + ("..." if len(json_str) > 500 else "")) + print("-" * 60) + + # Cleanup + cleanup_gemm() + + print("\n" + "=" * 60) + print("JSON Export complete!") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/07_stress_test.py b/dispatcher/examples/gemm/python/07_stress_test.py new file mode 100644 index 00000000000..81600306319 --- /dev/null +++ b/dispatcher/examples/gemm/python/07_stress_test.py @@ -0,0 +1,513 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 07: Stress Test - Multiple Kernels with Validation + +Consolidated stress test that: +1. Declares multiple kernel configurations (various tiles, pipelines, layouts) +2. Prints all registered kernels with details +3. Validates each kernel against NumPy reference +4. 
Optional benchmarking mode + +This tests: +- Multiple tile sizes (64x64, 128x128, 256x256) +- Multiple pipelines (compv3, compv4) +- Multiple data types (fp16, bf16) +- Different schedulers (intrawave, interwave) + +Complexity: ★★★★☆ + +Usage: + python3 07_stress_test.py + python3 07_stress_test.py --help + python3 07_stress_test.py --num-kernels 10 + python3 07_stress_test.py --benchmark + python3 07_stress_test.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path +from dataclasses import dataclass +from typing import List, Tuple + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, + Validator, +) + + +@dataclass +class KernelSpec: + """A kernel specification for testing""" + + name: str + tile_m: int + tile_n: int + tile_k: int + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + warp_m: int = 32 + warp_n: int = 32 + warp_k: int = 16 + pipeline: str = "compv3" + scheduler: str = "intrawave" + layout: str = "rcr" + + def to_config(self, dtype: str, arch: str) -> KernelConfig: + """Convert to KernelConfig""" + # Adjust warp tiles for smaller tiles + warp_m = min(self.warp_m, self.tile_m // self.wave_m) + warp_n = min(self.warp_n, self.tile_n // self.wave_n) + warp_k = self.warp_k + + return KernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + dtype_acc="fp32", + layout_a={"r": "row", "c": "col"}[self.layout[0]], + layout_b={"r": "row", "c": "col"}[self.layout[1]], + layout_c={"r": "row", "c": "col"}[self.layout[2]], + tile_m=self.tile_m, + tile_n=self.tile_n, + tile_k=self.tile_k, + wave_m=self.wave_m, + wave_n=self.wave_n, + wave_k=self.wave_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + pipeline=self.pipeline, + scheduler=self.scheduler, + epilogue="cshuffle", + gfx_arch=arch, + ) + + +# Define stress test kernel configurations +KERNEL_SPECS = [ + # Small 
tiles - compv3 + KernelSpec( + "small_compv3", + 64, + 64, + 32, + wave_m=2, + wave_n=2, + warp_m=16, + warp_n=16, + warp_k=32, + pipeline="compv3", + ), + KernelSpec( + "small_compv4", + 64, + 64, + 32, + wave_m=2, + wave_n=2, + warp_m=16, + warp_n=16, + warp_k=32, + pipeline="compv4", + ), + # Medium tiles + KernelSpec( + "medium_compv3", + 128, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + KernelSpec( + "medium_compv4", + 128, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv4", + ), + KernelSpec( + "medium_k64", + 128, + 128, + 64, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + # Rectangular tiles + KernelSpec( + "rect_64x128", + 64, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + KernelSpec( + "rect_128x64", + 128, + 64, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + # Different schedulers + KernelSpec( + "interwave", + 128, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + scheduler="interwave", + ), + # Large tiles + KernelSpec( + "large_compv3", + 256, + 128, + 32, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv3", + ), + KernelSpec( + "large_compv4", + 256, + 128, + 64, + wave_m=2, + wave_n=2, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline="compv4", + ), +] + + +def print_kernel_summary(specs: List[KernelSpec], dtype: str): + """Print a summary table of all kernel specs""" + print("\n" + "=" * 80) + print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)") + print("=" * 80) + print( + f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Wave':<10} {'Warp':<12} {'Pipeline':<10} {'Sched':<10}" + ) + print(" " + "-" * 78) + + for i, spec in enumerate(specs, 1): + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + wave = 
f"{spec.wave_m}x{spec.wave_n}x{spec.wave_k}" + warp = f"{spec.warp_m}x{spec.warp_n}x{spec.warp_k}" + print( + f" {i:<3} {spec.name:<18} {tile:<12} {wave:<10} {warp:<12} {spec.pipeline:<10} {spec.scheduler:<10}" + ) + + print(" " + "-" * 78) + print(f" Data type: {dtype}\n") + + +def validate_kernel( + spec: KernelSpec, + dtype: str, + arch: str, + size: int, + validator: Validator, + kernel_index: int = 0, + verbose: bool = False, +) -> Tuple[bool, float, str]: + """ + Validate a single kernel configuration. + Returns: (passed, max_error, message) + """ + np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32 + + # Create config + config = spec.to_config(dtype, arch) + + # Setup dispatcher + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"stress_{spec.name}", + verbose=False, + auto_rebuild=True, + ) + + if not setup.success: + return False, 0.0, f"Setup failed: {setup.error}" + + dispatcher = setup.dispatcher + M, N, K = size, size, size + + if not dispatcher.is_supported(M, N, K): + cleanup_gemm() + return False, 0.0, f"Size {M}x{N}x{K} not supported" + + # Use different seed per kernel to get unique test data + # This ensures each kernel is tested with different matrices + np.random.seed(42 + kernel_index * 1000) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + # Run GPU GEMM + result = dispatcher.run(A, B, M, N, K) + + if not result.success: + cleanup_gemm() + return False, 0.0, "GPU execution failed" + + # Validate against NumPy + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + is_valid, max_err, _ = validator.check(result.output, C_ref) + + cleanup_gemm() + + return is_valid, max_err, f"{result.time_ms:.2f}ms, {result.tflops:.1f} TFLOPS" + + +def benchmark_kernel( + spec: KernelSpec, + dtype: str, + arch: str, + size: int, + warmup: int = 3, + iterations: int = 10, +) -> Tuple[bool, float, float]: + """ + Benchmark a kernel 
configuration. + Returns: (success, avg_time_ms, tflops) + """ + np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32 + + config = spec.to_config(dtype, arch) + setup = setup_gemm_dispatcher( + config=config, + registry_name=f"bench_{spec.name}", + verbose=False, + auto_rebuild=True, + ) + + if not setup.success: + return False, 0.0, 0.0 + + dispatcher = setup.dispatcher + M, N, K = size, size, size + + if not dispatcher.is_supported(M, N, K): + cleanup_gemm() + return False, 0.0, 0.0 + + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + # Warmup + for _ in range(warmup): + dispatcher.run(A, B, M, N, K) + + # Benchmark + times = [] + for _ in range(iterations): + result = dispatcher.run(A, B, M, N, K) + if result.success: + times.append(result.time_ms) + + cleanup_gemm() + + if not times: + return False, 0.0, 0.0 + + avg_time = sum(times) / len(times) + tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12 + + return True, avg_time, tflops + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM Stress Test - Multiple kernels with validation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 07_stress_test.py # Test all kernels + python3 07_stress_test.py --num-kernels 5 # Test first 5 kernels + python3 07_stress_test.py --benchmark # Include benchmarks + python3 07_stress_test.py --dtype bf16 # Test BF16 + python3 07_stress_test.py --size 2048 # Use 2048x2048 matrices + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--num-kernels", + type=int, + default=0, + help="Number of kernels to test (0 = all)", + ) + parser.add_argument( + "--size", + type=int, + default=512, + help="Problem size MxNxK (default: 512)", + ) + parser.add_argument( + "--benchmark", + action="store_true", + help="Include benchmark timing", + ) + 
parser.add_argument( + "--rtol", + type=float, + default=1e-2, + help="Relative tolerance (default: 1e-2)", + ) + parser.add_argument( + "--atol", + type=float, + default=1e-2, + help="Absolute tolerance (default: 1e-2)", + ) + parser.add_argument( + "--arch", + default="gfx942", + help="Target architecture (default: gfx942)", + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 80) + print("Example 07: GEMM Stress Test - Multiple Kernels") + print("=" * 80) + + # Select kernels to test + specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS + + # Print kernel summary + print_kernel_summary(specs, args.dtype) + + # Run validation + print("\n" + "=" * 80) + print(" VALIDATION RESULTS") + print("=" * 80) + + validator = Validator(rtol=args.rtol, atol=args.atol) + + if args.benchmark: + print( + f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Time':>10} {'TFLOPS':>8} {'Status':<8}" + ) + else: + print( + f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Info':<25} {'Status':<8}" + ) + print(" " + "-" * 78) + + passed = 0 + failed = 0 + skipped = 0 + + for i, spec in enumerate(specs, 1): + tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" + + try: + is_valid, max_err, info = validate_kernel( + spec, args.dtype, args.arch, args.size, validator, kernel_index=i + ) + + if is_valid: + status = "PASS" + passed += 1 + else: + status = "FAIL" + failed += 1 + + if args.benchmark: + success, avg_time, tflops = benchmark_kernel( + spec, args.dtype, args.arch, args.size + ) + if success: + print( + f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {avg_time:>9.2f}ms {tflops:>7.1f} {status:<8}" + ) + else: + print( + f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {'N/A':>10} {'N/A':>8} {status:<8}" + ) + else: + print( + f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {info:<25} {status:<8}" + ) + + except Exception as e: + skipped += 1 + print( + f" {i:<3} {spec.name:<18} {tile:<12} {'N/A':>10} 
{str(e)[:25]:<25} {'SKIP':<8}" + ) + + # Summary + print("\n" + "=" * 80) + print(" SUMMARY") + print("=" * 80) + total = passed + failed + skipped + print(f"\n Results: {passed}/{total} passed, {failed} failed, {skipped} skipped") + print(f" Settings: dtype={args.dtype}, size={args.size}x{args.size}x{args.size}") + print(f" Tolerance: rtol={args.rtol}, atol={args.atol}") + print(f" Architecture: {args.arch}") + + if failed == 0 and skipped == 0: + print("\n *** ALL KERNELS PASSED ***") + elif failed > 0: + print(f"\n *** {failed} KERNELS FAILED ***") + + print("=" * 80) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/08_heuristics.py b/dispatcher/examples/gemm/python/08_heuristics.py new file mode 100644 index 00000000000..e2763c05135 --- /dev/null +++ b/dispatcher/examples/gemm/python/08_heuristics.py @@ -0,0 +1,718 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 08: Custom Heuristics + +Demonstrates custom kernel selection heuristics based on problem characteristics. + +This example shows how to: +1. Define multiple kernel configurations for different workloads +2. Implement custom heuristics to select the best kernel +3. 
Test heuristic selection across different problem sizes + +Heuristic strategies: +- Size-based: Small tiles for small problems, large tiles for large problems +- Compute-bound: Maximize compute utilization for large matrices +- Memory-bound: Optimize memory access for bandwidth-limited cases +- Latency-focused: Minimize kernel launch overhead for small problems + +Complexity: ★★★★☆ + +Usage: + python3 08_heuristics.py + python3 08_heuristics.py --help + python3 08_heuristics.py --strategy compute + python3 08_heuristics.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path +from dataclasses import dataclass +from typing import List +from enum import Enum + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +# ============================================================================= +# Kernel Specifications +# ============================================================================= + + +@dataclass +class KernelSpec: + """Kernel specification with metadata for heuristic selection""" + + name: str + tile_m: int + tile_n: int + tile_k: int + pipeline: str = "compv3" + scheduler: str = "intrawave" + # Metadata for heuristics + category: str = "balanced" # small, balanced, large, compute, memory + min_problem_size: int = 0 + max_problem_size: int = float("inf") + + +# Define kernel pool for heuristic selection (20+ kernels) +KERNEL_POOL = [ + # ========================================================================== + # SMALL TILES - Low latency, good for small problems + # ========================================================================== + KernelSpec( + "small_64x64_k32", + 64, + 64, + 32, + "compv3", + "intrawave", + category="small", + max_problem_size=256 * 256, + ), + KernelSpec( + "small_64x64_k64", + 64, + 64, + 64, + "compv3", + "intrawave", + 
category="small", + max_problem_size=256 * 256, + ), + KernelSpec( + "small_64x64_v4", + 64, + 64, + 32, + "compv4", + "intrawave", + category="small", + max_problem_size=256 * 256, + ), + # ========================================================================== + # MEDIUM TILES - Balanced performance + # ========================================================================== + KernelSpec( + "medium_128x128_k32", + 128, + 128, + 32, + "compv3", + "intrawave", + category="balanced", + min_problem_size=128 * 128, + max_problem_size=2048 * 2048, + ), + KernelSpec( + "medium_128x128_k64", + 128, + 128, + 64, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "medium_128x128_k128", + 128, + 128, + 128, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "medium_128x128_v4_k32", + 128, + 128, + 32, + "compv4", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "medium_128x128_v4_k64", + 128, + 128, + 64, + "compv4", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + # Rectangular medium tiles + KernelSpec( + "rect_64x128_k32", + 64, + 128, + 32, + "compv3", + "intrawave", + category="balanced", + min_problem_size=128 * 128, + ), + KernelSpec( + "rect_128x64_k32", + 128, + 64, + 32, + "compv3", + "intrawave", + category="balanced", + min_problem_size=128 * 128, + ), + KernelSpec( + "rect_64x128_k64", + 64, + 128, + 64, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + KernelSpec( + "rect_128x64_k64", + 128, + 64, + 64, + "compv3", + "intrawave", + category="balanced", + min_problem_size=256 * 256, + ), + # ========================================================================== + # LARGE TILES - High throughput for large problems + # ========================================================================== + KernelSpec( + "large_256x128_k32", + 256, + 128, + 32, + 
"compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_256x128_k64", + 256, + 128, + 64, + "compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_128x256_k32", + 128, + 256, + 32, + "compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_128x256_k64", + 128, + 256, + 64, + "compv3", + "intrawave", + category="large", + min_problem_size=512 * 512, + ), + KernelSpec( + "large_256x256_k32", + 256, + 256, + 32, + "compv3", + "intrawave", + category="large", + min_problem_size=1024 * 1024, + ), + KernelSpec( + "large_256x256_k64", + 256, + 256, + 64, + "compv3", + "intrawave", + category="large", + min_problem_size=1024 * 1024, + ), + # ========================================================================== + # COMPUTE-OPTIMIZED - compv4 pipeline for compute-bound workloads + # ========================================================================== + KernelSpec( + "compute_128x128_v4_k32", + 128, + 128, + 32, + "compv4", + "intrawave", + category="compute", + min_problem_size=256 * 256, + ), + KernelSpec( + "compute_128x128_v4_k64", + 128, + 128, + 64, + "compv4", + "intrawave", + category="compute", + min_problem_size=256 * 256, + ), + KernelSpec( + "compute_256x128_v4", + 256, + 128, + 64, + "compv4", + "intrawave", + category="compute", + min_problem_size=512 * 512, + ), + KernelSpec( + "compute_256x256_v4", + 256, + 256, + 64, + "compv4", + "intrawave", + category="compute", + min_problem_size=1024 * 1024, + ), + # ========================================================================== + # MEMORY-OPTIMIZED - Good cache utilization for memory-bound workloads + # ========================================================================== + KernelSpec( + "memory_128x128_k16", + 128, + 128, + 16, + "compv3", + "intrawave", + category="memory", + min_problem_size=256 * 256, + ), + KernelSpec( + "memory_64x128_k16", + 64, + 
128, + 16, + "compv3", + "intrawave", + category="memory", + min_problem_size=128 * 128, + ), +] + + +def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: + """Create KernelConfig from spec""" + warp_m = 16 if spec.tile_m <= 64 else 32 + warp_n = 16 if spec.tile_n <= 64 else 32 + + return KernelConfig( + dtype_a=dtype, + dtype_b=dtype, + dtype_c=dtype, + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=spec.tile_m, + tile_n=spec.tile_n, + tile_k=spec.tile_k, + wave_m=2, + wave_n=2, + wave_k=1, + warp_m=warp_m, + warp_n=warp_n, + warp_k=16, + pipeline=spec.pipeline, + scheduler=spec.scheduler, + epilogue="cshuffle", + gfx_arch=arch, + ) + + +# ============================================================================= +# Heuristic Strategies +# ============================================================================= + + +class HeuristicStrategy(Enum): + SIZE_BASED = "size" + COMPUTE_BOUND = "compute" + MEMORY_BOUND = "memory" + LATENCY_FOCUSED = "latency" + + +def size_based_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel based on problem size. + - Small problems: Use small tiles for low latency + - Medium problems: Use balanced tiles + - Large problems: Use large tiles for high throughput + + Also considers K dimension for tile_k selection. 
+ """ + total_elements = M * N + + # Filter by problem size constraints + candidates = [ + k for k in kernels if k.min_problem_size <= total_elements <= k.max_problem_size + ] + + if not candidates: + candidates = kernels # Fall back to all kernels + + # Determine target category based on problem size + if total_elements < 256 * 256: + target_category = "small" + elif total_elements < 1024 * 1024: + target_category = "balanced" + else: + target_category = "large" + + # Filter by category if possible + category_candidates = [k for k in candidates if k.category == target_category] + if category_candidates: + candidates = category_candidates + + # Select best tile_k based on K dimension + # Prefer tile_k that divides K well + def tile_k_score(k): + if K % k.tile_k == 0: + return 0 # Perfect division + return K % k.tile_k # Remainder (lower is better) + + # Sort by tile_k fit, then by tile size + candidates.sort(key=lambda k: (tile_k_score(k), -k.tile_m * k.tile_n)) + + return candidates[0] + + +def compute_bound_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel optimized for compute-bound workloads. + Prefers compv4 pipeline and larger tiles. + Selects based on problem size to maximize compute utilization. 
+ """ + total_elements = M * N + + # Prefer compute category kernels + compute_kernels = [k for k in kernels if k.category == "compute"] + + if not compute_kernels: + # Fall back to compv4 kernels + compute_kernels = [k for k in kernels if k.pipeline == "compv4"] + + if not compute_kernels: + compute_kernels = kernels + + # Filter by problem size + valid = [k for k in compute_kernels if k.min_problem_size <= total_elements] + if valid: + compute_kernels = valid + + # For large problems, prefer larger tiles + if total_elements >= 1024 * 1024: + return max(compute_kernels, key=lambda k: k.tile_m * k.tile_n * k.tile_k) + else: + # For smaller problems, prefer medium tiles + return min( + compute_kernels, key=lambda k: abs(k.tile_m - 128) + abs(k.tile_n - 128) + ) + + +def memory_bound_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel optimized for memory-bound workloads. + Prefers smaller tile_k for better memory access patterns. + """ + # Prefer memory category kernels first + memory_kernels = [k for k in kernels if k.category == "memory"] + if memory_kernels: + # Select based on problem size + total = M * N + if total < 512 * 512: + return min(memory_kernels, key=lambda k: k.tile_m * k.tile_n) + return max(memory_kernels, key=lambda k: k.tile_m * k.tile_n) + + # Fall back to balanced with smaller tile_k + balanced = [k for k in kernels if k.category == "balanced"] + if balanced: + # Prefer smaller tile_k for memory-bound + return min(balanced, key=lambda k: k.tile_k) + + # Fall back to medium-sized tile with small tile_k + return min( + kernels, key=lambda k: (k.tile_k, abs(k.tile_m - 128) + abs(k.tile_n - 128)) + ) + + +def latency_focused_heuristic( + M: int, N: int, K: int, kernels: List[KernelSpec] +) -> KernelSpec: + """ + Select kernel optimized for low latency. + Prefers smaller tiles and compv4 for faster execution. 
+ """ + # Prefer small category + small_kernels = [k for k in kernels if k.category == "small"] + + if small_kernels: + # Among small kernels, prefer compv4 for lower latency + v4_small = [k for k in small_kernels if k.pipeline == "compv4"] + if v4_small: + return v4_small[0] + return small_kernels[0] + + # Fall back to smallest tile with compv4 if available + all_v4 = [k for k in kernels if k.pipeline == "compv4"] + if all_v4: + return min(all_v4, key=lambda k: k.tile_m * k.tile_n) + + # Fall back to smallest tile + return min(kernels, key=lambda k: k.tile_m * k.tile_n) + + +HEURISTICS = { + HeuristicStrategy.SIZE_BASED: size_based_heuristic, + HeuristicStrategy.COMPUTE_BOUND: compute_bound_heuristic, + HeuristicStrategy.MEMORY_BOUND: memory_bound_heuristic, + HeuristicStrategy.LATENCY_FOCUSED: latency_focused_heuristic, +} + + +# ============================================================================= +# Main +# ============================================================================= + + +def print_kernel_pool(kernels: List[KernelSpec]): + """Print available kernels""" + print("\n" + "=" * 75) + print(" KERNEL POOL") + print("=" * 75) + print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Category':<12}") + print(" " + "-" * 73) + + for i, k in enumerate(kernels, 1): + tile = f"{k.tile_m}x{k.tile_n}x{k.tile_k}" + print(f" {i:<3} {k.name:<22} {tile:<14} {k.pipeline:<10} {k.category:<12}") + + print(" " + "-" * 73) + + +def main(): + parser = argparse.ArgumentParser( + description="Custom Heuristics Example - intelligent kernel selection", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 08_heuristics.py # Default size-based heuristic + python3 08_heuristics.py --strategy compute # Compute-bound heuristic + python3 08_heuristics.py --strategy memory # Memory-bound heuristic + python3 08_heuristics.py --strategy latency # Latency-focused heuristic + python3 08_heuristics.py --dtype bf16 # BF16 mode + 
""", + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--strategy", + default="size", + choices=["size", "compute", "memory", "latency"], + help="Heuristic strategy (default: size)", + ) + parser.add_argument( + "--arch", + default="gfx942", + help="Target architecture (default: gfx942)", + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 75) + print("Example 08: Custom Heuristics") + print("=" * 75) + + # Map strategy string to enum + strategy_map = { + "size": HeuristicStrategy.SIZE_BASED, + "compute": HeuristicStrategy.COMPUTE_BOUND, + "memory": HeuristicStrategy.MEMORY_BOUND, + "latency": HeuristicStrategy.LATENCY_FOCUSED, + } + strategy = strategy_map[args.strategy] + heuristic_fn = HEURISTICS[strategy] + + print(f"\n Strategy: {strategy.value}") + print(f" Data type: {args.dtype}") + + # Print kernel pool + print_kernel_pool(KERNEL_POOL) + + # ========================================================================= + # Test heuristic selection across different problem sizes + # ========================================================================= + print("\n" + "=" * 75) + print(" HEURISTIC SELECTION TEST") + print("=" * 75) + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + test_sizes = [ + (128, 128, 64), # Small + (256, 256, 128), # Small-medium + (512, 512, 256), # Medium + (1024, 1024, 512), # Medium-large + (2048, 2048, 1024), # Large + ] + + print( + f"\n {'Size':<20} {'Selected Kernel':<25} {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}" + ) + print(" " + "-" * 78) + + results = [] + + for M, N, K in test_sizes: + # Use heuristic to select kernel + selected_spec = heuristic_fn(M, N, K, KERNEL_POOL) + + # Create config and setup + config = create_kernel_config(selected_spec, args.dtype, args.arch) + + setup = setup_gemm_dispatcher( + config=config, + 
registry_name=f"heuristic_{selected_spec.name}", + verbose=False, + auto_rebuild=True, + ) + + size_str = f"{M}x{N}x{K}" + + if not setup.success: + print( + f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((size_str, selected_spec.name, False, 0, 0)) + cleanup_gemm() + continue + + dispatcher = setup.dispatcher + + if not dispatcher.is_supported(M, N, K): + print( + f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'SKIP':<8}" + ) + results.append((size_str, selected_spec.name, False, 0, 0)) + cleanup_gemm() + continue + + # Run GEMM + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + + result = dispatcher.run(A, B, M, N, K) + + if not result.success: + print( + f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + ) + results.append((size_str, selected_spec.name, False, 0, 0)) + cleanup_gemm() + continue + + # Validate + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + max_err = np.max(np.abs(result.output - C_ref)) + passed = max_err < 1e-2 + + status = "PASS" if passed else "FAIL" + print( + f" {size_str:<20} {selected_spec.name:<25} {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}" + ) + results.append( + (size_str, selected_spec.name, passed, result.time_ms, result.tflops) + ) + + cleanup_gemm() + + # ========================================================================= + # Summary + # ========================================================================= + print("\n" + "=" * 75) + print(" SUMMARY") + print("=" * 75) + + passed = sum(1 for r in results if r[2]) + failed = len(results) - passed + + print(f"\n Strategy: {strategy.value}") + print(f" Results: {passed}/{len(results)} tests passed") + + # Show kernel selection distribution + kernel_usage = {} + for r in results: + kernel_usage[r[1]] = kernel_usage.get(r[1], 0) + 1 + + print("\n 
Kernel Selection Distribution:") + for kernel, count in sorted(kernel_usage.items(), key=lambda x: -x[1]): + print(f" {kernel}: {count} times") + + if results: + valid_results = [r for r in results if r[2]] + if valid_results: + avg_tflops = sum(r[4] for r in valid_results) / len(valid_results) + print(f"\n Average TFLOPS: {avg_tflops:.2f}") + + if failed == 0: + print("\n *** ALL TESTS PASSED ***") + else: + print(f"\n *** {failed} TESTS FAILED ***") + + print("=" * 75) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/09_multi_registry.py b/dispatcher/examples/gemm/python/09_multi_registry.py new file mode 100644 index 00000000000..97cbce34974 --- /dev/null +++ b/dispatcher/examples/gemm/python/09_multi_registry.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 09: Multiple Registries + +Demonstrates multiple registries for different optimization targets. 
+ +Complexity: ★★★★★ + +Usage: + python3 09_multi_registry.py + python3 09_multi_registry.py --help + python3 09_multi_registry.py --dtype bf16 +""" + +import sys +import argparse +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) +import numpy as np + +from ctypes_utils import ( + KernelConfig, + Registry, + Dispatcher, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Multiple Registries Example - optimization-specific registries", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 09_multi_registry.py # Default FP16 + python3 09_multi_registry.py --dtype bf16 # BF16 mode + """, + ) + parser.add_argument( + "--dtype", + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type (default: fp16)", + ) + parser.add_argument( + "--arch", default="gfx942", help="Target architecture (default: gfx942)" + ) + args = parser.parse_args() + + reset_for_example() + + print("=" * 60) + print("Example 09: Multiple Registries") + print("=" * 60) + + # ========================================================================= + # Step 1: Setup base dispatcher + # ========================================================================= + print("\nStep 1: Setup Base Dispatcher") + + base_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + gfx_arch=args.arch, + ) + + setup = setup_gemm_dispatcher(base_config, registry_name="base", verbose=True) + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + lib = setup.lib + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 2: Define configs for different optimization targets + # 
========================================================================= + print("\nStep 2: Define Optimization Targets") + + compute_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=256, + tile_n=256, + tile_k=64, + wave_m=4, + wave_n=4, + pipeline="compv4", + gfx_arch=args.arch, + ) + memory_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=128, + tile_n=128, + tile_k=32, + wave_m=2, + wave_n=2, + pipeline="compv4", + gfx_arch=args.arch, + ) + latency_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + tile_m=64, + tile_n=64, + tile_k=32, + wave_m=1, + wave_n=1, + pipeline="compv3", + gfx_arch=args.arch, + ) + + print(f" Compute: {compute_config.tile_str} (large matrices)") + print(f" Memory: {memory_config.tile_str} (medium matrices)") + print(f" Latency: {latency_config.tile_str} (small matrices)") + + # ========================================================================= + # Step 3: Create registries + # ========================================================================= + print("\nStep 3: Create Registries") + + compute_registry = Registry(name="compute", lib=lib) + compute_registry.register_kernel(compute_config) + + memory_registry = Registry(name="memory", lib=lib) + memory_registry.register_kernel(memory_config) + + latency_registry = Registry(name="latency", lib=lib) + latency_registry.register_kernel(latency_config) + + # ========================================================================= + # Step 4: Create dispatchers + # ========================================================================= + print("\nStep 4: Create Dispatchers") + + compute_dispatcher = Dispatcher(registry=compute_registry, lib=lib) + memory_dispatcher = Dispatcher(registry=memory_registry, lib=lib) + latency_dispatcher = Dispatcher(registry=latency_registry, lib=lib) + + print(f" {compute_dispatcher}") + print(f" 
{memory_dispatcher}") + print(f" {latency_dispatcher}") + + # ========================================================================= + # Step 5: Smart dispatcher selection + # ========================================================================= + print("\nStep 5: Smart Dispatcher Selection") + + def select_dispatcher(M: int, N: int, K: int) -> Dispatcher: + elements = M * N + if elements >= 4096 * 4096: + return compute_dispatcher + elif elements >= 1024 * 1024: + return memory_dispatcher + else: + return latency_dispatcher + + test_sizes = [ + (256, 256, 256), + (512, 512, 512), + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + ] + + print(f"\n {'Size':<20} {'Registry':>10} {'Time (ms)':>12} {'TFLOPS':>10}") + print(" " + "-" * 55) + + for M, N, K in test_sizes: + dispatcher = select_dispatcher(M, N, K) + + if not dispatcher.is_supported(M, N, K): + continue + + A = np.random.randn(M, K).astype(np_dtype) * 0.1 + B = np.random.randn(K, N).astype(np_dtype) * 0.1 + + result = dispatcher.run(A, B, M, N, K) + + if result.success: + print( + f" {M}x{N}x{K:<10} {dispatcher.registry.name:>10} " + f"{result.time_ms:>12.4f} {result.tflops:>10.2f}" + ) + + # Cleanup + cleanup_gemm() + + # Summary + print("\n" + "=" * 60) + print("Multi-Registry Pattern:") + print("=" * 60) + print(" 1. Define KernelConfig for each optimization target") + print(" 2. Create Registry for each target") + print(" 3. Register configs to appropriate registries") + print(" 4. Create Dispatcher for each registry") + print(" 5. Select dispatcher based on problem characteristics") + print(" 6. 
Run GEMM with selected dispatcher") + print("=" * 60) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/10_advanced_benchmark.py b/dispatcher/examples/gemm/python/10_advanced_benchmark.py new file mode 100644 index 00000000000..e16e4e271f0 --- /dev/null +++ b/dispatcher/examples/gemm/python/10_advanced_benchmark.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 10: Advanced Benchmarking with Full Control + +This example demonstrates all available benchmark parameters: + - warmup: Number of warmup iterations (default: 5) + - repeat: Number of benchmark iterations (default: 20) + - flush_cache: Flush GPU cache between iterations (default: False) + - timer: Timer type - "gpu" (default) or "cpu" + - init: Initialization method - "random", "linear", "constant" + +Usage: + python3 10_advanced_benchmark.py + python3 10_advanced_benchmark.py --warmup 10 --repeat 100 + python3 10_advanced_benchmark.py --init linear +""" + +import argparse +import sys +from pathlib import Path + +# Add paths for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +import numpy as np + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Advanced GEMM benchmarking with full parameter control" + ) + + # Problem size + parser.add_argument("-m", type=int, default=2048, help="M dimension") + parser.add_argument("-n", type=int, default=2048, help="N dimension") + parser.add_argument("-k", type=int, default=2048, help="K dimension") + + # Benchmark parameters + parser.add_argument( + "--warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--repeat", type=int, default=20, help="Number of benchmark iterations" + ) + 
parser.add_argument( + "--flush-cache", action="store_true", help="Flush GPU cache between iterations" + ) + parser.add_argument( + "--timer", choices=["gpu", "cpu"], default="gpu", help="Timer type (gpu or cpu)" + ) + parser.add_argument( + "--init", + choices=["random", "linear", "constant"], + default="random", + help="Initialization method", + ) + + # Kernel configuration + parser.add_argument("--dtype", default="fp16", help="Data type") + parser.add_argument("--pipeline", default="compv4", help="Pipeline type") + parser.add_argument("--arch", default="gfx942", help="GPU architecture") + + return parser.parse_args() + + +def initialize_matrix(shape, method, dtype): + """Initialize matrix with specified method""" + if method == "random": + return np.random.randn(*shape).astype(dtype) * 0.5 + elif method == "linear": + total = np.prod(shape) + return np.arange(total).reshape(shape).astype(dtype) / total + elif method == "constant": + return np.ones(shape, dtype=dtype) + else: + return np.random.randn(*shape).astype(dtype) + + +def main(): + args = parse_args() + + reset_for_example() + + print("=" * 70) + print("Example 10: Advanced GEMM Benchmarking") + print("=" * 70) + + # Show benchmark configuration + print("\nBenchmark Configuration:") + print(f" Problem Size: {args.m} x {args.n} x {args.k}") + print(f" Warmup: {args.warmup} iterations") + print(f" Repeat: {args.repeat} iterations") + print(f" Flush Cache: {args.flush_cache}") + print(f" Timer: {args.timer}") + print(f" Init Method: {args.init}") + print(f" Data Type: {args.dtype}") + print(f" Pipeline: {args.pipeline}") + print(f" Architecture: {args.arch}") + print() + + # Map dtype + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # Initialize matrices + print("Step 1: Initialize matrices...") + A = initialize_matrix((args.m, args.k), args.init, np_dtype) + B = initialize_matrix((args.k, args.n), args.init, np_dtype) + print(f" A: {A.shape} ({args.init})") + print(f" B: 
{B.shape} ({args.init})") + + # Create kernel config (does not include M/N/K - those are problem size) + print("\nStep 2: Create kernel configuration...") + kernel_config = KernelConfig( + dtype_a=args.dtype, + dtype_b=args.dtype, + dtype_c=args.dtype, + dtype_acc="fp32", + layout_a="row", + layout_b="col", # B is column-major for optimal performance + layout_c="row", + tile_m=128, + tile_n=128, + tile_k=32, + wave_m=2, + wave_n=2, + wave_k=1, + warp_m=32, + warp_n=32, + warp_k=16, + pipeline=args.pipeline, + scheduler="intrawave", + epilogue="cshuffle", + gfx_arch=args.arch, + ) + print(f" Config: {args.dtype}, tile=128x128x32, {args.pipeline}") + + # Setup dispatcher + print("\nStep 3: Setup dispatcher...") + setup = setup_gemm_dispatcher( + config=kernel_config, + registry_name="benchmark_gemm", + verbose=False, + auto_rebuild=True, + ) + + if not setup.success: + print(f" ERROR: {setup.error}") + return 1 + + dispatcher = setup.dispatcher + print(f" Library: {setup.lib.path if setup.lib else 'N/A'}") + print(f" Kernel: {setup.lib.get_kernel_name() if setup.lib else 'N/A'}") + + # Run benchmark with multiple iterations + print("\nStep 4: Run benchmark...") + print(f" Running {args.warmup} warmup + {args.repeat} benchmark iterations...") + + # Warmup + for _ in range(args.warmup): + _ = dispatcher.run(A, B, args.m, args.n, args.k) + + # Benchmark + times = [] + for _ in range(args.repeat): + result = dispatcher.run(A, B, args.m, args.n, args.k) + if result.success: + times.append(result.time_ms) + + if times: + avg_time = sum(times) / len(times) + min_time = min(times) + max_time = max(times) + + # Calculate TFLOPS + flops = 2 * args.m * args.n * args.k + avg_tflops = (flops / 1e12) / (avg_time / 1000) if avg_time > 0 else 0 + max_tflops = (flops / 1e12) / (min_time / 1000) if min_time > 0 else 0 + + # Calculate bandwidth (C has same dtype as A and B) + C_bytes = args.m * args.n * np.dtype(np_dtype).itemsize + bandwidth_gb = ( + (A.nbytes + B.nbytes + C_bytes) / 
1e9 / (avg_time / 1000) + if avg_time > 0 + else 0 + ) + + print(f"\n *** BENCHMARK RESULTS ({args.repeat} iterations) ***") + print(f" Average Time: {avg_time:.4f} ms") + print(f" Min Time: {min_time:.4f} ms") + print(f" Max Time: {max_time:.4f} ms") + print(f" Avg TFLOPS: {avg_tflops:.2f}") + print(f" Peak TFLOPS: {max_tflops:.2f}") + print(f" Bandwidth: {bandwidth_gb:.2f} GB/s") + else: + print(" FAILED: No successful runs") + return 1 + + # Summary + print("\n" + "=" * 70) + print("BENCHMARK PARAMETERS REFERENCE") + print("=" * 70) + print(""" +Available parameters for GEMM benchmarking: + + --warmup N Number of warmup iterations (discard results) + Higher = more stable results, longer run time + Default: 5 + + --repeat N Number of benchmark iterations + Higher = more accurate average, longer run time + Default: 20 + + --flush-cache Flush GPU L2 cache between iterations + Use for memory-bound benchmarks + Default: off + + --timer {gpu,cpu} Timer type + gpu = HIP events (more accurate for GPU) + cpu = std::chrono (includes kernel launch overhead) + Default: gpu + + --init METHOD Matrix initialization + random = uniform random [-0.5, 0.5] + linear = sequential values + constant = all ones + Default: random + +Note: For C++ examples, these parameters are passed to stream_config: + + ck_tile::stream_config cfg{ + nullptr, // stream_id + true, // time_kernel + 1, // log_level + 5, // cold_niters (warmup) + 20, // nrepeat + true, // is_gpu_timer + false, // flush_cache + 1 // rotating_count + }; +""") + + # Cleanup + cleanup_gemm() + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/gemm/python/11_json_import.py b/dispatcher/examples/gemm/python/11_json_import.py new file mode 100644 index 00000000000..06743af4064 --- /dev/null +++ b/dispatcher/examples/gemm/python/11_json_import.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +""" +Example 11: JSON-based Kernel Configuration Import + +Demonstrates loading kernel configurations from JSON files, similar to tile_engine. +This enables easy customization of kernel sets without modifying code. + +Key Features: + - Load tile configs from JSON (compatible with tile_engine format) + - Generate kernel sets from configuration + - Use arch_filter validation on loaded configs + - Export to C++ DECL_KERNEL_SET format + +Complexity: ★★★☆☆ + +Usage: + python3 11_json_import.py + python3 11_json_import.py --config my_kernels.json + python3 11_json_import.py --export-cpp +""" + +import sys +import argparse +import json +from pathlib import Path + +# Add codegen to path for kernel_config_loader +script_dir = Path(__file__).parent.resolve() +sys.path.insert(0, str(script_dir.parent.parent.parent / "codegen")) +sys.path.insert(0, str(script_dir.parent.parent.parent / "python")) + +from kernel_config_loader import ( # noqa: E402 + load_kernel_configs, + KernelConfig, + generate_cpp_kernel_set_declaration, +) + +from ctypes_utils import ( # noqa: E402 + KernelConfig as DispatcherKernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + reset_for_example, + validate_kernel_config, +) + +# Sample JSON configuration (embedded for demonstration) +SAMPLE_JSON_CONFIG = { + "_comment": "Sample kernel configuration for GEMM", + "kernel_set_name": "inference_kernels", + "datatype": {"a": "fp16", "b": "fp16", "c": "fp16", "acc": "fp32"}, + "layout": "rcr", + "tile_config": { + "tile_m": {"values": [128, 256]}, + "tile_n": {"values": [128, 256]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [32]}, + "warp_tile_n": {"values": [32]}, + "warp_tile_k": {"values": [16]}, + }, + "trait_config": { + "pipeline": {"values": ["compv4"]}, + "scheduler": {"values": ["intrawave"]}, + "epilogue": {"values": ["cshuffle"]}, + "pad_m": {"values": [False]}, 
+ "pad_n": {"values": [False]}, + "pad_k": {"values": [False]}, + }, + "gpu_targets": ["gfx942"], +} + + +def print_section(title: str): + """Print a section header""" + print(f"\n{'=' * 70}") + print(f" {title}") + print(f"{'=' * 70}\n") + + +def convert_to_dispatcher_config( + config: KernelConfig, arch: str = "gfx942" +) -> DispatcherKernelConfig: + """Convert kernel_config_loader.KernelConfig to dispatcher KernelConfig""" + return DispatcherKernelConfig( + dtype_a=config.dtype_a, + dtype_b=config.dtype_b, + dtype_c=config.dtype_c, + dtype_acc=config.dtype_acc, + tile_m=config.tile.tile_m, + tile_n=config.tile.tile_n, + tile_k=config.tile.tile_k, + wave_m=config.tile.warp_m, + wave_n=config.tile.warp_n, + wave_k=config.tile.warp_k, + warp_m=config.tile.warp_tile_m, + warp_n=config.tile.warp_tile_n, + warp_k=config.tile.warp_tile_k, + pipeline=config.trait.pipeline, + scheduler=config.trait.scheduler, + epilogue=config.trait.epilogue, + pad_m=config.trait.pad_m, + pad_n=config.trait.pad_n, + pad_k=config.trait.pad_k, + gfx_arch=arch, + variant=config.variant, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="JSON Kernel Configuration Import Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 11_json_import.py # Use embedded sample config + python3 11_json_import.py --config my.json # Load from file + python3 11_json_import.py --export-cpp # Generate C++ declarations + python3 11_json_import.py --validate # Validate configs against arch + """, + ) + parser.add_argument( + "--config", + type=str, + help="Path to JSON configuration file (uses embedded sample if not provided)", + ) + parser.add_argument( + "--export-cpp", + action="store_true", + help="Export kernel set as C++ DECL_KERNEL_SET", + ) + parser.add_argument( + "--validate", + action="store_true", + help="Validate all configurations against arch filter", + ) + parser.add_argument( + "--arch", + default="gfx942", + help="Target GPU 
architecture (default: gfx942)", + ) + args = parser.parse_args() + + reset_for_example() + + print_section("Example 11: JSON Kernel Configuration Import") + + # ========================================================================= + # Step 1: Load configuration from JSON + # ========================================================================= + print("Step 1: Load Kernel Configuration from JSON") + print("-" * 50) + + if args.config: + config_path = Path(args.config) + if not config_path.exists(): + print(f" ERROR: Config file not found: {config_path}") + return 1 + print(f" Loading from: {config_path}") + config_set = load_kernel_configs(config_path) + else: + # Use embedded sample config + print(" Using embedded sample configuration") + # Write to temp file and load + temp_path = Path("/tmp/sample_gemm_config.json") + with open(temp_path, "w") as f: + json.dump(SAMPLE_JSON_CONFIG, f, indent=2) + config_set = load_kernel_configs(temp_path) + + print(f"\n Kernel Set Name: {config_set.name}") + print( + f" Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}" + ) + print(f" Layout: {config_set.layout}") + print(f" GPU Targets: {config_set.gpu_targets}") + print(f" Total Configurations: {config_set.config_count()}") + + # ========================================================================= + # Step 2: Display configuration details + # ========================================================================= + print("\nStep 2: Configuration Details") + print("-" * 50) + + print("\n Tile Configurations:") + print(f" tile_m: {config_set.tile_m_values}") + print(f" tile_n: {config_set.tile_n_values}") + print(f" tile_k: {config_set.tile_k_values}") + print( + f" warp (wave): {config_set.warp_m_values}x{config_set.warp_n_values}x{config_set.warp_k_values}" + ) + print( + f" warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}" + ) + + print("\n Trait Configurations:") + 
print(f" pipeline: {config_set.pipeline_values}") + print(f" scheduler: {config_set.scheduler_values}") + print(f" epilogue: {config_set.epilogue_values}") + print( + f" padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}" + ) + + # ========================================================================= + # Step 3: Generate and display kernel names + # ========================================================================= + print("\nStep 3: Generated Kernel Names") + print("-" * 50) + + configs = list(config_set.generate_configs()) + for i, config in enumerate(configs[:5]): + print(f" {i + 1}. {config.kernel_name()}") + if len(configs) > 5: + print(f" ... and {len(configs) - 5} more configurations") + + # ========================================================================= + # Step 4: Validate against arch filter (optional) + # ========================================================================= + if args.validate: + print("\nStep 4: Architecture Validation") + print("-" * 50) + + valid_count = 0 + invalid_count = 0 + + for config in configs: + disp_config = convert_to_dispatcher_config(config, args.arch) + result = validate_kernel_config(disp_config) + + if result.is_valid: + valid_count += 1 + else: + invalid_count += 1 + if invalid_count <= 3: # Show first 3 invalid + print(f"\n ✗ Invalid: {config.kernel_name()}") + for error in result.errors: + print(f" Error: {error}") + + print("\n Validation Summary:") + print(f" ✓ Valid: {valid_count}") + print(f" ✗ Invalid: {invalid_count}") + print(f" Total: {len(configs)}") + + # ========================================================================= + # Step 5: Export to C++ (optional) + # ========================================================================= + if args.export_cpp: + print("\nStep 5: C++ Export") + print("-" * 50) + print("\n // Generated DECL_KERNEL_SET from JSON config:") + print(" // " + "=" * 56) + cpp_code = 
generate_cpp_kernel_set_declaration(config_set) + for line in cpp_code.split("\n"): + print(f" {line}") + + # ========================================================================= + # Step 6: Use first config with dispatcher (demo) + # ========================================================================= + print("\nStep 6: Dispatcher Integration Demo") + print("-" * 50) + + if configs: + first_config = configs[0] + disp_config = convert_to_dispatcher_config(first_config, args.arch) + + print( + f"\n Using first config: {first_config.tile.tile_m}x{first_config.tile.tile_n}x{first_config.tile.tile_k}" + ) + + setup = setup_gemm_dispatcher( + disp_config, registry_name="json_import", verbose=False + ) + if setup.success: + print(" ✓ Dispatcher setup successful") + print( + f" Kernel header: {setup.kernel_header.name if setup.kernel_header else 'N/A'}" + ) + else: + print(f" ⚠ Dispatcher setup: {setup.error}") + print(" (This is expected if kernels aren't generated)") + + # ========================================================================= + # Summary + # ========================================================================= + print_section("Summary") + print(" JSON configuration allows easy kernel set customization:") + print(" - Define tile sizes and ranges") + print(" - Specify trait combinations (pipeline, scheduler, etc.)") + print(" - Target multiple GPU architectures") + print(" - Export to C++ DECL_KERNEL_SET for static compilation") + print() + print(" JSON Format (tile_engine compatible):") + print(' {"tile_config": {"tile_m": {"values": [128, 256]}, ...},') + print(' "trait_config": {"pipeline": {"values": ["compv4"]}, ...}}') + print() + print(" Usage:") + print(" config_set = load_kernel_configs('my_kernels.json')") + print(" for config in config_set.generate_configs():") + print(" # Use config for codegen or dispatcher setup") + + cleanup_gemm() + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git 
a/dispatcher/examples/gemm/python/README.md b/dispatcher/examples/gemm/python/README.md new file mode 100644 index 00000000000..0a83f3533fc --- /dev/null +++ b/dispatcher/examples/gemm/python/README.md @@ -0,0 +1,299 @@ +# GEMM Python Examples + +CK Tile Dispatcher Python examples for GEMM (General Matrix Multiplication) operations. + +> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md) + +## Quick Start + +### Build Library + +```bash +cd /path/to/composable_kernel/dispatcher +mkdir -p build && cd build + +cmake .. \ + -DCMAKE_PREFIX_PATH=/opt/rocm \ + -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -DBUILD_DISPATCHER_EXAMPLES=ON + +# Build Python library (kernels generated automatically) +make dispatcher_gemm_lib -j$(nproc) +``` + +### Run Examples + +```bash +cd /path/to/composable_kernel/dispatcher + +python3 examples/gemm/python/01_basic_gemm.py +python3 examples/gemm/python/04_validation.py +python3 examples/gemm/python/07_stress_test.py +python3 examples/gemm/python/08_heuristics.py +``` + +## Examples + +| Example | Description | +|---------|-------------| +| [01_basic_gemm.py](01_basic_gemm.py) | Basic GEMM with multi-kernel support | +| [02_batch_gemm.py](02_batch_gemm.py) | Batched GEMM operations | +| [03_benchmark.py](03_benchmark.py) | Performance benchmarking | +| [04_validation.py](04_validation.py) | CPU reference validation | +| [05_numpy_integration.py](05_numpy_integration.py) | NumPy array integration | +| [06_json_export.py](06_json_export.py) | Registry JSON export | +| [07_stress_test.py](07_stress_test.py) | Multi-kernel stress testing | +| [08_heuristics.py](08_heuristics.py) | Heuristic-based kernel selection | +| [09_multi_registry.py](09_multi_registry.py) | Multiple registries | +| [10_advanced_benchmark.py](10_advanced_benchmark.py) | Advanced benchmark with full control | +| [11_json_import.py](11_json_import.py) | Import kernels from JSON | + +## Example Details + +### 
01_basic_gemm.py - Basic GEMM +Demonstrates the Python API with multi-kernel support: + +```python +from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table + +# Define multiple kernel configurations +kernels = [ + KernelConfig( + tile_m=128, tile_n=128, tile_k=32, + wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv3", scheduler="intrawave" + ), + KernelConfig( + tile_m=256, tile_n=256, tile_k=32, + wave_m=2, wave_n=2, wave_k=1, + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + pipeline="compv4", scheduler="intrawave" + ), +] + +# Display configurations +print_kernel_config_table(kernels) + +# Set up dispatcher with all kernels +lib, dispatcher, registry = setup_gemm_dispatcher(kernels) + +# Run GEMM +elapsed_ms = run_gemm(lib, M, N, K, ...) +``` + +### 02_batch_gemm.py - Batch GEMM +Batched matrix multiplication: +- Multiple independent GEMM operations +- Batch dimension handling + +### 03_benchmark.py - Benchmarking +Performance measurement: +- GPU timing +- TFLOPS calculation +- Multiple iterations + +### 04_validation.py - Validation +Correctness verification: +- NumPy reference implementation +- Tolerance-based validation +- Error reporting + +### 05_numpy_integration.py - NumPy Integration +Seamless NumPy integration: +- NumPy arrays to GPU buffers +- Results back to NumPy +- Automatic type conversion + +### 06_json_export.py - JSON Export +Registry serialization for tool integration: +- Export kernel configurations +- Machine-readable format + +### 07_stress_test.py - Stress Testing +Comprehensive multi-kernel stress testing: + +```python +from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table + +# Define 48 unique kernel configurations +kernels = [ + KernelConfig(tile_m=128, tile_n=128, tile_k=32, pipeline="compv3", ...), + KernelConfig(tile_m=256, tile_n=256, tile_k=32, pipeline="compv4", ...), + KernelConfig(tile_m=128, tile_n=256, tile_k=64, 
pipeline="compv3", ...), + # ... many more configurations +] + +# Test each kernel +for i, kernel in enumerate(kernels): + lib, dispatcher, registry = setup_gemm_dispatcher([kernel]) + result = run_and_validate(lib, M, N, K, seed=42 + i) # Different seed per kernel + print(f"Kernel {i}: {result.max_err:.6e} {'PASS' if result.passed else 'FAIL'}") +``` + +**Features:** +- 48 unique kernel configurations +- Various tile sizes, pipelines, and schedulers +- Per-kernel validation with unique random seeds +- Performance reporting + +### 08_heuristics.py - Heuristic Selection +Custom kernel selection based on problem characteristics: + +```python +# Define kernel pools for different strategies +SMALL_KERNELS = [KernelConfig(tile_m=64, tile_n=64, ...), ...] +LARGE_KERNELS = [KernelConfig(tile_m=256, tile_n=256, ...), ...] +COMPUTE_KERNELS = [KernelConfig(pipeline="compv4", ...), ...] +MEMORY_KERNELS = [KernelConfig(pipeline="compv3", ...), ...] + +# Size-based heuristic +def size_based_heuristic(M, N, K): + if M * N < 512 * 512: + return SMALL_KERNELS + else: + return LARGE_KERNELS + +# Strategy-based selection +def compute_strategy(): + return COMPUTE_KERNELS # Optimized for compute-bound problems + +def memory_strategy(): + return MEMORY_KERNELS # Optimized for memory-bound problems + +# Test different strategies +for strategy in [size_based_heuristic, compute_strategy, memory_strategy]: + kernels = strategy(M, N, K) + lib, dispatcher, registry = setup_gemm_dispatcher(kernels) + elapsed_ms = run_gemm(lib, M, N, K, ...) 
+``` + +**Features:** +- 24 kernel configurations across 6 categories +- Size-based heuristic (small vs large) +- Optimization strategies (compute, memory, latency) +- Performance comparison across strategies + +### 09_multi_registry.py - Multiple Registries +Separate registries for different workloads: +- Compute-optimized registry +- Latency-optimized registry +- Dynamic registry selection + +### 10_advanced_benchmark.py - Advanced Benchmark +Full control over benchmark parameters: +- Warmup iterations +- Benchmark iterations +- Statistical analysis + +### 11_json_import.py - JSON Import +Import kernel configurations from JSON: +- External configuration files +- Dynamic kernel loading + +## Utility Module: ctypes_utils.py + +```python +from ctypes_utils import ( + KernelConfig, # Single kernel configuration + setup_gemm_dispatcher, # Set up dispatcher with kernels + print_kernel_config_table, # Display kernel configurations + Dispatcher, # High-level dispatcher + Registry, # Kernel registry + Validator, # Validation utilities +) +``` + +### KernelConfig + +```python +config = KernelConfig( + # Tile sizes + tile_m=256, tile_n=256, tile_k=32, + # Wave configuration + wave_m=2, wave_n=2, wave_k=1, + # Warp tile sizes + warp_tile_m=32, warp_tile_n=32, warp_tile_k=16, + # Pipeline and scheduler + pipeline="compv4", # "compv3" or "compv4" + scheduler="intrawave", # "intrawave" or "interwave" + # Optional + epilogue="default", + padding=True, + double_buffer=True, +) +``` + +### setup_gemm_dispatcher + +```python +# Single kernel +lib, dispatcher, registry = setup_gemm_dispatcher(config) + +# Multiple kernels +lib, dispatcher, registry = setup_gemm_dispatcher([config1, config2, ...]) + +# With auto-rebuild +lib, dispatcher, registry = setup_gemm_dispatcher(config, auto_rebuild=True) +``` + +### print_kernel_config_table + +```python +kernels = [config1, config2, config3] +print_kernel_config_table(kernels) +# Output: +# 
+----+-------+-------+-------+--------+-----------+ +# | # | Tile | Wave | Warp | Pipe | Scheduler | +# +----+-------+-------+-------+--------+-----------+ +# | 1 | 128x128x32 | 2x2x1 | 32x32x16 | compv3 | intrawave | +# | 2 | 256x256x32 | 2x2x1 | 32x32x16 | compv4 | intrawave | +# | 3 | 128x256x64 | 2x2x1 | 32x32x16 | compv3 | interwave | +# +----+-------+-------+-------+--------+-----------+ +``` + +### GPU Memory Management + +```python +import ctypes +import numpy as np + +# Load HIP library +hip = ctypes.CDLL("libamdhip64.so") + +# Allocate GPU memory +gpu_ptr = ctypes.c_void_p() +hip.hipMalloc(ctypes.byref(gpu_ptr), size_in_bytes) + +# Copy to GPU (1 = hipMemcpyHostToDevice) +hip.hipMemcpy(gpu_ptr, host_array.ctypes.data, size, 1) + +# Copy back (2 = hipMemcpyDeviceToHost) +hip.hipMemcpy(host_array.ctypes.data, gpu_ptr, size, 2) + +# Free +hip.hipFree(gpu_ptr) +``` + +## Performance Testing + +Test compilation performance with different kernel counts: + +```bash +# Test with 10 kernels (~15s compile time) +python3 01_basic_gemm.py --num-kernels 10 + +# Test with 20 kernels (~25s compile time) +python3 01_basic_gemm.py --num-kernels 20 + +# Test with 48 kernels (~50s compile time) +python3 01_basic_gemm.py --num-kernels 48 +``` + +Compilation time scales roughly linearly with kernel count. 
+ +## Related Documentation + +- [C++ GEMM Examples](../cpp/README.md) +- [Python Conv Examples](../../conv/python/README.md) +- [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/gemm/python/kernels.json b/dispatcher/examples/gemm/python/kernels.json new file mode 100644 index 00000000000..214b1cc42c0 --- /dev/null +++ b/dispatcher/examples/gemm/python/kernels.json @@ -0,0 +1,80 @@ +{ + "registry": "export_demo", + "kernel_count": 3, + "kernels": [ + { + "tile": "128x128x32", + "dtypes": { + "A": "fp16", + "B": "fp16", + "C": "fp16" + }, + "layout": "rcr", + "pipeline": "compv4", + "target": "gfx942" + }, + { + "tile": "256x256x64", + "dtypes": { + "A": "fp16", + "B": "fp16", + "C": "fp16" + }, + "layout": "rcr", + "pipeline": "compv4", + "target": "gfx942" + }, + { + "tile": "64x64x32", + "dtypes": { + "A": "fp16", + "B": "fp16", + "C": "fp16" + }, + "layout": "rcr", + "pipeline": "compv4", + "target": "gfx942" + } + ], + "cpp_registry": { + "metadata": { + "timestamp": "Dec 4 2025 06:23:15", + "total_kernels": 1, + "export_version": "1.0", + "dispatcher_version": "1.0.0" + }, + "statistics": { + "by_datatype": {}, + "by_pipeline": {}, + "by_scheduler": {} + }, + "kernels": [ + { + "identifier": "128x128x32_2x2x1_32x32x16_nopers", + "name": "gemm_fp16_rcrr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16", + "algorithm": { + "tile_shape": { + "m": 128, + "n": 128, + "k": 32 + }, + "wave_shape": { + "m": 2, + "n": 2, + "k": 1 + }, + "warp_tile_shape": { + "m": 32, + "n": 32, + "k": 16 + }, + "block_size": 256, + "persistent": false, + "double_buffer": true, + "preshuffle": false, + "transpose_c": false + } + } + ] + } +} \ No newline at end of file diff --git a/dispatcher/include/ck_tile/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher.hpp new file mode 100644 index 00000000000..98d8bb93332 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher.hpp @@ -0,0 +1,19 @@ +// Copyright (c) Advanced Micro 
Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +/// Main dispatcher header - includes all core components +/// Use this for convenient access to the full dispatcher API + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/kernel_config.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/arch_filter.hpp" +#include "ck_tile/dispatcher/backends/tile_backend.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" +#include "ck_tile/dispatcher/utils.hpp" diff --git a/dispatcher/include/ck_tile/dispatcher/README.md b/dispatcher/include/ck_tile/dispatcher/README.md new file mode 100644 index 00000000000..db3ce996a92 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/README.md @@ -0,0 +1,161 @@ +# CK Tile Dispatcher - C++ Headers + +C++ API for the CK Tile dispatcher. + +> **See also:** [Main Dispatcher README](../../../../README.md) for installation and core concepts. + +## File Organization + +``` +dispatcher/ +├── dispatcher.hpp # Main dispatcher (kernel selection) +├── registry.hpp # Kernel registry (storage & lookup) +├── problem.hpp # Problem specification +├── kernel_key.hpp # Kernel configuration key +├── kernel_instance.hpp # Kernel instance interface +├── utils.hpp # Utilities (timers, GPU buffers) +│ +└── backends/ # Backend implementations + ├── generated_tile_backend.hpp # CK Tile kernels (production) + └── tile_backend.hpp # Tile backend base +``` + +## Quick Start + +```cpp +#include "ck_tile/dispatcher.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::utils; + +int main() { + // 1. 
Build kernel key + KernelKeyBuilder builder = KernelKeyBuilder::fp16_rcr(); + builder.tile_m = 128; + builder.tile_n = 128; + builder.tile_k = 32; + KernelKey key = builder.build(); + + // 2. Register kernel + auto kernel = create_generated_tile_kernel<...>(key, "my_kernel"); + Registry::instance().register_kernel(kernel, Priority::High); + + // 3. Run GEMM + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + float time_ms = dispatcher.run(a_ptr, b_ptr, c_ptr, problem, nullptr); +} +``` + +## Core Classes + +### KernelKey (`kernel_key.hpp`) + +Uniquely identifies a kernel configuration: + +```cpp +KernelKeyBuilder builder; +builder.dtype_a = DataType::FP16; +builder.layout_a = LayoutTag::Row; +builder.tile_m = 256; +builder.pipeline = Pipeline::CompV4; +KernelKey key = builder.build(); +``` + +### Registry (`registry.hpp`) + +Thread-safe kernel storage: + +```cpp +auto& registry = Registry::instance(); +registry.register_kernel(kernel, Priority::High); +registry.get_kernel_count(); +registry.export_json(); +``` + +### Dispatcher (`dispatcher.hpp`) + +Kernel selection and execution: + +```cpp +Dispatcher dispatcher; + +// Strategies +dispatcher.set_strategy(SelectionStrategy::FirstFit); +dispatcher.set_strategy(SelectionStrategy::Heuristic); + +// Run +float time = dispatcher.run(a, b, c, problem, stream); +``` + +### Problem (`problem.hpp`) + +GEMM problem specification: + +```cpp +Problem problem(M, N, K); +problem.batch_size = 4; +problem.alpha = 1.0f; +problem.beta = 0.0f; + +// Auto-inference +auto p = Problem::from_ab(a_rows, a_cols, b_rows, b_cols, trans_a, trans_b); +``` + +## Utilities (`utils.hpp`) + +### GPU Memory + +```cpp +GpuBuffer buffer(size); +buffer.copy_from_host(host_ptr); +buffer.copy_to_host(host_ptr); +buffer.zero(); +``` + +### Timing + +```cpp +GpuTimer timer; +timer.start(); +// kernel... 
+timer.stop(); +float ms = timer.elapsed_ms(); +``` + +### Quick Helpers + +```cpp +// Create FP16 RCR key +auto key = create_fp16_rcr_key(tile_m, tile_n, tile_k, ...); + +// Performance +double tflops = calculate_tflops(M, N, K, time_ms); + +// Validation +auto result = validate_result(gpu_ptr, cpu_ptr, size); +``` + +## Backend + +### Generated Tile Backend + +```cpp +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +auto kernel = create_generated_tile_kernel< + SelectedKernel, ADataType, BDataType, CDataType, AccDataType +>(key, name); +``` + +## Best Practices + +1. Use `Release` build for performance +2. Register kernels at startup +3. Use `Priority::High` for hand-tuned kernels +4. Reuse dispatcher instances +5. Clear registry between test runs + +--- + +> **More info:** See [../../../../README.md](../../../../README.md) for full documentation. diff --git a/dispatcher/include/ck_tile/dispatcher/arch_filter.hpp b/dispatcher/include/ck_tile/dispatcher/arch_filter.hpp new file mode 100644 index 00000000000..33a864a6490 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/arch_filter.hpp @@ -0,0 +1,393 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Architecture-Specific Kernel Filtering for CK Tile Dispatcher + * + * Provides GPU architecture-aware validation of kernel configurations. + * Uses arch_specs_generated.hpp as single source of truth (generated from arch_specs.json). + * + * Usage: + * ArchFilter filter("gfx942"); + * + * // Check if a kernel configuration is valid + * if (filter.is_valid(kernel_key)) { + * registry.register_kernel(kernel); + * } + * + * // Get validation result with error details + * auto result = filter.validate(kernel_key); + * if (!result.valid) { + * for (const auto& error : result.errors) { + * std::cerr << error << "\n"; + * } + * } + * + * Adding New GPU Support: + * 1. Edit dispatcher/codegen/arch_specs.json + * 2. 
Run: python dispatcher/codegen/generate_arch_specs.py + * 3. Rebuild the dispatcher + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/arch_specs_generated.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +// ============================================================================= +// Re-export from generated header for convenience +// ============================================================================= + +// Use the generated types and functions from arch_specs namespace +using GpuArch = arch_specs::GpuArch; +using WarpConfig = arch_specs::WarpConfig; +using WarpTileConfig = std::array; + +// Re-export string conversion functions +using arch_specs::arch_to_string; +using arch_specs::element_size; +using arch_specs::get_lds_capacity; +using arch_specs::get_supported_warp_configs; +using arch_specs::is_trait_unsupported; +using arch_specs::string_to_arch; + +// ============================================================================= +// Additional Helper Functions +// ============================================================================= + +/// Get supported warp tile configurations for arch and data types +/// This function wraps the generated data with runtime logic +inline std::vector get_supported_warp_tiles(GpuArch arch, + DataType dtype_a, + DataType dtype_b, + [[maybe_unused]] DataType dtype_c) +{ + // Common FP16 configurations (from arch_specs.json) + std::vector fp16_configs = { + {32, 32, 8}, {16, 16, 16}, {32, 32, 16}, {16, 16, 32}, {4, 64, 16}, {64, 4, 16}}; + + // FP8 configurations + std::vector fp8_gfx942 = { + {32, 32, 16}, {32, 32, 32}, {16, 16, 32}, {16, 16, 64}}; + std::vector fp8_gfx950 = { + {32, 32, 16}, {32, 32, 32}, {16, 16, 32}, {16, 16, 64}, {16, 16, 128}, {32, 32, 64}}; + + // INT8 configurations + std::vector int8_configs = {{16, 16, 32}, {32, 32, 16}}; + + // GFX1201 only supports limited FP16 + std::vector rdna4_fp16 = 
{{16, 16, 16}}; + + // Match based on architecture and data types + if(dtype_a == DataType::FP16 && dtype_b == DataType::FP16) + { + if(arch == GpuArch::GFX_1201) + return rdna4_fp16; + return fp16_configs; + } + if(dtype_a == DataType::BF16 && dtype_b == DataType::BF16) + { + if(arch == GpuArch::GFX_1201) + return {}; // Not supported on RDNA4 + return fp16_configs; // Same as FP16 + } + if(dtype_a == DataType::FP8 || dtype_a == DataType::BF8) + { + if(arch == GpuArch::GFX_950) + return fp8_gfx950; + if(arch == GpuArch::GFX_942) + return fp8_gfx942; + if(arch == GpuArch::GFX_90A) + return {{32, 32, 16}, {32, 32, 32}}; + } + if(dtype_a == DataType::INT8 && dtype_b == DataType::INT8) + { + if(arch == GpuArch::GFX_942) + return int8_configs; + } + + return {}; // Unknown combination +} + +// ============================================================================= +// Validation Result +// ============================================================================= + +/// Result of kernel validation +struct ValidationResult +{ + bool valid = true; + std::vector errors; + std::vector warnings; + + explicit operator bool() const { return valid; } + + void add_error(const std::string& msg) + { + errors.push_back(msg); + valid = false; + } + + void add_warning(const std::string& msg) { warnings.push_back(msg); } +}; + +// ============================================================================= +// Architecture Filter +// ============================================================================= + +/** + * Architecture-specific kernel filter. + * + * Validates kernel configurations against GPU architecture constraints + * including warp configurations, warp tiles, LDS capacity, and traits. + */ +class ArchFilter +{ + public: + /** + * Create architecture filter. 
+ * @param arch Target GPU architecture + * @param strict_mode If true, unknown configurations are rejected + */ + explicit ArchFilter(GpuArch arch, bool strict_mode = false) + : arch_(arch), strict_mode_(strict_mode) + { + } + + /** + * Create architecture filter from string. + * @param arch_str GPU architecture string (e.g., "gfx942") + * @param strict_mode If true, unknown configurations are rejected + */ + explicit ArchFilter(const std::string& arch_str, bool strict_mode = false) + : arch_(string_to_arch(arch_str)), strict_mode_(strict_mode) + { + } + + /** + * Quick validation check. + * @param key Kernel configuration key + * @return true if configuration is valid for this architecture + */ + [[nodiscard]] bool is_valid(const KernelKey& key) const { return validate(key).valid; } + + /** + * Detailed validation with error messages. + * @param key Kernel configuration key + * @return ValidationResult with valid flag and error/warning messages + */ + [[nodiscard]] ValidationResult validate(const KernelKey& key) const + { + ValidationResult result; + + // Check architecture match + if(!key.gfx_arch.empty() && string_to_arch(key.gfx_arch) != arch_) + { + result.add_warning("Kernel compiled for different architecture: " + key.gfx_arch); + } + + // Validate dimensions + validate_dimensions(key, result); + + // Validate warp configuration + validate_warp_config(key, result); + + // Validate warp tile configuration + validate_warp_tiles(key, result); + + // Validate trait combination + validate_traits(key, result); + + // Validate LDS capacity + validate_lds(key, result); + + return result; + } + + /// Get target architecture + [[nodiscard]] GpuArch get_arch() const { return arch_; } + + /// Get target architecture as string + [[nodiscard]] std::string get_arch_string() const { return arch_to_string(arch_); } + + private: + void validate_dimensions(const KernelKey& key, ValidationResult& result) const + { + const auto& alg = key.algorithm; + + // Check positive 
dimensions + if(alg.tile_shape.m <= 0 || alg.tile_shape.n <= 0 || alg.tile_shape.k <= 0) + { + result.add_error("Tile dimensions must be positive"); + return; + } + + // Check warp tiles fit in block tiles + int warp_m_coverage = alg.wave_shape.m * alg.warp_tile_shape.m; + int warp_n_coverage = alg.wave_shape.n * alg.warp_tile_shape.n; + int warp_k_coverage = alg.wave_shape.k * alg.warp_tile_shape.k; + + if(warp_m_coverage > alg.tile_shape.m) + { + result.add_error("warp_m * warp_tile_m > tile_m: " + std::to_string(warp_m_coverage) + + " > " + std::to_string(alg.tile_shape.m)); + } + if(warp_n_coverage > alg.tile_shape.n) + { + result.add_error("warp_n * warp_tile_n > tile_n: " + std::to_string(warp_n_coverage) + + " > " + std::to_string(alg.tile_shape.n)); + } + if(warp_k_coverage > alg.tile_shape.k) + { + result.add_error("warp_k * warp_tile_k > tile_k: " + std::to_string(warp_k_coverage) + + " > " + std::to_string(alg.tile_shape.k)); + } + + // Check alignment + if(alg.tile_shape.m % warp_m_coverage != 0) + { + result.add_error("tile_m must be divisible by warp_m * warp_tile_m"); + } + if(alg.tile_shape.n % warp_n_coverage != 0) + { + result.add_error("tile_n must be divisible by warp_n * warp_tile_n"); + } + if(alg.tile_shape.k % warp_k_coverage != 0) + { + result.add_error("tile_k must be divisible by warp_k * warp_tile_k"); + } + } + + void validate_warp_config(const KernelKey& key, ValidationResult& result) const + { + auto supported = get_supported_warp_configs(arch_); + if(supported.empty()) + { + if(strict_mode_) + { + result.add_error("No warp configurations defined for " + get_arch_string()); + } + else + { + result.add_warning("No warp configurations defined for " + get_arch_string()); + } + return; + } + + WarpConfig current = { + key.algorithm.wave_shape.m, key.algorithm.wave_shape.n, key.algorithm.wave_shape.k}; + + bool found = false; + for(const auto& cfg : supported) + { + if(cfg == current) + { + found = true; + break; + } + } + + if(!found) + { 
+ result.add_error("Invalid warp configuration [" + std::to_string(current[0]) + ", " + + std::to_string(current[1]) + ", " + std::to_string(current[2]) + + "] for " + get_arch_string()); + } + } + + void validate_warp_tiles(const KernelKey& key, ValidationResult& result) const + { + auto supported = get_supported_warp_tiles( + arch_, key.signature.dtype_a, key.signature.dtype_b, key.signature.dtype_c); + + if(supported.empty()) + { + // Unknown data type combination - allow with warning + result.add_warning("No warp tile combinations defined for data types"); + return; + } + + WarpTileConfig current = {key.algorithm.warp_tile_shape.m, + key.algorithm.warp_tile_shape.n, + key.algorithm.warp_tile_shape.k}; + + bool found = false; + for(const auto& cfg : supported) + { + if(cfg == current) + { + found = true; + break; + } + } + + if(!found) + { + result.add_error("Invalid warp tile [" + std::to_string(current[0]) + ", " + + std::to_string(current[1]) + ", " + std::to_string(current[2]) + + "] for " + get_arch_string()); + } + } + + void validate_traits(const KernelKey& key, ValidationResult& result) const + { + if(is_trait_unsupported( + key.algorithm.pipeline, key.algorithm.epilogue, key.algorithm.scheduler)) + { + result.add_error("Unsupported trait combination"); + } + } + + void validate_lds(const KernelKey& key, ValidationResult& result) const + { + const auto& sig = key.signature; + const auto& alg = key.algorithm; + + float elem_a = element_size(sig.dtype_a); + float elem_b = element_size(sig.dtype_b); + + std::size_t matrix_a_size = alg.tile_shape.m * alg.tile_shape.k * elem_a; + std::size_t matrix_b_size = alg.tile_shape.n * alg.tile_shape.k * elem_b; + std::size_t total_lds = matrix_a_size + matrix_b_size; + + std::size_t max_lds = get_lds_capacity(alg.pipeline); + + if(total_lds > max_lds) + { + result.add_error("LDS capacity exceeded: " + std::to_string(total_lds) + " bytes > " + + std::to_string(max_lds) + " bytes limit"); + } + } + + GpuArch arch_; + 
bool strict_mode_; +}; + +// ============================================================================= +// Registry Integration Helper +// ============================================================================= + +/** + * Create a filter function for use with Registry::filter() + * + * @tparam KernelT Kernel instance type with get_key() method + * @param arch Target GPU architecture + * @return Predicate function that returns true for valid kernels + */ +template +inline auto make_arch_filter_predicate(const std::string& arch) +{ + return [filter = ArchFilter(arch)](const KernelT& kernel) { + return filter.is_valid(kernel.get_key()); + }; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/arch_specs_generated.hpp b/dispatcher/include/ck_tile/dispatcher/arch_specs_generated.hpp new file mode 100644 index 00000000000..af52c8eb1d0 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/arch_specs_generated.hpp @@ -0,0 +1,168 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY! + * + * Generated from: arch_specs.json + * Generated at: 2026-01-05T19:34:01.229811 + * + * To update this file: + * 1. Edit arch_specs.json + * 2. 
Run: python generate_arch_specs.py + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace arch_specs { + +// ============================================================================= +// GPU Architecture Enum (Generated) +// ============================================================================= + +enum class GpuArch : std::uint8_t +{ + GFX_908, // AMD Instinct MI100 + GFX_90A, // AMD Instinct MI200 series + GFX_942, // AMD Instinct MI300 series + GFX_950, // AMD Instinct MI350 series + GFX_1100, // AMD Radeon RX 7900 series (RDNA3) + GFX_1200, // AMD Radeon RX 9000 series (RDNA4) + GFX_1201, // AMD Radeon RX 9000 series (RDNA4) + UNKNOWN +}; + +// ============================================================================= +// String Conversion Functions (Generated) +// ============================================================================= + +inline std::string arch_to_string(GpuArch arch) +{ + switch(arch) + { + case GpuArch::GFX_908: return "gfx908"; + case GpuArch::GFX_90A: return "gfx90a"; + case GpuArch::GFX_942: return "gfx942"; + case GpuArch::GFX_950: return "gfx950"; + case GpuArch::GFX_1100: return "gfx1100"; + case GpuArch::GFX_1200: return "gfx1200"; + case GpuArch::GFX_1201: return "gfx1201"; + default: return "unknown"; + } +} + +inline GpuArch string_to_arch(const std::string& arch_str) +{ + if(arch_str == "gfx908") + return GpuArch::GFX_908; + if(arch_str == "gfx90a") + return GpuArch::GFX_90A; + if(arch_str == "gfx942") + return GpuArch::GFX_942; + if(arch_str == "gfx950") + return GpuArch::GFX_950; + if(arch_str == "gfx1100") + return GpuArch::GFX_1100; + if(arch_str == "gfx1200") + return GpuArch::GFX_1200; + if(arch_str == "gfx1201") + return GpuArch::GFX_1201; + return GpuArch::UNKNOWN; +} + +// ============================================================================= +// Element Size (Generated) +// 
============================================================================= + +inline float element_size(DataType dtype) +{ + switch(dtype) + { + case DataType::FP16: return 2.0f; + case DataType::BF16: return 2.0f; + case DataType::FP32: return 4.0f; + case DataType::FP64: return 8.0f; + case DataType::FP8: return 1.0f; + case DataType::BF8: return 1.0f; + case DataType::INT8: return 1.0f; + case DataType::INT4: return 0.5f; + case DataType::INT32: return 4.0f; + default: return 2.0f; + } +} + +// ============================================================================= +// Warp Configurations (Generated) +// ============================================================================= + +using WarpConfig = std::array; + +inline std::vector get_supported_warp_configs(GpuArch arch) +{ + switch(arch) + { + case GpuArch::GFX_908: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_90A: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_942: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_950: return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + case GpuArch::GFX_1100: return {{2, 4, 1}, {1, 8, 1}, {8, 1, 1}, {4, 2, 1}}; + case GpuArch::GFX_1200: return {{2, 4, 1}, {1, 8, 1}, {8, 1, 1}, {4, 2, 1}}; + case GpuArch::GFX_1201: return {{2, 4, 1}, {1, 8, 1}, {8, 1, 1}, {4, 2, 1}}; + default: return {}; + } +} + +// ============================================================================= +// LDS Capacity Limits (Generated) +// ============================================================================= + +inline std::size_t get_lds_capacity(Pipeline pipeline) +{ + if(pipeline == Pipeline::Mem) + return 65536; + if(pipeline == Pipeline::CompV1) + return 65536; + if(pipeline == Pipeline::CompV2) + return 65536; + if(pipeline == Pipeline::CompV3) + return 65536; + if(pipeline == Pipeline::CompV4) + return 32768; + if(pipeline == Pipeline::CompV5) + return 65536; + if(pipeline == Pipeline::PreShuffleV1) + return 32768; + if(pipeline == 
Pipeline::PreShuffleV2) + return 32768; + return 65536; // Default +} + +// ============================================================================= +// Unsupported Trait Combinations (Generated) +// ============================================================================= + +inline bool +is_trait_unsupported(Pipeline pipeline, [[maybe_unused]] Epilogue epilogue, Scheduler scheduler) +{ + // Generated from unsupported_trait_combos in arch_specs.json + if(scheduler == Scheduler::Interwave) + { + if(pipeline == Pipeline::CompV3 || pipeline == Pipeline::CompV4) + { + return true; + } + } + return false; +} + +} // namespace arch_specs +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp new file mode 100644 index 00000000000..79f8f30a9b3 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_kernel_backend.hpp @@ -0,0 +1,143 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Generated Kernel Backend + * + * Backend for kernels generated by unified_gemm_codegen.py + * with unique namespace wrapping (Kernel_{name}). + * + * Status: Work in progress - use generated_tile_backend.hpp for now + * + * This backend handles the new codegen format with unique kernel structs. + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/** + * Kernel instance wrapper for unified_gemm_codegen.py generated kernels + * + * These kernels have: + * - namespace {kernel_name}_ns { ... 
} (NEW format)
 * - struct Kernel_{name} with static launch() method
 * - struct SelectedKernel alias for compatibility
 * - Type aliases: ADataType, BDataType, CDataType, AccDataType
 *
 * Note: Currently use generated_tile_backend.hpp for production
 */
template <typename SelectedKernelType>
class GeneratedKernelInstance : public KernelInstance
{
    public:
    using SelectedKernel = SelectedKernelType;
    using ADataType      = typename SelectedKernel::ADataType;
    using BDataType      = typename SelectedKernel::BDataType;
    using CDataType      = typename SelectedKernel::CDataType;
    using AccDataType    = typename SelectedKernel::AccDataType;

    GeneratedKernelInstance(const KernelKey& key, const std::string& name)
        : key_(key), name_(name)
    {
    }

    const KernelKey& get_key() const override { return key_; }

    /// A problem is supported when every dimension without padding divides
    /// the corresponding compile-time tile extent.
    bool supports(const Problem& problem) const override
    {
        constexpr bool pad_m = SelectedKernel::kPadM;
        constexpr bool pad_n = SelectedKernel::kPadN;
        constexpr bool pad_k = SelectedKernel::kPadK;

        // Fully padded kernels accept any problem shape.
        if(pad_m && pad_n && pad_k)
        {
            return true;
        }

        constexpr int tile_m = SelectedKernel::TileM;
        constexpr int tile_n = SelectedKernel::TileN;
        constexpr int tile_k = SelectedKernel::TileK;

        const bool m_ok = pad_m || (problem.M % tile_m == 0);
        const bool n_ok = pad_n || (problem.N % tile_n == 0);
        const bool k_ok = pad_k || (problem.K % tile_k == 0);
        return m_ok && n_ok && k_ok;
    }

    std::string get_name() const override { return name_; }

    /// Launch the generated kernel and return its measured time in milliseconds.
    /// NOTE(review): the strides below assume row-major A/C and column-major B
    /// regardless of the layouts recorded in key_ — confirm against codegen.
    float run(const void* a_ptr,
              const void* b_ptr,
              void* c_ptr,
              const void** d_ptrs,
              const Problem& problem,
              void* stream) const override
    {
        (void)d_ptrs; // Basic GEMM: no fused D tensors.

        // GemmHostArgs constructor order: a, b, e, k_batch, M, N, K, strides.
        ck_tile::GemmHostArgs host_args(a_ptr,
                                        b_ptr,
                                        c_ptr,
                                        problem.k_batch,
                                        problem.M,
                                        problem.N,
                                        problem.K,
                                        problem.K, // stride_A (row-major A: stride = K)
                                        problem.K, // stride_B (column-major B: stride = K)
                                        problem.N  // stride_E/C (row-major C: stride = N)
        );

        // Timing configuration: 5 warmup + 10 measured iterations, GPU timer.
        ck_tile::stream_config cfg;
        cfg.stream_id_      = reinterpret_cast<hipStream_t>(stream);
        cfg.time_kernel_    = true;
        cfg.log_level_      = 0;
        cfg.cold_niters_    = 5;
        cfg.nrepeat_        = 10;
        cfg.is_gpu_timer_   = true;
        cfg.flush_cache_    = false;
        cfg.rotating_count_ = 1;

        return SelectedKernel::launch(host_args, cfg);
    }

    /// Stub: always reports success. A real check would need a reference GEMM.
    bool validate(const void* a_ptr,
                  const void* b_ptr,
                  const void* c_ptr,
                  const void** d_ptrs,
                  const Problem& problem,
                  float tolerance) const override
    {
        (void)a_ptr;
        (void)b_ptr;
        (void)c_ptr;
        (void)d_ptrs;
        (void)problem;
        (void)tolerance;
        return true;
    }

    private:
    KernelKey key_;     // metadata describing this kernel's configuration
    std::string name_;  // human-readable kernel name
};

} // namespace backends
} // namespace dispatcher
} // namespace ck_tile
// ---- carried diff metadata (file boundary, preserved from the patch) ----
// diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp
//          b/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend.hpp
// new file mode 100644, index 00000000000..76565045cfc, --- /dev/null, @@ -0,0 +1,157 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/validation/reference_kernels.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/** + * Kernel instance wrapper for unified_gemm_codegen.py generated kernels + * + * These kernels have structure: + * - Types defined outside: using ADataType = ...; using BDataType = ...; + * - struct SelectedKernel with static constexpr config and launch() method + * - constexpr const char* KERNEL_NAME = "..."; + * + * This is different from tile_engine style where everything is in SelectedKernel. + */ +template +class GeneratedTileKernelInstance : public KernelInstance +{ + public: + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + using AccDataType = AccDataType_; + using SelectedKernel = SelectedKernelType; + + GeneratedTileKernelInstance(const KernelKey& key, const std::string& name) + : key_(key), name_(name) + { + } + + const KernelKey& get_key() const override { return key_; } + + bool supports(const Problem& problem) const override + { + // Check dimension divisibility if padding not enabled + constexpr bool pad_m = SelectedKernel::kPadM; + constexpr bool pad_n = SelectedKernel::kPadN; + constexpr bool pad_k = SelectedKernel::kPadK; + + if(pad_m && pad_n && pad_k) + { + return true; // Padding enabled - supports any size + } + + // Check divisibility + constexpr int tile_m = SelectedKernel::TileM; + constexpr int tile_n = SelectedKernel::TileN; + constexpr int tile_k = SelectedKernel::TileK; + + if(!pad_m && problem.M % tile_m != 0) + return false; + if(!pad_n && problem.N % tile_n != 0) + return false; + if(!pad_k && problem.K % tile_k != 0) + return false; + + return true; + } + + std::string get_name() const override { 
return name_; } + + float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const override + { + (void)d_ptrs; // Not used in basic GEMM + + // Create arguments using constructor (correct order!) + // Order from GemmHostArgs constructor: a_ptr, b_ptr, e_ptr, k_batch, M, N, K, stride_A, + // stride_B, stride_E + ck_tile::GemmHostArgs args(a_ptr, // a_ptr + b_ptr, // b_ptr + c_ptr, // e_ptr/c_ptr + problem.k_batch, // k_batch (4th argument!) + problem.M, // M + problem.N, // N + problem.K, // K + problem.K, // stride_A (row-major A: stride = K) + problem.K, // stride_B (column-major B: stride = K) + problem.N // stride_E/C (row-major C: stride = N) + ); + + // Create stream config for timing + ck_tile::stream_config stream_cfg; + stream_cfg.stream_id_ = reinterpret_cast(stream); + stream_cfg.time_kernel_ = true; + stream_cfg.log_level_ = 0; // No logging for performance + stream_cfg.cold_niters_ = 5; // Warmup iterations + stream_cfg.nrepeat_ = 10; // Measurement iterations + stream_cfg.is_gpu_timer_ = true; + stream_cfg.flush_cache_ = false; + stream_cfg.rotating_count_ = 1; + + // Call the generated kernel's launch method + return SelectedKernel::launch(args, stream_cfg); + } + + bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const override + { + (void)a_ptr; + (void)b_ptr; + (void)c_ptr; + (void)d_ptrs; + (void)problem; + (void)tolerance; + // Validation would require reference implementation + return true; + } + + private: + KernelKey key_; + std::string name_; +}; + +/// Helper function to create a generated tile kernel instance wrapper +template +std::shared_ptr create_generated_tile_kernel(const KernelKey& key, + const std::string& name) +{ + return std::make_shared< + GeneratedTileKernelInstance>( + key, name); +} + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile 
diff --git a/dispatcher/include/ck_tile/dispatcher/backends/kernel_registration.hpp b/dispatcher/include/ck_tile/dispatcher/backends/kernel_registration.hpp new file mode 100644 index 00000000000..01ab1f5e521 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/kernel_registration.hpp @@ -0,0 +1,109 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/backends/tile_backend.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/// Helper to register a CK Tile generated kernel +/// This should be called from generated code for each kernel +template +void register_tile_kernel(Registry& registry, const std::string& kernel_name) +{ + // Extract metadata from SelectedKernel static members + KernelKey key; + + // Signature + key.signature.dtype_a = static_cast(SelectedKernel::ADataType); + key.signature.dtype_b = static_cast(SelectedKernel::BDataType); + key.signature.dtype_c = static_cast(SelectedKernel::CDataType); + key.signature.dtype_acc = static_cast(SelectedKernel::AccDataType); + + key.signature.layout_a = static_cast(SelectedKernel::ALayout); + key.signature.layout_b = static_cast(SelectedKernel::BLayout); + key.signature.layout_c = static_cast(SelectedKernel::CLayout); + + key.signature.transpose_a = false; // Extract from kernel if available + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + + key.signature.elementwise_op = "PassThrough"; // Extract if available + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = SelectedKernel::UseStructuredSparsity; + + // Algorithm + key.algorithm.tile_shape.m = SelectedKernel::TileM; + key.algorithm.tile_shape.n = SelectedKernel::TileN; + key.algorithm.tile_shape.k = SelectedKernel::TileK; + + key.algorithm.wave_shape.m = SelectedKernel::WarpPerBlock_M; + 
key.algorithm.wave_shape.n = SelectedKernel::WarpPerBlock_N; + key.algorithm.wave_shape.k = SelectedKernel::WarpPerBlock_K; + + key.algorithm.warp_tile_shape.m = SelectedKernel::WarpTileM; + key.algorithm.warp_tile_shape.n = SelectedKernel::WarpTileN; + key.algorithm.warp_tile_shape.k = SelectedKernel::WarpTileK; + + // Extract pipeline, epilogue, scheduler from traits + key.algorithm.pipeline = Pipeline::CompV4; // Extract from kernel + key.algorithm.epilogue = Epilogue::Default; // Extract from kernel + key.algorithm.scheduler = Scheduler::Auto; // Extract from kernel + + key.algorithm.block_size = SelectedKernel::BlockSize; + key.algorithm.double_buffer = SelectedKernel::DoubleSmemBuffer; + key.algorithm.persistent = SelectedKernel::UsePersistentKernel; + key.algorithm.preshuffle = false; // Extract if available + key.algorithm.transpose_c = SelectedKernel::TransposeC; + key.algorithm.num_wave_groups = 1; // Extract if available + + key.gfx_arch = 942; // Extract from build configuration + + // Create kernel instance + auto kernel_instance = std::make_shared>(key, kernel_name); + + // Register with high priority (Tile kernels preferred) + registry.register_kernel(kernel_instance, Registry::Priority::High); +} + +/// Macro to simplify kernel registration in generated code +#define CK_TILE_REGISTER_KERNEL(SelectedKernel, KernelName, Registry) \ + ::ck_tile::dispatcher::backends::register_tile_kernel(Registry, KernelName) + +/// Helper to register multiple kernels from a list +template +struct KernelRegistrar +{ + static void register_all(Registry& registry) + { + // This would be specialized for each kernel set + // For now, empty implementation + } +}; + +/// Auto-registration helper +/// Place this in generated files to automatically register kernels +template +struct AutoRegister +{ + AutoRegister(const std::string& kernel_name) + { + auto& registry = Registry::instance(); + register_tile_kernel(registry, kernel_name); + } +}; + +/// Macro for auto-registration 
+#define CK_TILE_AUTO_REGISTER(SelectedKernel, KernelName) \ + static ::ck_tile::dispatcher::backends::AutoRegister \ + auto_register_##SelectedKernel{KernelName}; + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp new file mode 100644 index 00000000000..a3a0b046856 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/tile_backend.hpp @@ -0,0 +1,173 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/validation/reference_kernels.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +/// Kernel instance for CK Tile generated kernels +template +class TileKernelInstance : public KernelInstance +{ + public: + TileKernelInstance(const KernelKey& key, const std::string& name) : key_(key), name_(name) {} + + const KernelKey& get_key() const override { return key_; } + + bool supports(const Problem& problem) const override + { + // Check dimension divisibility if padding not enabled + constexpr bool pad_m = SelectedKernel::kPadM; + constexpr bool pad_n = SelectedKernel::kPadN; + constexpr bool pad_k = SelectedKernel::kPadK; + + if(pad_m && pad_n && pad_k) + { + // Padding enabled - supports any size + return true; + } + + // Check divisibility + constexpr int tile_m = SelectedKernel::TileM; + constexpr int tile_n = SelectedKernel::TileN; + constexpr int tile_k = SelectedKernel::TileK; + + if(!pad_m && problem.M % tile_m != 0) + return false; + if(!pad_n && problem.N % tile_n != 0) + return false; + if(!pad_k && problem.K % tile_k != 0) + return false; + + // Check shared memory budget if specified + 
if(problem.smem_budget > 0) + { + int64_t estimated_smem = estimate_smem_usage(); + if(estimated_smem > problem.smem_budget) + return false; + } + + return true; + } + + std::string get_name() const override { return name_; } + + float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const override + { + // Convert void* stream to hipStream_t + hipStream_t hip_stream = reinterpret_cast(stream); + + // Construct kernel arguments + using ADataType = typename SelectedKernel::ADataType; + using BDataType = typename SelectedKernel::BDataType; + using CDataType = typename SelectedKernel::CDataType; + + // Note: d_ptrs not yet supported in basic CK Tile kernels + (void)d_ptrs; // Suppress unused parameter warning + + auto kargs = SelectedKernel::MakeKernelArgs(static_cast(a_ptr), + static_cast(b_ptr), + static_cast(c_ptr), + problem.M, + problem.N, + problem.K, + problem.k_batch); + + // Validate arguments + if(!SelectedKernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel does not support the given arguments"); + } + + // Calculate grid and block dimensions + dim3 grids = SelectedKernel::GridSize(problem.M, problem.N, problem.K); + dim3 blocks = SelectedKernel::BlockSize(); + size_t lds_bytes = SelectedKernel::GetSmemSize(); + + // Time kernel execution + hipEvent_t start, stop; + (void)hipEventCreate(&start); + (void)hipEventCreate(&stop); + + (void)hipEventRecord(start, hip_stream); + + // Launch kernel + ck_tile::launch_kernel(SelectedKernel::Kernel, grids, blocks, lds_bytes, hip_stream, kargs); + + (void)hipEventRecord(stop, hip_stream); + (void)hipEventSynchronize(stop); + + float elapsed_ms = 0.0f; + (void)hipEventElapsedTime(&elapsed_ms, start, stop); + + (void)hipEventDestroy(start); + (void)hipEventDestroy(stop); + + return elapsed_ms; + } + + bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + 
float tolerance) const override + { + // Use validation helper + using ADataType = typename SelectedKernel::ADataType; + using BDataType = typename SelectedKernel::BDataType; + using CDataType = typename SelectedKernel::CDataType; + using AccDataType = typename SelectedKernel::AccDataType; + + // d_ptrs not yet supported + (void)d_ptrs; + + // Convert tolerance to rtol and atol + float rtol = tolerance; + float atol = tolerance * 1e-2f; // atol is typically smaller + + return validation::validate_gemm_kernel( + a_ptr, b_ptr, c_ptr, problem, rtol, atol); + } + + private: + int64_t estimate_smem_usage() const + { + // Use kernel's reported shared memory size + return SelectedKernel::GetSmemSize(); + } + + KernelKey key_; + std::string name_; +}; + +/// Helper function to create a tile kernel instance wrapper +/// This should be called from generated code that knows the SelectedKernel type +template +std::shared_ptr create_tile_kernel_instance(const KernelKey& key, + const std::string& name) +{ + return std::make_shared>(key, name); +} + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp new file mode 100644 index 00000000000..6d3f5481382 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp @@ -0,0 +1,146 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Dispatcher - Main Kernel Selection and Execution Engine + * + * The Dispatcher provides unified interface for selecting and executing + * CK Tile GEMM kernels based on problem specifications. 
+ * + * Features: + * - Multiple selection strategies (FirstFit, Heuristic) + * - Custom heuristic functions + * - Thread-safe registry integration + * - Real GPU execution with timing + * + * Usage: + * Dispatcher dispatcher; + * Problem problem(M, N, K); + * float time = dispatcher.run(a_dev, b_dev, c_dev, problem); + * + * Status: Production ready - 319 TFLOPS validated + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Heuristic function type: maps Problem to ordered list of kernel identifiers +/// Returns kernel identifiers ranked by expected performance (best first) +using HeuristicFunction = std::function(const Problem&)>; + +/// Dispatcher: Top-level orchestration for kernel selection and execution +/// Provides unified interface for kernel dispatch across different backends +class Dispatcher +{ + public: + /// Selection strategy for kernel choice + enum class SelectionStrategy + { + FirstFit, // Use first kernel that supports the problem + Heuristic // Use heuristic function to guide selection + }; + + /// Constructor + /// @param registry Registry instance to use (default: global singleton) + explicit Dispatcher(Registry* registry = nullptr); + + /// Register a heuristic function for kernel selection + /// @param heuristic Function that maps problems to ranked kernel identifiers + void set_heuristic(HeuristicFunction heuristic); + + /// Set selection strategy + /// @param strategy Strategy to use for kernel selection + void set_strategy(SelectionStrategy strategy); + + /// Select a kernel for the given problem + /// @param problem Problem configuration + /// @return Selected kernel instance, or nullptr if no suitable kernel found + [[nodiscard]] KernelInstancePtr select_kernel(const Problem& problem) const; + + /// Execute GEMM operation with automatic 
kernel selection + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds + /// @throws std::runtime_error if no suitable kernel found + [[nodiscard]] float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const Problem& problem, + void* stream = nullptr) const; + + /// Execute GEMM operation with fusion (multi-D) + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param d_ptrs Array of pointers to additional D tensors (device memory) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds + /// @throws std::runtime_error if no suitable kernel found + [[nodiscard]] float run_fused(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream = nullptr) const; + + /// Execute with explicit kernel selection + /// @param kernel_id Kernel identifier string + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param d_ptrs Array of pointers to additional D tensors (device memory) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds + /// @throws std::runtime_error if kernel not found or doesn't support problem + [[nodiscard]] float run_explicit(const std::string& kernel_id, + const void* a_ptr, + const void* b_ptr, + void* 
c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream = nullptr) const; + + /// Validate kernel output + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, kernel output) + /// @param d_ptrs Array of pointers to additional D tensors (device memory) + /// @param problem Problem configuration + /// @param tolerance Relative error tolerance + /// @return true if validation passes, false otherwise + [[nodiscard]] bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance = 1e-3f) const; + + private: + Registry* registry_; + HeuristicFunction heuristic_; + SelectionStrategy strategy_; + + /// Select kernel using first-fit strategy + [[nodiscard]] KernelInstancePtr select_first_fit(const Problem& problem) const; + + /// Select kernel using heuristic strategy + [[nodiscard]] KernelInstancePtr select_heuristic(const Problem& problem) const; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/example_args.hpp b/dispatcher/include/ck_tile/dispatcher/example_args.hpp new file mode 100644 index 00000000000..f93a4d61f6b --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/example_args.hpp @@ -0,0 +1,230 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace utils { + +/** + * Simple command-line argument parser for examples. 
+ * + * Usage: + * ExampleArgs args("Example 01: Basic GEMM", "Demonstrates basic GEMM usage"); + * args.add_flag("--list", "List all kernel sets"); + * args.add_option("--dtype", "fp16", "Data type (fp16, bf16, fp32)"); + * args.add_option("--size", "1024", "Problem size MxNxK"); + * + * if (!args.parse(argc, argv)) return 0; // --help was printed + * + * bool do_list = args.has("--list"); + * std::string dtype = args.get("--dtype"); + * int size = args.get_int("--size"); + */ +class ExampleArgs +{ + public: + ExampleArgs(const std::string& name, const std::string& description = "") + : name_(name), description_(description) + { + // Always add --help + add_flag("--help", "Show this help message"); + add_flag("-h", "Show this help message"); + } + + // Add a boolean flag (no value) + void add_flag(const std::string& name, const std::string& help) + { + flags_[name] = false; + help_[name] = help; + order_.push_back(name); + } + + // Add an option with a default value + void + add_option(const std::string& name, const std::string& default_val, const std::string& help) + { + options_[name] = default_val; + defaults_[name] = default_val; + help_[name] = help; + order_.push_back(name); + } + + // Parse arguments. Returns false if --help was requested. 
+ bool parse(int argc, char* argv[]) + { + for(int i = 1; i < argc; ++i) + { + std::string arg = argv[i]; + + // Check for --help + if(arg == "--help" || arg == "-h") + { + print_help(); + return false; + } + + // Check for flags + if(flags_.find(arg) != flags_.end()) + { + flags_[arg] = true; + continue; + } + + // Check for options (--name=value or --name value) + std::string name, value; + size_t eq_pos = arg.find('='); + if(eq_pos != std::string::npos) + { + name = arg.substr(0, eq_pos); + value = arg.substr(eq_pos + 1); + } + else if(options_.find(arg) != options_.end() && i + 1 < argc) + { + name = arg; + value = argv[++i]; + } + else + { + // Positional argument - store as _pos_N + std::string pos_name = "_pos_" + std::to_string(positional_.size()); + positional_.push_back(arg); + continue; + } + + if(options_.find(name) != options_.end()) + { + options_[name] = value; + } + } + return true; + } + + // Check if a flag is set + bool has(const std::string& name) const + { + auto it = flags_.find(name); + return it != flags_.end() && it->second; + } + + // Get an option value as string + std::string get(const std::string& name) const + { + auto it = options_.find(name); + return it != options_.end() ? it->second : ""; + } + + // Get an option value as string with default + std::string get(const std::string& name, const std::string& default_val) const + { + auto it = options_.find(name); + return it != options_.end() ? it->second : default_val; + } + + // Get an option value as int + int get_int(const std::string& name, int default_val = 0) const + { + std::string val = get(name); + if(val.empty()) + return default_val; + try + { + return std::stoi(val); + } + catch(...) + { + return default_val; + } + } + + // Get an option value as float + float get_float(const std::string& name, float default_val = 0.0f) const + { + std::string val = get(name); + if(val.empty()) + return default_val; + try + { + return std::stof(val); + } + catch(...) 
+ { + return default_val; + } + } + + // Get positional arguments + const std::vector& positional() const { return positional_; } + + // Print help message + void print_help() const + { + std::cout << "\n"; + std::cout << " " << name_ << "\n"; + if(!description_.empty()) + { + std::cout << " " << description_ << "\n"; + } + std::cout << "\n"; + std::cout << "Usage:\n"; + std::cout << " ./example [OPTIONS]\n"; + std::cout << "\n"; + std::cout << "Options:\n"; + + // Find max option name length for alignment + size_t max_len = 0; + for(const auto& name : order_) + { + if(name == "-h") + continue; // Skip -h, show --help only + max_len = std::max(max_len, name.length()); + } + + // Print options in order + for(const auto& name : order_) + { + if(name == "-h") + continue; + + std::cout << " " << std::left << std::setw(max_len + 2) << name; + + auto help_it = help_.find(name); + if(help_it != help_.end()) + { + std::cout << help_it->second; + } + + // Show default value for options + auto def_it = defaults_.find(name); + if(def_it != defaults_.end() && !def_it->second.empty()) + { + std::cout << " (default: " << def_it->second << ")"; + } + + std::cout << "\n"; + } + std::cout << "\n"; + } + + private: + std::string name_; + std::string description_; + std::map flags_; + std::map options_; + std::map defaults_; + std::map help_; + std::vector order_; + std::vector positional_; +}; + +} // namespace utils +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/json_export.hpp b/dispatcher/include/ck_tile/dispatcher/json_export.hpp new file mode 100644 index 00000000000..ab1c45412ff --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/json_export.hpp @@ -0,0 +1,370 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * JSON Export Utilities for Dispatcher Registry + * + * Provides functionality to export kernel registry metadata to JSON format, + * similar to the tile engine benchmarking JSON export. + * + * Features: + * - Export all registered kernels with full metadata + * - Include kernel configuration (tile shapes, pipeline, scheduler, etc.) + * - Group kernels by various properties (data type, layout, pipeline, etc.) + * - Export to string or file + * + * Usage: + * auto& registry = Registry::instance(); + * std::string json = export_registry_json(registry); + * // or + * export_registry_json_to_file(registry, "kernels.json"); + */ + +#pragma once + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Convert DataType enum to string +inline std::string datatype_to_string(DataType dtype) +{ + switch(dtype) + { + case DataType::FP16: return "fp16"; + case DataType::BF16: return "bf16"; + case DataType::FP32: return "fp32"; + case DataType::FP8: return "fp8"; + case DataType::BF8: return "bf8"; + case DataType::INT8: return "int8"; + case DataType::INT32: return "int32"; + default: return "unknown"; + } +} + +/// Convert LayoutTag enum to string +inline std::string layout_to_string(LayoutTag layout) +{ + switch(layout) + { + case LayoutTag::RowMajor: return "row_major"; + case LayoutTag::ColMajor: return "col_major"; + case LayoutTag::PackedExternal: return "packed_external"; + default: return "unknown"; + } +} + +/// Convert Pipeline enum to string +inline std::string pipeline_to_string(Pipeline pipeline) +{ + switch(pipeline) + { + case Pipeline::Mem: return "mem"; + case Pipeline::CompV1: return "compv1"; + case Pipeline::CompV2: return "compv2"; + case Pipeline::CompV3: return "compv3"; + case Pipeline::CompV4: return "compv4"; + case Pipeline::CompV5: return 
"compv5"; + default: return "unknown"; + } +} + +/// Convert Epilogue enum to string +inline std::string epilogue_to_string(Epilogue epilogue) +{ + switch(epilogue) + { + case Epilogue::None: return "none"; + case Epilogue::Bias: return "bias"; + case Epilogue::Activation: return "activation"; + case Epilogue::CShuffle: return "cshuffle"; + case Epilogue::Default: return "default"; + default: return "unknown"; + } +} + +/// Convert Scheduler enum to string +inline std::string scheduler_to_string(Scheduler scheduler) +{ + switch(scheduler) + { + case Scheduler::Auto: return "auto"; + case Scheduler::Intrawave: return "intrawave"; + case Scheduler::Interwave: return "interwave"; + default: return "unknown"; + } +} + +/// Escape string for JSON +inline std::string json_escape(const std::string& str) +{ + std::ostringstream oss; + for(char c : str) + { + switch(c) + { + case '"': oss << "\\\""; break; + case '\\': oss << "\\\\"; break; + case '\b': oss << "\\b"; break; + case '\f': oss << "\\f"; break; + case '\n': oss << "\\n"; break; + case '\r': oss << "\\r"; break; + case '\t': oss << "\\t"; break; + default: + if(c < 0x20) + { + oss << "\\u" << std::hex << std::setw(4) << std::setfill('0') << (int)c; + } + else + { + oss << c; + } + } + } + return oss.str(); +} + +/// Get current timestamp in ISO 8601 format +inline std::string get_iso_timestamp() +{ + auto now = std::chrono::system_clock::now(); + auto time_t = std::chrono::system_clock::to_time_t(now); + std::tm tm_buf; + localtime_r(&time_t, &tm_buf); + + std::ostringstream oss; + oss << std::put_time(&tm_buf, "%Y-%m-%dT%H:%M:%S"); + return oss.str(); +} + +/// Export a single kernel's metadata to JSON +inline std::string export_kernel_json(const KernelInstance& kernel) +{ + std::ostringstream json; + const auto& key = kernel.get_key(); + + json << " {\n"; + json << " \"name\": \"" << json_escape(kernel.get_name()) << "\",\n"; + json << " \"identifier\": \"" << json_escape(key.encode_identifier()) << "\",\n"; + 
+ // Signature (what operation is computed) + json << " \"signature\": {\n"; + json << " \"dtype_a\": \"" << datatype_to_string(key.signature.dtype_a) << "\",\n"; + json << " \"dtype_b\": \"" << datatype_to_string(key.signature.dtype_b) << "\",\n"; + json << " \"dtype_c\": \"" << datatype_to_string(key.signature.dtype_c) << "\",\n"; + json << " \"dtype_acc\": \"" << datatype_to_string(key.signature.dtype_acc) << "\",\n"; + json << " \"layout_a\": \"" << layout_to_string(key.signature.layout_a) << "\",\n"; + json << " \"layout_b\": \"" << layout_to_string(key.signature.layout_b) << "\",\n"; + json << " \"layout_c\": \"" << layout_to_string(key.signature.layout_c) << "\",\n"; + json << " \"transpose_a\": " << (key.signature.transpose_a ? "true" : "false") << ",\n"; + json << " \"transpose_b\": " << (key.signature.transpose_b ? "true" : "false") << ",\n"; + json << " \"grouped\": " << (key.signature.grouped ? "true" : "false") << ",\n"; + json << " \"split_k\": " << (int)key.signature.split_k << ",\n"; + json << " \"elementwise_op\": \"" << json_escape(key.signature.elementwise_op) + << "\",\n"; + json << " \"num_d_tensors\": " << (int)key.signature.num_d_tensors << ",\n"; + json << " \"structured_sparsity\": " + << (key.signature.structured_sparsity ? 
"true" : "false") << "\n"; + json << " },\n"; + + // Algorithm (how it's implemented) + json << " \"algorithm\": {\n"; + json << " \"tile_shape\": {\n"; + json << " \"m\": " << key.algorithm.tile_shape.m << ",\n"; + json << " \"n\": " << key.algorithm.tile_shape.n << ",\n"; + json << " \"k\": " << key.algorithm.tile_shape.k << "\n"; + json << " },\n"; + json << " \"wave_shape\": {\n"; + json << " \"m\": " << (int)key.algorithm.wave_shape.m << ",\n"; + json << " \"n\": " << (int)key.algorithm.wave_shape.n << ",\n"; + json << " \"k\": " << (int)key.algorithm.wave_shape.k << "\n"; + json << " },\n"; + json << " \"warp_tile_shape\": {\n"; + json << " \"m\": " << (int)key.algorithm.warp_tile_shape.m << ",\n"; + json << " \"n\": " << (int)key.algorithm.warp_tile_shape.n << ",\n"; + json << " \"k\": " << (int)key.algorithm.warp_tile_shape.k << "\n"; + json << " },\n"; + json << " \"pipeline\": \"" << pipeline_to_string(key.algorithm.pipeline) << "\",\n"; + json << " \"scheduler\": \"" << scheduler_to_string(key.algorithm.scheduler) << "\",\n"; + json << " \"epilogue\": \"" << epilogue_to_string(key.algorithm.epilogue) << "\",\n"; + json << " \"block_size\": " << key.algorithm.block_size << ",\n"; + json << " \"double_buffer\": " << (key.algorithm.double_buffer ? "true" : "false") + << ",\n"; + json << " \"persistent\": " << (key.algorithm.persistent ? "true" : "false") << ",\n"; + json << " \"preshuffle\": " << (key.algorithm.preshuffle ? "true" : "false") << ",\n"; + json << " \"transpose_c\": " << (key.algorithm.transpose_c ? 
"true" : "false") << ",\n"; + json << " \"num_wave_groups\": " << (int)key.algorithm.num_wave_groups << "\n"; + json << " },\n"; + + json << " \"gfx_arch\": \"" << json_escape(key.gfx_arch) << "\"\n"; + json << " }"; + + return json.str(); +} + +/// Export registry metadata and statistics to JSON +inline std::string export_registry_json(const Registry& registry, bool include_statistics = true) +{ + std::ostringstream json; + + auto all_kernels = registry.get_all(); + + json << "{\n"; + + // Metadata + json << " \"metadata\": {\n"; + json << " \"timestamp\": \"" << get_iso_timestamp() << "\",\n"; + json << " \"registry_name\": \"" << json_escape(registry.get_name()) << "\",\n"; + json << " \"total_kernels\": " << all_kernels.size() << ",\n"; + json << " \"export_version\": \"1.0.0\"\n"; + json << " },\n"; + + // Statistics (if enabled) + if(include_statistics && !all_kernels.empty()) + { + std::map by_datatype; + std::map by_pipeline; + std::map by_scheduler; + std::map by_layout; + std::map by_gfx_arch; + + for(const auto& kernel : all_kernels) + { + const auto& key = kernel->get_key(); + + // Count by data type + std::string dtype_key = datatype_to_string(key.signature.dtype_a) + "_" + + datatype_to_string(key.signature.dtype_b) + "_" + + datatype_to_string(key.signature.dtype_c); + by_datatype[dtype_key]++; + + // Count by pipeline + by_pipeline[pipeline_to_string(key.algorithm.pipeline)]++; + + // Count by scheduler + by_scheduler[scheduler_to_string(key.algorithm.scheduler)]++; + + // Count by layout + std::string layout_key = layout_to_string(key.signature.layout_a) + "_" + + layout_to_string(key.signature.layout_b) + "_" + + layout_to_string(key.signature.layout_c); + by_layout[layout_key]++; + + // Count by GFX architecture + by_gfx_arch[key.gfx_arch]++; + } + + json << " \"statistics\": {\n"; + + // Data type breakdown + json << " \"by_datatype\": {\n"; + bool first = true; + for(const auto& [dtype, count] : by_datatype) + { + if(!first) + json << ",\n"; + 
json << " \"" << dtype << "\": " << count; + first = false; + } + json << "\n },\n"; + + // Pipeline breakdown + json << " \"by_pipeline\": {\n"; + first = true; + for(const auto& [pipeline, count] : by_pipeline) + { + if(!first) + json << ",\n"; + json << " \"" << pipeline << "\": " << count; + first = false; + } + json << "\n },\n"; + + // Scheduler breakdown + json << " \"by_scheduler\": {\n"; + first = true; + for(const auto& [scheduler, count] : by_scheduler) + { + if(!first) + json << ",\n"; + json << " \"" << scheduler << "\": " << count; + first = false; + } + json << "\n },\n"; + + // Layout breakdown + json << " \"by_layout\": {\n"; + first = true; + for(const auto& [layout, count] : by_layout) + { + if(!first) + json << ",\n"; + json << " \"" << layout << "\": " << count; + first = false; + } + json << "\n },\n"; + + // GFX architecture breakdown + json << " \"by_gfx_arch\": {\n"; + first = true; + for(const auto& [arch, count] : by_gfx_arch) + { + if(!first) + json << ",\n"; + json << " \"" << arch << "\": " << count; + first = false; + } + json << "\n }\n"; + + json << " },\n"; + } + + // Kernels list + json << " \"kernels\": [\n"; + for(size_t i = 0; i < all_kernels.size(); ++i) + { + json << export_kernel_json(*all_kernels[i]); + if(i < all_kernels.size() - 1) + { + json << ","; + } + json << "\n"; + } + json << " ]\n"; + + json << "}\n"; + + return json.str(); +} + +/// Export registry to a JSON file +inline bool export_registry_json_to_file(const Registry& registry, + const std::string& filename, + bool include_statistics = true) +{ + std::string json = export_registry_json(registry, include_statistics); + + std::ofstream file(filename); + if(!file.is_open()) + { + return false; + } + + file << json; + file.close(); + + return true; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_config.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_config.hpp new file mode 100644 index 
00000000000..05011d2c2d9 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/kernel_config.hpp @@ -0,0 +1,370 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file kernel_config.hpp + * @brief Explicit kernel configuration for CK Tile Dispatcher + * + * This header provides a KernelConfig struct that mirrors the Python API, + * allowing explicit, self-contained kernel configuration without relying + * on force-included generated headers. + * + * Usage: + * #include "ck_tile/dispatcher/kernel_config.hpp" + * using namespace ck_tile::dispatcher; + * + * // Step 1: Define explicit config + * auto config = KernelConfig::fp16_rcr() + * .tile(128, 128, 32) + * .wave(2, 2, 1) + * .warp_tile(32, 32, 16) + * .pipeline(Pipeline::CompV4) + * .scheduler(Scheduler::Intrawave); + * + * // Step 2: Create registry and register + * Registry registry; + * registry.register_kernel(config.build_key(), config.get_name()); + * + * // Step 3: Create dispatcher + * Dispatcher dispatcher(®istry); + * + * // Step 4: Run GEMM + * dispatcher.run(a, b, c, Problem(M, N, K)); + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/** + * @brief Explicit kernel configuration matching Python's KernelConfig + * + * This provides a fluent builder API for creating kernel configurations + * with all parameters visible and explicit. 
+ */ +class KernelConfig +{ + public: + // ========================================================================= + // Data types + // ========================================================================= + DataType dtype_a = DataType::FP16; + DataType dtype_b = DataType::FP16; + DataType dtype_c = DataType::FP16; + DataType dtype_acc = DataType::FP32; + + // ========================================================================= + // Layouts + // ========================================================================= + LayoutTag layout_a = LayoutTag::RowMajor; + LayoutTag layout_b = LayoutTag::ColMajor; + LayoutTag layout_c = LayoutTag::RowMajor; + + // ========================================================================= + // Tile shape + // ========================================================================= + int tile_m = 128; + int tile_n = 128; + int tile_k = 32; + + // ========================================================================= + // Wave shape (warps per block) + // ========================================================================= + int wave_m = 2; + int wave_n = 2; + int wave_k = 1; + + // ========================================================================= + // Warp tile shape + // ========================================================================= + int warp_m = 32; + int warp_n = 32; + int warp_k = 16; + + // ========================================================================= + // Block and pipeline + // ========================================================================= + int block_size = 256; + Pipeline pipeline_type = Pipeline::CompV4; + Scheduler scheduler_type = Scheduler::Intrawave; + Epilogue epilogue_type = Epilogue::CShuffle; + + // ========================================================================= + // Padding and features + // ========================================================================= + bool pad_m = true; + bool pad_n = true; + bool pad_k = true; + bool 
preshuffle = false; + + // ========================================================================= + // Target architecture + // ========================================================================= + std::string gfx_arch = "gfx942"; + + // ========================================================================= + // Fluent builder methods + // ========================================================================= + + /// Set tile dimensions (M x N x K) + KernelConfig& tile(int m, int n, int k) + { + tile_m = m; + tile_n = n; + tile_k = k; + return *this; + } + + /// Set wave dimensions (warps per block M x N x K) + KernelConfig& wave(int m, int n, int k) + { + wave_m = m; + wave_n = n; + wave_k = k; + return *this; + } + + /// Set warp tile dimensions (M x N x K) + KernelConfig& warp_tile(int m, int n, int k) + { + warp_m = m; + warp_n = n; + warp_k = k; + return *this; + } + + /// Set block size + KernelConfig& block(int size) + { + block_size = size; + return *this; + } + + /// Set pipeline type + KernelConfig& pipeline(Pipeline p) + { + pipeline_type = p; + return *this; + } + + /// Set scheduler type + KernelConfig& scheduler(Scheduler s) + { + scheduler_type = s; + return *this; + } + + /// Set epilogue type + KernelConfig& epilogue(Epilogue e) + { + epilogue_type = e; + return *this; + } + + /// Set data types for A, B, C + KernelConfig& dtypes(DataType a, DataType b, DataType c, DataType acc = DataType::FP32) + { + dtype_a = a; + dtype_b = b; + dtype_c = c; + dtype_acc = acc; + return *this; + } + + /// Set layouts for A, B, C + KernelConfig& layouts(LayoutTag a, LayoutTag b, LayoutTag c) + { + layout_a = a; + layout_b = b; + layout_c = c; + return *this; + } + + /// Set padding flags + KernelConfig& padding(bool m, bool n, bool k) + { + pad_m = m; + pad_n = n; + pad_k = k; + return *this; + } + + /// Set target GPU architecture + KernelConfig& arch(const std::string& gpu) + { + gfx_arch = gpu; + return *this; + } + + // 
========================================================================= + // Preset configurations + // ========================================================================= + + /// FP16 Row-Column-Row layout (most common) + static KernelConfig fp16_rcr() { return KernelConfig{}; } + + /// FP16 Row-Row-Row layout + static KernelConfig fp16_rrr() + { + KernelConfig cfg; + cfg.layout_b = LayoutTag::RowMajor; + return cfg; + } + + /// BF16 Row-Column-Row layout + static KernelConfig bf16_rcr() + { + KernelConfig cfg; + cfg.dtype_a = DataType::BF16; + cfg.dtype_b = DataType::BF16; + cfg.dtype_c = DataType::BF16; + return cfg; + } + + /// FP32 Row-Column-Row layout + static KernelConfig fp32_rcr() + { + KernelConfig cfg; + cfg.dtype_a = DataType::FP32; + cfg.dtype_b = DataType::FP32; + cfg.dtype_c = DataType::FP32; + cfg.dtype_acc = DataType::FP32; + return cfg; + } + + // ========================================================================= + // Build KernelKey + // ========================================================================= + + /// Build a KernelKey from this configuration + [[nodiscard]] KernelKey build_key() const + { + KernelKey key; + + // Signature + key.signature.dtype_a = dtype_a; + key.signature.dtype_b = dtype_b; + key.signature.dtype_c = dtype_c; + key.signature.dtype_acc = dtype_acc; + key.signature.layout_a = layout_a; + key.signature.layout_b = layout_b; + key.signature.layout_c = layout_c; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + // Algorithm + key.algorithm.tile_shape = {static_cast(tile_m), + static_cast(tile_n), + static_cast(tile_k)}; + key.algorithm.wave_shape = {static_cast(wave_m), + static_cast(wave_n), + static_cast(wave_k)}; + key.algorithm.warp_tile_shape = {static_cast(warp_m), + 
static_cast(warp_n), + static_cast(warp_k)}; + key.algorithm.pipeline = pipeline_type; + key.algorithm.scheduler = scheduler_type; + key.algorithm.epilogue = epilogue_type; + key.algorithm.block_size = block_size; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = preshuffle; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + + key.gfx_arch = gfx_arch; + + return key; + } + + // ========================================================================= + // String representations + // ========================================================================= + + /// Get tile string (e.g., "128x128x32") + [[nodiscard]] std::string tile_str() const + { + std::ostringstream oss; + oss << tile_m << "x" << tile_n << "x" << tile_k; + return oss.str(); + } + + /// Get wave string (e.g., "2x2x1") + [[nodiscard]] std::string wave_str() const + { + std::ostringstream oss; + oss << wave_m << "x" << wave_n << "x" << wave_k; + return oss.str(); + } + + /// Get warp tile string (e.g., "32x32x16") + [[nodiscard]] std::string warp_tile_str() const + { + std::ostringstream oss; + oss << warp_m << "x" << warp_n << "x" << warp_k; + return oss.str(); + } + + /// Get layout string (e.g., "rcr") + [[nodiscard]] std::string layout_str() const + { + std::ostringstream oss; + oss << to_string(layout_a) << to_string(layout_b) << to_string(layout_c); + return oss.str(); + } + + /// Get kernel name for generated code lookup + [[nodiscard]] std::string get_name() const + { + std::ostringstream oss; + oss << "gemm_" << to_string(dtype_a) << "_" << layout_str() << "_" + << to_string(pipeline_type) << "_" << to_string(epilogue_type) << "_" + << to_string(scheduler_type) << "_" << (pad_m ? "True" : "False") << "_" + << (pad_n ? "True" : "False") << "_" << (pad_k ? 
"True" : "False") << "_" + << "False" // preshuffle + << "_" << tile_str() << "_" << wave_str() << "_" << warp_tile_str(); + return oss.str(); + } + + /// Print configuration to stdout + void print_config(std::ostream& os = std::cout) const + { + os << " Data types:\n"; + os << " dtype_a = " << to_string(dtype_a) << "\n"; + os << " dtype_b = " << to_string(dtype_b) << "\n"; + os << " dtype_c = " << to_string(dtype_c) << "\n"; + os << " dtype_acc = " << to_string(dtype_acc) << "\n"; + os << " Layouts:\n"; + os << " layout_a = " << to_string(layout_a) << "\n"; + os << " layout_b = " << to_string(layout_b) << "\n"; + os << " layout_c = " << to_string(layout_c) << "\n"; + os << " Tile shape:\n"; + os << " tile = " << tile_str() << "\n"; + os << " wave = " << wave_str() << "\n"; + os << " warp_tile = " << warp_tile_str() << "\n"; + os << " Pipeline:\n"; + os << " pipeline = " << to_string(pipeline_type) << "\n"; + os << " scheduler = " << to_string(scheduler_type) << "\n"; + os << " epilogue = " << to_string(epilogue_type) << "\n"; + os << " Padding:\n"; + os << " pad_m = " << (pad_m ? "true" : "false") << "\n"; + os << " pad_n = " << (pad_n ? "true" : "false") << "\n"; + os << " pad_k = " << (pad_k ? "true" : "false") << "\n"; + os << " Target:\n"; + os << " gfx_arch = " << gfx_arch << "\n"; + } +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/kernel_decl.hpp b/dispatcher/include/ck_tile/dispatcher/kernel_decl.hpp new file mode 100644 index 00000000000..095de52e06b --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/kernel_decl.hpp @@ -0,0 +1,509 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

/**
 * @file kernel_decl.hpp
 * @brief Declarative kernel specification with KernelSet
 *
 * USAGE:
 * ======
 *
 *   // Named kernel sets
 *   DECL_KERNEL_SET(compute_bound,
 *       .add("fp16", "rcr", 256, 256, 64)
 *       .add("fp16", "rcr", 128, 128, 32)
 *   );
 *
 *   // Access at runtime
 *   auto& set = KernelSetRegistry::instance().get("compute_bound");
 */

#pragma once

#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace ck_tile {
namespace dispatcher {
namespace decl {

// =============================================================================
// Wildcard constants
// =============================================================================

constexpr const char* ANY = "*";
constexpr int ANY_INT     = -1;

// =============================================================================
// Signature Builder
// =============================================================================

/// Describes WHAT operation a kernel computes: data types, layouts, and
/// element-wise fusion. Fields are public strings so declarations stay
/// lightweight; validation happens at expansion time.
class Signature
{
    public:
    std::string dtype_a_        = "fp16";
    std::string dtype_b_        = "fp16";
    std::string dtype_c_        = "fp16";
    std::string dtype_acc_      = "fp32";
    std::string layout_a_       = "row";
    std::string layout_b_       = "col";
    std::string layout_c_       = "row";
    std::string elementwise_op_ = "PassThrough";
    int num_d_tensors_          = 0;
    bool structured_sparsity_   = false;

    /// Set A/B/C (and optionally accumulator) data types individually
    Signature& dtype(const std::string& a,
                     const std::string& b,
                     const std::string& c,
                     const std::string& acc = "fp32")
    {
        dtype_a_   = a;
        dtype_b_   = b;
        dtype_c_   = c;
        dtype_acc_ = acc;
        return *this;
    }

    /// Set A, B and C to the same data type; accumulator resets to fp32
    Signature& dtype(const std::string& all)
    {
        dtype_a_ = dtype_b_ = dtype_c_ = all;
        dtype_acc_                     = "fp32";
        return *this;
    }

    /// Set layouts individually ("row"/"col")
    Signature& layout(const std::string& a, const std::string& b, const std::string& c)
    {
        layout_a_ = a;
        layout_b_ = b;
        layout_c_ = c;
        return *this;
    }

    /// Set layouts from a compact triple, e.g. "rcr" -> row/col/row.
    /// Any character other than 'r' is treated as column-major.
    Signature& layout(const std::string& combined)
    {
        if(combined.size() >= 3)
        {
            layout_a_ = (combined[0] == 'r') ? "row" : "col";
            layout_b_ = (combined[1] == 'r') ? "row" : "col";
            layout_c_ = (combined[2] == 'r') ? "row" : "col";
        }
        return *this;
    }

    /// Set element-wise fusion op and number of extra D input tensors
    Signature& elementwise(const std::string& op, int num_d = 0)
    {
        elementwise_op_ = op;
        num_d_tensors_  = num_d;
        return *this;
    }

    /// Compact layout triple, e.g. "rcr"
    std::string layout_str() const
    {
        std::string r;
        r += (layout_a_ == "col") ? 'c' : 'r';
        r += (layout_b_ == "col") ? 'c' : 'r';
        r += (layout_c_ == "col") ? 'c' : 'r';
        return r;
    }
};

// =============================================================================
// Algorithm Builder
// =============================================================================

/// Describes HOW a kernel is implemented: tiling, pipeline, scheduler.
/// Fields left at ANY_INT / "*" are wildcards to be expanded later.
class Algorithm
{
    public:
    int tile_m_ = 128, tile_n_ = 128, tile_k_ = 32;
    int wave_m_ = ANY_INT, wave_n_ = ANY_INT, wave_k_ = 1;
    int warp_m_ = ANY_INT, warp_n_ = ANY_INT, warp_k_ = 16;
    std::string pipeline_  = "compv4";
    std::string scheduler_ = "intrawave";
    std::string epilogue_  = "cshuffle";
    int block_size_        = 256;
    int pad_m_ = 1, pad_n_ = 1, pad_k_ = 1;
    bool preshuffle_ = false;

    /// Set block tile dimensions
    Algorithm& tile(int m, int n, int k)
    {
        tile_m_ = m;
        tile_n_ = n;
        tile_k_ = k;
        return *this;
    }

    /// Set warps-per-block dimensions
    Algorithm& wave(int m, int n, int k = 1)
    {
        wave_m_ = m;
        wave_n_ = n;
        wave_k_ = k;
        return *this;
    }

    /// Set warp tile dimensions
    Algorithm& warp(int m, int n, int k = 16)
    {
        warp_m_ = m;
        warp_n_ = n;
        warp_k_ = k;
        return *this;
    }

    /// Set pipeline name (e.g. "compv4", "mem")
    Algorithm& pipeline(const std::string& p)
    {
        pipeline_ = p;
        return *this;
    }
    /// Set scheduler name (e.g. "intrawave")
    Algorithm& scheduler(const std::string& s)
    {
        scheduler_ = s;
        return *this;
    }
    /// Set epilogue name (e.g. "cshuffle")
    Algorithm& epilogue(const std::string& e)
    {
        epilogue_ = e;
        return *this;
    }

    /// Set M/N/K padding flags (stored as 0/1 ints)
    Algorithm& pad(bool m, bool n, bool k)
    {
        pad_m_ = m ? 1 : 0;
        pad_n_ = n ? 1 : 0;
        pad_k_ = k ? 1 : 0;
        return *this;
    }

    /// Enable/disable weight preshuffle
    Algorithm& preshuffle(bool v)
    {
        preshuffle_ = v;
        return *this;
    }

    /// True if any field is still a wildcard and needs expansion
    bool needs_expansion() const
    {
        return wave_m_ == ANY_INT || warp_m_ == ANY_INT || pipeline_ == "*" ||
               pad_m_ == ANY_INT;
    }

    /// Replace remaining wildcards with sensible defaults
    void auto_fill()
    {
        if(wave_m_ == ANY_INT)
            wave_m_ = 2;
        if(wave_n_ == ANY_INT)
            wave_n_ = 2;
        if(wave_k_ == ANY_INT)
            wave_k_ = 1;
        if(warp_m_ == ANY_INT)
            warp_m_ = 32;
        if(warp_n_ == ANY_INT)
            warp_n_ = 32;
        if(warp_k_ == ANY_INT)
            warp_k_ = 16;
    }
};

// =============================================================================
// Kernel Declaration
// =============================================================================

/// One declared kernel: signature + algorithm + target architecture
struct KernelDecl
{
    Signature signature;
    Algorithm algorithm;
    std::string arch = "gfx942";

    KernelDecl() = default;

    KernelDecl(const Signature& sig, const Algorithm& algo, const std::string& a = "gfx942")
        : signature(sig), algorithm(algo), arch(a)
    {
    }

    /// Short identifier, e.g. "fp16_rcr_128x128x32"; tile part is omitted
    /// when the tile dims are wildcards (<= 0).
    std::string name() const
    {
        std::ostringstream oss;
        oss << signature.dtype_a_ << "_" << signature.layout_str();
        if(algorithm.tile_m_ > 0)
        {
            oss << "_" << algorithm.tile_m_ << "x" << algorithm.tile_n_ << "x"
                << algorithm.tile_k_;
        }
        return oss.str();
    }

    /// True if the algorithm or the architecture still contains wildcards
    bool has_wildcards() const { return algorithm.needs_expansion() || arch == "*"; }
};

// =============================================================================
// KernelSet - Collection of declarations
// =============================================================================

/// An ordered, optionally tagged collection of kernel declarations
class KernelSet
{
    public:
    KernelSet() = default;

    /// Add a declaration from explicit signature/algorithm builders
    KernelSet& add(const Signature& sig, const Algorithm& algo, const std::string& arch = "gfx942")
    {
        decls_.emplace_back(sig, algo, arch);
        return *this;
    }

    /// Add a declaration from shorthand (dtype, layout triple, tile dims)
    KernelSet& add(const std::string& dtype,
                   const std::string& layout,
                   int tm,
                   int tn,
                   int tk,
                   const std::string& arch = "gfx942")
    {
        Signature sig;
        sig.dtype(dtype).layout(layout);
        Algorithm algo;
        algo.tile(tm, tn, tk);
        decls_.emplace_back(sig, algo, arch);
        return *this;
    }

    /// Add a ready-made declaration
    KernelSet& add(const KernelDecl& decl)
    {
        decls_.push_back(decl);
        return *this;
    }

    /// Append all declarations from another set
    KernelSet& merge(const KernelSet& other)
    {
        decls_.insert(decls_.end(), other.decls_.begin(), other.decls_.end());
        return *this;
    }

    const std::vector<KernelDecl>& declarations() const { return decls_; }
    size_t size() const { return decls_.size(); }

    /// True if any contained declaration still needs wildcard expansion
    bool needs_expansion() const
    {
        for(const auto& d : decls_)
        {
            if(d.algorithm.needs_expansion())
                return true;
        }
        return false;
    }

    /// Human-readable dump of the set (defaults to stdout)
    void print(std::ostream& os = std::cout) const
    {
        os << "KernelSet (" << size() << " declarations):\n";
        for(const auto& d : decls_)
        {
            os << "  - " << d.name();
            if(d.algorithm.needs_expansion())
                os << " [expands]";
            os << "\n";
        }
    }

    /// Attach a tag (set name); returns *this for chaining
    KernelSet& tag(const std::string& t)
    {
        tag_ = t;
        return *this;
    }
    std::string tag() const { return tag_; }

    private:
    std::vector<KernelDecl> decls_;
    std::string tag_;
};

// =============================================================================
// KernelSet Registry
// =============================================================================

/// Process-wide registry of named kernel sets (Meyers singleton)
class KernelSetRegistry
{
    public:
    static KernelSetRegistry& instance()
    {
        static KernelSetRegistry reg;
        return reg;
    }

    /// Register (or replace) a named set.
    /// FIX: previously the name was appended to `order_` unconditionally,
    /// so re-registering a name produced duplicate entries in iteration
    /// order; now insertion order is recorded only on first registration.
    void add(const std::string& name, const KernelSet& set)
    {
        if(sets_.find(name) == sets_.end())
        {
            order_.push_back(name);
        }
        sets_[name] = set;
    }

    /// Look up a set by name; returns a shared empty set when absent
    const KernelSet& get(const std::string& name) const
    {
        static KernelSet empty;
        auto it = sets_.find(name);
        return it != sets_.end() ? it->second : empty;
    }

    bool has(const std::string& name) const { return sets_.find(name) != sets_.end(); }

    /// Registration order of set names (const ref to avoid deep copy)
    const std::vector<std::string>& names() const { return order_; }
    size_t size() const { return sets_.size(); }

    void print() const
    {
        std::cout << "Named Kernel Sets (" << size() << "):\n";
        for(const auto& name : order_)
        {
            const auto& set = sets_.at(name);
            std::cout << "  " << name << ": " << set.size() << " declarations\n";
        }
    }

    private:
    KernelSetRegistry() = default;
    std::unordered_map<std::string, KernelSet> sets_;
    std::vector<std::string> order_;
};

// =============================================================================
// Declaration Registry (for DECL_KERNEL)
// =============================================================================

/// Process-wide registry of individual kernel declarations
class Registry
{
    public:
    static Registry& instance()
    {
        static Registry reg;
        return reg;
    }

    /// Register a declaration. Wildcard declarations get a synthetic
    /// unique key; named declarations replace earlier ones with the same
    /// name.
    /// FIX: `order_` previously received the key unconditionally, so
    /// replacing a named declaration made all() return it twice; order is
    /// now recorded only for keys seen for the first time.
    void add(const KernelDecl& decl)
    {
        std::string key = decl.has_wildcards()
                              ? ("wildcard_" + std::to_string(declarations_.size()))
                              : decl.name();
        if(declarations_.find(key) == declarations_.end())
        {
            order_.push_back(key);
        }
        declarations_[key] = decl;
    }

    /// All declarations in registration order
    std::vector<KernelDecl> all() const
    {
        std::vector<KernelDecl> result;
        for(const auto& key : order_)
        {
            result.push_back(declarations_.at(key));
        }
        return result;
    }

    size_t size() const { return declarations_.size(); }

    void print() const
    {
        std::cout << "Declared kernels (" << size() << "):\n";
        for(const auto& key : order_)
        {
            const auto& d = declarations_.at(key);
            std::cout << "  " << d.name();
            if(d.has_wildcards())
                std::cout << " [wildcards]";
            std::cout << "\n";
        }
    }

    private:
    Registry() = default;
    std::unordered_map<std::string, KernelDecl> declarations_;
    std::vector<std::string> order_;
};

// =============================================================================
// Static Registrars
// =============================================================================

/// Registers a KernelDecl at static-initialization time (used by macros)
struct Declarator
{
    Declarator(const Signature& sig, const Algorithm& algo, const std::string& arch = "gfx942")
    {
        Registry::instance().add(KernelDecl(sig, algo, arch));
    }

    Declarator(const std::string& dtype,
               const std::string& layout,
               int tm,
               int tn,
               int tk,
               const std::string& arch = "gfx942")
    {
        Signature sig;
        sig.dtype(dtype).layout(layout);
        Algorithm algo;
        algo.tile(tm, tn, tk);
        Registry::instance().add(KernelDecl(sig, algo, arch));
    }

    /// Wildcard form: all tile dims set to ANY_INT for later expansion
    Declarator(const std::string& dtype, const std::string& layout, const std::string& arch)
    {
        Signature sig;
        sig.dtype(dtype).layout(layout);
        Algorithm algo;
        algo.tile(ANY_INT, ANY_INT, ANY_INT);
        Registry::instance().add(KernelDecl(sig, algo, arch));
    }
};

/// Registers a named KernelSet at static-initialization time
struct KernelSetRegistrar
{
    KernelSetRegistrar(const std::string& name, const KernelSet& set)
    {
        KernelSetRegistry::instance().add(name, set);
    }
};

} // namespace decl

// =============================================================================
// Convenience Aliases
// =============================================================================

using KernelSignature   = decl::Signature;
using KernelAlgorithm   = decl::Algorithm;
using KernelDecl        = decl::KernelDecl;
using KernelDeclRegistry = decl::Registry;
using KernelSet         = decl::KernelSet;
using KernelSetRegistry = decl::KernelSetRegistry;

constexpr const char* ANY = decl::ANY;
constexpr int ANY_INT     = decl::ANY_INT;

} // namespace dispatcher
} // namespace ck_tile

// =============================================================================
// Declaration Macros
// =============================================================================

#define CK_DECL_CAT_(a, b) CK_DECL_CAT_IMPL_(a, b)
#define CK_DECL_CAT_IMPL_(a, b) a##b

// Note: __extension__ suppresses warnings about __COUNTER__ being a GCC/Clang extension
#define DECL_KERNEL(sig, algo, ...)                                                      \
    __extension__ static ::ck_tile::dispatcher::decl::Declarator CK_DECL_CAT_(           \
        _kdecl_, __COUNTER__)(sig, algo, ##__VA_ARGS__)

#define DECL_KERNEL_SIMPLE(dtype, layout, tm, tn, tk)                                    \
    __extension__ static ::ck_tile::dispatcher::decl::Declarator CK_DECL_CAT_(           \
        _kdecl_, __COUNTER__)(#dtype, #layout, tm, tn, tk)

#define DECL_KERNEL_ALL(dtype, layout)                                                   \
    __extension__ static ::ck_tile::dispatcher::decl::Declarator CK_DECL_CAT_(           \
        _kdecl_, __COUNTER__)(#dtype, #layout, "*")

#define DECL_KERNEL_SET(name, ...)                                                       \
    __extension__ static ::ck_tile::dispatcher::decl::KernelSetRegistrar CK_DECL_CAT_(   \
        _kset_reg_, __COUNTER__)(#name,                                                  \
                                 ::ck_tile::dispatcher::decl::KernelSet() __VA_ARGS__.tag(#name))

#define KERNEL_SET(name) ::ck_tile::dispatcher::decl::KernelSet name
#define BEGIN_KERNEL_SET() ::ck_tile::dispatcher::decl::KernelSet()

// Legacy compatibility
// Legacy aliases removed - use DECL_KERNEL_SET instead
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// KernelInstance: Uniform interface for kernel execution +/// Abstracts away implementation details (CK Library vs CK Tile vs future JIT) +/// Enables type-erased storage in registry while backends perform type-safe casts +class KernelInstance +{ + public: + virtual ~KernelInstance() = default; + + /// Get the kernel's configuration metadata + [[nodiscard]] virtual const KernelKey& get_key() const = 0; + + /// Check if this kernel supports the given problem + /// Returns false if problem dimensions don't meet kernel requirements + /// (e.g., divisibility constraints, resource limits) + [[nodiscard]] virtual bool supports(const Problem& problem) const = 0; + + /// Get human-readable kernel name for logging and debugging + [[nodiscard]] virtual std::string get_name() const = 0; + + /// Execute the kernel with given problem and data pointers + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, input/output) + /// @param d_ptrs Array of pointers to additional D tensors for fusion (device memory) + /// @param problem Problem configuration + /// @param stream HIP stream for kernel launch (nullptr = default stream) + /// @return Kernel execution time in milliseconds (0 if timing not available) + [[nodiscard]] virtual float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream = nullptr) const = 0; + + /// Validate kernel output against reference implementation + /// @param a_ptr Pointer to matrix A (device memory) + /// @param b_ptr Pointer to matrix B (device memory) + /// @param c_ptr Pointer to matrix C (device memory, kernel output) + /// @param d_ptrs Array of pointers to 
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <cstdint>
#include <string>
#include <tuple>

namespace ck_tile {
namespace dispatcher {

/// Data types supported by CK Tile GEMM kernels
/// Matches tile_engine DATA_TYPE_MAP for full compatibility
enum class DataType : std::uint8_t
{
    FP16,  // ck_tile::half_t
    BF16,  // ck_tile::bf16_t
    FP32,  // float
    FP64,  // double
    FP8,   // ck_tile::fp8_t (E4M3)
    BF8,   // ck_tile::bf8_t (E5M2)
    INT8,  // ck_tile::int8_t
    INT4,  // ck_tile::pk_int4_t (packed int4)
    INT32, // ck_tile::int32_t
    UNKNOWN
};

/// Memory layout tags for tensors
enum class LayoutTag : std::uint8_t
{
    RowMajor,
    ColMajor,
    PackedExternal
};

/// Pipeline variants for memory/compute optimization
/// Matches tile_engine PIPELINE_MAP for full compatibility
enum class Pipeline : std::uint8_t
{
    Mem,          // Memory-bound pipeline
    CompV1,       // Compute pipeline v1
    CompV2,       // Compute pipeline v2
    CompV3,       // Compute pipeline v3
    CompV4,       // Compute pipeline v4 (double buffering)
    CompV5,       // Compute pipeline v5
    PreShuffleV1, // Weight preshuffle pipeline v1
    PreShuffleV2  // Weight preshuffle pipeline v2 (optimized)
};

/// Epilogue strategies for output processing
/// Matches tile_engine epilogue options for full compatibility
enum class Epilogue : std::uint8_t
{
    None,
    Default,       // DefaultGemm2DEpilogue
    CShuffle,      // CShuffleEpilogue (cross-shuffle)
    Bias,          // Bias addition
    Activation,    // Fused activation
    BiasActivation // Fused bias + activation
};

/// Scheduler types for wave coordination
enum class Scheduler : std::uint8_t
{
    Auto,
    Intrawave,
    Interwave
};

/// KernelKey: Compile-time kernel configuration metadata
/// Organized into Signature (what operation) and Algorithm (how it's implemented)
struct KernelKey
{
    /// Signature: Describes WHAT operation is computed (mathematical semantics)
    /// Two kernels with different signatures compute different mathematical operations
    struct Signature
    {
        DataType dtype_a;
        DataType dtype_b;
        DataType dtype_c;
        DataType dtype_acc;
        LayoutTag layout_a;
        LayoutTag layout_b;
        LayoutTag layout_c;
        bool transpose_a;
        bool transpose_b;
        bool grouped;
        std::uint8_t split_k;

        // Element-wise fusion: Describes mathematical operation applied to GEMM output
        // Examples: PassThrough (C = A*B), MultiDAdd (E = C + D0 + D1),
        //           MultiDMultiply (E = C * D0 * D1), Clamp, Relu, Gelu, etc.
        // This affects the mathematical result, so it belongs in Signature
        std::string elementwise_op; // e.g., "PassThrough", "MultiDAdd", "Relu"
        std::uint8_t
            num_d_tensors; // Number of additional input tensors for fusion (0 for basic GEMM)

        bool structured_sparsity; // 2:4 sparsity affects mathematical correctness
    } signature;

    /// Algorithm: Describes HOW it's implemented (performance tuning parameters)
    /// Two kernels with same signature but different algorithms compute the same result
    /// with different performance characteristics
    struct Algorithm
    {
        // Hierarchical tiling configuration (primary tuning knobs)
        struct TileShape
        {
            std::uint16_t m;
            std::uint16_t n;
            std::uint16_t k;
        } tile_shape;

        struct WaveShape
        {
            std::uint8_t m; // WarpPerBlock_M in generated kernels
            std::uint8_t n; // WarpPerBlock_N
            std::uint8_t k; // WarpPerBlock_K
        } wave_shape;

        struct WarpTileShape
        {
            std::uint8_t m; // WarpTileM in generated kernels
            std::uint8_t n; // WarpTileN
            std::uint8_t k; // WarpTileK
        } warp_tile_shape;

        // Pipeline and scheduling strategy
        Pipeline pipeline;
        Scheduler scheduler;
        Epilogue epilogue;

        // Block and memory configuration
        std::uint16_t block_size;     // BlockSize in generated kernels (typically 256)
        bool double_buffer;           // DoubleSmemBuffer (true for compv4)
        bool persistent;              // UsePersistentKernel
        bool preshuffle;              // Preshuffle (for weight preshuffle variants)
        bool transpose_c;             // TransposeC
        std::uint8_t num_wave_groups; // NumWaveGroups
    } algorithm;

    std::string gfx_arch; // e.g. "gfx942", "gfx90a", "gfx908"

    /// Generate a unique string identifier for this kernel configuration
    /// Format matches tile_engine naming convention for registry lookup
    /// Note: Defined after to_string() functions to use them
    [[nodiscard]] std::string encode_identifier() const;

    /// Create a tuple of references over all fields for comparison operators.
    /// FIX: signature.structured_sparsity was listed twice (once in the
    /// signature group, again after gfx_arch); the duplicate is removed and
    /// gfx_arch moved to the end so the order is signature / algorithm / arch.
    auto tie() const
    {
        return std::tie(signature.dtype_a,
                        signature.dtype_b,
                        signature.dtype_c,
                        signature.dtype_acc,
                        signature.layout_a,
                        signature.layout_b,
                        signature.layout_c,
                        signature.transpose_a,
                        signature.transpose_b,
                        signature.grouped,
                        signature.split_k,
                        signature.elementwise_op,
                        signature.num_d_tensors,
                        signature.structured_sparsity,
                        algorithm.tile_shape.m,
                        algorithm.tile_shape.n,
                        algorithm.tile_shape.k,
                        algorithm.wave_shape.m,
                        algorithm.wave_shape.n,
                        algorithm.wave_shape.k,
                        algorithm.warp_tile_shape.m,
                        algorithm.warp_tile_shape.n,
                        algorithm.warp_tile_shape.k,
                        algorithm.pipeline,
                        algorithm.epilogue,
                        algorithm.scheduler,
                        algorithm.block_size,
                        algorithm.double_buffer,
                        algorithm.persistent,
                        algorithm.preshuffle,
                        algorithm.transpose_c,
                        algorithm.num_wave_groups,
                        gfx_arch);
    }

    /// Equality comparison
    friend bool operator==(const KernelKey& lhs, const KernelKey& rhs)
    {
        return lhs.tie() == rhs.tie();
    }

    /// Inequality comparison
    friend bool operator!=(const KernelKey& lhs, const KernelKey& rhs) { return !(lhs == rhs); }
};

} // namespace dispatcher
} // namespace ck_tile
DataType::FP8: return "fp8"; + case DataType::BF8: return "bf8"; + case DataType::INT8: return "int8"; + case DataType::INT4: return "int4"; + case DataType::INT32: return "int32"; + default: return "unknown"; + } +} + +/// Convert string to DataType +inline DataType string_to_dtype(const std::string& str) +{ + if(str == "fp16") + return DataType::FP16; + if(str == "bf16") + return DataType::BF16; + if(str == "fp32") + return DataType::FP32; + if(str == "fp64") + return DataType::FP64; + if(str == "fp8") + return DataType::FP8; + if(str == "bf8") + return DataType::BF8; + if(str == "int8") + return DataType::INT8; + if(str == "int4") + return DataType::INT4; + if(str == "int32") + return DataType::INT32; + return DataType::UNKNOWN; +} + +/// Convert LayoutTag to string +inline std::string to_string(LayoutTag layout) +{ + switch(layout) + { + case LayoutTag::RowMajor: return "r"; + case LayoutTag::ColMajor: return "c"; + case LayoutTag::PackedExternal: return "p"; + default: return "?"; + } +} + +/// Convert string to LayoutTag +inline LayoutTag string_to_layout(const std::string& str) +{ + if(str == "r" || str == "row" || str == "RowMajor") + return LayoutTag::RowMajor; + if(str == "c" || str == "col" || str == "ColMajor") + return LayoutTag::ColMajor; + if(str == "p" || str == "packed") + return LayoutTag::PackedExternal; + return LayoutTag::RowMajor; // Default +} + +/// Convert Pipeline to string +inline std::string to_string(Pipeline pipeline) +{ + switch(pipeline) + { + case Pipeline::Mem: return "mem"; + case Pipeline::CompV1: return "compv1"; + case Pipeline::CompV2: return "compv2"; + case Pipeline::CompV3: return "compv3"; + case Pipeline::CompV4: return "compv4"; + case Pipeline::CompV5: return "compv5"; + case Pipeline::PreShuffleV1: return "preshufflev1"; + case Pipeline::PreShuffleV2: return "preshufflev2"; + default: return "unknown"; + } +} + +/// Convert string to Pipeline +inline Pipeline string_to_pipeline(const std::string& str) +{ + if(str == 
"mem") + return Pipeline::Mem; + if(str == "compv1") + return Pipeline::CompV1; + if(str == "compv2") + return Pipeline::CompV2; + if(str == "compv3") + return Pipeline::CompV3; + if(str == "compv4") + return Pipeline::CompV4; + if(str == "compv5") + return Pipeline::CompV5; + if(str == "preshufflev1") + return Pipeline::PreShuffleV1; + if(str == "preshufflev2") + return Pipeline::PreShuffleV2; + return Pipeline::Mem; // Default +} + +/// Convert Epilogue to string +inline std::string to_string(Epilogue epilogue) +{ + switch(epilogue) + { + case Epilogue::None: return "none"; + case Epilogue::Default: return "default"; + case Epilogue::CShuffle: return "cshuffle"; + case Epilogue::Bias: return "bias"; + case Epilogue::Activation: return "activation"; + case Epilogue::BiasActivation: return "bias_activation"; + default: return "unknown"; + } +} + +/// Convert string to Epilogue +inline Epilogue string_to_epilogue(const std::string& str) +{ + if(str == "none") + return Epilogue::None; + if(str == "default") + return Epilogue::Default; + if(str == "cshuffle") + return Epilogue::CShuffle; + if(str == "bias") + return Epilogue::Bias; + if(str == "activation") + return Epilogue::Activation; + if(str == "bias_activation") + return Epilogue::BiasActivation; + return Epilogue::Default; // Default +} + +/// Convert Scheduler to string +inline std::string to_string(Scheduler scheduler) +{ + switch(scheduler) + { + case Scheduler::Auto: return "auto"; + case Scheduler::Intrawave: return "intrawave"; + case Scheduler::Interwave: return "interwave"; + default: return "unknown"; + } +} + +/// Convert string to Scheduler +inline Scheduler string_to_scheduler(const std::string& str) +{ + if(str == "auto") + return Scheduler::Auto; + if(str == "intrawave") + return Scheduler::Intrawave; + if(str == "interwave") + return Scheduler::Interwave; + return Scheduler::Intrawave; // Default +} + +/// Common elementwise operations (for reference in elementwise_op field) +/// These match CK 
Tile's ck_tile::element_wise namespace +namespace ElementwiseOps { +constexpr const char* PassThrough = "PassThrough"; +constexpr const char* Add = "Add"; +constexpr const char* Multiply = "Multiply"; +constexpr const char* MultiDAdd = "MultiDAdd"; +constexpr const char* MultiDMultiply = "MultiDMultiply"; +constexpr const char* Relu = "Relu"; +constexpr const char* Gelu = "Gelu"; +constexpr const char* Clamp = "Clamp"; +constexpr const char* Sigmoid = "Sigmoid"; +constexpr const char* Tanh = "Tanh"; +constexpr const char* Swish = "Swish"; +constexpr const char* HardSwish = "HardSwish"; +} // namespace ElementwiseOps + +// ============================================================================= +// KernelKey::encode_identifier() implementation +// Defined after to_string() functions to use them +// ============================================================================= + +inline std::string KernelKey::encode_identifier() const +{ + std::ostringstream oss; + + // Include data types and layout for uniqueness across different signatures + oss << to_string(signature.dtype_a) << "_"; + oss << to_string(signature.layout_a) << to_string(signature.layout_b) + << to_string(signature.layout_c) << "_"; + + // Include pipeline, scheduler, epilogue for uniqueness + oss << to_string(algorithm.pipeline) << "_"; + oss << to_string(algorithm.scheduler) << "_"; + oss << to_string(algorithm.epilogue) << "_"; + + // Match tile_engine naming: tile_m x tile_n x tile_k _ warp_m x warp_n x warp_k _ + // warp_tile_m x warp_tile_n x warp_tile_k + oss << algorithm.tile_shape.m << "x" << algorithm.tile_shape.n << "x" << algorithm.tile_shape.k + << "_" << unsigned(algorithm.wave_shape.m) << "x" << unsigned(algorithm.wave_shape.n) << "x" + << unsigned(algorithm.wave_shape.k) << "_" << unsigned(algorithm.warp_tile_shape.m) << "x" + << unsigned(algorithm.warp_tile_shape.n) << "x" << unsigned(algorithm.warp_tile_shape.k); + + // Add trait flags + oss << "_" << (algorithm.persistent ? 
"persist" : "nopers"); + + if(signature.split_k > 1) + oss << "_splitk" << unsigned(signature.split_k); + if(!signature.elementwise_op.empty() && signature.elementwise_op != "PassThrough") + oss << "_" << signature.elementwise_op; + if(signature.num_d_tensors > 0) + oss << "_d" << unsigned(signature.num_d_tensors); + if(signature.structured_sparsity) + oss << "_sparse"; + if(algorithm.preshuffle) + oss << "_preshuffle"; + + return oss.str(); +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/problem.hpp b/dispatcher/include/ck_tile/dispatcher/problem.hpp new file mode 100644 index 00000000000..437511d1ba3 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/problem.hpp @@ -0,0 +1,311 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +// ============================================================================= +// Tensor Information for Automatic MNK Inference +// ============================================================================= + +/// TensorShape: Describes tensor dimensions for automatic MNK inference +struct TensorShape +{ + std::int64_t rows; // First dimension + std::int64_t cols; // Second dimension + bool is_transposed; // Whether the tensor is transposed (column-major) + + TensorShape() : rows(0), cols(0), is_transposed(false) {} + TensorShape(std::int64_t r, std::int64_t c, bool trans = false) + : rows(r), cols(c), is_transposed(trans) + { + } + + /// Get logical M (rows when not transposed) + [[nodiscard]] std::int64_t logical_rows() const { return is_transposed ? cols : rows; } + + /// Get logical N (cols when not transposed) + [[nodiscard]] std::int64_t logical_cols() const { return is_transposed ? 
rows : cols; } +}; + +// ============================================================================= +// Problem: Runtime Parameters +// ============================================================================= + +/// Problem: Runtime parameters for kernel invocation +/// Captures problem dimensions and resource constraints that vary between invocations +/// even when using the same kernel +struct Problem +{ + // Problem dimensions + std::int64_t M; // Number of rows in A and C + std::int64_t N; // Number of columns in B and C + std::int64_t K; // Shared dimension (columns of A, rows of B) + + // Batch configuration + std::int32_t k_batch; // Number of K-dimension splits for split-K GEMM + + // Resource preferences + std::int32_t smem_budget; // Shared memory budget in bytes (0 = no constraint) + bool prefer_persistent; // Prefer persistent kernel variants + + // Validation control + bool enable_validation; // Enable output validation against reference + + /// Default constructor with sensible defaults + Problem() + : M(0), + N(0), + K(0), + k_batch(1), + smem_budget(0), + prefer_persistent(false), + enable_validation(false) + { + } + + /// Constructor with problem dimensions + Problem(std::int64_t m, std::int64_t n, std::int64_t k) + : M(m), + N(n), + K(k), + k_batch(1), + smem_budget(0), + prefer_persistent(false), + enable_validation(false) + { + } + + /// Check if problem dimensions are valid + [[nodiscard]] bool is_valid() const { return M > 0 && N > 0 && K > 0 && k_batch > 0; } + + /// Get total number of operations (for performance metrics) + [[nodiscard]] std::int64_t num_ops() const + { + return 2 * M * N * K; // Multiply-add counts as 2 ops + } + + // ========================================================================= + // Factory Methods for Automatic MNK Inference + // ========================================================================= + + /** + * Create Problem by inferring MNK from tensor shapes. 
+ * + * For GEMM: C[M,N] = A[M,K] × B[K,N] + * + * @param a_shape Shape of matrix A (M x K, or K x M if transposed) + * @param b_shape Shape of matrix B (K x N, or N x K if transposed) + * @param c_shape Shape of matrix C (M x N) - used for validation + * @throws std::invalid_argument if dimensions are inconsistent + * + * Example: + * // A is 512x256, B is 256x1024, C is 512x1024 + * auto problem = Problem::from_shapes({512, 256}, {256, 1024}, {512, 1024}); + * // Infers: M=512, N=1024, K=256 + */ + [[nodiscard]] static Problem + from_shapes(TensorShape a_shape, TensorShape b_shape, TensorShape c_shape) + { + // For C = A × B: + // A: [M, K] (or [K, M] if transposed) + // B: [K, N] (or [N, K] if transposed) + // C: [M, N] + + std::int64_t M_from_A = a_shape.logical_rows(); + std::int64_t K_from_A = a_shape.logical_cols(); + std::int64_t K_from_B = b_shape.logical_rows(); + std::int64_t N_from_B = b_shape.logical_cols(); + std::int64_t M_from_C = c_shape.logical_rows(); + std::int64_t N_from_C = c_shape.logical_cols(); + + // Validate K dimension matches between A and B + if(K_from_A != K_from_B) + { + throw std::invalid_argument( + "K dimension mismatch: A has K=" + std::to_string(K_from_A) + + ", B has K=" + std::to_string(K_from_B)); + } + + // Validate M dimension matches between A and C + if(M_from_A != M_from_C) + { + throw std::invalid_argument( + "M dimension mismatch: A has M=" + std::to_string(M_from_A) + + ", C has M=" + std::to_string(M_from_C)); + } + + // Validate N dimension matches between B and C + if(N_from_B != N_from_C) + { + throw std::invalid_argument( + "N dimension mismatch: B has N=" + std::to_string(N_from_B) + + ", C has N=" + std::to_string(N_from_C)); + } + + return Problem(M_from_A, N_from_B, K_from_A); + } + + /** + * Create Problem from tensor dimensions (simple version without transpose). 
+ * + * @param a_rows Rows of matrix A (= M) + * @param a_cols Columns of matrix A (= K) + * @param b_rows Rows of matrix B (= K) + * @param b_cols Columns of matrix B (= N) + * @param c_rows Rows of matrix C (= M) - for validation + * @param c_cols Columns of matrix C (= N) - for validation + * @throws std::invalid_argument if dimensions are inconsistent + * + * Example: + * // A[512,256] × B[256,1024] = C[512,1024] + * auto problem = Problem::from_dimensions(512, 256, 256, 1024, 512, 1024); + */ + [[nodiscard]] static Problem from_dimensions(std::int64_t a_rows, + std::int64_t a_cols, + std::int64_t b_rows, + std::int64_t b_cols, + std::int64_t c_rows, + std::int64_t c_cols) + { + return from_shapes( + TensorShape(a_rows, a_cols), TensorShape(b_rows, b_cols), TensorShape(c_rows, c_cols)); + } + + /** + * Create Problem from A and B dimensions only (C is inferred). + * + * @param a_rows Rows of matrix A (= M) + * @param a_cols Columns of matrix A (= K) + * @param b_rows Rows of matrix B (= K) - validated + * @param b_cols Columns of matrix B (= N) + * @throws std::invalid_argument if K dimensions don't match + * + * Example: + * // A[512,256] × B[256,1024] = C[512,1024] + * auto problem = Problem::from_ab(512, 256, 256, 1024); + */ + [[nodiscard]] static Problem + from_ab(std::int64_t a_rows, std::int64_t a_cols, std::int64_t b_rows, std::int64_t b_cols) + { + if(a_cols != b_rows) + { + throw std::invalid_argument("K dimension mismatch: A.cols=" + std::to_string(a_cols) + + ", B.rows=" + std::to_string(b_rows)); + } + return Problem(a_rows, b_cols, a_cols); + } + + /** + * Validate that tensor pointers have consistent sizes. + * Call this before kernel execution to catch dimension errors early. 
+ * + * @param a_size Total elements in A tensor + * @param b_size Total elements in B tensor + * @param c_size Total elements in C tensor + * @throws std::invalid_argument if sizes don't match expected dimensions + */ + void validate_sizes(std::int64_t a_size, std::int64_t b_size, std::int64_t c_size) const + { + std::int64_t expected_a = M * K; + std::int64_t expected_b = K * N; + std::int64_t expected_c = M * N; + + if(a_size != expected_a) + { + throw std::invalid_argument("A tensor size mismatch: got " + std::to_string(a_size) + + ", expected " + std::to_string(expected_a) + " (M*K = " + + std::to_string(M) + "*" + std::to_string(K) + ")"); + } + if(b_size != expected_b) + { + throw std::invalid_argument("B tensor size mismatch: got " + std::to_string(b_size) + + ", expected " + std::to_string(expected_b) + " (K*N = " + + std::to_string(K) + "*" + std::to_string(N) + ")"); + } + if(c_size != expected_c) + { + throw std::invalid_argument("C tensor size mismatch: got " + std::to_string(c_size) + + ", expected " + std::to_string(expected_c) + " (M*N = " + + std::to_string(M) + "*" + std::to_string(N) + ")"); + } + } +}; + +// ============================================================================= +// Convenience Builders +// ============================================================================= + +/// Builder pattern for Problem configuration +class ProblemBuilder +{ + public: + ProblemBuilder() = default; + + /// Set dimensions from A and B shapes + ProblemBuilder& + from_ab(std::int64_t a_rows, std::int64_t a_cols, std::int64_t b_rows, std::int64_t b_cols) + { + problem_ = Problem::from_ab(a_rows, a_cols, b_rows, b_cols); + return *this; + } + + /// Set MNK directly + ProblemBuilder& dimensions(std::int64_t m, std::int64_t n, std::int64_t k) + { + problem_.M = m; + problem_.N = n; + problem_.K = k; + return *this; + } + + /// Set split-K batch count + ProblemBuilder& split_k(std::int32_t k_batch) + { + problem_.k_batch = k_batch; + return *this; + 
} + + /// Set shared memory budget + ProblemBuilder& smem_budget(std::int32_t budget) + { + problem_.smem_budget = budget; + return *this; + } + + /// Prefer persistent kernels + ProblemBuilder& persistent(bool prefer = true) + { + problem_.prefer_persistent = prefer; + return *this; + } + + /// Enable validation + ProblemBuilder& validate(bool enable = true) + { + problem_.enable_validation = enable; + return *this; + } + + /// Build the Problem + [[nodiscard]] Problem build() const + { + if(!problem_.is_valid()) + { + throw std::invalid_argument("Invalid problem dimensions"); + } + return problem_; + } + + private: + Problem problem_; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/registry.hpp b/dispatcher/include/ck_tile/dispatcher/registry.hpp new file mode 100644 index 00000000000..93d1eb9f648 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/registry.hpp @@ -0,0 +1,197 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Registry - Thread-Safe Kernel Storage + * + * Central registry for all available kernel instances with priority-based + * ordering and efficient lookup. 
+ * + * Features: + * - Thread-safe registration and lookup + * - Priority-based ordering (High, Normal, Low) + * - Lookup by name or KernelKey + * - Filter by problem compatibility + * - Supports both singleton and multiple instance patterns + * + * Usage (Singleton - backward compatible): + * auto& registry = Registry::instance(); + * registry.register_kernel(kernel, Priority::High); + * auto kernel = registry.lookup("kernel_name"); + * + * Usage (Multiple registries): + * Registry fp16_registry; + * Registry bf16_registry; + * fp16_registry.register_kernel(fp16_kernel, Priority::High); + * bf16_registry.register_kernel(bf16_kernel, Priority::High); + * + * Dispatcher fp16_dispatcher(&fp16_registry); + * Dispatcher bf16_dispatcher(&bf16_registry); + * + * Status: Production ready, thread-safe + */ + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Registry: Central mapping from kernel configurations to executable instances +/// Thread-safe kernel registration and lookup +/// Supports both singleton pattern and multiple independent instances +class Registry +{ + public: + /// Priority levels for conflict resolution when multiple kernels have same key + enum class Priority + { + Low = 0, + Normal = 1, + High = 2 + }; + + /// Default constructor - creates an empty registry instance + /// Use this to create independent registries for different kernel sets + Registry(); + + /// Destructor - triggers auto-export if enabled + ~Registry(); + + /// Move constructor + Registry(Registry&& other) noexcept; + + /// Move assignment + Registry& operator=(Registry&& other) noexcept; + + // Prevent copying (registries contain shared_ptrs that shouldn't be duplicated) + Registry(const Registry&) = delete; + Registry& operator=(const Registry&) = delete; + + /// Register a kernel instance with the registry 
+ /// @param instance Kernel instance to register + /// @param priority Priority level for conflict resolution (default: Normal) + /// @return true if registered successfully, false if duplicate with higher priority exists + bool register_kernel(KernelInstancePtr instance, Priority priority = Priority::Normal); + + /// Lookup a kernel by its string identifier + /// @param identifier Kernel identifier string + /// @return Kernel instance if found, nullptr otherwise + [[nodiscard]] KernelInstancePtr lookup(const std::string& identifier) const; + + /// Lookup a kernel by its KernelKey + /// @param key Kernel configuration key + /// @return Kernel instance if found, nullptr otherwise + [[nodiscard]] KernelInstancePtr lookup(const KernelKey& key) const; + + /// Get all registered kernels + /// @return Vector of all kernel instances + [[nodiscard]] std::vector get_all() const; + + /// Get all kernels matching a predicate + /// @param predicate Function to filter kernels + /// @return Vector of matching kernel instances + [[nodiscard]] std::vector + filter(std::function predicate) const; + + /// Get number of registered kernels + [[nodiscard]] std::size_t size() const; + + /// Check if registry is empty + [[nodiscard]] bool empty() const; + + /// Clear all registered kernels + void clear(); + + /// Get registry name (for logging/debugging) + [[nodiscard]] const std::string& get_name() const; + + /// Set registry name (for logging/debugging) + void set_name(const std::string& name); + + /// Export registry to JSON string + /// @param include_statistics Whether to include kernel statistics breakdown + /// @return JSON string with all kernel metadata + [[nodiscard]] std::string export_json(bool include_statistics = true) const; + + /// Export registry to JSON file + /// @param filename Output filename + /// @param include_statistics Whether to include kernel statistics breakdown + /// @return true if export succeeded, false otherwise + bool export_json_to_file(const 
std::string& filename, bool include_statistics = true) const; + + /// Enable automatic JSON export on kernel registration + /// @param filename Output filename for auto-export + /// @param include_statistics Whether to include statistics in auto-export + /// @param export_on_every_registration If true, exports after every registration (default). + /// If false, only exports on destruction. + void enable_auto_export(const std::string& filename, + bool include_statistics = true, + bool export_on_every_registration = true); + + /// Disable automatic JSON export + void disable_auto_export(); + + /// Check if auto-export is enabled + [[nodiscard]] bool is_auto_export_enabled() const; + + /// Merge kernels from another registry into this one + /// @param other Registry to merge from + /// @param priority Priority for merged kernels (default: Normal) + /// @return Number of kernels successfully merged + std::size_t merge_from(const Registry& other, Priority priority = Priority::Normal); + + /// Filter kernels in-place by architecture + /// @param gpu_arch Target GPU architecture string (e.g., "gfx942") + /// @return Number of kernels removed + std::size_t filter_by_arch(const std::string& gpu_arch); + + /// Get singleton instance of the global registry (backward compatible) + /// This is the default registry used when no specific registry is provided + static Registry& instance(); + + private: + struct RegistryEntry + { + KernelInstancePtr instance; + Priority priority; + }; + + /// Perform auto-export if enabled + void perform_auto_export(); + + mutable std::mutex mutex_; + std::unordered_map kernels_; + std::string name_; + + // Auto-export configuration + bool auto_export_enabled_ = false; + std::string auto_export_filename_; + bool auto_export_include_statistics_ = true; + bool auto_export_on_every_registration_ = true; +}; + +/// Shared pointer type for registries (useful for managing lifetime) +using RegistryPtr = std::shared_ptr; + +/// Create a new registry 
instance (factory function) +inline RegistryPtr make_registry(const std::string& name = "") +{ + auto reg = std::make_shared(); + if(!name.empty()) + { + reg->set_name(name); + } + return reg; +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/utils.hpp b/dispatcher/include/ck_tile/dispatcher/utils.hpp new file mode 100644 index 00000000000..0f9990c45ea --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/utils.hpp @@ -0,0 +1,724 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file utils.hpp + * @brief Common utilities for CK Tile Dispatcher + * + * This header provides reusable utilities for: + * - GPU memory management (GpuBuffer) + * - Performance measurement (Timer, GpuTimer, BenchmarkStats) + * - Validation (ValidationResult, validate_result) + * - Kernel registration helpers + * - Data generation (fill_random, etc.) + * + * Usage: + * #include "ck_tile/dispatcher/utils.hpp" + * using namespace ck_tile::dispatcher::utils; + * + * // GPU memory + * GpuBuffer buffer(1024); + * + * // Timing + * GpuTimer timer; + * timer.start(); + * // ... kernel ... 
+ * timer.stop(); + * float ms = timer.elapsed_ms(); + * + * // Validation + * auto result = validate_result(gpu_data, ref_data, size); + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +namespace ck_tile { +namespace dispatcher { +namespace utils { + +// ============================================================================= +// HIP Error Handling +// ============================================================================= + +#define CK_HIP_CHECK(call) \ + do \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP error at " << __FILE__ << ":" << __LINE__ << ": " \ + << hipGetErrorString(err) << std::endl; \ + return false; \ + } \ + } while(0) + +#define CK_HIP_CHECK_THROW(call) \ + do \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + throw std::runtime_error(std::string("HIP error: ") + hipGetErrorString(err)); \ + } \ + } while(0) + +// ============================================================================= +// Timing Utilities +// ============================================================================= + +/** + * @brief High-resolution timer for CPU timing + */ +class Timer +{ + public: + void start() { start_ = std::chrono::high_resolution_clock::now(); } + + double elapsed_ms() const + { + auto end = std::chrono::high_resolution_clock::now(); + return std::chrono::duration(end - start_).count(); + } + + private: + std::chrono::high_resolution_clock::time_point start_; +}; + +/** + * @brief GPU timing using HIP events + * + * Times kernel execution on a specific HIP stream. Events are recorded + * on the provided stream to accurately measure kernel execution time. 
+ * + * Usage: + * hipStream_t stream; + * hipStreamCreate(&stream); + * GpuTimer timer(stream); // or timer.set_stream(stream) + * timer.start(); + * kernel<<>>(...); + * timer.stop(); + * float ms = timer.elapsed_ms(); + */ +class GpuTimer +{ + public: + /** + * @brief Construct timer with optional stream + * @param stream HIP stream to record events on (default: null stream) + */ + explicit GpuTimer(hipStream_t stream = nullptr) : stream_(stream) + { + (void)hipEventCreate(&start_); + (void)hipEventCreate(&stop_); + } + + ~GpuTimer() + { + (void)hipEventDestroy(start_); + (void)hipEventDestroy(stop_); + } + + // Non-copyable + GpuTimer(const GpuTimer&) = delete; + GpuTimer& operator=(const GpuTimer&) = delete; + + // Movable + GpuTimer(GpuTimer&& other) noexcept + : start_(other.start_), stop_(other.stop_), stream_(other.stream_) + { + other.start_ = nullptr; + other.stop_ = nullptr; + other.stream_ = nullptr; + } + + GpuTimer& operator=(GpuTimer&& other) noexcept + { + if(this != &other) + { + if(start_) + (void)hipEventDestroy(start_); + if(stop_) + (void)hipEventDestroy(stop_); + start_ = other.start_; + stop_ = other.stop_; + stream_ = other.stream_; + other.start_ = nullptr; + other.stop_ = nullptr; + other.stream_ = nullptr; + } + return *this; + } + + /** + * @brief Set the stream to record events on + * @param stream HIP stream (pass nullptr for default stream) + */ + void set_stream(hipStream_t stream) { stream_ = stream; } + + /** + * @brief Get the current stream + */ + hipStream_t get_stream() const { return stream_; } + + /** + * @brief Record start event on the stream + */ + void start() { (void)hipEventRecord(start_, stream_); } + + /** + * @brief Record stop event on the stream + */ + void stop() { (void)hipEventRecord(stop_, stream_); } + + /** + * @brief Get elapsed time in milliseconds + * + * Synchronizes on the stop event before calculating time. 
// ---------------------------------------------------------------------------
// utils.hpp (continued): performance metrics, benchmarking, validation and
// CPU reference GEMM.
// ---------------------------------------------------------------------------

namespace ck_tile {
namespace dispatcher {
namespace utils {

// =============================================================================
// Performance Metrics
// =============================================================================

/**
 * @brief TFLOPS achieved by an MxNxK GEMM that ran in time_ms milliseconds
 *        (one multiply-add counts as 2 FLOPs).
 */
inline double calculate_tflops(int64_t M, int64_t N, int64_t K, double time_ms)
{
    double flops = 2.0 * M * N * K;
    return (flops / (time_ms * 1e-3)) / 1e12;
}

/**
 * @brief Memory bandwidth in GB/s, assuming A, B and C are each touched once.
 */
template <typename AType, typename BType, typename CType>
inline double calculate_bandwidth_gbs(int64_t M, int64_t N, int64_t K, double time_ms)
{
    double bytes = M * K * sizeof(AType) + K * N * sizeof(BType) + M * N * sizeof(CType);
    return (bytes / (time_ms * 1e-3)) / 1e9;
}

/**
 * @brief Aggregated timing statistics produced by run_benchmark().
 */
struct BenchmarkStats
{
    double min_ms        = 0;
    double avg_ms        = 0;
    double max_ms        = 0;
    double median_ms     = 0;
    double tflops        = 0;
    double bandwidth_gbs = 0;
    int iterations       = 0;

    /// Pretty-print the statistics (4 decimal places for times, 2 for rates).
    void print(std::ostream& os = std::cout) const
    {
        os << std::fixed << std::setprecision(4);
        os << "  Min:    " << min_ms << " ms\n";
        os << "  Avg:    " << avg_ms << " ms\n";
        os << "  Max:    " << max_ms << " ms\n";
        os << "  Median: " << median_ms << " ms\n";
        os << "  TFLOPS: " << std::setprecision(2) << tflops << "\n";
        os << "  Bandwidth: " << bandwidth_gbs << " GB/s\n";
    }
};

/**
 * @brief Run a timed callable repeatedly and compute summary statistics.
 *
 * @param func       Callable returning the elapsed time (ms) of one run
 * @param warmup     Untimed warm-up invocations
 * @param iterations Timed invocations contributing to the statistics
 * @return BenchmarkStats; all-zero when iterations <= 0
 *
 * Fix: previously iterations <= 0 dereferenced an empty vector (front()/
 * back()) and divided by zero; it now returns default statistics.
 */
template <typename Func>
BenchmarkStats run_benchmark(Func&& func, int warmup = 2, int iterations = 10)
{
    BenchmarkStats stats;
    if(iterations <= 0)
        return stats; // guard: nothing to measure

    std::vector<double> times;
    times.reserve(iterations);

    for(int i = 0; i < warmup; ++i)
        func();

    for(int i = 0; i < iterations; ++i)
        times.push_back(func());

    std::sort(times.begin(), times.end());

    stats.iterations = iterations;
    stats.min_ms     = times.front();
    stats.max_ms     = times.back();
    stats.median_ms  = times[iterations / 2]; // upper median for even counts

    double sum = 0;
    for(double t : times)
        sum += t;
    stats.avg_ms = sum / iterations;

    return stats;
}

// =============================================================================
// Validation Utilities
// =============================================================================

/**
 * @brief Result of comparing a computed buffer against a reference.
 */
struct ValidationResult
{
    bool correct     = false;
    double max_diff  = 0;
    double mean_diff = 0;
    double accuracy  = 0;
    int64_t matches  = 0;
    int64_t total    = 0;

    /// Pretty-print the comparison summary.
    void print(std::ostream& os = std::cout) const
    {
        os << "  Correct:   " << (correct ? "YES" : "NO") << "\n";
        os << "  Max diff:  " << max_diff << "\n";
        os << "  Mean diff: " << mean_diff << "\n";
        os << "  Accuracy:  " << accuracy << "%\n";
        os << "  Matches:   " << matches << "/" << total << "\n";
    }
};

/**
 * @brief Element-wise comparison of a result buffer against a reference.
 *
 * An element matches when |result - reference| <= atol + rtol * |reference|.
 * The overall result is "correct" when all elements match or >= 99.9% do.
 *
 * Fix: previously size == 0 divided by zero (NaN mean/accuracy); an empty
 * comparison is now trivially correct.
 */
template <typename T>
ValidationResult validate_result(
    const T* result, const T* reference, int64_t size, double rtol = 1e-3, double atol = 1e-2)
{
    ValidationResult v;
    v.total = size;
    if(size <= 0)
    {
        // guard: empty comparison is vacuously correct
        v.correct  = true;
        v.accuracy = 100.0;
        return v;
    }

    v.max_diff = 0;
    v.matches  = 0;

    double sum_diff = 0;

    for(int64_t i = 0; i < size; ++i)
    {
        double r    = static_cast<double>(result[i]);
        double ref  = static_cast<double>(reference[i]);
        double diff = std::abs(r - ref);

        v.max_diff = std::max(v.max_diff, diff);
        sum_diff += diff;

        double threshold = atol + rtol * std::abs(ref);
        if(diff <= threshold)
            ++v.matches;
    }

    v.mean_diff = sum_diff / size;
    v.accuracy  = 100.0 * v.matches / v.total;
    v.correct   = (v.matches == v.total) || (v.accuracy >= 99.9);

    return v;
}

/**
 * @brief Naive CPU reference GEMM: C[M,N] = A[M,K] * B[K,N], all row-major,
 *        accumulating in double precision.
 */
template <typename AType, typename BType, typename CType>
void compute_reference_gemm(
    const AType* A, const BType* B, CType* C, int64_t M, int64_t N, int64_t K)
{
    for(int64_t m = 0; m < M; ++m)
    {
        for(int64_t n = 0; n < N; ++n)
        {
            double acc = 0;
            for(int64_t k = 0; k < K; ++k)
                acc += static_cast<double>(A[m * K + k]) * static_cast<double>(B[k * N + n]);
            C[m * N + n] = static_cast<CType>(acc);
        }
    }
}

} // namespace utils
} // namespace dispatcher
} // namespace ck_tile
++n) + { + double acc = 0; + for(int64_t k = 0; k < K; ++k) + acc += static_cast(A[m * K + k]) * static_cast(B[k * N + n]); + C[m * N + n] = static_cast(acc); + } + } +} + +// ============================================================================= +// Data Generation +// ============================================================================= + +template +void fill_random(T* data, int64_t size, T min_val = T(-1), T max_val = T(1)) +{ + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(static_cast(min_val), + static_cast(max_val)); + for(int64_t i = 0; i < size; ++i) + data[i] = static_cast(dist(gen)); +} + +template +void fill_zeros(T* data, int64_t size) +{ + std::fill(data, data + size, T(0)); +} + +template +void fill_ones(T* data, int64_t size) +{ + std::fill(data, data + size, T(1)); +} + +template +void fill_identity(T* data, int64_t rows, int64_t cols) +{ + fill_zeros(data, rows * cols); + int64_t min_dim = std::min(rows, cols); + for(int64_t i = 0; i < min_dim; ++i) + data[i * cols + i] = T(1); +} + +// ============================================================================= +// GPU Memory Management +// ============================================================================= + +/** + * @brief RAII wrapper for GPU memory + */ +template +class GpuBuffer +{ + public: + GpuBuffer() : data_(nullptr), size_(0) {} + + explicit GpuBuffer(int64_t count) : size_(count * sizeof(T)) + { + CK_HIP_CHECK_THROW(hipMalloc(&data_, size_)); + } + + ~GpuBuffer() + { + if(data_) + (void)hipFree(data_); + } + + // Non-copyable + GpuBuffer(const GpuBuffer&) = delete; + GpuBuffer& operator=(const GpuBuffer&) = delete; + + // Movable + GpuBuffer(GpuBuffer&& other) noexcept : data_(other.data_), size_(other.size_) + { + other.data_ = nullptr; + other.size_ = 0; + } + + GpuBuffer& operator=(GpuBuffer&& other) noexcept + { + if(this != &other) + { + if(data_) + (void)hipFree(data_); + data_ = other.data_; + size_ = 
other.size_; + other.data_ = nullptr; + other.size_ = 0; + } + return *this; + } + + T* get() { return data_; } + const T* get() const { return data_; } + int64_t size_bytes() const { return size_; } + int64_t count() const { return size_ / sizeof(T); } + + void copy_from_host(const T* host_data) + { + CK_HIP_CHECK_THROW(hipMemcpy(data_, host_data, size_, hipMemcpyHostToDevice)); + } + + void copy_to_host(T* host_data) const + { + CK_HIP_CHECK_THROW(hipMemcpy(host_data, data_, size_, hipMemcpyDeviceToHost)); + } + + void zero() { CK_HIP_CHECK_THROW(hipMemset(data_, 0, size_)); } + + private: + T* data_; + int64_t size_; +}; + +// ============================================================================= +// Printing Utilities +// ============================================================================= + +inline void print_separator(char c = '=', int width = 70) +{ + std::cout << std::string(width, c) << "\n"; +} + +inline void print_header(const std::string& title) +{ + print_separator(); + std::cout << title << "\n"; + print_separator(); +} + +inline std::string format_size(int64_t M, int64_t N, int64_t K) +{ + std::ostringstream oss; + oss << M << "x" << N << "x" << K; + return oss.str(); +} + +inline std::string format_number(int64_t n) +{ + std::string s = std::to_string(n); + int pos = static_cast(s.length()) - 3; + while(pos > 0) + { + s.insert(pos, ","); + pos -= 3; + } + return s; +} + +/** + * @brief Print all registered kernels in a registry + * + * @param registry The registry to list kernels from + * @param os Output stream (default: std::cout) + * @param verbose If true, show full kernel config details + */ +inline void print_registered_kernels(const Registry& registry, + std::ostream& os = std::cout, + bool verbose = false) +{ + const auto& kernels = registry.get_all(); + os << "Registered Kernels (" << kernels.size() << "):\n"; + os << std::string(70, '-') << "\n"; + + int idx = 1; + for(const auto& kernel : kernels) + { + const auto& key = 
kernel->get_key(); + + os << " " << idx++ << ". " << kernel->get_name() << "\n"; + + if(verbose) + { + os << " Tile: " << key.algorithm.tile_shape.m << "x" + << key.algorithm.tile_shape.n << "x" << key.algorithm.tile_shape.k << "\n"; + os << " Wave: " << static_cast(key.algorithm.wave_shape.m) << "x" + << static_cast(key.algorithm.wave_shape.n) << "x" + << static_cast(key.algorithm.wave_shape.k) << "\n"; + os << " WarpTile: " << static_cast(key.algorithm.warp_tile_shape.m) << "x" + << static_cast(key.algorithm.warp_tile_shape.n) << "x" + << static_cast(key.algorithm.warp_tile_shape.k) << "\n"; + os << " Pipeline: " << to_string(key.algorithm.pipeline) << "\n"; + os << " Scheduler: " << to_string(key.algorithm.scheduler) << "\n"; + os << " Arch: " << key.gfx_arch << "\n"; + os << "\n"; + } + } + + if(!verbose && !kernels.empty()) + { + os << "\n Use --list-verbose for full details\n"; + } + os << std::string(70, '-') << "\n"; +} + +/** + * @brief Print a single kernel's configuration + */ +inline void print_kernel_info(const KernelInstance& kernel, std::ostream& os = std::cout) +{ + const auto& key = kernel.get_key(); + + os << "Kernel: " << kernel.get_name() << "\n"; + os << " Signature:\n"; + os << " dtype: " << to_string(key.signature.dtype_a) << "/" + << to_string(key.signature.dtype_b) << "/" << to_string(key.signature.dtype_c) << "\n"; + os << " layout: " << to_string(key.signature.layout_a) << to_string(key.signature.layout_b) + << to_string(key.signature.layout_c) << "\n"; + + os << " Algorithm:\n"; + os << " tile: " << key.algorithm.tile_shape.m << "x" << key.algorithm.tile_shape.n + << "x" << key.algorithm.tile_shape.k << "\n"; + os << " wave: " << static_cast(key.algorithm.wave_shape.m) << "x" + << static_cast(key.algorithm.wave_shape.n) << "x" + << static_cast(key.algorithm.wave_shape.k) << "\n"; + os << " warp_tile: " << static_cast(key.algorithm.warp_tile_shape.m) << "x" + << static_cast(key.algorithm.warp_tile_shape.n) << "x" + << 
static_cast(key.algorithm.warp_tile_shape.k) << "\n"; + os << " pipeline: " << to_string(key.algorithm.pipeline) << "\n"; + os << " scheduler: " << to_string(key.algorithm.scheduler) << "\n"; + os << " epilogue: " << to_string(key.algorithm.epilogue) << "\n"; + + os << " Target: " << key.gfx_arch << "\n"; +} + +// ============================================================================= +// Kernel Key Builders +// ============================================================================= + +/** + * @brief Build a KernelKey for FP16 Row-Col-Row layout GEMM + * + * This is the most common configuration. Customize parameters as needed. + */ +struct KernelKeyBuilder +{ + // Tile shape + int tile_m = 128; + int tile_n = 128; + int tile_k = 32; + + // Wave shape (warps per block) + int wave_m = 2; + int wave_n = 2; + int wave_k = 1; + + // Warp tile shape + int warp_m = 32; + int warp_n = 32; + int warp_k = 16; + + // Block size + int block_size = 256; + + // Data types + DataType dtype_a = DataType::FP16; + DataType dtype_b = DataType::FP16; + DataType dtype_c = DataType::FP16; + DataType dtype_acc = DataType::FP32; + + // Layouts + LayoutTag layout_a = LayoutTag::RowMajor; + LayoutTag layout_b = LayoutTag::ColMajor; + LayoutTag layout_c = LayoutTag::RowMajor; + + // Pipeline/scheduler + Pipeline pipeline = Pipeline::CompV4; + Scheduler scheduler = Scheduler::Intrawave; + Epilogue epilogue = Epilogue::CShuffle; + + // Features + bool preshuffle = false; + int num_d_tensors = 0; // Multi-D: number of additional input tensors + std::string elementwise_op = "PassThrough"; + + // Target GPU + std::string gfx_arch = "gfx942"; + + /** + * @brief Build the KernelKey + */ + KernelKey build() const + { + KernelKey key; + + // Signature + key.signature.dtype_a = dtype_a; + key.signature.dtype_b = dtype_b; + key.signature.dtype_c = dtype_c; + key.signature.dtype_acc = dtype_acc; + key.signature.layout_a = layout_a; + key.signature.layout_b = layout_b; + 
key.signature.layout_c = layout_c; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = elementwise_op; + key.signature.num_d_tensors = num_d_tensors; + key.signature.structured_sparsity = false; + + // Algorithm + key.algorithm.tile_shape = {static_cast(tile_m), + static_cast(tile_n), + static_cast(tile_k)}; + key.algorithm.wave_shape = {static_cast(wave_m), + static_cast(wave_n), + static_cast(wave_k)}; + key.algorithm.warp_tile_shape = {static_cast(warp_m), + static_cast(warp_n), + static_cast(warp_k)}; + key.algorithm.pipeline = pipeline; + key.algorithm.scheduler = scheduler; + key.algorithm.epilogue = epilogue; + key.algorithm.block_size = block_size; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = preshuffle; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + + key.gfx_arch = gfx_arch; + + return key; + } + + // Convenience preset methods + static KernelKeyBuilder fp16_rcr() { return KernelKeyBuilder{}; } + + static KernelKeyBuilder fp16_rrr() + { + auto b = KernelKeyBuilder{}; + b.layout_b = LayoutTag::RowMajor; + return b; + } + + static KernelKeyBuilder preshuffle_v1() + { + auto b = KernelKeyBuilder{}; + b.pipeline = Pipeline::PreShuffleV1; + b.preshuffle = true; + return b; + } + + static KernelKeyBuilder preshuffle_v2() + { + auto b = KernelKeyBuilder{}; + b.pipeline = Pipeline::PreShuffleV2; + b.preshuffle = true; + return b; + } + + static KernelKeyBuilder multi_d(int num_d, const std::string& op = "MultiDAdd") + { + auto b = KernelKeyBuilder{}; + b.num_d_tensors = num_d; + b.elementwise_op = op; + return b; + } +}; + +} // namespace utils +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/validation/reference_kernels.hpp b/dispatcher/include/ck_tile/dispatcher/validation/reference_kernels.hpp new file mode 
100644 index 00000000000..a7e063c3cc6 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/validation/reference_kernels.hpp @@ -0,0 +1,228 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/problem.hpp" +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace validation { + +/// Reference CPU GEMM implementation for validation +template +void reference_gemm_cpu(const ADataType* a, + const BDataType* b, + CDataType* c, + int M, + int N, + int K, + int stride_a, + int stride_b, + int stride_c, + bool transpose_a = false, + bool transpose_b = false) +{ + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + AccDataType acc = 0; + + for(int k = 0; k < K; ++k) + { + // Get A element + int a_idx = transpose_a ? (k * stride_a + m) : (m * stride_a + k); + AccDataType a_val = static_cast(a[a_idx]); + + // Get B element + int b_idx = transpose_b ? (n * stride_b + k) : (k * stride_b + n); + AccDataType b_val = static_cast(b[b_idx]); + + acc += a_val * b_val; + } + + // Write C element + int c_idx = m * stride_c + n; + c[c_idx] = static_cast(acc); + } + } +} + +/// Validate kernel output against reference +template +bool validate_output(const CDataType* result, + const CDataType* reference, + int size, + float rtol = 1e-3f, + float atol = 1e-5f) +{ + int errors = 0; + const int max_errors_to_print = 10; + + for(int i = 0; i < size; ++i) + { + float res_val = static_cast(result[i]); + float ref_val = static_cast(reference[i]); + + float abs_diff = std::abs(res_val - ref_val); + float abs_ref = std::abs(ref_val); + + bool is_valid = (abs_diff <= atol) || (abs_diff <= rtol * abs_ref); + + if(!is_valid) + { + if(errors < max_errors_to_print) + { + printf("Mismatch at index %d: result=%.6f, reference=%.6f, diff=%.6e\n", + i, + res_val, + ref_val, + abs_diff); + } + errors++; + } + } + + if(errors > 0) + { + printf("Validation failed: 
%d/%d elements mismatched (%.2f%%)\n", + errors, + size, + 100.0f * errors / size); + return false; + } + + return true; +} + +/// Validate kernel with reference implementation +template +bool validate_gemm_kernel(const void* a_dev_ptr, + const void* b_dev_ptr, + const void* c_dev_ptr, + const Problem& problem, + float rtol = 1e-3f, + float atol = 1e-5f) +{ + const int M = problem.M; + const int N = problem.N; + const int K = problem.K; + + // Allocate host memory + std::vector a_host(M * K); + std::vector b_host(K * N); + std::vector c_host(M * N); + std::vector c_ref(M * N); + + // Copy from device + hipMemcpy(a_host.data(), a_dev_ptr, M * K * sizeof(ADataType), hipMemcpyDeviceToHost); + hipMemcpy(b_host.data(), b_dev_ptr, K * N * sizeof(BDataType), hipMemcpyDeviceToHost); + hipMemcpy(c_host.data(), c_dev_ptr, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); + + // Compute reference + reference_gemm_cpu(a_host.data(), + b_host.data(), + c_ref.data(), + M, + N, + K, + K, // stride_a (row-major) + N, // stride_b (row-major) + N, // stride_c (row-major) + false, + false); + + // Validate + return validate_output(c_host.data(), c_ref.data(), M * N, rtol, atol); +} + +/// Validator class for kernel instances +class KernelValidator +{ + public: + KernelValidator(float rtol = 1e-3f, float atol = 1e-5f) : rtol_(rtol), atol_(atol) {} + + /// Validate a kernel instance + template + bool validate(KernelInstance& kernel, + const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const Problem& problem) + { + // Use kernel's validate method if available + return kernel.validate(a_ptr, b_ptr, c_ptr, problem, rtol_, atol_); + } + + /// Set tolerances + void set_tolerances(float rtol, float atol) + { + rtol_ = rtol; + atol_ = atol; + } + + /// Get tolerances + std::pair get_tolerances() const { return {rtol_, atol_}; } + + private: + float rtol_; + float atol_; +}; + +/// Helper to generate random test data +template +void generate_random_data(T* data, int size, float 
min_val = -1.0f, float max_val = 1.0f) +{ + for(int i = 0; i < size; ++i) + { + float rand_val = min_val + (max_val - min_val) * (rand() / (float)RAND_MAX); + data[i] = static_cast(rand_val); + } +} + +/// Helper to allocate and initialize test tensors +template +struct TestTensor +{ + T* host_ptr; + T* device_ptr; + int size; + + TestTensor(int size_) : size(size_) + { + host_ptr = new T[size]; + hipMalloc(&device_ptr, size * sizeof(T)); + } + + ~TestTensor() + { + delete[] host_ptr; + hipFree(device_ptr); + } + + void randomize(float min_val = -1.0f, float max_val = 1.0f) + { + generate_random_data(host_ptr, size, min_val, max_val); + hipMemcpy(device_ptr, host_ptr, size * sizeof(T), hipMemcpyHostToDevice); + } + + void copy_to_device() + { + hipMemcpy(device_ptr, host_ptr, size * sizeof(T), hipMemcpyHostToDevice); + } + + void copy_from_device() + { + hipMemcpy(host_ptr, device_ptr, size * sizeof(T), hipMemcpyDeviceToHost); + } + + void zero() { hipMemset(device_ptr, 0, size * sizeof(T)); } +}; + +} // namespace validation +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/python/CMakeLists.txt b/dispatcher/python/CMakeLists.txt new file mode 100644 index 00000000000..e57678952ec --- /dev/null +++ b/dispatcher/python/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# This directory contains Python utilities for the dispatcher examples. +# The main utility file is ctypes_utils.py which is used by GEMM Python examples. +# Conv Python examples use their own conv_utils.py in the examples directory. + +# No build targets needed - these are pure Python utilities. 
+message(STATUS "Python utilities directory configured (no build targets)") diff --git a/dispatcher/python/README.md b/dispatcher/python/README.md new file mode 100644 index 00000000000..9286acbf72d --- /dev/null +++ b/dispatcher/python/README.md @@ -0,0 +1,60 @@ +# CK Tile Dispatcher Python Utilities + +This directory contains Python utilities used by the dispatcher examples. + +## Contents + +- `ctypes_utils.py` - Core ctypes utilities for GEMM Python examples + - `KernelConfig` - Kernel configuration dataclass + - `setup_gemm_dispatcher()` - Setup dispatcher with auto-correction + - `cleanup_gemm()` - Cleanup dispatcher resources + - `GemmRunner` - GPU execution helper + - Auto-correction and validation utilities + +- `conv_utils.py` - Core utilities for Conv Python examples + - `ConvSignature`, `ConvAlgorithm` - Convolution configuration + - `ConvProblem` - Problem definition + - `GpuConvRunner` - GPU execution helper + - `EnhancedConvCodegenRunner` - Kernel codegen utilities + +## Usage + +### GEMM Examples + +The GEMM Python examples in `dispatcher/examples/gemm/python/` import: + +```python +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from ctypes_utils import ( + KernelConfig, + setup_gemm_dispatcher, + cleanup_gemm, + GemmRunner, +) +``` + +### Conv Examples + +The Conv Python examples in `dispatcher/examples/conv/python/` import: + +```python +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from conv_utils import ( + ConvSignature, + ConvAlgorithm, + ConvProblem, + GpuConvRunner, +) +``` + +## Requirements + +- Python 3.8+ +- NumPy +- HIP runtime (for GPU execution) diff --git a/dispatcher/python/ctypes_utils.py b/dispatcher/python/ctypes_utils.py new file mode 100644 index 00000000000..821fc2b08dc --- /dev/null +++ b/dispatcher/python/ctypes_utils.py @@ -0,0 +1,2347 @@ +#!/usr/bin/env python3 + +# Copyright 
(c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +CK Tile Dispatcher Utilities + +Common utilities for loading, compiling, and using the CK Tile dispatcher. + +Usage: + from ck_tile_dispatcher.utils import DispatcherLib, GemmRunner, Validator + + # Option 1: Auto-compile and load + lib = DispatcherLib.auto() + + # Option 2: Load existing library + lib = DispatcherLib.load("/path/to/libdispatcher_gemm.so") + + # Run GEMM + runner = GemmRunner(lib) + result = runner.run(A, B) + + # Validate + validator = Validator() + check = validator.check(result.C, C_reference) +""" + +import ctypes +import subprocess +import numpy as np +from pathlib import Path +from typing import Optional, Tuple, List, Dict, Any +from dataclasses import dataclass, field +from concurrent.futures import ProcessPoolExecutor, as_completed +import multiprocessing +import time + + +# ============================================================================= +# Path Configuration +# ============================================================================= + + +def get_dispatcher_root() -> Path: + """Get the dispatcher root directory""" + # This file is in dispatcher/python/ + return Path(__file__).parent.parent + + +def get_ck_root() -> Path: + """Get the CK root directory""" + return get_dispatcher_root().parent + + +def get_build_dir() -> Path: + """Get the build directory""" + return get_dispatcher_root() / "build" + + +# ============================================================================= +# Supported Data Types +# ============================================================================= + +# All supported GEMM dtype combinations from warp_gemm_dispatcher.hpp +SUPPORTED_DTYPES = { + # dtype_a, dtype_b -> acc_dtype, warp_tiles + ("fp32", "fp32"): {"acc": "fp32", "warp_tiles": [(16, 16, 4), (16, 16, 16)]}, + ("fp16", "fp16"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 8), (32, 32, 16), (16, 16, 16), (16, 16, 32)], + }, + ("bf16", 
"bf16"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 8), (32, 32, 16), (16, 16, 16), (16, 16, 32)], + }, + ("fp8", "fp8"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 16), (32, 32, 32), (16, 16, 32), (16, 16, 64)], + }, + ("fp8", "bf8"): {"acc": "fp32", "warp_tiles": [(32, 32, 16), (16, 16, 32)]}, + ("bf8", "fp8"): {"acc": "fp32", "warp_tiles": [(32, 32, 16), (16, 16, 128)]}, + ("bf8", "bf8"): { + "acc": "fp32", + "warp_tiles": [(32, 32, 16), (32, 32, 32), (16, 16, 32)], + }, + ("int8", "int8"): { + "acc": "int32", + "warp_tiles": [(32, 32, 16), (16, 16, 32), (16, 16, 16)], + }, + ("pk_fp4", "pk_fp4"): {"acc": "fp32", "warp_tiles": [(16, 16, 128)]}, +} + +# All valid individual dtypes +VALID_DTYPES = ["fp16", "bf16", "fp32", "fp8", "bf8", "int8", "pk_fp4"] + + +def get_generated_kernels_dir() -> Path: + """Get the generated kernels directory""" + return get_build_dir() / "generated_kernels" + + +# ============================================================================= +# Arch Filter and Validation +# ============================================================================= + + +def get_arch_filter_data() -> Dict[str, Any]: + """Load arch filter data from arch_specs_generated if available.""" + codegen_dir = get_dispatcher_root() / "codegen" + import sys + + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + TRAIT_UNSUPPORTED_COMBINATIONS, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + return { + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + # Fallback defaults + return { + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + }, + "warp_combos": { + "gfx942": [[1, 4, 
1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + }, + "warp_tile_combos": { + "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + "gfx90a": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + }, + "supported_archs": ["gfx90a", "gfx942", "gfx950"], + } + + +@dataclass +class ValidationResult: + """Result of kernel config validation.""" + + is_valid: bool + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + suggested_fixes: Dict[str, Any] = field(default_factory=dict) + + def print_result(self, indent: str = " "): + """Print validation result.""" + if self.is_valid: + print(f"{indent}✓ Configuration valid") + else: + print(f"{indent}⚠ Configuration has issues:") + for err in self.errors: + print(f"{indent} - {err}") + + if self.warnings: + for warn in self.warnings: + print(f"{indent} Warning: {warn}") + + if self.suggested_fixes: + print(f"{indent} Suggested fixes:") + for key, val in self.suggested_fixes.items(): + print(f"{indent} {key}: {val}") + + +def validate_kernel_config(config: "KernelConfig") -> ValidationResult: + """ + Validate a KernelConfig against arch filter rules. + + Validation considers the GEMM variant (standard, preshuffle, multi_d) + for operator-specific constraints like minimum tile sizes. + + Returns ValidationResult with is_valid, errors, and suggested fixes. 
+ """ + arch_data = get_arch_filter_data() + + errors = [] + warnings = [] + suggested_fixes = {} + + pipeline = config.pipeline + epilogue = config.epilogue + scheduler = config.scheduler + dtype = config.dtype_a + arch = config.gfx_arch + variant = getattr(config, "variant", "standard") + + wave_m = config.wave_m + wave_n = config.wave_n + wave_k = config.wave_k + + warp_m = config.warp_m + warp_n = config.warp_n + warp_k = config.warp_k + + # Variant-specific tile constraints + if variant == "preshuffle": + # Preshuffle requires larger minimum tiles for efficiency + if config.tile_m < 64: + errors.append(f"Preshuffle requires tile_m >= 64, got {config.tile_m}") + suggested_fixes["tile_m"] = 64 + if config.tile_n < 64: + errors.append(f"Preshuffle requires tile_n >= 64, got {config.tile_n}") + suggested_fixes["tile_n"] = 64 + if config.tile_k < 32: + errors.append(f"Preshuffle requires tile_k >= 32, got {config.tile_k}") + suggested_fixes["tile_k"] = 32 + + elif variant == "multi_d": + # Multi-D has standard GEMM constraints + # Could add specific constraints here if needed + pass + + # Check trait combination (pipeline, epilogue, scheduler) + combo = (pipeline, epilogue, scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, epilogue={epilogue}, scheduler={scheduler}" + ) + suggested_fixes["scheduler"] = "intrawave" + + # Check wave configuration for this arch + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}. 
Valid: {valid_str}" + ) + if warp_combos: + suggested_fixes["wave_m"] = warp_combos[0][0] + suggested_fixes["wave_n"] = warp_combos[0][1] + suggested_fixes["wave_k"] = warp_combos[0][2] + + # Check warp tile configuration for this arch and dtype + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}. Valid: {valid_str}" + ) + if warp_tile_combos: + suggested_fixes["warp_m"] = warp_tile_combos[0][0] + suggested_fixes["warp_n"] = warp_tile_combos[0][1] + suggested_fixes["warp_k"] = warp_tile_combos[0][2] + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}. Supported: {', '.join(arch_data['supported_archs'])}" + ) + + return ValidationResult( + is_valid=len(errors) == 0, + errors=errors, + warnings=warnings, + suggested_fixes=suggested_fixes, + ) + + +def auto_correct_kernel_config( + config: "KernelConfig", verbose: bool = False +) -> Tuple["KernelConfig", bool, List[str]]: + """ + Validate and auto-correct a KernelConfig. + + Returns (corrected_config, was_modified, corrections_list). + If the config was valid, returns (original_config, False, []). + If corrections were made, returns (new_config, True, [list of correction descriptions]). 
+ """ + validation = validate_kernel_config(config) + + if validation.is_valid: + return config, False, [] + + # Apply suggested fixes and track what changed + from dataclasses import replace + + fixes = validation.suggested_fixes + corrections = [] + + # Check each fix and describe what changed + if "scheduler" in fixes and fixes["scheduler"] != config.scheduler: + corrections.append( + f"Scheduler: {config.scheduler} → {fixes['scheduler']} " + f"('{config.scheduler}' not supported with pipeline={config.pipeline}, epilogue={config.epilogue})" + ) + + if "wave_m" in fixes or "wave_n" in fixes or "wave_k" in fixes: + old_wave = f"[{config.wave_m}, {config.wave_n}, {config.wave_k}]" + new_wave = f"[{fixes.get('wave_m', config.wave_m)}, {fixes.get('wave_n', config.wave_n)}, {fixes.get('wave_k', config.wave_k)}]" + if old_wave != new_wave: + corrections.append( + f"Wave config: {old_wave} → {new_wave} " + f"(original not supported on {config.gfx_arch})" + ) + + if "warp_m" in fixes or "warp_n" in fixes or "warp_k" in fixes: + old_warp = f"[{config.warp_m}, {config.warp_n}, {config.warp_k}]" + new_warp = f"[{fixes.get('warp_m', config.warp_m)}, {fixes.get('warp_n', config.warp_n)}, {fixes.get('warp_k', config.warp_k)}]" + if old_warp != new_warp: + corrections.append( + f"Warp tile: {old_warp} → {new_warp} " + f"(original not supported for {config.dtype_a} on {config.gfx_arch})" + ) + + new_config = replace( + config, + scheduler=fixes.get("scheduler", config.scheduler), + wave_m=fixes.get("wave_m", config.wave_m), + wave_n=fixes.get("wave_n", config.wave_n), + wave_k=fixes.get("wave_k", config.wave_k), + warp_m=fixes.get("warp_m", config.warp_m), + warp_n=fixes.get("warp_n", config.warp_n), + warp_k=fixes.get("warp_k", config.warp_k), + ) + + return new_config, True, corrections + + +def print_kernel_config(config: "KernelConfig", title: str = "KERNEL CONFIGURATION"): + """ + Print a formatted kernel configuration for GEMM. 
+ + Args: + config: The KernelConfig to print + title: Title to display (e.g., "REQUESTED KERNEL CONFIGURATION") + """ + print() + print("=" * 70) + print(f" {title}") + print("=" * 70) + print(f" Data Type A: {config.dtype_a}") + print(f" Data Type B: {config.dtype_b}") + print(f" Data Type C: {config.dtype_c}") + print(f" Accumulator: {config.dtype_acc}") + print() + print( + f" Layout: {config.layout} (A={config.layout_a}, B={config.layout_b}, C={config.layout_c})" + ) + print() + print(f" Tile M x N x K: {config.tile_m} x {config.tile_n} x {config.tile_k}") + print(f" Wave Config: {config.wave_m} x {config.wave_n} x {config.wave_k}") + print(f" Warp Tile: {config.warp_m} x {config.warp_n} x {config.warp_k}") + print() + print(f" Pipeline: {config.pipeline}") + print(f" Scheduler: {config.scheduler}") + print(f" Epilogue: {config.epilogue}") + print() + print(f" Target Arch: {config.gfx_arch}") + print("=" * 70) + print() + + +def print_auto_correction( + original: "KernelConfig", + corrected: "KernelConfig", + corrections: List[str], + indent: str = " ", +): + """ + Print what was auto-corrected and why. + + Args: + original: Original configuration before correction + corrected: Configuration after correction + corrections: List of correction descriptions + indent: Indentation for output + """ + if not corrections: + print(f"{indent}✓ Configuration valid - no corrections needed") + return + + print(f"\n{indent}⚠ AUTO-CORRECTION APPLIED:") + print(f"{indent}" + "-" * 50) + for correction in corrections: + print(f"{indent} • {correction}") + print(f"{indent}" + "-" * 50) + print() + + +def find_matching_kernel_header(config: "KernelConfig") -> Optional[Path]: + """ + Find a kernel header that EXACTLY matches the config. + + Uses progressively relaxed matching strategies. 
+ """ + kernel_dir = get_generated_kernels_dir() + + dtype = config.dtype_a + layout = config.layout + pipeline = config.pipeline + scheduler = config.scheduler + tile_str = config.tile_str + wave_str = f"{config.wave_m}x{config.wave_n}x{config.wave_k}" + warp_str = f"{config.warp_m}x{config.warp_n}x{config.warp_k}" + + # Strategy 1: Exact match with ALL parameters including warp tile + pattern = f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_{wave_str}_{warp_str}.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 2: Match with tile and wave, any warp + pattern = ( + f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_{wave_str}_*.hpp" + ) + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 3: Match with just tile (ignore wave/warp) + pattern = f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 4: Match with intrawave (known to work) + pattern = f"gemm_{dtype}_{layout}_*_intrawave_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + # Strategy 5: Any kernel with matching dtype/layout/tile + pattern = f"gemm_{dtype}_{layout}_*_{tile_str}_*.hpp" + matches = list(kernel_dir.glob(pattern)) + if matches: + return matches[0] + + return None + + +# ============================================================================= +# Library Loading +# ============================================================================= + + +class DispatcherLib: + """Wrapper for the dispatcher dynamic library""" + + # Default library search paths (relative to dispatcher root) + SEARCH_PATHS = [ + "build/examples/libdispatcher_gemm_lib.so", + "build/libdispatcher_gemm_lib.so", + "build/examples/libdispatcher_gemm.so", + "build/lib/libdispatcher_gemm.so", + ] + + # Track loaded libraries globally for cleanup + _loaded_libs: 
List[Path] = [] + + def __init__(self, lib: ctypes.CDLL, path: Path): + self._lib = lib + self._path = path + self._closed = False + DispatcherLib._loaded_libs.append(path) + self._setup_functions() + + def _setup_functions(self): + """Setup ctypes function signatures""" + # Initialize + self._lib.dispatcher_initialize.argtypes = [] + self._lib.dispatcher_initialize.restype = ctypes.c_int + + # Alias for init + self._lib.dispatcher_init.argtypes = [] + self._lib.dispatcher_init.restype = ctypes.c_int + + # Get kernel count + self._lib.dispatcher_get_kernel_count.argtypes = [] + self._lib.dispatcher_get_kernel_count.restype = ctypes.c_int + + # Check if supported + self._lib.dispatcher_is_supported.argtypes = [ + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_int64, + ] + self._lib.dispatcher_is_supported.restype = ctypes.c_int + + # Run GEMM + self._lib.dispatcher_run_gemm.argtypes = [ + ctypes.c_void_p, # A + ctypes.c_void_p, # B + ctypes.c_void_p, # C + ctypes.c_int64, # M + ctypes.c_int64, # N + ctypes.c_int64, # K + ctypes.POINTER(ctypes.c_float), # time_ms + ] + self._lib.dispatcher_run_gemm.restype = ctypes.c_int + + # Get kernel name + self._lib.dispatcher_get_kernel_name.argtypes = [] + self._lib.dispatcher_get_kernel_name.restype = ctypes.c_char_p + + # Select kernel + self._lib.dispatcher_select_kernel.argtypes = [ + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_int64, + ctypes.c_char_p, + ctypes.c_int, + ] + self._lib.dispatcher_select_kernel.restype = ctypes.c_int + + # Export JSON + self._lib.dispatcher_export_registry_json.argtypes = [] + self._lib.dispatcher_export_registry_json.restype = ctypes.c_char_p + + # Cleanup + self._lib.dispatcher_cleanup.argtypes = [] + self._lib.dispatcher_cleanup.restype = None + + @property + def path(self) -> Path: + return self._path + + def initialize(self) -> bool: + """Initialize the dispatcher""" + return self._lib.dispatcher_initialize() == 0 + + def get_kernel_count(self) -> int: + """Get number of registered 
kernels""" + return self._lib.dispatcher_get_kernel_count() + + def is_supported(self, M: int, N: int, K: int) -> bool: + """Check if a problem size is supported""" + return self._lib.dispatcher_is_supported(M, N, K) == 1 + + def get_kernel_name(self) -> str: + """Get the kernel name""" + name = self._lib.dispatcher_get_kernel_name() + return name.decode("utf-8") if name else "unknown" + + def select_kernel(self, M: int, N: int, K: int) -> Optional[str]: + """Select kernel for problem and return its name""" + buffer = ctypes.create_string_buffer(256) + result = self._lib.dispatcher_select_kernel(M, N, K, buffer, 256) + if result == 0: + return buffer.value.decode("utf-8") + return None + + def run_gemm( + self, A: np.ndarray, B: np.ndarray, C: np.ndarray, M: int, N: int, K: int + ) -> Tuple[int, float]: + """ + Run GEMM operation + + Returns: (status, time_ms) + status: 0 = success, -1 = error, -2 = no suitable kernel + """ + time_ms = ctypes.c_float(0.0) + + status = self._lib.dispatcher_run_gemm( + A.ctypes.data_as(ctypes.c_void_p), + B.ctypes.data_as(ctypes.c_void_p), + C.ctypes.data_as(ctypes.c_void_p), + M, + N, + K, + ctypes.byref(time_ms), + ) + + return status, time_ms.value + + def export_json(self) -> Optional[str]: + """Export registry to JSON string""" + json_ptr = self._lib.dispatcher_export_registry_json() + if json_ptr: + return json_ptr.decode("utf-8") + return None + + def export_registry_json(self) -> str: + """Alias for export_json for compatibility""" + return self.export_json() or "{}" + + def cleanup(self): + """Cleanup dispatcher resources""" + self._lib.dispatcher_cleanup() + + @classmethod + def find(cls) -> Optional[Path]: + """Find the dispatcher library""" + root = get_dispatcher_root() + + for rel_path in cls.SEARCH_PATHS: + path = root / rel_path + if path.exists(): + return path + + return None + + @classmethod + def load(cls, path: Optional[Path] = None) -> Optional["DispatcherLib"]: + """Load the dispatcher library from path or 
auto-find""" + if path is None: + path = cls.find() + + if path is None or not path.exists(): + return None + + try: + lib = ctypes.CDLL(str(path)) + return cls(lib, path) + except OSError as e: + print(f"Failed to load library: {e}") + return None + + @classmethod + def compile(cls, output_path: Optional[Path] = None) -> Optional[Path]: + """Compile the dispatcher library""" + root = get_dispatcher_root() + ck_root = get_ck_root() + + if output_path is None: + output_path = get_build_dir() / "examples" / "libdispatcher_gemm.so" + + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Find a kernel header to include + kernel_dir = get_generated_kernels_dir() + kernel_headers = list(kernel_dir.glob("gemm_fp16_rcr_compv4*128x128x32*.hpp")) + + if not kernel_headers: + print("No kernel headers found. Generate kernels first.") + return None + + kernel_header = kernel_headers[0] + + # Use the ctypes binding source file + ctypes_source = root / "bindings/ctypes/gemm_ctypes_lib.cpp" + if not ctypes_source.exists(): + print(f"Source file not found: {ctypes_source}") + print( + "Please build with CMake: cd build && cmake .. 
&& make dispatcher_gemm_lib" + ) + return None + + # CK_TILE_SINGLE_KERNEL_INCLUDE exports types to global namespace for ctypes binding + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{root / 'build/generated_kernels'}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", # Enable global namespace exports + f"-include{kernel_header}", + "-D__HIP_PLATFORM_AMD__", + "--offload-arch=gfx942", + "-DAMDGPU_ARCH=gfx942", + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(output_path), + ] + + try: + result = subprocess.run( + compile_cmd, capture_output=True, text=True, timeout=120 + ) + if result.returncode == 0: + return output_path + else: + print(f"Compilation failed:\n{result.stderr}") + return None + except subprocess.TimeoutExpired: + print("Compilation timed out") + return None + + @classmethod + def auto(cls, recompile: bool = False) -> Optional["DispatcherLib"]: + """Auto-find or compile the library. + + Note: The library is built by CMake with a specific kernel configuration. + If you need a different dtype/layout, rebuild with: + cd build && cmake .. && make dispatcher_gemm_lib + """ + lib = cls.load() + if lib is not None: + if lib.initialize(): + return lib + else: + print(" Library found but failed to initialize") + print( + " Rebuild with: cd build && cmake .. && make dispatcher_gemm_lib" + ) + + # Don't fall back to old compile method - use CMake instead + print(" Library not found. Build with:") + print(" cd dispatcher/build && cmake .. 
&& make dispatcher_gemm_lib") + return None + + +# ============================================================================= +# GEMM Runner +# ============================================================================= + + +@dataclass +class GemmResult: + """Result of a GEMM operation""" + + output: np.ndarray # The output C matrix + time_ms: float + status: int + tflops: float + kernel_name: str + + @property + def success(self) -> bool: + return self.status == 0 + + # Alias for backward compatibility + @property + def C(self) -> np.ndarray: + return self.output + + +class GemmRunner: + """High-level GEMM runner using the dispatcher""" + + def __init__(self, lib: DispatcherLib): + self.lib = lib + + def run(self, A: np.ndarray, B: np.ndarray, dtype=np.float16) -> GemmResult: + """ + Run GEMM: C = A @ B + + Args: + A: Input matrix (M x K) + B: Input matrix (K x N) + dtype: Output data type (default: float16) + + Returns: + GemmResult with output matrix and timing + """ + M, K = A.shape + K2, N = B.shape + + assert K == K2, f"Dimension mismatch: A is {M}x{K}, B is {K2}x{N}" + + # Ensure contiguous float16 arrays + A_gpu = np.ascontiguousarray(A, dtype=np.float16) + B_gpu = np.ascontiguousarray(B.T, dtype=np.float16) # Column-major + C_gpu = np.zeros((M, N), dtype=np.float16) + + # Run + status, time_ms = self.lib.run_gemm(A_gpu, B_gpu, C_gpu, M, N, K) + + # Calculate TFLOPS + flops = 2.0 * M * N * K + tflops = (flops / (time_ms * 1e-3)) / 1e12 if time_ms > 0 else 0 + + return GemmResult( + output=C_gpu, + time_ms=time_ms, + status=status, + tflops=tflops, + kernel_name=self.lib.get_kernel_name(), + ) + + def benchmark( + self, M: int, N: int, K: int, warmup: int = 2, iterations: int = 10 + ) -> dict: + """Benchmark GEMM for given dimensions""" + A = np.random.randn(M, K).astype(np.float16) + B = np.random.randn(K, N).astype(np.float16) + + times = [] + + # Warmup + for _ in range(warmup): + self.run(A, B) + + # Benchmark + for _ in range(iterations): + result 
= self.run(A, B) + if result.success: + times.append(result.time_ms) + + if not times: + return {"error": "All iterations failed"} + + flops = 2.0 * M * N * K + avg_time = sum(times) / len(times) + + return { + "M": M, + "N": N, + "K": K, + "min_ms": min(times), + "avg_ms": avg_time, + "max_ms": max(times), + "tflops": (flops / (avg_time * 1e-3)) / 1e12, + "iterations": len(times), + } + + +# ============================================================================= +# Validation Utilities +# ============================================================================= + + +class Validator: + """Utilities for validating GEMM results""" + + def __init__(self, rtol: float = 1e-3, atol: float = 1e-2): + self.rtol = rtol + self.atol = atol + + def check( + self, result: np.ndarray, reference: np.ndarray + ) -> Tuple[bool, float, float]: + """ + Check if result matches reference + + Returns: (is_correct, max_diff, mean_diff) + """ + result = result.astype(np.float32) + reference = reference.astype(np.float32) + + diff = np.abs(result - reference) + max_diff = float(np.max(diff)) + mean_diff = float(np.mean(diff)) + + close = np.allclose(result, reference, rtol=self.rtol, atol=self.atol) + + return close, max_diff, mean_diff + + def compute_reference(self, A: np.ndarray, B: np.ndarray) -> np.ndarray: + """Compute reference GEMM result using NumPy""" + return np.matmul(A.astype(np.float32), B.astype(np.float32)) + + +# ============================================================================= +# Code Generation Utilities +# ============================================================================= + + +def get_codegen_path() -> Path: + """Get path to unified_gemm_codegen.py""" + return get_dispatcher_root() / "codegen" / "unified_gemm_codegen.py" + + +@dataclass +class CodegenResult: + """Result of kernel code generation""" + + success: bool + output_dir: Path + variant: str + stdout: str = "" + stderr: str = "" + kernel_count: int = 0 + elapsed_seconds: float = 
0.0 + instance_names: List[str] = field(default_factory=list) + + def get_generated_kernels(self) -> List[Path]: + """Get list of generated kernel headers""" + if self.output_dir.exists(): + return list(self.output_dir.glob("*.hpp")) + return [] + + def print_instances(self, prefix: str = " "): + """Print all generated instance names.""" + for name in self.instance_names: + print(f"{prefix}{name}") + + +def _run_codegen_subprocess(args: Dict[str, Any]) -> CodegenResult: + """ + Worker function for parallel codegen execution. + + This is a module-level function to allow pickling for ProcessPoolExecutor. + """ + import sys + import subprocess + from pathlib import Path + + codegen_path = Path(args["codegen_path"]) + out_dir = Path(args["output_dir"]) + variant = args["variant"] + datatype = args["datatype"] + layout = args["layout"] + gpu_target = args["gpu_target"] + extra_args = args.get("extra_args", []) + timeout = args.get("timeout", 300) + + out_dir.mkdir(parents=True, exist_ok=True) + + start = time.time() + + # Get existing kernels before generation + existing_kernels = set(out_dir.glob("*.hpp")) if out_dir.exists() else set() + + cmd = [ + sys.executable, + str(codegen_path), + "--output-dir", + str(out_dir), + "--datatype", + datatype, + "--layout", + layout, + "--gpu-target", + gpu_target, + "--variants", + variant, + ] + + if extra_args: + cmd.extend(extra_args) + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout) + + # Get new kernels after generation + all_kernels = set(out_dir.glob("*.hpp")) + new_kernels = all_kernels - existing_kernels + kernel_count = len(all_kernels) + elapsed = time.time() - start + + # Build instance names list for verbose output + instance_names = sorted([k.stem for k in new_kernels]) + + return CodegenResult( + success=result.returncode == 0, + output_dir=out_dir, + variant=variant, + stdout=result.stdout, + stderr=result.stderr, + kernel_count=kernel_count, + elapsed_seconds=elapsed, + 
instance_names=instance_names, + ) + except subprocess.TimeoutExpired: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=variant, + stderr=f"Code generation timed out ({timeout}s)", + elapsed_seconds=time.time() - start, + ) + except Exception as e: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=variant, + stderr=str(e), + elapsed_seconds=time.time() - start, + ) + + +# ============================================================================= +# Preshuffle Utilities +# ============================================================================= + + +def preshuffle_weight_matrix( + B: np.ndarray, + warp_tile_n: int, + warp_tile_k: int, + arch: str = "gfx942", +) -> np.ndarray: + """ + Preshuffle the B (weight) matrix for optimized GEMM inference. + + This transforms the B matrix layout to match the expected memory access + pattern for preshuffle-enabled kernels. The transformation reorders data + so that warp-level loads are coalesced. 
+ + Args: + B: Weight matrix of shape (K, N) in column-major / (K, N) layout + warp_tile_n: Warp tile size in N dimension (e.g., 32) + warp_tile_k: Warp tile size in K dimension (e.g., 16) + arch: Target GPU architecture (gfx9xx, gfx11xx, gfx12xx) + + Returns: + Shuffled B matrix with same data but reordered layout + + Example: + >>> B = np.random.randn(1024, 2048).astype(np.float16) + >>> B_shuffled = preshuffle_weight_matrix(B, warp_tile_n=32, warp_tile_k=16) + >>> # Use B_shuffled with preshuffle-enabled kernel + """ + K, N = B.shape + + # Validate dimensions are divisible by warp tiles + if N % warp_tile_n != 0: + raise ValueError(f"N ({N}) must be divisible by warp_tile_n ({warp_tile_n})") + if K % warp_tile_k != 0: + raise ValueError(f"K ({K}) must be divisible by warp_tile_k ({warp_tile_k})") + + # Architecture-specific shuffle patterns + # Based on ck_tile/host/tensor_shuffle_utils.hpp + if arch.startswith("gfx12"): + # GFX12 (RDNA4) pattern + divisor = 2 + k_abk1_per_lane = 8 + k_abk0_per_lane = warp_tile_k // divisor // k_abk1_per_lane + + if k_abk0_per_lane <= 0: + raise ValueError( + f"warp_tile_k ({warp_tile_k}) too small for GFX12 preshuffle" + ) + + # Reshape: (K, N) -> (N/warp_n, warp_n, K/warp_k, k0, div, k1) + B_view = B.T.reshape( + N // warp_tile_n, + warp_tile_n, + K // warp_tile_k, + k_abk0_per_lane, + divisor, + k_abk1_per_lane, + ) + # Permute: {0, 2, 4, 1, 3, 5} + B_shuffled = np.transpose(B_view, (0, 2, 4, 1, 3, 5)) + + elif arch.startswith("gfx11"): + # GFX11 (RDNA3) pattern - divisor = 1 + divisor = 1 + + # Reshape: (K, N) -> (N/warp_n, warp_n, K/warp_k, div, warp_k/div) + B_view = B.T.reshape( + N // warp_tile_n, + warp_tile_n, + K // warp_tile_k, + divisor, + warp_tile_k // divisor, + ) + # Permute: {0, 2, 3, 1, 4} + B_shuffled = np.transpose(B_view, (0, 2, 3, 1, 4)) + + else: + # GFX9 (CDNA) pattern - wave64 + divisor = 2 if warp_tile_n == 32 else 4 + + # Reshape: (K, N) -> (N/warp_n, warp_n, K/warp_k, div, warp_k/div) + B_view = 
B.T.reshape( + N // warp_tile_n, + warp_tile_n, + K // warp_tile_k, + divisor, + warp_tile_k // divisor, + ) + # Permute: {0, 2, 3, 1, 4} + B_shuffled = np.transpose(B_view, (0, 2, 3, 1, 4)) + + # Return contiguous array with same dtype + return np.ascontiguousarray(B_shuffled.reshape(-1)).reshape(B.shape) + + +def is_preshuffle_supported(arch: str) -> bool: + """Check if preshuffle is supported for the given architecture.""" + # Preshuffle is supported on CDNA (gfx9xx) and RDNA (gfx11xx, gfx12xx) + return arch.startswith(("gfx9", "gfx11", "gfx12")) + + +@dataclass +class KernelConfig: + """ + Complete kernel configuration for GEMM. + + This defines all parameters needed to generate and run a specific kernel. + """ + + # Data types + dtype_a: str = "fp16" + dtype_b: str = "fp16" + dtype_c: str = "fp16" + dtype_acc: str = "fp32" + + # Layouts (row/col) + layout_a: str = "row" + layout_b: str = "col" + layout_c: str = "row" + + # Tile shape (work per thread block) + tile_m: int = 128 + tile_n: int = 128 + tile_k: int = 32 + + # Wave shape (warps per block) + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + + # Warp tile (elements per warp) + warp_m: int = 32 + warp_n: int = 32 + warp_k: int = 16 + + # Block configuration + block_size: int = 256 + + # Pipeline configuration + pipeline: str = "compv4" + scheduler: str = "intrawave" + epilogue: str = "cshuffle" + + # Padding (enables arbitrary problem sizes) + pad_m: bool = True + pad_n: bool = True + pad_k: bool = True + + # GPU target + gfx_arch: str = "gfx942" + + # GEMM variant (affects arch filter validation) + # "standard", "preshuffle", or "multi_d" + variant: str = "standard" + + @property + def layout(self) -> str: + """Get layout string (e.g., 'rcr' for row-col-row)""" + mapping = {"row": "r", "col": "c"} + return mapping[self.layout_a] + mapping[self.layout_b] + mapping[self.layout_c] + + @property + def tile_str(self) -> str: + """Get tile size string""" + return 
f"{self.tile_m}x{self.tile_n}x{self.tile_k}" + + def print_config(self, indent: str = " "): + """Pretty print the configuration.""" + print(f"{indent}KernelConfig:") + print( + f"{indent} Data types: A={self.dtype_a}, B={self.dtype_b}, C={self.dtype_c}, Acc={self.dtype_acc}" + ) + print( + f"{indent} Layouts: A={self.layout_a}, B={self.layout_b}, C={self.layout_c} ({self.layout})" + ) + print(f"{indent} Tile: {self.tile_m}x{self.tile_n}x{self.tile_k}") + print(f"{indent} Waves: {self.wave_m}x{self.wave_n}x{self.wave_k}") + print(f"{indent} Warp tile: {self.warp_m}x{self.warp_n}x{self.warp_k}") + print(f"{indent} Block size: {self.block_size}") + print(f"{indent} Pipeline: {self.pipeline}/{self.scheduler}/{self.epilogue}") + print(f"{indent} Padding: M={self.pad_m}, N={self.pad_n}, K={self.pad_k}") + print(f"{indent} Target: {self.gfx_arch}") + + +class CodegenRunner: + """ + Runner for the unified GEMM code generator with parallel execution support. + + Usage: + codegen = CodegenRunner() + + # Generate standard kernels + result = codegen.generate("standard") + + # Generate preshuffle kernels + result = codegen.generate("preshuffle") + + # Generate multi-D kernels + result = codegen.generate("multi_d") + + # Generate all variants IN PARALLEL + results = codegen.generate_all_parallel() + + # Generate multiple configs IN PARALLEL + configs = [KernelConfig(...), KernelConfig(...)] + results = codegen.generate_configs_parallel(configs) + + # Generate with custom output directory + result = codegen.generate("standard", output_dir=Path("/custom/path")) + + # Generate from specific config + config = KernelConfig(tile_m=256, tile_n=256, tile_k=64) + result = codegen.generate_from_config(config) + """ + + VARIANTS = ["standard", "preshuffle", "multi_d"] + + def __init__( + self, + codegen_path: Optional[Path] = None, + output_dir: Optional[Path] = None, + datatype: str = "fp16", + layout: str = "rcr", + gpu_target: str = "gfx942", + max_workers: Optional[int] = None, + ): + 
self.codegen_path = codegen_path or get_codegen_path() + self.output_dir = output_dir or get_generated_kernels_dir() + self.datatype = datatype + self.layout = layout + self.gpu_target = gpu_target + # Default to CPU count, but cap at reasonable value + self.max_workers = max_workers or min(multiprocessing.cpu_count(), 8) + + def _make_args( + self, + variant: str, + output_dir: Optional[Path] = None, + extra_args: Optional[List[str]] = None, + timeout: int = 300, + show_instances: bool = False, + ) -> Dict[str, Any]: + """Build args dict for parallel worker.""" + return { + "codegen_path": str(self.codegen_path), + "output_dir": str(output_dir or self.output_dir), + "variant": variant, + "datatype": self.datatype, + "layout": self.layout, + "gpu_target": self.gpu_target, + "extra_args": extra_args or [], + "timeout": timeout, + "show_instances": show_instances, + } + + def generate( + self, + variant: str = "standard", + output_dir: Optional[Path] = None, + extra_args: Optional[List[str]] = None, + show_instances: bool = False, + ) -> CodegenResult: + """ + Generate kernels for a specific variant (single-threaded). 
+ + Args: + variant: One of "standard", "preshuffle", "multi_d" + output_dir: Override output directory + extra_args: Additional arguments to pass to codegen + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + CodegenResult with generation status and info + """ + args = self._make_args( + variant, output_dir, extra_args, show_instances=show_instances + ) + result = _run_codegen_subprocess(args) + + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + + return result + + def generate_all(self, output_dir: Optional[Path] = None) -> List[CodegenResult]: + """Generate all variants sequentially (use generate_all_parallel for speed).""" + results = [] + for variant in self.VARIANTS: + result = self.generate(variant, output_dir) + results.append(result) + return results + + def generate_all_parallel( + self, + output_dir: Optional[Path] = None, + variants: Optional[List[str]] = None, + verbose: bool = True, + show_instances: bool = False, + ) -> List[CodegenResult]: + """ + Generate all variants IN PARALLEL. + + Args: + output_dir: Override output directory + variants: List of variants to generate (default: all) + verbose: Print progress + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + List of CodegenResult for each variant + """ + variants = variants or self.VARIANTS + start_total = time.time() + + if verbose: + print( + f"Generating {len(variants)} variants in parallel (workers={self.max_workers})..." 
+ ) + + # Build args for each variant + args_list = [self._make_args(v, output_dir) for v in variants] + for args in args_list: + args["show_instances"] = show_instances + + results = [] + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_codegen_subprocess, args): args["variant"] + for args in args_list + } + + for future in as_completed(futures): + variant = futures[future] + try: + result = future.result() + results.append(result) + if verbose: + status = "✓" if result.success else "✗" + print( + f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" + ) + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + except Exception as e: + results.append( + CodegenResult( + success=False, + output_dir=output_dir or self.output_dir, + variant=variant, + stderr=str(e), + ) + ) + if verbose: + print(f" ✗ {variant}: FAILED - {e}") + + total_time = time.time() - start_total + if verbose: + total_kernels = sum(r.kernel_count for r in results) + print(f"Total: {total_kernels} kernels in {total_time:.2f}s") + + return results + + def generate_configs_parallel( + self, + configs: List["KernelConfig"], + output_dir: Optional[Path] = None, + verbose: bool = True, + show_instances: bool = False, + ) -> List[CodegenResult]: + """ + Generate kernels from multiple configs IN PARALLEL. + + Each config generates independently, allowing maximum parallelism. + + Args: + configs: List of KernelConfig objects + output_dir: Override output directory + verbose: Print progress + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + List of CodegenResult for each config + """ + start_total = time.time() + out_dir = output_dir or self.output_dir + + if verbose: + print( + f"Generating {len(configs)} configs in parallel (workers={self.max_workers})..." 
+ ) + + results = [] + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = {} + for config in configs: + args = { + "codegen_path": str(self.codegen_path), + "output_dir": str(out_dir), + "variant": "standard", + "datatype": config.dtype_a, + "layout": config.layout, + "gpu_target": config.gfx_arch, + "extra_args": [], + "timeout": 300, + "show_instances": show_instances, + } + future = executor.submit(_run_codegen_subprocess, args) + futures[future] = config.tile_str + + for future in as_completed(futures): + tile_str = futures[future] + try: + result = future.result() + results.append(result) + if verbose: + status = "✓" if result.success else "✗" + print( + f" {status} {tile_str}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" + ) + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + except Exception as e: + results.append( + CodegenResult( + success=False, + output_dir=out_dir, + variant=f"config:{tile_str}", + stderr=str(e), + ) + ) + if verbose: + print(f" ✗ {tile_str}: FAILED - {e}") + + total_time = time.time() - start_total + if verbose: + total_kernels = sum(r.kernel_count for r in results) + print(f"Total: {total_kernels} kernels in {total_time:.2f}s") + + return results + + def generate_batch_parallel( + self, + batch: List[Dict[str, Any]], + verbose: bool = True, + show_instances: bool = False, + ) -> List[CodegenResult]: + """ + Generate a batch of kernel specs IN PARALLEL. + + This is the most flexible parallel generation method. 
+ + Args: + batch: List of dicts with keys: variant, datatype, layout, gpu_target, output_dir + verbose: Print progress + show_instances: Print "Adding Instance" and "Building Instance" for each kernel + + Returns: + List of CodegenResult + """ + start_total = time.time() + + if verbose: + print( + f"Generating {len(batch)} kernel specs in parallel (workers={self.max_workers})..." + ) + + # Build args for each spec + args_list = [] + for spec in batch: + args = { + "codegen_path": str(self.codegen_path), + "output_dir": str(spec.get("output_dir", self.output_dir)), + "variant": spec.get("variant", "standard"), + "datatype": spec.get("datatype", self.datatype), + "layout": spec.get("layout", self.layout), + "gpu_target": spec.get("gpu_target", self.gpu_target), + "extra_args": spec.get("extra_args", []), + "timeout": spec.get("timeout", 300), + "show_instances": show_instances, + } + args_list.append(args) + + results = [] + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_codegen_subprocess, args): args["variant"] + for args in args_list + } + + for future in as_completed(futures): + variant = futures[future] + try: + result = future.result() + results.append(result) + if verbose: + status = "✓" if result.success else "✗" + print( + f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" + ) + if show_instances and result.instance_names: + for name in result.instance_names: + print(f" Adding Instance: {name}") + print(f" Building Instance: {name}") + except Exception as e: + results.append( + CodegenResult( + success=False, + output_dir=self.output_dir, + variant=variant, + stderr=str(e), + ) + ) + if verbose: + print(f" ✗ {variant}: FAILED - {e}") + + total_time = time.time() - start_total + if verbose: + total_kernels = sum(r.kernel_count for r in results) + print(f"Total: {total_kernels} kernels in {total_time:.2f}s") + + return results + + def generate_from_config( + self, + 
config: KernelConfig, + output_dir: Optional[Path] = None, + force: bool = False, + show_instances: bool = False, + ) -> CodegenResult: + """ + Generate kernel from a specific KernelConfig. + + This generates ONLY the specific kernel header needed (not all kernels). + Note: This does NOT rebuild the library - use build_library_for_configs() + for that. + + Args: + config: KernelConfig with all kernel parameters + output_dir: Override output directory + force: Force regeneration even if kernel exists + show_instances: Print instance names when generating + + Returns: + CodegenResult with the specific kernel + """ + import sys + import json + import tempfile + + out_dir = output_dir or self.output_dir + out_dir.mkdir(parents=True, exist_ok=True) + + # Build kernel filename pattern for this config + # Note: padding flags may differ from config (arch filter may enable padding) + tile_str = config.tile_str # e.g., "128x128x32" + wave_str = f"{config.wave_m}x{config.wave_n}x{config.wave_k}" + warp_str = f"{config.warp_m}x{config.warp_n}x{config.warp_k}" + + # Build pattern - use * for padding flags since arch filter may change them + precise_pattern = f"gemm_{config.dtype_a}_{config.layout}_{config.pipeline}_{config.epilogue}_{config.scheduler}_*_*_*_*_{tile_str}_{wave_str}_{warp_str}.hpp" + + # Check if exact kernel already exists + existing = list(out_dir.glob(precise_pattern)) + if existing and not force: + instance_names = sorted([k.stem for k in existing]) + if show_instances: + for name in instance_names: + print(f" Kernel exists: {name}") + + return CodegenResult( + success=True, + output_dir=out_dir, + variant=f"config:{tile_str}", + kernel_count=len(existing), + instance_names=instance_names, + stdout=f"Kernel exists, using: {existing[0].name}", + ) + + if not self.codegen_path.exists(): + return CodegenResult( + success=False, + output_dir=out_dir, + variant=f"config:{tile_str}", + stderr=f"Codegen not found at {self.codegen_path}", + ) + + start = time.time() 
+ + # Create a temporary config file for single-kernel generation + # Format must match what unified_gemm_codegen.py expects + single_config = { + "tile_config": { + "tile_m": [config.tile_m], + "tile_n": [config.tile_n], + "tile_k": [config.tile_k], + "warp_m": [config.wave_m], + "warp_n": [config.wave_n], + "warp_k": [config.wave_k], + "warp_tile_m": [config.warp_m], + "warp_tile_n": [config.warp_n], + "warp_tile_k": [config.warp_k], + }, + "trait_config": { + "pipeline": [config.pipeline], + "epilogue": [config.epilogue], + "scheduler": [config.scheduler], + "pad_m": [config.pad_m], + "pad_n": [config.pad_n], + "pad_k": [config.pad_k], + "persistent": [False], + }, + } + + # Write temp config file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(single_config, f) + config_file = f.name + + try: + # Generate ONLY this specific kernel using config file + cmd = [ + sys.executable, + str(self.codegen_path), + "--output-dir", + str(out_dir), + "--datatype", + config.dtype_a, + "--layout", + config.layout, + "--gpu-target", + config.gfx_arch, + "--config", + config_file, + "--variants", + "standard", + ] + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + + # Find the generated kernel + matching = list(out_dir.glob(precise_pattern)) + kernel_count = len(matching) + elapsed = time.time() - start + + instance_names = sorted([k.stem for k in matching]) + if show_instances and instance_names: + for name in instance_names: + print(f" Generated: {name}") + + return CodegenResult( + success=result.returncode == 0 and kernel_count > 0, + output_dir=out_dir, + variant=f"config:{tile_str}", + stdout=result.stdout, + stderr=result.stderr, + kernel_count=kernel_count, + elapsed_seconds=elapsed, + instance_names=instance_names, + ) + except Exception as e: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=f"config:{tile_str}", + stderr=str(e), + ) + finally: + # Clean up temp file + 
import os + + try: + os.unlink(config_file) + except Exception: + pass + + def _rebuild_library_for_config( + self, config: KernelConfig, kernel_header: Path + ) -> Optional[Path]: + """ + Rebuild the library with the specified kernel header using hipcc directly. + + This compiles a new library with exactly the kernel specified. + Builds to a UNIQUE filename to avoid conflicts with loaded libraries. + + Architecture Note - C++ vs Python Paths: + ----------------------------------------- + C++ Multi-Kernel Path: + - Each kernel is in its own namespace (ns_gemm_...) + - Multiple kernel headers can be included together + - Uses namespace-qualified types: ns_...:SelectedKernel + - Does NOT define CK_TILE_SINGLE_KERNEL_INCLUDE + - Registration code uses block-scoped type aliases + + Python Single-Kernel JIT Path (this function): + - Each library contains exactly ONE kernel + - Uses -DCK_TILE_SINGLE_KERNEL_INCLUDE to export types to global namespace + - gemm_ctypes_lib.cpp expects: SelectedKernel, KERNEL_NAME, ADataType, etc. 
+ - Different configs get different library files (by dtype/layout) + - This enables Python to use any kernel config without pre-building all + + Returns: Path to new library, or None on failure + """ + build_dir = get_build_dir() + # Use unique filename based on dtype/layout to avoid overwriting loaded library + lib_name = f"libdispatcher_gemm_{config.dtype_a}_{config.layout}_lib.so" + lib_path = build_dir / "examples" / lib_name + + print(f" Rebuilding library: {lib_name}") + print(f" With kernel: {kernel_header.name}") + + root = get_dispatcher_root() + ck_root = root.parent + + ctypes_source = root / "bindings/ctypes/gemm_ctypes_lib.cpp" + if not ctypes_source.exists(): + print(f" Source not found: {ctypes_source}") + return None + + # Link against the static dispatcher library (contains Registry, Dispatcher) + static_lib = build_dir / "libck_tile_dispatcher.a" + if not static_lib.exists(): + print(f" Static library not found: {static_lib}") + print(" Build with: cd build && cmake .. 
&& make ck_tile_dispatcher") + return None + + # Compile source to object first, then link + obj_file = lib_path.with_suffix(".o") + + # Step 1: Compile source to object + # CK_TILE_SINGLE_KERNEL_INCLUDE enables global namespace exports in the kernel header + # This exports: SelectedKernel, KERNEL_NAME, ADataType, BDataType, CDataType, AccDataType + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", # Compile only + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{root / 'build/generated_kernels'}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", # Enable global namespace exports + f"-include{kernel_header}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={config.gfx_arch}", + f'-DGFX_ARCH="{config.gfx_arch}"', # Pass arch as string for gemm_ctypes_lib.cpp + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + + try: + print(" Compiling source...") + result = subprocess.run( + compile_cmd, capture_output=True, text=True, timeout=300 + ) + if result.returncode != 0: + print(f" Compilation failed: {result.stderr[:300]}") + return None + + # Step 2: Link object with static library into shared library + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={config.gfx_arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + print(" Linking...") + result = subprocess.run( + link_cmd, capture_output=True, text=True, timeout=300 + ) + if result.returncode == 0: + print(f" ✓ Library rebuilt: {lib_path.name}") + # Clean up object file + obj_file.unlink(missing_ok=True) + return lib_path + else: + print(f" Linking failed: {result.stderr[:300]}") + return None + except subprocess.TimeoutExpired: + print(" Build timed out") + return None + except Exception as e: + print(f" Build error: {e}") + return None + + def generate_preselected( + self, preset: str = 
"fp16_rcr_essential", output_dir: Optional[Path] = None + ) -> CodegenResult: + """ + Generate kernels from a preselected set. + + Args: + preset: Preselected kernel set name (e.g., "fp16_rcr_essential") + output_dir: Override output directory + + Returns: + CodegenResult + """ + import sys + + out_dir = output_dir or self.output_dir + out_dir.mkdir(parents=True, exist_ok=True) + + cmd = [ + sys.executable, + str(self.codegen_path), + "--output-dir", + str(out_dir), + "--preselected", + preset, + ] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + kernel_count = len(list(out_dir.glob("*.hpp"))) + + return CodegenResult( + success=result.returncode == 0, + output_dir=out_dir, + variant=f"preselected:{preset}", + stdout=result.stdout, + stderr=result.stderr, + kernel_count=kernel_count, + ) + except Exception as e: + return CodegenResult( + success=False, + output_dir=out_dir, + variant=f"preselected:{preset}", + stderr=str(e), + ) + + def ensure_kernels_exist(self) -> bool: + """ + Ensure kernel headers exist, generating if necessary. + + Returns: + True if kernels exist or were successfully generated + """ + if self.output_dir.exists(): + kernels = list(self.output_dir.glob("*.hpp")) + if kernels: + return True + + # Generate standard kernels + result = self.generate("standard") + return result.success + + def list_kernels(self) -> List[Path]: + """List all generated kernel headers""" + if self.output_dir.exists(): + return sorted(self.output_dir.glob("*.hpp")) + return [] + + def categorize_kernels(self) -> dict: + """ + Categorize kernels by tile size and variant. 
+ + Returns: + Dict with categories by tile size and variant type + """ + kernels = self.list_kernels() + + # Separate by variant first + preshuffle = [k for k in kernels if "_preshuffle" in k.name] + multi_d = [k for k in kernels if "_multid_" in k.name] + standard = [ + k + for k in kernels + if "_preshuffle" not in k.name and "_multid_" not in k.name + ] + + # Categorize standard kernels by tile size + compute = [k for k in standard if "_256x" in k.name] + memory = [k for k in standard if "_128x" in k.name] + latency = [k for k in standard if "_64x" in k.name or "_32x" in k.name] + + return { + "total": len(kernels), + "standard": len(standard), + "compute": compute, + "memory": memory, + "latency": latency, + "preshuffle": preshuffle, + "multi_d": multi_d, + } + + +# ============================================================================= +# Registry and Dispatcher (Explicit API) +# ============================================================================= + + +class Registry: + """ + Kernel registry - stores and manages kernel instances. + + This provides an explicit registry API that mirrors the C++ Registry class. 
+ + Usage: + registry = Registry() + registry.register_kernel(kernel_config) + dispatcher = Dispatcher(registry) + """ + + def __init__(self, lib: Optional[DispatcherLib] = None, name: str = "default"): + self._lib = lib + self._name = name + self._kernels: List[KernelConfig] = [] + + @property + def name(self) -> str: + return self._name + + @property + def kernel_count(self) -> int: + if self._lib: + return self._lib.get_kernel_count() + return len(self._kernels) + + def register_kernel(self, config: KernelConfig) -> bool: + """Register a kernel configuration.""" + self._kernels.append(config) + return True + + def get_kernels(self) -> List[KernelConfig]: + """Get all registered kernel configs.""" + return self._kernels.copy() + + def clear(self): + """Clear all kernels.""" + self._kernels.clear() + + def bind_library(self, lib: DispatcherLib): + """Bind to a loaded dispatcher library.""" + self._lib = lib + + def __repr__(self) -> str: + return f"Registry(name='{self._name}', kernels={self.kernel_count})" + + +class Dispatcher: + """ + Kernel dispatcher - selects and runs kernels for problems. + + This provides an explicit dispatcher API that mirrors the C++ Dispatcher class. 
+ + Usage: + registry = Registry() + registry.register_kernel(config) + + dispatcher = Dispatcher(registry) + result = dispatcher.run(A, B, M, N, K) + """ + + def __init__(self, registry: Registry, lib: Optional[DispatcherLib] = None): + self._registry = registry + self._lib = lib or registry._lib + + @property + def registry(self) -> Registry: + return self._registry + + def select_kernel(self, M: int, N: int, K: int) -> Optional[str]: + """Select best kernel for problem dimensions.""" + if self._lib: + return self._lib.select_kernel(M, N, K) + # Fallback: return first matching kernel + for config in self._registry.get_kernels(): + return f"kernel_{config.tile_str}" + return None + + def is_supported(self, M: int, N: int, K: int) -> bool: + """Check if problem size is supported.""" + if self._lib: + return self._lib.is_supported(M, N, K) + return len(self._registry.get_kernels()) > 0 + + def run(self, A: np.ndarray, B: np.ndarray, M: int, N: int, K: int) -> GemmResult: + """ + Run GEMM: C = A @ B + + Args: + A: Input matrix (M x K) + B: Input matrix (K x N) + M, N, K: Problem dimensions + + Returns: + GemmResult with output and timing + """ + if self._lib is None: + raise RuntimeError("Dispatcher not bound to library") + + # Ensure contiguous float16 arrays + A_gpu = np.ascontiguousarray(A, dtype=np.float16) + B_gpu = np.ascontiguousarray(B.T, dtype=np.float16) # Column-major + C_gpu = np.zeros((M, N), dtype=np.float16) + + # Run via library + status, time_ms = self._lib.run_gemm(A_gpu, B_gpu, C_gpu, M, N, K) + + # Calculate TFLOPS + flops = 2.0 * M * N * K + tflops = (flops / (time_ms * 1e-3)) / 1e12 if time_ms > 0 else 0 + + return GemmResult( + output=C_gpu, + time_ms=time_ms, + status=status, + tflops=tflops, + kernel_name=self._lib.get_kernel_name() if self._lib else "unknown", + ) + + def __repr__(self) -> str: + return f"Dispatcher(registry={self._registry.name}, kernels={self._registry.kernel_count})" + + +# 
============================================================================= +# Main (self-test) +# ============================================================================= + +if __name__ == "__main__": + print("CK Tile Dispatcher Utils Self-Test") + print("=" * 60) + + # Test library loading + print("\n1. Loading library...") + lib = DispatcherLib.auto() + if lib is None: + print(" FAILED: Could not load library") + exit(1) + print(f" OK: Loaded from {lib.path}") + print(f" Kernel: {lib.get_kernel_name()}") + print(f" Registered kernels: {lib.get_kernel_count()}") + + # Test GEMM + print("\n2. Running GEMM 256x256x256...") + runner = GemmRunner(lib) + A = np.random.randn(256, 256).astype(np.float16) + B = np.random.randn(256, 256).astype(np.float16) + + result = runner.run(A, B) + print(f" Status: {'OK' if result.success else 'FAILED'}") + print(f" Time: {result.time_ms:.4f} ms") + print(f" TFLOPS: {result.tflops:.2f}") + + # Test validation + print("\n3. Validating result...") + validator = Validator() + reference = validator.compute_reference(A, B) + correct, max_diff, mean_diff = validator.check(result.output, reference) + print(f" Correct: {correct}") + print(f" Max diff: {max_diff:.6f}") + + print("\n" + "=" * 60) + print("All tests passed!") + + +# ============================================================================= +# High-Level Helper Functions +# ============================================================================= + + +@dataclass +class GemmSetupResult: + """Result of setup_gemm_dispatcher""" + + success: bool + dispatcher: Optional[Dispatcher] = None + lib: Optional[DispatcherLib] = None + registry: Optional[Registry] = None + codegen: Optional[CodegenRunner] = None + config: Optional[KernelConfig] = None + kernel_header: Optional[Path] = None + error: str = "" + corrections: List[str] = field(default_factory=list) + + +def setup_gemm_dispatcher( + config: KernelConfig, + registry_name: str = "gemm_registry", + verbose: bool = 
True, + auto_rebuild: bool = True, +) -> GemmSetupResult: + """ + High-level helper to setup a GEMM dispatcher from a kernel config. + + This handles: + 1. Validate config against arch filter (auto-correct if needed) + 2. Generate kernel code if needed + 3. Find matching kernel header + 4. Load or rebuild library (if dtype mismatch) + 5. Create registry and dispatcher + + Args: + config: KernelConfig with all parameters + registry_name: Name for the registry + verbose: Print progress messages + auto_rebuild: Rebuild library if dtype doesn't match + + Returns: + GemmSetupResult with dispatcher, lib, registry, etc. + """ + result = GemmSetupResult(success=False, config=config) + + def log(msg): + if verbose: + print(msg) + + # Step 1: Validate config + log(" Validating config...") + validation = validate_kernel_config(config) + if not validation.is_valid: + log(" ⚠ Auto-correcting configuration...") + config, was_modified, corrections = auto_correct_kernel_config( + config, verbose=verbose + ) + result.config = config + result.corrections = corrections + # Note: corrections will be displayed by the caller via print_auto_correction + + # Step 2: Setup codegen and generate kernel + log(f" Generating kernel (tile={config.tile_str})...") + codegen = CodegenRunner( + datatype=config.dtype_a, + layout=config.layout, + gpu_target=config.gfx_arch, + ) + result.codegen = codegen + + codegen_result = codegen.generate_from_config(config) + if not codegen_result.success: + log(" ⚠ Kernel generation: using existing") + + # Step 3: Find matching kernel header + kernel_header = find_matching_kernel_header(config) + result.kernel_header = kernel_header + if not kernel_header: + log(" ⚠ No matching kernel header found") + + # Step 4: Load library + log(" Loading library...") + lib = DispatcherLib.auto() + if lib is None: + result.error = "Could not load dispatcher library" + return result + result.lib = lib + + # Check if library kernel matches config - rebuild if ANY parameter 
differs + lib_kernel = lib.get_kernel_name() + needs_rebuild = False + mismatches = [] + + if lib_kernel: + # Build expected kernel signature components from config + expected_parts = { + "dtype": config.dtype_a, + "layout": config.layout, + "pipeline": config.pipeline, + "epilogue": config.epilogue, + "scheduler": config.scheduler, + "tile": f"{config.tile_m}x{config.tile_n}x{config.tile_k}", + "wave": f"{config.wave_m}x{config.wave_n}x{config.wave_k}", + "warp": f"{config.warp_m}x{config.warp_n}x{config.warp_k}", + } + + # Check each component against the library kernel name + for name, expected in expected_parts.items(): + if expected not in lib_kernel: + needs_rebuild = True + mismatches.append(f"{name}={expected}") + + if needs_rebuild and auto_rebuild: + log(f" Library kernel doesn't match config: {', '.join(mismatches)}") + log(" Rebuilding library for exact config match...") + + # First ensure we have a kernel header for this exact config + if not kernel_header: + # Generate kernel for the exact config + log(" Generating kernel for config...") + codegen_result = codegen.generate_from_config(config, force=True) + kernel_header = find_matching_kernel_header(config) + result.kernel_header = kernel_header + + if kernel_header: + new_lib_path = codegen._rebuild_library_for_config(config, kernel_header) + if new_lib_path: + lib = DispatcherLib.load(new_lib_path) + if lib is None or not lib.initialize(): + result.error = "Failed to load rebuilt library" + return result + result.lib = lib + log(f" ✓ Rebuilt library: {lib.get_kernel_name()}") + else: + log(" ⚠ Rebuild failed, using existing library") + else: + log(" ⚠ No kernel header found for config, using existing library") + + # Step 5: Create registry and dispatcher + log(" Creating registry and dispatcher...") + registry = Registry(name=registry_name, lib=lib) + registry.register_kernel(config) + result.registry = registry + + dispatcher = Dispatcher(registry=registry, lib=lib) + result.dispatcher = dispatcher 
+ + log(f" ✓ Ready: {lib.get_kernel_name()}") + + result.success = True + return result + + +def cleanup_gemm(): + """ + Cleanup function to call after running GEMM examples. + + This helps ensure clean state between examples by: + 1. Clearing any global state + 2. Suggesting garbage collection + """ + import gc + + # Clear loaded libraries list + DispatcherLib._loaded_libs.clear() + + # Suggest garbage collection + gc.collect() + + +def cleanup_generated_kernels( + keep_default: bool = True, + verbose: bool = False, +) -> int: + """ + Clean up generated kernel files. + + Call this at the start of examples to ensure fresh state. + + Args: + keep_default: Keep the default fp16 kernel (True) or delete all (False) + verbose: Print what's being deleted + + Returns: + Number of files deleted + """ + + kernel_dir = get_generated_kernels_dir() + if not kernel_dir.exists(): + return 0 + + deleted = 0 + + # Default kernel pattern to keep + default_pattern = ( + "gemm_fp16_rcr_compv4_cshuffle_intrawave_*_128x128x32_2x2x1_16x16x16.hpp" + ) + + for f in kernel_dir.glob("*.hpp"): + # Skip dispatcher_wrappers directory + if f.is_dir(): + continue + + # Optionally keep default kernel + if keep_default and f.match(default_pattern): + continue + + if verbose: + print(f" Deleting: {f.name}") + f.unlink() + deleted += 1 + + # Also clean up any temp libs + build_dir = get_build_dir() + examples_dir = build_dir / "examples" + if examples_dir.exists(): + for f in examples_dir.glob("libdispatcher_gemm_*_lib.so"): + if f.name != "libdispatcher_gemm_lib.so": + if verbose: + print(f" Deleting: {f.name}") + f.unlink() + deleted += 1 + + return deleted + + +def reset_for_example(verbose: bool = False): + """ + Reset state for a fresh example run. + + Call this at the START of each example to ensure clean state. + Cleans up generated kernels (except default) and resets globals. 
+ """ + # Cleanup any previously generated kernels + deleted = cleanup_generated_kernels(keep_default=True, verbose=verbose) + if verbose and deleted > 0: + print(f" Cleaned up {deleted} generated files") + + # Clear any cached state + cleanup_gemm() + + +# Main (self-test) +# ============================================================================= + +if __name__ == "__main__": + print("CK Tile Dispatcher Utils Self-Test") + print("=" * 60) + + # Test library loading + print("\n1. Loading library...") + lib = DispatcherLib.auto() + if lib is None: + print(" FAILED: Could not load library") + exit(1) + print(f" OK: Loaded from {lib.path}") + print(f" Kernel: {lib.get_kernel_name()}") + print(f" Registered kernels: {lib.get_kernel_count()}") + + # Test GEMM + print("\n2. Running GEMM 256x256x256...") + runner = GemmRunner(lib) + A = np.random.randn(256, 256).astype(np.float16) + B = np.random.randn(256, 256).astype(np.float16) + + result = runner.run(A, B) + print(f" Status: {'OK' if result.success else 'FAILED'}") + print(f" Time: {result.time_ms:.4f} ms") + print(f" TFLOPS: {result.tflops:.2f}") + + # Test validation + print("\n3. Validating result...") + validator = Validator() + reference = validator.compute_reference(A, B) + correct, max_diff, mean_diff = validator.check(result.output, reference) + print(f" Correct: {correct}") + print(f" Max diff: {max_diff:.6f}") + + # Test high-level helper + print("\n4. 
Testing setup_gemm_dispatcher...") + config = KernelConfig(tile_m=128, tile_n=128, tile_k=32) + setup = setup_gemm_dispatcher(config, verbose=True) + print(f" Success: {setup.success}") + + # Cleanup + cleanup_gemm() + + print("\n" + "=" * 60) + print("All tests passed!") diff --git a/dispatcher/python/pytest.ini b/dispatcher/python/pytest.ini new file mode 100644 index 00000000000..08cd235fdae --- /dev/null +++ b/dispatcher/python/pytest.ini @@ -0,0 +1,43 @@ +[pytest] +# Pytest configuration for CK Tile Dispatcher Python tests + +# Test discovery +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Test paths +testpaths = tests + +# Options +addopts = + -v + --strict-markers + --tb=short + --color=yes + --durations=10 + +# Markers +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + cuda: marks tests requiring CUDA/ROCm + torch: marks tests requiring PyTorch + integration: marks integration tests + unit: marks unit tests + +# Coverage +[coverage:run] +source = . 
+omit = + */tests/* + */examples/* + setup.py + +[coverage:report] +precision = 2 +show_missing = True +skip_covered = False + +[coverage:html] +directory = htmlcov + diff --git a/dispatcher/python/requirements.txt b/dispatcher/python/requirements.txt new file mode 100644 index 00000000000..9d429235f77 --- /dev/null +++ b/dispatcher/python/requirements.txt @@ -0,0 +1,22 @@ +# Core dependencies +numpy>=1.19.0 + +# Optional dependencies (install with pip install -e ".[torch]") +# torch>=2.0.0 + +# Development dependencies (install with pip install -e ".[dev]") +# pytest>=6.0.0 +# pytest-cov>=2.0.0 +# black>=21.0 +# flake8>=3.9.0 +# mypy>=0.910 +# isort>=5.0.0 + +# Visualization dependencies (install with pip install -e ".[viz]") +# matplotlib>=3.3.0 +# seaborn>=0.11.0 + +# Documentation dependencies +# sphinx>=4.0.0 +# sphinx-rtd-theme>=1.0.0 + diff --git a/dispatcher/scripts/compile_gemm_examples.py b/dispatcher/scripts/compile_gemm_examples.py new file mode 100644 index 00000000000..b19c18a13a4 --- /dev/null +++ b/dispatcher/scripts/compile_gemm_examples.py @@ -0,0 +1,2253 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Cross-platform build script for declarative kernel workflow. + +Uses existing ctypes_utils.py for path management and codegen. 
+ +Usage: + python3 compile_gemm_examples.py [output_name] + +Example: + python3 compile_gemm_examples.py examples/cpp/01_basic_gemm.cpp my_app +""" + +import argparse +import os +import re +import subprocess +import sys +from pathlib import Path +import shutil + +# Add dispatcher/python to path to reuse existing utilities +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) + +# Import existing utilities (after sys.path modification) +from ctypes_utils import ( # noqa: E402 + get_dispatcher_root, + get_ck_root, + get_build_dir, + get_generated_kernels_dir, + CodegenRunner, +) + + +# ============================================================================= +# Terminal Colors (cross-platform) +# ============================================================================= + + +class Colors: + if sys.platform != "win32" and sys.stdout.isatty(): + GREEN = "\033[0;32m" + YELLOW = "\033[1;33m" + RED = "\033[0;31m" + NC = "\033[0m" + else: + GREEN = YELLOW = RED = NC = "" + + +def print_phase(msg: str): + print(f"{Colors.YELLOW}{msg}{Colors.NC}") + + +def print_success(msg: str): + print(f"{Colors.GREEN}{msg}{Colors.NC}") + + +def print_error(msg: str): + print(f"{Colors.RED}{msg}{Colors.NC}", file=sys.stderr) + + +# ============================================================================= +# Compiler Detection +# ============================================================================= + + +def find_hipcc() -> str: + """Find hipcc compiler.""" + candidates = [ + os.environ.get("HIPCC"), + "/opt/rocm/bin/hipcc", + "/opt/rocm/hip/bin/hipcc", + shutil.which("hipcc"), + ] + + for path in candidates: + if path and os.path.isfile(path): + return path + + raise RuntimeError( + "hipcc not found. Please install ROCm or set HIPCC environment variable." 
+ ) + + +# ============================================================================= +# Declaration Extraction +# ============================================================================= + + +def extract_conv_kernel_declarations(source_file: Path) -> list: + """Extract CONVOLUTION kernel declarations from C++ source file. + + Supports DECL_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern. + Extracts all parameters: dtype, layout, conv_type, dims, tile, wave, warp, pipeline, scheduler. + """ + content = source_file.read_text() + declarations = [] + seen = set() + + # Pattern: DECL_CONV_KERNEL_SET(name, .add(...).add(...)) + set_pattern = r"DECL_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)" + + for match in re.finditer(set_pattern, content, re.DOTALL): + set_name = match.group(1) + set_body = match.group(2) + + # Pattern 1: Simple add("dtype", "layout", "conv_type", tile_k, tile_c) + simple_add = ( + r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)' + ) + for add_match in re.finditer(simple_add, set_body): + dtype = add_match.group(1) + layout = add_match.group(2) + conv_type = add_match.group(3) + tile_k = int(add_match.group(4)) + tile_c = int(add_match.group(5)) + + name = f"{set_name}:{dtype}_{layout}_{conv_type}_{tile_k}x{tile_c}" + if name not in seen: + seen.add(name) + declarations.append( + { + "type": "conv", + "dtype": dtype, + "layout": layout, + "conv_type": conv_type, + "num_dims": 2, + "groups": 1, + "tile_n": 1, + "tile_k": tile_k, + "tile_c": tile_c, + "wave_m": -1, # Wildcard - will expand + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv3", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "set": set_name, + "arch": "gfx942", + } + ) + + # Pattern 2: Full specification with ConvSig() and ConvAlgo() + # Match .add( ConvSig()..., ConvAlgo()..., "arch" ) + # Use robust parsing that handles multi-line and comments + + # Find all .add( blocks 
containing ConvSig + add_blocks = re.findall( + r"\.add\s*\(\s*ConvSig\(\)([\s\S]*?)(?=\.add\s*\(|$)", set_body + ) + + for add_block in add_blocks: + # Find ConvAlgo and arch in this block + algo_match = re.search(r'ConvAlgo\(\)([\s\S]*?),\s*"(\w+)"\s*\)', add_block) + if not algo_match: + continue + + sig_str = add_block[: add_block.find("ConvAlgo()")] + algo_str = algo_match.group(1) + arch = algo_match.group(2) + + # Parse ConvSig + dtype = "fp16" + dtype_match = re.search(r'\.dtype\s*\(\s*"([^"]+)"', sig_str) + if dtype_match: + dtype = dtype_match.group(1) + + layout = "nhwgc" + layout_match = re.search(r'\.layout\s*\(\s*"([^"]+)"', sig_str) + if layout_match: + layout = layout_match.group(1) + + conv_type = "forward" + conv_type_match = re.search(r'\.conv_type\s*\(\s*"([^"]+)"', sig_str) + if conv_type_match: + conv_type = conv_type_match.group(1) + + num_dims = 2 + dims_match = re.search(r"\.dims\s*\(\s*(\d+)", sig_str) + if dims_match: + num_dims = int(dims_match.group(1)) + + groups = 1 + groups_match = re.search(r"\.groups\s*\(\s*(\d+)", sig_str) + if groups_match: + groups = int(groups_match.group(1)) + + # Parse ConvAlgo + tile_n, tile_k, tile_c = 1, 128, 128 + tile_match = re.search( + r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", algo_str + ) + if tile_match: + tile_n = int(tile_match.group(1)) + tile_k = int(tile_match.group(2)) + tile_c = int(tile_match.group(3)) + + wave_m, wave_n, wave_k = 2, 2, 1 + wave_match = re.search( + r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if wave_match: + wave_m = int(wave_match.group(1)) + wave_n = int(wave_match.group(2)) + wave_k = int(wave_match.group(3) or 1) + + warp_m, warp_n, warp_k = 32, 32, 16 + warp_match = re.search( + r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if warp_match: + warp_m = int(warp_match.group(1)) + warp_n = int(warp_match.group(2)) + warp_k = int(warp_match.group(3) or 16) + + pipeline = "compv3" + pipeline_match = 
re.search(r'\.pipeline\s*\(\s*"([^"]+)"', algo_str) + if pipeline_match: + pipeline = pipeline_match.group(1) + + scheduler = "intrawave" + scheduler_match = re.search(r'\.scheduler\s*\(\s*"([^"]+)"', algo_str) + if scheduler_match: + scheduler = scheduler_match.group(1) + + epilogue = "cshuffle" + epilogue_match = re.search(r'\.epilogue\s*\(\s*"([^"]+)"', algo_str) + if epilogue_match: + epilogue = epilogue_match.group(1) + + # Build unique name with full config + name = f"{set_name}:{dtype}_{conv_type}_{num_dims}d_{pipeline}_{scheduler}_{tile_k}x{tile_c}_{wave_m}x{wave_n}x{wave_k}" + if name not in seen: + seen.add(name) + declarations.append( + { + "type": "conv", + "dtype": dtype, + "layout": layout, + "conv_type": conv_type, + "num_dims": num_dims, + "groups": groups, + "tile_n": tile_n, + "tile_k": tile_k, + "tile_c": tile_c, + "wave_m": wave_m, + "wave_n": wave_n, + "wave_k": wave_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "pipeline": pipeline, + "scheduler": scheduler, + "epilogue": epilogue, + "name": name, + "set": set_name, + "arch": arch, + } + ) + + return declarations + + +def expand_conv_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") -> list: + """Expand a convolution declaration to all valid combinations. 
+ + Like GEMM, convolution supports wildcard expansion for: + - wave/warp: If -1, generates all valid combinations + - pipeline/scheduler: If "*", generates all valid trait combinations + """ + # Import arch filter + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + TRAIT_UNSUPPORTED_COMBINATIONS, + ) + except ImportError: + # Fallback + WARP_SUPPORTED_COMBINATIONS = { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + } + WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + } + TRAIT_UNSUPPORTED_COMBINATIONS = set() + + d = decl.copy() + tile_k = d.get("tile_k", 128) + tile_c = d.get("tile_c", 128) + dtype = d.get("dtype", "fp16") + + # Check what needs expansion + needs_wave_expansion = d.get("wave_m", -1) < 0 or d.get("wave_n", -1) < 0 + needs_warp_expansion = d.get("warp_m", -1) < 0 or d.get("warp_n", -1) < 0 + needs_pipeline_expansion = d.get("pipeline", "compv4") == "*" + needs_scheduler_expansion = d.get("scheduler", "intrawave") == "*" + + if ( + not needs_wave_expansion + and not needs_warp_expansion + and not needs_pipeline_expansion + and not needs_scheduler_expansion + ): + return [d] + + # Build valid combinations + if needs_wave_expansion or needs_warp_expansion: + wave_configs = WARP_SUPPORTED_COMBINATIONS.get(arch, [[2, 2, 1]]) + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_configs = WARP_TILE_SUPPORTED_COMBINATIONS.get(arch, {}).get( + dtype_key, [[32, 32, 16], [16, 16, 16]] + ) + else: + wave_configs = [[d.get("wave_m", 2), d.get("wave_n", 2), d.get("wave_k", 1)]] + warp_tile_configs = [ + [d.get("warp_m", 32), d.get("warp_n", 32), d.get("warp_k", 16)] + ] + + # Pipeline/scheduler combinations + ALL_PIPELINES = ["compv3", "compv4"] + ALL_SCHEDULERS = ["intrawave", "interwave"] + + pipelines = ( + ALL_PIPELINES if needs_pipeline_expansion else 
[d.get("pipeline", "compv4")] + ) + schedulers = ( + ALL_SCHEDULERS + if needs_scheduler_expansion + else [d.get("scheduler", "intrawave")] + ) + + expanded = [] + + for wm, wn, wk in wave_configs: + for wtm, wtn, wtk in warp_tile_configs: + # Check divisibility for conv (M=output spatial, N=K channels, K=C channels) + # Simplified check for now + if tile_k % (wn * wtn) != 0: + continue + if tile_c % (wk * wtk) != 0: + continue + + for pipeline in pipelines: + for scheduler in schedulers: + # Check trait combination + if ( + pipeline, + "cshuffle", + scheduler, + ) in TRAIT_UNSUPPORTED_COMBINATIONS: + continue + + expanded_d = d.copy() + expanded_d["wave_m"] = wm + expanded_d["wave_n"] = wn + expanded_d["wave_k"] = wk + expanded_d["warp_m"] = wtm + expanded_d["warp_n"] = wtn + expanded_d["warp_k"] = wtk + expanded_d["pipeline"] = pipeline + expanded_d["scheduler"] = scheduler + + expanded_d["name"] = ( + f"conv_{d['conv_type']}_{dtype}_{d['num_dims']}d_{pipeline}_" + f"{scheduler}_{tile_k}x{tile_c}_{wm}x{wn}x{wk}" + ) + expanded.append(expanded_d) + + if not expanded: + # Fallback to defaults + d["wave_m"] = 2 + d["wave_n"] = 2 + d["wave_k"] = 1 + d["warp_m"] = 32 + d["warp_n"] = 32 + d["warp_k"] = 16 + d["pipeline"] = "compv4" + d["scheduler"] = "intrawave" + return [d] + + return expanded + + +def generate_conv_kernels(declarations: list, gpu_target: str = "gfx942") -> int: + """Generate convolution kernels using unified_conv_codegen.""" + kernel_dir = get_generated_kernels_dir() + kernel_dir.mkdir(parents=True, exist_ok=True) + + # Import conv codegen + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from unified_conv_codegen import ( + UnifiedConvCodegen, + ConvKernelConfig, + ConvVariant, + TileConfig, + TraitConfig, + ) + except ImportError as e: + print_error(f" Failed to import conv codegen: {e}") + return 0 + + codegen = UnifiedConvCodegen(kernel_dir) + total_generated = 0 + + # Group by dtype and variant 
for efficient generation + groups = {} + for decl in declarations: + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + key = (dtype, conv_type, num_dims) + if key not in groups: + groups[key] = [] + groups[key].append(decl) + + for (dtype, conv_type, num_dims), decls in groups.items(): + print(f" Generating {dtype} {conv_type} {num_dims}D kernels...") + + # Map to ConvVariant + variant = ConvVariant.FORWARD + if conv_type == "bwd_data": + variant = ConvVariant.BACKWARD_DATA + elif conv_type == "bwd_weight": + variant = ConvVariant.BACKWARD_WEIGHT + + for decl in decls: + pipeline = decl.get("pipeline", "compv3") + scheduler = decl.get("scheduler", "intrawave") + epilogue = decl.get("epilogue", "cshuffle") + + tile_k = decl.get("tile_k", 128) + tile_c = decl.get("tile_c", 128) + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + # Adjust tile_k for compv4 + adj_tile_k = 64 * 2 if pipeline == "compv4" else 64 + + # Create TileConfig + tile_config = TileConfig( + tile_m=tile_k, # K is M in conv GEMM view + tile_n=tile_c, # C is N in conv GEMM view + tile_k=adj_tile_k, + warp_m=wave_m, + warp_n=wave_n, + warp_k=1, + warp_tile_m=warp_m, + warp_tile_n=warp_n, + warp_tile_k=warp_k, + ) + + # Create TraitConfig + trait_config = TraitConfig( + pipeline=pipeline, + scheduler=scheduler, + epilogue=epilogue, + double_smem_buffer=(pipeline == "compv4"), + pad_m=True, + pad_n=True, + pad_k=True, + ) + + # Create ConvKernelConfig + config = ConvKernelConfig( + tile=tile_config, + trait=trait_config, + variant=variant, + ndim_spatial=num_dims, + arch=gpu_target, + ) + + try: + filepath = codegen.generate_kernel(config, dtype) + total_generated += 1 + print(f" Generated: {filepath.name}") + except Exception as e: + print_error(f" Failed to generate {decl['name']}: {e}") + + return total_generated 
+ + +# Original GEMM extraction continues here +def extract_kernel_declarations(source_file: Path) -> list: + """Extract GEMM kernel declarations from C++ source file.""" + content = source_file.read_text() + declarations = [] + seen = set() + + # ------------------------------------------------------------------------- + # Pattern 1: Simple DECL_KERNEL_SIMPLE(dtype, layout, tile_m, tile_n, tile_k) + # ------------------------------------------------------------------------- + legacy_pattern = r"DECL_KERNEL_SIMPLE\s*\(\s*(\w+)\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)" + for match in re.findall(legacy_pattern, content): + dtype, layout, tm, tn, tk = match + name = f"{dtype}_{layout}_{tm}x{tn}x{tk}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": int(tm), + "tile_n": int(tn), + "tile_k": int(tk), + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "wildcard": False, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 2: Fluent API: DECL_KERNEL(Signature()..., Algorithm()..., arch) + # ------------------------------------------------------------------------- + # Match DECL_KERNEL( ... 
); blocks + fluent_pattern = r'DECL_KERNEL\s*\(\s*(Signature\(\)[^,]+),\s*(Algorithm\(\)[^,]+)(?:,\s*"([^"]+)")?\s*\)' + + for match in re.finditer(fluent_pattern, content, re.DOTALL): + sig_str = match.group(1) + algo_str = match.group(2) + arch = match.group(3) or "gfx942" + + # Parse Signature + sig = {"dtype_a": "fp16", "dtype_b": "fp16", "dtype_c": "fp16", "layout": "rcr"} + + # .dtype("fp16", "fp16", "fp16", "fp32") or .dtype("fp16") + dtype_match = re.search( + r'\.dtype\("([^"]+)"(?:,\s*"([^"]+)")?(?:,\s*"([^"]+)")?', sig_str + ) + if dtype_match: + sig["dtype_a"] = dtype_match.group(1) + sig["dtype_b"] = dtype_match.group(2) or dtype_match.group(1) + sig["dtype_c"] = dtype_match.group(3) or dtype_match.group(1) + + # .layout("rcr") or .layout("row", "col", "row") + layout_match = re.search( + r'\.layout\("([^"]+)"(?:,\s*"([^"]+)")?(?:,\s*"([^"]+)")?', sig_str + ) + if layout_match: + if layout_match.group(2): # Three-arg form + la = layout_match.group(1) + lb = layout_match.group(2) + lc = layout_match.group(3) or "row" + sig["layout"] = ( + ("r" if la == "row" else "c") + + ("r" if lb == "row" else "c") + + ("r" if lc == "row" else "c") + ) + else: # Single arg "rcr" + sig["layout"] = layout_match.group(1) + + # Parse Algorithm + algo = {} + + # .tile(128, 128, 32) + tile_match = re.search(r"\.tile\((\d+),\s*(\d+),\s*(\d+)\)", algo_str) + if tile_match: + algo["tile_m"] = int(tile_match.group(1)) + algo["tile_n"] = int(tile_match.group(2)) + algo["tile_k"] = int(tile_match.group(3)) + + # .wave(2, 2, 1) + wave_match = re.search(r"\.wave\((\d+),\s*(\d+)(?:,\s*(\d+))?\)", algo_str) + if wave_match: + algo["wave_m"] = int(wave_match.group(1)) + algo["wave_n"] = int(wave_match.group(2)) + algo["wave_k"] = int(wave_match.group(3) or 1) + + # .warp(32, 32, 16) + warp_match = re.search(r"\.warp\((\d+),\s*(\d+)(?:,\s*(\d+))?\)", algo_str) + if warp_match: + algo["warp_m"] = int(warp_match.group(1)) + algo["warp_n"] = int(warp_match.group(2)) + algo["warp_k"] = 
int(warp_match.group(3) or 16) + + # .pipeline("compv4"), .scheduler("intrawave"), .epilogue("cshuffle") + for field in ["pipeline", "scheduler", "epilogue"]: + fmatch = re.search(rf'\.{field}\("([^"]+)"\)', algo_str) + if fmatch: + algo[field] = fmatch.group(1) + + # Build declaration + tm = algo.get("tile_m", 128) + tn = algo.get("tile_n", 128) + tk = algo.get("tile_k", 32) + + name = f"{sig['dtype_a']}_{sig['layout']}_{tm}x{tn}x{tk}" + + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": sig["dtype_a"], + "dtype_b": sig["dtype_b"], + "dtype_c": sig["dtype_c"], + "layout": sig["layout"], + "tile_m": tm, + "tile_n": tn, + "tile_k": tk, + "wave_m": algo.get("wave_m", -1), + "wave_n": algo.get("wave_n", -1), + "wave_k": algo.get("wave_k", 1), + "warp_m": algo.get("warp_m", -1), + "warp_n": algo.get("warp_n", -1), + "warp_k": algo.get("warp_k", 16), + "pipeline": algo.get("pipeline", "compv4"), + "scheduler": algo.get("scheduler", "intrawave"), + "epilogue": algo.get("epilogue", "cshuffle"), + "arch": arch, + "name": name, + "wildcard": False, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 3: DECL_KERNEL_ALL(dtype, layout) - wildcard + # ------------------------------------------------------------------------- + all_pattern = r"DECL_KERNEL(?:S)?_ALL\s*\(\s*(\w+)\s*,\s*(\w+)\s*\)" + for match in re.findall(all_pattern, content): + dtype, layout = match + name = f"wildcard_{dtype}_{layout}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": -1, + "tile_n": -1, + "tile_k": -1, + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "wildcard": True, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 4: 
DECL_KERNEL_SIMPLE(dtype, layout, tm, tn, tk) + # ------------------------------------------------------------------------- + simple_pattern = r"DECL_KERNEL_SIMPLE\s*\(\s*(\w+)\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)" + for match in re.findall(simple_pattern, content): + dtype, layout, tm, tn, tk = match + name = f"{dtype}_{layout}_{tm}x{tn}x{tk}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": int(tm), + "tile_n": int(tn), + "tile_k": int(tk), + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "name": name, + "wildcard": False, + "set": None, + } + ) + + # ------------------------------------------------------------------------- + # Pattern 5: DECL_KERNEL_SET(name, .add(...).add(...)) + # Named kernel sets for multiple registries + # Match only DECL_KERNEL_SET at start of line (not in comments) + # ------------------------------------------------------------------------- + set_pattern = r"^DECL_KERNEL_SET\s*\(\s*(\w+)\s*,([\s\S]*?)\)\s*;" + for match in re.finditer(set_pattern, content, re.MULTILINE): + set_name = match.group(1) + set_body = match.group(2) + + # Parse .add("dtype", "layout", tm, tn, tk) calls - simple form + add_simple = r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)' + for add_match in re.findall(add_simple, set_body): + dtype, layout, tm, tn, tk = add_match + name = f"{set_name}:{dtype}_{layout}_{tm}x{tn}x{tk}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": int(tm), + "tile_n": int(tn), + "tile_k": int(tk), + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + 
"epilogue": "cshuffle", + "name": name, + "wildcard": False, + "set": set_name, + } + ) + + # Parse .add(Signature()..., Algorithm()..., "arch") fluent calls + # Robust approach: find each .add( block and parse methods individually + # This handles any method order and optional methods + + # Split set_body into .add() blocks + add_blocks = [] + add_starts = [m.start() for m in re.finditer(r"\.add\s*\(", set_body)] + + for i, start in enumerate(add_starts): + # Find the matching closing paren by counting parens + depth = 0 + end = start + in_string = False + escape_next = False + + for j, ch in enumerate(set_body[start:], start): + if escape_next: + escape_next = False + continue + if ch == "\\": + escape_next = True + continue + if ch == '"' and not escape_next: + in_string = not in_string + continue + if in_string: + continue + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + end = j + 1 + break + + if end > start: + add_blocks.append(set_body[start:end]) + + for add_block in add_blocks: + # Skip if doesn't have both Signature() and Algorithm() + if "Signature()" not in add_block or "Algorithm()" not in add_block: + continue + + # Split on Algorithm() to separate Signature and Algorithm parts + algo_idx = add_block.find("Algorithm()") + if algo_idx == -1: + continue + + sig_str = add_block[:algo_idx] + algo_str = add_block[algo_idx:] # Include Algorithm() and everything after + + # Parse dtype from Signature - handles .dtype("fp16", "fp16", "fp16", "fp32") + dtype = "fp16" + dtype_m = re.search(r'\.dtype\s*\(\s*"([^"]+)"', sig_str) + if dtype_m: + dtype = dtype_m.group(1) + + # Parse layout from Signature - handles .layout("row", "col", "row") + layout = "rcr" + layout_m = re.search( + r'\.layout\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"', sig_str + ) + if layout_m: + la, lb, lc = layout_m.group(1), layout_m.group(2), layout_m.group(3) + layout = ( + ("r" if la == "row" else "c") + + ("r" if lb == "row" else "c") + + ("r" if lc 
== "row" else "c") + ) + else: + # Single arg form: .layout("rcr") + layout_m = re.search(r'\.layout\s*\(\s*"([^"]+)"', sig_str) + if layout_m: + layout = layout_m.group(1) + + # Parse tile from Algorithm + tm, tn, tk = 128, 128, 32 + tile_m = re.search( + r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", algo_str + ) + if tile_m: + tm, tn, tk = ( + int(tile_m.group(1)), + int(tile_m.group(2)), + int(tile_m.group(3)), + ) + + # Parse wave + wave_m, wave_n, wave_k = 2, 2, 1 + wave_match = re.search( + r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if wave_match: + wave_m, wave_n = int(wave_match.group(1)), int(wave_match.group(2)) + wave_k = int(wave_match.group(3) or 1) + + # Parse warp + warp_m, warp_n, warp_k = 32, 32, 16 + warp_match = re.search( + r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?\s*\)", algo_str + ) + if warp_match: + warp_m, warp_n = int(warp_match.group(1)), int(warp_match.group(2)) + warp_k = int(warp_match.group(3) or 16) + + # Parse pipeline - NEW: extract from declaration + pipeline = "compv4" + pipeline_m = re.search(r'\.pipeline\s*\(\s*"([^"]+)"', algo_str) + if pipeline_m: + pipeline = pipeline_m.group(1) + + # Parse scheduler - NEW: extract from declaration + scheduler = "intrawave" + scheduler_m = re.search(r'\.scheduler\s*\(\s*"([^"]+)"', algo_str) + if scheduler_m: + scheduler = scheduler_m.group(1) + + # Parse epilogue - NEW: extract from declaration + epilogue = "cshuffle" + epilogue_m = re.search(r'\.epilogue\s*\(\s*"([^"]+)"', algo_str) + if epilogue_m: + epilogue = epilogue_m.group(1) + + # Parse padding - NEW: extract from declaration + pad_m, pad_n, pad_k = False, False, False + pad_match = re.search( + r"\.pad\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)\s*\)", + algo_str, + re.IGNORECASE, + ) + if pad_match: + pad_m = pad_match.group(1).lower() == "true" + pad_n = pad_match.group(2).lower() == "true" + pad_k = pad_match.group(3).lower() == "true" + + # Parse elementwise from 
Signature - for Multi-D kernels + elementwise_op = "PassThrough" + num_d_tensors = 0 + elem_match = re.search( + r'\.elementwise\s*\(\s*"([^"]+)"\s*,\s*(\d+)\s*\)', + sig_str, + ) + if elem_match: + elementwise_op = elem_match.group(1) + num_d_tensors = int(elem_match.group(2)) + + name = f"{set_name}:{dtype}_{layout}_{pipeline}_{scheduler}_{tm}x{tn}x{tk}_{wave_m}x{wave_n}x{wave_k}" + if elementwise_op != "PassThrough": + name += f"_{elementwise_op}_d{num_d_tensors}" + if name not in seen: + seen.add(name) + declarations.append( + { + "dtype_a": dtype, + "dtype_b": dtype, + "dtype_c": dtype, + "layout": layout, + "tile_m": tm, + "tile_n": tn, + "tile_k": tk, + "wave_m": wave_m, + "wave_n": wave_n, + "wave_k": wave_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "pipeline": pipeline, + "scheduler": scheduler, + "epilogue": epilogue, + "pad_m": pad_m, + "pad_n": pad_n, + "pad_k": pad_k, + "elementwise_op": elementwise_op, + "num_d_tensors": num_d_tensors, + "name": name, + "wildcard": False, + "set": set_name, + } + ) + + return declarations + + +def expand_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") -> list: + """Expand a declaration to all valid combinations using arch filter. + + Expands wildcards for: + - wave/warp: If -1, generates all valid wave/warp_tile combinations + - pipeline/scheduler/epilogue: If "*", generates all valid trait combinations + + Uses the arch_filter module for architecture-specific validation. 
+ """ + # Import arch filter + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + TRAIT_UNSUPPORTED_COMBINATIONS, + ) + except ImportError: + # Fallback to hardcoded valid combinations + WARP_SUPPORTED_COMBINATIONS = { + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + } + WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]}, + } + TRAIT_UNSUPPORTED_COMBINATIONS = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + } + + d = decl.copy() + tm = d.get("tile_m", 128) + tn = d.get("tile_n", 128) + tk = d.get("tile_k", 32) + dtype = d.get("dtype_a", "fp16") + + # Check what needs expansion + needs_wave_expansion = d.get("wave_m", -1) < 0 or d.get("wave_n", -1) < 0 + needs_warp_expansion = d.get("warp_m", -1) < 0 or d.get("warp_n", -1) < 0 + needs_pipeline_expansion = d.get("pipeline", "compv4") == "*" + needs_scheduler_expansion = d.get("scheduler", "intrawave") == "*" + needs_epilogue_expansion = d.get("epilogue", "cshuffle") == "*" + needs_pad_m_expansion = d.get("pad_m", 1) == -1 + needs_pad_n_expansion = d.get("pad_n", 1) == -1 + needs_pad_k_expansion = d.get("pad_k", 1) == -1 + needs_trait_expansion = ( + needs_pipeline_expansion + or needs_scheduler_expansion + or needs_epilogue_expansion + ) + needs_pad_expansion = ( + needs_pad_m_expansion or needs_pad_n_expansion or needs_pad_k_expansion + ) + + if ( + not needs_wave_expansion + and not needs_warp_expansion + and not needs_trait_expansion + and not needs_pad_expansion + ): + # Already fully specified + return [d] + + # === Build valid combinations === + + # Wave configurations + if needs_wave_expansion: + wave_configs = 
WARP_SUPPORTED_COMBINATIONS.get(arch, [[2, 2, 1]]) + else: + wave_configs = [[d.get("wave_m", 2), d.get("wave_n", 2), d.get("wave_k", 1)]] + + # Warp tile configurations + if needs_warp_expansion: + arch_warp_tiles = WARP_TILE_SUPPORTED_COMBINATIONS.get(arch, {}) + + # Try to find warp tile configs for this dtype + # Keys are like: fp16_fp16_fp32, int8_int8_int32, etc. + warp_tile_configs = None + dtype_key_variants = [ + f"{dtype}_{dtype}_{dtype}", # e.g., fp32_fp32_fp32 + f"{dtype}_{dtype}_fp32", # e.g., fp16_fp16_fp32 + f"{dtype}_{dtype}_int32", # e.g., int8_int8_int32 + ] + for dtype_key in dtype_key_variants: + warp_tile_configs = arch_warp_tiles.get(dtype_key, None) + if warp_tile_configs is not None: + break + + # If dtype is not supported on this arch, return empty list + if warp_tile_configs is None: + return [] + else: + warp_tile_configs = [ + [d.get("warp_m", 32), d.get("warp_n", 32), d.get("warp_k", 16)] + ] + + # Pipeline/scheduler/epilogue combinations + # Valid options per category + ALL_PIPELINES = ["compv3", "compv4"] # Most common; add more if needed + ALL_SCHEDULERS = ["intrawave", "interwave"] + ALL_EPILOGUES = ["cshuffle", "default"] + ALL_PAD_OPTIONS = [False, True] # 0 and 1 + + pipelines = ( + ALL_PIPELINES if needs_pipeline_expansion else [d.get("pipeline", "compv4")] + ) + schedulers = ( + ALL_SCHEDULERS + if needs_scheduler_expansion + else [d.get("scheduler", "intrawave")] + ) + epilogues = ( + ALL_EPILOGUES if needs_epilogue_expansion else [d.get("epilogue", "cshuffle")] + ) + pad_m_opts = ALL_PAD_OPTIONS if needs_pad_m_expansion else [bool(d.get("pad_m", 1))] + pad_n_opts = ALL_PAD_OPTIONS if needs_pad_n_expansion else [bool(d.get("pad_n", 1))] + pad_k_opts = ALL_PAD_OPTIONS if needs_pad_k_expansion else [bool(d.get("pad_k", 1))] + + expanded = [] + + # Generate all valid combinations + for wm, wn, wk in wave_configs: + for wtm, wtn, wtk in warp_tile_configs: + # Check divisibility constraints + if tm % (wm * wtm) != 0: + continue + 
            # Tile must be evenly covered by (waves x warp-tile) in every
            # dimension, otherwise the wave layout cannot map onto the tile.
            if tn % (wn * wtn) != 0:
                continue
            if tk % (wk * wtk) != 0:
                continue

            for pipeline in pipelines:
                for scheduler in schedulers:
                    for epilogue in epilogues:
                        # Check trait combination is valid
                        # (unsupported set is keyed (pipeline, epilogue, scheduler))
                        if (
                            pipeline,
                            epilogue,
                            scheduler,
                        ) in TRAIT_UNSUPPORTED_COMBINATIONS:
                            continue

                        for pad_m in pad_m_opts:
                            for pad_n in pad_n_opts:
                                for pad_k in pad_k_opts:
                                    # Create expanded declaration
                                    expanded_d = d.copy()
                                    expanded_d["wave_m"] = wm
                                    expanded_d["wave_n"] = wn
                                    expanded_d["wave_k"] = wk
                                    expanded_d["warp_m"] = wtm
                                    expanded_d["warp_n"] = wtn
                                    expanded_d["warp_k"] = wtk
                                    expanded_d["pipeline"] = pipeline
                                    expanded_d["scheduler"] = scheduler
                                    expanded_d["epilogue"] = epilogue
                                    # Pad flags are stored as 0/1 ints, not bools.
                                    expanded_d["pad_m"] = int(pad_m)
                                    expanded_d["pad_n"] = int(pad_n)
                                    expanded_d["pad_k"] = int(pad_k)

                                    pad_str = f"{'T' if pad_m else 'F'}{'T' if pad_n else 'F'}{'T' if pad_k else 'F'}"
                                    # Rebuild the name so every expanded variant is unique.
                                    expanded_d["name"] = (
                                        f"{dtype}_{d.get('layout', 'rcr')}_{pipeline}_{scheduler}_"
                                        f"pad{pad_str}_{tm}x{tn}x{tk}_{wm}x{wn}x{wk}"
                                    )
                                    expanded_d["wildcard"] = False
                                    expanded.append(expanded_d)

    if not expanded:
        # No valid combinations found, return single default
        # NOTE(review): this hardcoded default is returned without re-checking
        # it against the arch tables - confirm it is legal on every target.
        print(f" Warning: No valid combinations for {tm}x{tn}x{tk} on {arch}")
        d["wave_m"] = 2
        d["wave_n"] = 2
        d["wave_k"] = 1
        d["warp_m"] = 32
        d["warp_n"] = 32
        d["warp_k"] = 16
        d["pipeline"] = "compv4"
        d["scheduler"] = "intrawave"
        d["epilogue"] = "cshuffle"
        return [d]

    return expanded


def auto_fill_declaration(decl: dict) -> dict:
    """Auto-fill with single default (for backward compat).

    Expands the declaration for its declared arch (default "gfx942") and
    keeps only the first valid combination; returns the declaration
    unchanged when expansion produced nothing.
    """
    expanded = expand_declaration_with_arch_filter(decl, decl.get("arch", "gfx942"))
    return expanded[0] if expanded else decl


# =============================================================================
# Build Functions
# =============================================================================


def generate_kernels(declarations: list, gpu_target: str = "gfx942") -> int:
"""Generate kernels using CodegenRunner from ctypes_utils.""" + kernel_dir = get_generated_kernels_dir() + kernel_dir.mkdir(parents=True, exist_ok=True) + + # Group by dtype+layout for efficient generation + groups = {} + for decl in declarations: + dtype = decl.get("dtype_a", decl.get("dtype", "fp16")) + layout = decl.get("layout", "rcr") + key = (dtype, layout) + if key not in groups: + groups[key] = [] + groups[key].append(auto_fill_declaration(decl)) + + total_generated = 0 + + for (dtype, layout), decls in groups.items(): + print(f" Generating {dtype} {layout} kernels...") + + # Check for wildcards - if any decl is wildcard, generate all + has_wildcard = any(d.get("wildcard", False) for d in decls) + + # Use CodegenRunner from ctypes_utils + runner = CodegenRunner( + datatype=dtype, + layout=layout, + gpu_target=gpu_target, + ) + + result = runner.generate("standard") + + if result.success: + total_generated += result.kernel_count + if has_wildcard: + print(f" [wildcard] Generated all {result.kernel_count} variants") + else: + print_error(f" Failed: {result.stderr[:200]}") + + return total_generated + + +def get_arch_filter_data(): + """Load arch filter data from arch_specs_generated if available.""" + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + TRAIT_UNSUPPORTED_COMBINATIONS, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + return { + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + # Fallback defaults + return { + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + }, + "warp_combos": { + "gfx942": [[1, 4, 1], [2, 2, 1], 
            [4, 1, 1]],
        },
        "warp_tile_combos": {
            "gfx942": {"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]]},
        },
        "supported_archs": ["gfx90a", "gfx942", "gfx950"],
    }


def is_wildcard_declaration(decl: dict) -> bool:
    """Check if declaration has wildcards that need expansion.

    A declaration is a wildcard when any wave/warp dimension is negative
    (-1 means "expand to all valid values") or when a trait field
    (pipeline / scheduler / epilogue) is the literal string "*".
    """
    # Wave/warp wildcards
    if decl.get("wave_m", 2) < 0 or decl.get("wave_n", 2) < 0:
        return True
    if decl.get("warp_m", 32) < 0 or decl.get("warp_n", 32) < 0:
        return True
    # Pipeline/scheduler wildcards
    if decl.get("pipeline", "compv4") == "*":
        return True
    if decl.get("scheduler", "intrawave") == "*":
        return True
    if decl.get("epilogue", "cshuffle") == "*":
        return True
    return False


def validate_kernel_config(decl: dict, arch: str = "gfx942") -> tuple:
    """Validate a kernel configuration against known supported combinations.

    Uses arch_specs_generated for architecture-specific validation.

    For wildcard declarations (-1 values or "*" strings), validation is skipped
    because the expansion phase will generate only valid combinations.
+ + Returns: (is_valid, error_message) + """ + # Skip validation for wildcards - expansion will filter invalid combos + if is_wildcard_declaration(decl): + return (True, None) + + arch_data = get_arch_filter_data() + + pipeline = decl.get("pipeline", "compv4") + epilogue = decl.get("epilogue", "cshuffle") + scheduler = decl.get("scheduler", "intrawave") + dtype = decl.get("dtype_a", "fp16") + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + errors = [] + + # Check trait combination (pipeline, epilogue, scheduler) + combo = (pipeline, epilogue, scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, epilogue={epilogue}, scheduler={scheduler}\n" + f" Valid schedulers for {pipeline}+{epilogue}: intrawave" + ) + + # Check wave configuration for this arch + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}\n" + f" Valid wave configs: {valid_str}" + ) + + # Check warp tile configuration for this arch and dtype + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}\n" + f" Valid warp tiles: {valid_str}" + ) + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}\n" + f" 
Supported: {', '.join(arch_data['supported_archs'])}"
        )

    if errors:
        # Join every accumulated problem into one message for the caller.
        return (False, "\n".join(errors))

    return (True, None)


def build_exact_kernel_filename(decl: dict) -> str:
    """Build the exact kernel filename from a fully-specified declaration.

    Standard format:
    gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{pad_m}_{pad_n}_{pad_k}_{preshuffle}_{tile}_{wave}_{warp}.hpp

    Multi-D format:
    gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{pad_m}_{pad_n}_{pad_k}_{preshuffle}_{tile}_{wave}_{warp}_multid_{op}_d{num}.hpp
    """
    # dtype falls back from the GEMM key ("dtype_a") to the generic "dtype".
    dtype = decl.get("dtype_a", decl.get("dtype", "fp16"))
    layout = decl.get("layout", "rcr")
    pipeline = decl.get("pipeline", "compv4")
    epilogue = decl.get("epilogue", "cshuffle")
    scheduler = decl.get("scheduler", "intrawave")

    # Boolean flags are encoded as the literal strings "True"/"False" to
    # match the codegen's file-naming scheme.
    pad_m = "True" if decl.get("pad_m", False) else "False"
    pad_n = "True" if decl.get("pad_n", False) else "False"
    pad_k = "True" if decl.get("pad_k", False) else "False"
    preshuffle = "True" if decl.get("preshuffle", False) else "False"

    tile_m = decl.get("tile_m", 128)
    tile_n = decl.get("tile_n", 128)
    tile_k = decl.get("tile_k", 32)

    wave_m = decl.get("wave_m", 2)
    wave_n = decl.get("wave_n", 2)
    wave_k = decl.get("wave_k", 1)

    warp_m = decl.get("warp_m", 32)
    warp_n = decl.get("warp_n", 32)
    warp_k = decl.get("warp_k", 16)

    tile_str = f"{tile_m}x{tile_n}x{tile_k}"
    wave_str = f"{wave_m}x{wave_n}x{wave_k}"
    warp_str = f"{warp_m}x{warp_n}x{warp_k}"

    base = f"gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler}_{pad_m}_{pad_n}_{pad_k}_{preshuffle}_{tile_str}_{wave_str}_{warp_str}"

    # Handle Multi-D kernels
    # The multid suffix only applies when a non-trivial elementwise op and
    # at least one D tensor are declared.
    elementwise_op = decl.get("elementwise_op", "PassThrough")
    num_d_tensors = decl.get("num_d_tensors", 0)
    if elementwise_op != "PassThrough" and num_d_tensors > 0:
        base += f"_multid_{elementwise_op}_d{num_d_tensors}"

    return f"{base}.hpp"


def generate_specific_kernel(decl: dict, gpu_target: str = "gfx942") -> bool:
"""Generate a specific kernel based on declaration.""" + dtype = decl.get("dtype_a", decl.get("dtype", "fp16")) + layout = decl.get("layout", "rcr") + + print(f" Generating kernel for {dtype}/{layout}...") + + # Use CodegenRunner to generate + runner = CodegenRunner( + datatype=dtype, + layout=layout, + gpu_target=gpu_target, + ) + + result = runner.generate("standard") + return result.success + + +def find_kernel_header(decl: dict, gpu_target: str = "gfx942") -> Path: + """Find a matching kernel header file for a declaration. + + Tries multiple matching strategies: + 1. Exact filename match + 2. Match with key parameters (dtype, layout, pipeline, scheduler, tile) + 3. Match with just dtype, layout, and tile (more flexible) + 4. Any kernel with matching dtype and layout + + If no kernel exists, attempts to generate it. + Returns None only if all strategies fail. + """ + kernel_dir = get_generated_kernels_dir() + + dtype = decl.get("dtype_a", decl.get("dtype", "fp16")) + layout = decl.get("layout", "rcr") + pipeline = decl.get("pipeline", "compv4") + scheduler = decl.get("scheduler", "intrawave") + tile_m = decl.get("tile_m", 128) + tile_n = decl.get("tile_n", 128) + tile_k = decl.get("tile_k", 32) + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + tile_str = f"{tile_m}x{tile_n}x{tile_k}" + wave_str = f"{wave_m}x{wave_n}x{wave_k}" + + # Build exact filename + exact_filename = build_exact_kernel_filename(decl) + exact_path = kernel_dir / exact_filename + + # Strategy 1: Exact filename match + if exact_path.exists(): + print(f" Found exact kernel: {exact_filename}") + return exact_path + + # Strategy 2: Match with key parameters + pattern = ( + f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_{wave_str}_*.hpp" + ) + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Found matching kernel: {matches[0].name}") + return matches[0] + + # Strategy 3: Match with just dtype, layout, tile (ignore 
# wave/warp)
    pattern = f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_*.hpp"
    # NOTE(review): glob() returns matches in unspecified order, so
    # matches[0] is nondeterministic across runs - consider sorting; verify.
    matches = list(kernel_dir.glob(pattern))
    if matches:
        print(f" Found kernel with matching tile: {matches[0].name}")
        return matches[0]

    # Strategy 4: Match with just dtype, layout (most flexible, for wildcards)
    # Prefer kernels with intrawave scheduler (known to work)
    pattern = f"gemm_{dtype}_{layout}_*_intrawave_*_{tile_str}_*.hpp"
    matches = list(kernel_dir.glob(pattern))
    if matches:
        print(f" Found kernel with intrawave: {matches[0].name}")
        return matches[0]

    # Strategy 5: Any kernel with matching dtype and layout
    pattern = f"gemm_{dtype}_{layout}_*_{tile_str}_*.hpp"
    matches = list(kernel_dir.glob(pattern))
    if matches:
        print(f" Found kernel with matching dtype/layout/tile: {matches[0].name}")
        return matches[0]

    # Strategy 6: Try to generate the kernel
    print(" No matching kernel found, attempting to generate...")
    if generate_specific_kernel(decl, gpu_target):
        # Check strategies again after generation
        for pattern in [
            f"gemm_{dtype}_{layout}_{pipeline}_*_{scheduler}_*_{tile_str}_*.hpp",
            f"gemm_{dtype}_{layout}_*_intrawave_*_{tile_str}_*.hpp",
            f"gemm_{dtype}_{layout}_*_{tile_str}_*.hpp",
        ]:
            matches = list(kernel_dir.glob(pattern))
            if matches:
                print(f" Generated: {matches[0].name}")
                return matches[0]

    # All strategies failed - return None (caller will try next expanded decl)
    return None


def is_conv_wildcard_declaration(decl: dict) -> bool:
    """Check if conv declaration has wildcards that need expansion.

    Mirrors is_wildcard_declaration but uses the conv defaults
    (pipeline "compv3") and does not consider the epilogue field.
    """
    if decl.get("wave_m", 2) < 0 or decl.get("wave_n", 2) < 0:
        return True
    if decl.get("warp_m", 32) < 0 or decl.get("warp_n", 32) < 0:
        return True
    if decl.get("pipeline", "compv3") == "*":
        return True
    if decl.get("scheduler", "intrawave") == "*":
        return True
    return False


def validate_conv_kernel_config(decl: dict, arch: str = "gfx942") -> tuple:
    """Validate a conv kernel
configuration against arch filter. + + For wildcard declarations, validation is skipped (expansion handles it). + + Returns: (is_valid, error_message) + """ + # Skip validation for wildcards + if is_conv_wildcard_declaration(decl): + return (True, None) + + arch_data = get_arch_filter_data() + + pipeline = decl.get("pipeline", "compv3") + epilogue = decl.get("epilogue", "cshuffle") + scheduler = decl.get("scheduler", "intrawave") + dtype = decl.get("dtype", "fp16") + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + errors = [] + + # Check trait combination + combo = (pipeline, epilogue, scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, epilogue={epilogue}, scheduler={scheduler}\n" + f" Valid schedulers for {pipeline}+{epilogue}: intrawave" + ) + + # Check wave configuration + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}\n" + f" Valid wave configs: {valid_str}" + ) + + # Check warp tile configuration + dtype_key = f"{dtype}_{dtype}_{dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}\n" + f" Valid warp tiles: {valid_str}" + ) + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}\n" + f" 
Supported: {', '.join(arch_data['supported_archs'])}" + ) + + if errors: + return (False, "\n".join(errors)) + + return (True, None) + + +def build_exact_conv_kernel_filename(decl: dict) -> str: + """Build the exact conv kernel filename from a fully-specified declaration. + + Conv filename format: + conv_{type}_{dtype}_{ndim}d_{pipeline}_{epilogue}_{scheduler}_{tile}_{wave}.hpp + + Example: + conv_fwd_fp16_2d_compv3_cshuffle_intrawave_128x128x32_2x2x1.hpp + """ + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + pipeline = decl.get("pipeline", "compv3") + epilogue = decl.get("epilogue", "cshuffle") + scheduler = decl.get("scheduler", "intrawave") + + # Map conv_type to filename prefix + if conv_type == "forward": + type_prefix = "fwd" + elif conv_type == "bwd_data": + type_prefix = "bwdd" + elif conv_type == "bwd_weight": + type_prefix = "bwdw" + else: + type_prefix = conv_type + + tile_k = decl.get("tile_k", 128) + tile_c = decl.get("tile_c", 128) + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + tile_str = f"{tile_k}x{tile_c}x32" # Conv uses tile_k x tile_c x 32 format + wave_str = f"{wave_m}x{wave_n}x{wave_k}" + + return f"conv_{type_prefix}_{dtype}_{num_dims}d_{pipeline}_{epilogue}_{scheduler}_{tile_str}_{wave_str}.hpp" + + +def generate_specific_conv_kernel(decl: dict, gpu_target: str = "gfx942") -> bool: + """Generate a specific conv kernel based on declaration.""" + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + + print(f" Generating conv kernel for {dtype}/{conv_type}/{num_dims}d...") + + # Map to variant name + if conv_type == "forward": + variant = "forward" + elif conv_type == "bwd_data": + variant = "bwd_data" + elif conv_type == "bwd_weight": + variant = "bwd_weight" + else: + variant = "forward" + + # Use unified_conv_codegen + codegen_dir = 
get_dispatcher_root() / "codegen" + codegen_script = codegen_dir / "unified_conv_codegen.py" + output_dir = get_generated_kernels_dir() + + cmd = [ + "python3", + str(codegen_script), + "--datatype", + dtype, + "--variant", + variant, + "--ndim", + str(num_dims), + "--arch", + gpu_target, + "--output", + str(output_dir), + ] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + return result.returncode == 0 + except subprocess.TimeoutExpired: + return False + + +def find_conv_kernel_header(decl: dict, gpu_target: str = "gfx942") -> Path: + """Find the EXACT matching conv kernel header file for a declaration. + + If the kernel doesn't exist, attempts to generate it. + Returns None only if generation also fails. + """ + kernel_dir = get_generated_kernels_dir() + + # Build exact filename + exact_filename = build_exact_conv_kernel_filename(decl) + exact_path = kernel_dir / exact_filename + + # Check if exact kernel exists + if exact_path.exists(): + print(f" Found exact conv kernel: {exact_filename}") + return exact_path + + # Try to find with glob (in case of minor variations) + dtype = decl.get("dtype", "fp16") + conv_type = decl.get("conv_type", "forward") + num_dims = decl.get("num_dims", 2) + pipeline = decl.get("pipeline", "compv3") + scheduler = decl.get("scheduler", "intrawave") + tile_k = decl.get("tile_k", 128) + tile_c = decl.get("tile_c", 128) + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + # Map conv_type to prefix + if conv_type == "forward": + type_prefix = "fwd" + elif conv_type == "bwd_data": + type_prefix = "bwdd" + elif conv_type == "bwd_weight": + type_prefix = "bwdw" + else: + type_prefix = conv_type + + tile_str = f"{tile_k}x{tile_c}" + wave_str = f"{wave_m}x{wave_n}x{wave_k}" + + # Search pattern with key parameters + pattern = f"conv_{type_prefix}_{dtype}_{num_dims}d_{pipeline}_*_{scheduler}_*{tile_str}*_{wave_str}.hpp" + matches = 
list(kernel_dir.glob(pattern)) + + if matches: + print(f" Found matching conv kernel: {matches[0].name}") + return matches[0] + + # Kernel doesn't exist - try to generate it + print(f" Conv kernel not found: {exact_filename}") + print(" Attempting to generate...") + + if generate_specific_conv_kernel(decl, gpu_target): + # Check again after generation + matches = list(kernel_dir.glob(pattern)) + if matches: + print(f" Generated: {matches[0].name}") + return matches[0] + + # Check for exact match + if exact_path.exists(): + print(f" Generated: {exact_filename}") + return exact_path + + # Still not found - print helpful error + print_error( + " ERROR: Could not find or generate conv kernel matching declaration:" + ) + print_error(f" dtype={dtype}, conv_type={conv_type}, num_dims={num_dims}") + print_error(f" pipeline={pipeline}, scheduler={scheduler}") + print_error(f" tile={tile_k}x{tile_c}, wave={wave_str}") + print_error(f" Expected: {exact_filename}") + print_error(f" Available conv kernels in {kernel_dir}:") + + available = list(kernel_dir.glob(f"conv_{type_prefix}_{dtype}_{num_dims}d_*.hpp"))[ + :5 + ] + for k in available: + print_error(f" - {k.name}") + if len(list(kernel_dir.glob(f"conv_{type_prefix}_{dtype}_{num_dims}d_*.hpp"))) > 5: + print_error(" ... 
and more") + + return None + + +def build_dispatcher_library(hipcc: str) -> bool: + """Build the dispatcher library if needed.""" + build_dir = get_build_dir() + lib_path = build_dir / "libck_tile_dispatcher.a" + + if lib_path.exists(): + return True + + print(" Building dispatcher library...") + build_dir.mkdir(parents=True, exist_ok=True) + + dispatcher_dir = get_dispatcher_root() + + # Run cmake + cmake_cmd = ["cmake", str(dispatcher_dir), f"-DCMAKE_CXX_COMPILER={hipcc}"] + result = subprocess.run( + cmake_cmd, cwd=str(build_dir), capture_output=True, text=True + ) + if result.returncode != 0: + print_error(f"CMake failed: {result.stderr}") + return False + + # Run make + make_cmd = ["make", "ck_tile_dispatcher", f"-j{os.cpu_count() or 4}"] + result = subprocess.run( + make_cmd, cwd=str(build_dir), capture_output=True, text=True + ) + if result.returncode != 0: + print_error(f"Make failed: {result.stderr}") + return False + + return True + + +def compile_application( + source_file: Path, + output_bin: Path, + kernel_header: Path, + hipcc: str, + gpu_target: str = "gfx942", +) -> bool: + """Compile the application with hipcc.""" + ck_root = get_ck_root() + dispatcher_dir = get_dispatcher_root() + build_dir = get_build_dir() + kernel_dir = get_generated_kernels_dir() + + includes = [ + f"-I{ck_root / 'include'}", + f"-I{dispatcher_dir / 'include'}", + f"-I{kernel_dir}", + ] + + cmd = [ + hipcc, + "-std=c++17", + "-O3", + f"--offload-arch={gpu_target}", + *includes, + "-include", + str(kernel_header), + f"-L{build_dir}", + "-lck_tile_dispatcher", + "-o", + str(output_bin), + str(source_file), + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + + # Filter out nodiscard warnings + if result.stderr: + lines = result.stderr.split("\n") + errors = [line for line in lines if "error:" in line.lower()] + if errors: + for err_line in errors[:5]: + print_error(f" {err_line}") + + return result.returncode == 0 + + +# 
# =============================================================================
# Main
# =============================================================================


def _resolve_source_path(raw: str, dispatcher_dir: Path) -> Path:
    """Resolve a possibly-relative source path against the dispatcher tree,
    then the examples subtree, then the current working directory."""
    source_file = Path(raw)
    if source_file.is_absolute():
        return source_file
    candidates = [
        dispatcher_dir / raw,
        dispatcher_dir / "examples" / raw,  # examples/gemm/cpp/...
        Path.cwd() / raw,
    ]
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return source_file


def _display_name(decl: dict) -> str:
    """Strip any namespace-style "prefix:" from a declaration's name."""
    name = decl["name"]
    return name.split(":")[-1] if ":" in name else name


def _print_declaration_sets(declarations, needs_expansion_fn) -> None:
    """Group declarations by kernel-set name and print up to five per set."""
    sets = {}
    for decl in declarations:
        sets.setdefault(decl.get("set") or "(global)", []).append(decl)

    for set_name, set_decls in sets.items():
        print(f"  [{set_name}] ({len(set_decls)} kernels):")
        for decl in set_decls[:5]:
            suffix = " [expands]" if needs_expansion_fn(decl) else ""
            print(f"    - {_display_name(decl)}{suffix}")
        if len(set_decls) > 5:
            print(f"    ... and {len(set_decls) - 5} more")


def _validate_declarations(
    declarations, gpu_target, is_wildcard_fn, validate_fn, default_pipeline, label
) -> None:
    """Validate explicit declarations against the arch filter.

    Invalid declarations are downgraded IN PLACE to wildcards so the later
    expansion step can replace them with valid configurations.  This helper
    collapses the previously duplicated GEMM/Conv validation loops.
    """
    print(f"\n  Validating against {gpu_target} arch filter...")
    wildcard_count = 0
    invalid_count = 0

    for decl in declarations:
        arch = decl.get("arch", gpu_target)

        # Wildcards are validated during expansion, not here.
        if is_wildcard_fn(decl):
            wildcard_count += 1
            continue

        is_valid, error_msg = validate_fn(decl, arch)
        if not is_valid:
            print(f"\n  ⚠ Invalid {label}configuration: {_display_name(decl)}")

            corrections = []
            lowered = error_msg.lower()

            if "wave configuration" in lowered:
                old = f"[{decl.get('wave_m', 2)}, {decl.get('wave_n', 2)}, {decl.get('wave_k', 1)}]"
                decl["wave_m"] = -1
                decl["wave_n"] = -1
                corrections.append(f"wave: {old} → [wildcard expansion]")

            if "warp tile" in lowered:
                old = f"[{decl.get('warp_m', 32)}, {decl.get('warp_n', 32)}, {decl.get('warp_k', 16)}]"
                decl["warp_m"] = -1
                decl["warp_n"] = -1
                corrections.append(f"warp_tile: {old} → [wildcard expansion]")

            if "trait combination" in lowered:
                old_pipeline = decl.get("pipeline", default_pipeline)
                old_scheduler = decl.get("scheduler", "intrawave")
                decl["pipeline"] = "*"
                decl["scheduler"] = "*"
                corrections.append(f"pipeline: {old_pipeline} → [wildcard expansion]")
                corrections.append(f"scheduler: {old_scheduler} → [wildcard expansion]")

            print("    AUTO-CORRECTION:")
            for corr in corrections:
                print(f"      • {corr}")

            invalid_count += 1
            wildcard_count += 1

    if invalid_count > 0:
        print(
            f"\n  ⚠ {invalid_count} invalid config(s) auto-corrected via wildcard expansion"
        )

    if wildcard_count > 0:
        print(
            f"  ✓ {len(declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)"
        )
    else:
        print(f"  ✓ All {len(declarations)} configurations valid")


def _expand_declarations(declarations, gpu_target, is_wildcard_fn, expand_fn):
    """Expand wildcard declarations into concrete configurations.

    Returns the new (possibly larger) declaration list; prints a short
    summary of what each wildcard expanded to.
    """
    print("\n  Expanding wildcards to valid configurations...")
    expanded_all = []

    for decl in declarations:
        arch = decl.get("arch", gpu_target)
        expanded = expand_fn(decl, arch)
        expanded_all.extend(expanded)

        if len(expanded) > 1:
            print(
                f"    {_display_name(decl)}: expanded to {len(expanded)} valid configurations"
            )
            for exp in expanded[:3]:
                wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
                warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
                print(
                    f"      → wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}"
                )
            if len(expanded) > 3:
                print(f"      ... and {len(expanded) - 3} more")
        elif len(expanded) == 1 and is_wildcard_fn(decl):
            exp = expanded[0]
            wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]"
            warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]"
            print(f"    {_display_name(decl)}: → wave={wave_str}, warp={warp_str}")

    if len(expanded_all) > len(declarations):
        print(
            f"\n  Total: {len(declarations)} declarations → {len(expanded_all)} configurations"
        )

    return expanded_all


def main():
    """CLI entry point: scan, validate, expand, generate, build, compile.

    Returns a process exit code (0 on success, 1 on any failure).
    """
    parser = argparse.ArgumentParser(
        description="Build CK Tile application with declarative kernels",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example:
  python3 compile_gemm_examples.py examples/cpp/01_basic_gemm_declarative.cpp my_app

In your C++ code, declare kernels like:
  DECL_KERNEL_SET(my_kernels,
      .add(Signature().dtype("fp16").layout("rcr"),
           Algorithm().tile(128, 128, 32).wave(2, 2, 1).warp(32, 32, 16)
               .pipeline("compv4").scheduler("intrawave"))
  );
""",
    )
    parser.add_argument("source", help="Source file (.cpp)")
    parser.add_argument(
        "output", nargs="?", help="Output name (default: source basename)"
    )
    parser.add_argument(
        "--gpu-target", default="gfx942", help="GPU target architecture"
    )
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
    args = parser.parse_args()

    # Resolve paths using utilities from ctypes_utils.
    dispatcher_dir = get_dispatcher_root()
    build_dir = get_build_dir()

    source_file = _resolve_source_path(args.source, dispatcher_dir)
    if not source_file.exists():
        print_error(f"Source file not found: {source_file}")
        return 1

    output_name = args.output or source_file.stem
    output_bin = build_dir / output_name
    build_dir.mkdir(parents=True, exist_ok=True)

    print_success("=== CK Tile Declarative Kernel Build ===")
    print()

    # Phase 1: Extract declarations (both GEMM and Conv).
    print_phase("Phase 1: Scanning for kernel declarations...")

    gemm_declarations = extract_kernel_declarations(source_file)
    conv_declarations = extract_conv_kernel_declarations(source_file)

    if not gemm_declarations and not conv_declarations:
        print_error("  No kernel declarations found!")
        print("  Add DECL_KERNEL_SET for GEMM or DECL_CONV_KERNEL_SET for Conv")
        return 1

    # GEMM and Conv share the validate/expand machinery; only the predicate,
    # validator, expander and default pipeline differ.
    if gemm_declarations:
        print(f"\n  GEMM: Found {len(gemm_declarations)} declaration(s)")
        _print_declaration_sets(
            gemm_declarations,
            lambda d: d.get("wave_m", -1) < 0 or d.get("warp_m", -1) < 0,
        )
        _validate_declarations(
            gemm_declarations,
            args.gpu_target,
            is_wildcard_declaration,
            validate_kernel_config,
            "compv4",
            "",
        )
        gemm_declarations = _expand_declarations(
            gemm_declarations,
            args.gpu_target,
            is_wildcard_declaration,
            expand_declaration_with_arch_filter,
        )

    if conv_declarations:
        print(f"\n  CONV: Found {len(conv_declarations)} declaration(s)")
        _print_declaration_sets(conv_declarations, is_conv_wildcard_declaration)
        _validate_declarations(
            conv_declarations,
            args.gpu_target,
            is_conv_wildcard_declaration,
            validate_conv_kernel_config,
            "compv3",
            "conv ",
        )
        conv_declarations = _expand_declarations(
            conv_declarations,
            args.gpu_target,
            is_conv_wildcard_declaration,
            expand_conv_declaration_with_arch_filter,
        )

    print()

    # Phase 2: Generate kernels.
    print_phase("Phase 2: Generating kernels...")

    total_generated = 0
    if gemm_declarations:
        print("  GEMM kernels:")
        num_gemm = generate_kernels(gemm_declarations, args.gpu_target)
        total_generated += num_gemm
        print(f"    Generated: {num_gemm}")
    if conv_declarations:
        print("  CONV kernels:")
        num_conv = generate_conv_kernels(conv_declarations, args.gpu_target)
        total_generated += num_conv
        print(f"    Generated: {num_conv}")

    print(f"  Total kernel files: {total_generated}")
    print()

    # Phase 3: Find kernel header.
    print_phase("Phase 3: Selecting kernel for compilation...")

    kernel_headers = []

    # GEMM: try each expanded declaration until one matches (missing header
    # is fatal for GEMM builds).
    if gemm_declarations:
        gemm_header = next(
            (
                h
                for h in (
                    find_kernel_header(d, args.gpu_target) for d in gemm_declarations
                )
                if h
            ),
            None,
        )
        if gemm_header:
            kernel_headers.append(gemm_header)
            print(f"  GEMM: {gemm_header.name}")
        else:
            print_error("  GEMM: No kernel found matching any declaration!")
            print_error(
                "  The kernels declared in DECL_KERNEL_SET must exist or be generatable."
            )
            return 1

    # Conv: BUG FIX - forward the requested GPU target (this call previously
    # defaulted to gfx942), and try every expanded declaration like the GEMM
    # path instead of only the first.
    if conv_declarations:
        conv_header = next(
            (
                h
                for h in (
                    find_conv_kernel_header(d, args.gpu_target)
                    for d in conv_declarations
                )
                if h
            ),
            None,
        )
        if conv_header:
            kernel_headers.append(conv_header)
            print(f"  CONV: {conv_header.name}")

    if not kernel_headers:
        print_error("  No kernel headers found!")
        return 1

    # Use first available header (can be extended to use multiple).
    kernel_header = kernel_headers[0]
    print()

    # Phase 4: Build dispatcher library.
    print_phase("Phase 4: Building dispatcher library...")
    hipcc = find_hipcc()

    if not build_dispatcher_library(hipcc):
        print_error("  Failed to build dispatcher library!")
        return 1
    print("  Done")
    print()

    # Phase 5: Compile application.
    print_phase("Phase 5: Compiling application...")

    if not compile_application(
        source_file, output_bin, kernel_header, hipcc, args.gpu_target
    ):
        print_error("  Compilation failed!")
        return 1

    print(f"  Output: {output_bin}")
    print()

    print_success("=== Build Complete ===")
    print()
    print("Run with:")
    print(f"  {output_bin}")
    print()
    print("List declared kernels:")
    print(f"  {output_bin} --list-kernels")

    return 0


if __name__ == "__main__":
    sys.exit(main())
import argparse
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Tuple


def find_hipcc() -> str:
    """Locate hipcc: $HIPCC override, then the ROCm default path, then $PATH.

    Falls back to the bare name "hipcc" so a later subprocess failure yields
    a clear "command not found" message.
    """
    for path in [os.environ.get("HIPCC"), "/opt/rocm/bin/hipcc", shutil.which("hipcc")]:
        if path and os.path.isfile(path):
            return path
    return "hipcc"


def find_ar() -> str:
    """Locate an archiver, preferring ROCm's llvm-ar; falls back to "ar"."""
    for path in [
        "/opt/rocm/llvm/bin/llvm-ar",
        shutil.which("llvm-ar"),
        shutil.which("ar"),
    ]:
        if path and os.path.isfile(path):
            return path
    return "ar"


def extract_balanced_parens(text: str, start_pos: int) -> str:
    """Extract the content between balanced parentheses at text[start_pos].

    Parentheses inside double-quoted string literals are ignored, so an
    argument such as .pipeline("comp(v3") no longer derails the depth counter
    (the previous version miscounted quoted parens).  Returns "" when
    text[start_pos] is not '(' or the parentheses never balance.
    """
    if start_pos >= len(text) or text[start_pos] != "(":
        return ""
    depth = 0
    in_string = False  # inside a double-quoted DSL string literal
    for i in range(start_pos, len(text)):
        c = text[i]
        if in_string:
            if c == '"':
                in_string = False
            continue
        if c == '"':
            in_string = True
        elif c == "(":
            depth += 1
        elif c == ")":
            depth -= 1
            if depth == 0:
                return text[start_pos + 1 : i]
    return ""


def _parse_conv_add_body(add_body: str) -> Dict:
    """Parse one .add(...) body of a DECL_CONV_KERNEL_SET into a kernel dict."""
    kernel: Dict = {}

    # ConvSig parameters - handle both single dtype and multi-dtype forms.
    # Multi-dtype: .dtype("fp16", "fp16", "fp16", "fp32") with optional acc.
    if m := re.search(
        r'\.dtype\s*\(\s*"([^"]+)"\s*,\s*"([^"]+)"\s*,\s*"([^"]+)"(?:\s*,\s*"([^"]+)")?\s*\)',
        add_body,
    ):
        kernel["dtype_in"] = m.group(1)
        kernel["dtype_wei"] = m.group(2)
        kernel["dtype_out"] = m.group(3)
        kernel["dtype_acc"] = m.group(4) if m.group(4) else "fp32"
        kernel["dtype"] = m.group(1)  # default for codegen
    # Single dtype: .dtype("fp16") - all tensors share it, acc is fp32.
    elif m := re.search(r'\.dtype\s*\(\s*"([^"]+)"\s*\)', add_body):
        kernel["dtype"] = m.group(1)
        kernel["dtype_in"] = m.group(1)
        kernel["dtype_wei"] = m.group(1)
        kernel["dtype_out"] = m.group(1)
        kernel["dtype_acc"] = "fp32"

    if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body):
        kernel["layout"] = m.group(1)
    if m := re.search(r'\.conv_type\s*\(\s*"([^"]+)"', add_body):
        kernel["conv_type"] = m.group(1)
    if m := re.search(r"\.dims\s*\(\s*(\d+)\s*\)", add_body):
        kernel["ndim"] = int(m.group(1))

    # ConvAlgo parameters - tile(G, M, N): G=batch, M=output, N=reduction.
    if m := re.search(r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body):
        kernel["tile_g"] = int(m.group(1))
        kernel["tile_m"] = int(m.group(2))
        kernel["tile_n"] = int(m.group(3))

    # wave(M_Warp, N_Warp, K_Warp) - warp distribution.
    if m := re.search(r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body):
        kernel["warp_m"] = int(m.group(1))
        kernel["warp_n"] = int(m.group(2))
        kernel["warp_k"] = int(m.group(3))

    # warp(M, N, K) - warp tile sizes.
    if m := re.search(r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body):
        kernel["warp_tile_m"] = int(m.group(1))
        kernel["warp_tile_n"] = int(m.group(2))
        kernel["warp_tile_k"] = int(m.group(3))

    # vector_sizes(A, B, C).
    if m := re.search(
        r"\.vector_sizes\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body
    ):
        kernel["vector_a"] = int(m.group(1))
        kernel["vector_b"] = int(m.group(2))
        kernel["vector_c"] = int(m.group(3))

    # Single-value parameters.
    if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body):
        kernel["pipeline"] = m.group(1)
    if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body):
        kernel["scheduler"] = m.group(1)
    if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body):
        kernel["epilogue"] = m.group(1)
    if m := re.search(r"\.block_per_cu\s*\(\s*(\d+)\s*\)", add_body):
        kernel["block_per_cu"] = int(m.group(1))
    if m := re.search(r"\.num_wave_groups\s*\(\s*(\d+)\s*\)", add_body):
        kernel["num_wave_groups"] = int(m.group(1))
    if m := re.search(r"\.num_groups_to_merge\s*\(\s*(\d+)\s*\)", add_body):
        kernel["num_groups_to_merge"] = int(m.group(1))
    if m := re.search(
        r"\.double_smem_buffer\s*\(\s*(true|false)\s*\)", add_body, re.I
    ):
        kernel["double_smem_buffer"] = m.group(1).lower() == "true"

    # Architecture: first gfxNNN string literal anywhere in the body.
    if m := re.search(r'"(gfx\d+)"', add_body):
        kernel["arch"] = m.group(1)

    return kernel


def parse_conv_declarations(content: str) -> List[Dict]:
    """Parse DECL_CONV_KERNEL_SET declarations with all parameters.

    Every .add(...) clause yields one kernel dict; entries without a dtype
    are discarded, and survivors are completed by auto_fill_conv_defaults().
    """
    kernels = []

    for match in re.finditer(r"DECL_CONV_KERNEL_SET\s*\(", content):
        body = extract_balanced_parens(content, match.end() - 1)
        if not body:
            continue

        for add_match in re.finditer(r"\.add\s*\(", body):
            add_body = extract_balanced_parens(body, add_match.end() - 1)
            kernel = _parse_conv_add_body(add_body)
            if kernel.get("dtype"):
                # Auto-fill missing parameters with defaults (autocorrect).
                kernels.append(auto_fill_conv_defaults(kernel))

    return kernels


def auto_fill_conv_defaults(kernel: Dict) -> Dict:
    """Auto-fill missing conv parameters with sensible defaults (in place).

    This implements:
      1. AUTOFILL: missing parameters get valid defaults (ConvConfigComputeV3).
      2. AUTOCORRECT: invalid values are corrected to valid ones.
    """
    # Default tile configuration matching ConvConfigComputeV3.
    defaults = {
        "tile_g": 1,
        "tile_m": 16,
        "tile_n": 64,
        "warp_m": 1,
        "warp_n": 4,
        "warp_k": 1,
        "warp_tile_m": 16,
        "warp_tile_n": 16,
        "warp_tile_k": 32,
        "pipeline": "compv3",
        "scheduler": "intrawave",
        "epilogue": "cshuffle",
        "vector_a": 4,
        "vector_b": 8,
        "vector_c": 8,
        "block_per_cu": 1,
        "num_wave_groups": 1,
        "num_groups_to_merge": 1,
        "ndim": 2,
        "layout": "nhwgc",
        "conv_type": "forward",
        "arch": "gfx942",
    }

    # AUTOFILL: fill missing (absent, None, or -1) parameters with defaults.
    autofilled = []
    for key, value in defaults.items():
        if key not in kernel or kernel[key] is None or kernel[key] == -1:
            kernel[key] = value
            autofilled.append(f"{key}={value}")

    if autofilled:
        print(f"  [AUTOFILL] {', '.join(autofilled)}")

    # AUTOCORRECT: fix invalid wave configurations for gfx942.
    valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]
    current_wave = (
        kernel.get("warp_m", 1),
        kernel.get("warp_n", 4),
        kernel.get("warp_k", 1),
    )
    if current_wave not in valid_wave_configs:
        old = current_wave
        kernel["warp_m"] = 1
        kernel["warp_n"] = 4
        kernel["warp_k"] = 1
        print(f"  [AUTOCORRECT] wave{old} -> wave(1,4,1) (invalid for gfx942)")

    # AUTOCORRECT: backward ops only support the compv3 pipeline.
    conv_type = kernel.get("conv_type", "forward")
    pipeline = kernel.get("pipeline", "compv3")
    if conv_type in ["bwd_data", "bwd_weight"] and pipeline in ["compv4", "compv5"]:
        kernel["pipeline"] = "compv3"
        print(
            f"  [AUTOCORRECT] pipeline {pipeline} -> compv3 (invalid for {conv_type})"
        )

    return kernel


def expand_conv_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]:
    """Expand wildcard parameters to multiple valid configurations.

    When users specify wildcards (-1 or *), this expands them to all valid
    configurations for the target architecture.

    NOTE(review): `arch` is currently unused - the tables below are gfx942
    specific; confirm before adding other targets.
    """
    # Valid wave configurations for gfx942.
    valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]
    # Valid warp tile configurations for gfx942 fp16.
    valid_warp_configs = [(16, 16, 32), (32, 32, 16)]

    needs_wave = kernel.get("warp_m") is None or kernel.get("warp_m") == -1
    needs_warp = kernel.get("warp_tile_m") is None or kernel.get("warp_tile_m") == -1

    if not needs_wave and not needs_warp:
        return [kernel]

    wave_configs = (
        valid_wave_configs
        if needs_wave
        else [
            (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1))
        ]
    )
    warp_configs = (
        valid_warp_configs
        if needs_warp
        else [
            (
                kernel.get("warp_tile_m", 32),
                kernel.get("warp_tile_n", 32),
                kernel.get("warp_tile_k", 16),
            )
        ]
    )

    expanded = []
    for wm, wn, wk in wave_configs:
        for wtm, wtn, wtk in warp_configs:
            new_kernel = kernel.copy()
            new_kernel["warp_m"] = wm
            new_kernel["warp_n"] = wn
            new_kernel["warp_k"] = wk
            new_kernel["warp_tile_m"] = wtm
            new_kernel["warp_tile_n"] = wtn
            new_kernel["warp_tile_k"] = wtk
            expanded.append(new_kernel)

    return expanded
def parse_int_or_wildcard(val: str) -> int:
    """Parse an integer token or return -1 for wildcards.

    Supported wildcard formats:
      - ANY_INT: macro defined as -1
      - -1: direct numeric wildcard
      - "*": string wildcard (also maps to -1 for integer params)
    """
    val = val.strip()
    if val in ("ANY_INT", "-1", "*"):
        return -1
    return int(val)


def _parse_gemm_add_body(add_body: str) -> Dict:
    """Parse one .add(...) body of a DECL_KERNEL_SET into a kernel dict."""
    kernel: Dict = {}

    # Signature parameters.
    if m := re.search(r'\.dtype\s*\(\s*"([^"]+)"', add_body):
        kernel["dtype"] = m.group(1)
    if m := re.search(r'\.layout\s*\(\s*"([^"]+)"', add_body):
        kernel["layout"] = m.group(1)
    if m := re.search(r'\.elementwise\s*\(\s*"([^"]+)"\s*,\s*(\d+)', add_body):
        kernel["elementwise_op"] = m.group(1)
        kernel["num_d_tensors"] = int(m.group(2))

    # Algorithm parameters.
    if m := re.search(r"\.tile\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", add_body):
        kernel["tile_m"] = int(m.group(1))
        kernel["tile_n"] = int(m.group(2))
        kernel["tile_k"] = int(m.group(3))

    # Wave: ANY_INT, -1 and "*" are wildcards.
    if m := re.search(
        r"\.wave\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)", add_body
    ):
        kernel["warp_m"] = parse_int_or_wildcard(m.group(1))
        kernel["warp_n"] = parse_int_or_wildcard(m.group(2))
        kernel["warp_k"] = parse_int_or_wildcard(m.group(3))

    # Warp tile: same wildcard handling.
    if m := re.search(
        r"\.warp\s*\(\s*([\w*-]+)\s*,\s*([\w*-]+)\s*,\s*([\w*-]+)\s*\)", add_body
    ):
        kernel["warp_tile_m"] = parse_int_or_wildcard(m.group(1))
        kernel["warp_tile_n"] = parse_int_or_wildcard(m.group(2))
        kernel["warp_tile_k"] = parse_int_or_wildcard(m.group(3))

    # Pipeline/Scheduler: "*" wildcard allowed.
    if m := re.search(r'\.pipeline\s*\(\s*"([^"]+)"', add_body):
        kernel["pipeline"] = m.group(1)
    if m := re.search(r'\.scheduler\s*\(\s*"([^"]+)"', add_body):
        kernel["scheduler"] = m.group(1)
    if m := re.search(r'\.epilogue\s*\(\s*"([^"]+)"', add_body):
        kernel["epilogue"] = m.group(1)
    if m := re.search(
        r"\.pad\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)",
        add_body,
        re.I,
    ):
        kernel["pad_m"] = m.group(1).lower() == "true"
        kernel["pad_n"] = m.group(2).lower() == "true"
        kernel["pad_k"] = m.group(3).lower() == "true"

    # Shorthand format: .add("dtype", "layout", M, N, K).
    if not kernel.get("dtype"):
        if m := re.match(
            r'\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)', add_body
        ):
            kernel["dtype"] = m.group(1)
            kernel["layout"] = m.group(2)
            kernel["tile_m"] = int(m.group(3))
            kernel["tile_n"] = int(m.group(4))
            kernel["tile_k"] = int(m.group(5))

    return kernel


def parse_gemm_declarations(content: str) -> List[Dict]:
    """Parse DECL_KERNEL_SET declarations for GEMM.

    Supports wildcards:
      - ANY_INT / -1 for numeric params (wave, warp) -> all valid combos
      - "*" for string params (pipeline, scheduler) -> valid options

    Each kernel is tagged with its kernel_set name for separate registration,
    wildcards are expanded, and every expanded kernel is completed by
    auto_fill_gemm_defaults().
    """
    kernels = []

    for match in re.finditer(r"DECL_KERNEL_SET\s*\(\s*(\w+)\s*,", content):
        kernel_set_name = match.group(1)
        # Locate the macro's opening paren directly (the original computed it
        # via a redundant slice + find round trip).
        body = extract_balanced_parens(content, content.find("(", match.start()))
        if not body:
            continue

        for add_match in re.finditer(r"\.add\s*\(", body):
            add_body = extract_balanced_parens(body, add_match.end() - 1)
            kernel = _parse_gemm_add_body(add_body)
            if kernel.get("dtype"):
                kernel["kernel_set"] = kernel_set_name
                kernels.append(kernel)

    # Expand wildcards to multiple kernels, then apply autocorrect to each.
    expanded = []
    for kernel in kernels:
        expanded.extend(expand_gemm_wildcards(kernel))
    return [auto_fill_gemm_defaults(k) for k in expanded]


def expand_gemm_wildcards(kernel: Dict, arch: str = "gfx942") -> List[Dict]:
    """Expand wildcard parameters to multiple valid configurations.

    When users specify ANY_INT (-1) or "*", this expands them to all valid
    configurations for the target architecture.

    Note: the block-size constraint filters invalid combos:
      - (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024
      - For a 128x128 tile only (32,32,k) works (16 warps * 64 = 1024)
      - For a 64x64 tile both (16,16,k) and (32,32,k) work

    NOTE(review): `arch` is currently unused - the tables below are gfx942
    specific; confirm before adding other targets.
    """
    # Valid wave configurations for gfx942.
    valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]
    # Valid warp tile configurations for gfx942 fp16.
    valid_warp_configs = [(16, 16, 32), (32, 32, 16)]
    # Valid pipelines and schedulers.
    valid_pipelines = ["compv3"]  # compv4 requires special handling
    valid_schedulers = ["intrawave"]

    needs_wave = kernel.get("warp_m") == -1
    needs_warp = kernel.get("warp_tile_m") == -1
    needs_pipeline = kernel.get("pipeline") == "*"
    needs_scheduler = kernel.get("scheduler") == "*"

    if not any([needs_wave, needs_warp, needs_pipeline, needs_scheduler]):
        return [kernel]

    wave_configs = (
        valid_wave_configs
        if needs_wave
        else [
            (kernel.get("warp_m", 2), kernel.get("warp_n", 2), kernel.get("warp_k", 1))
        ]
    )
    warp_configs = (
        valid_warp_configs
        if needs_warp
        else [
            (
                kernel.get("warp_tile_m", 32),
                kernel.get("warp_tile_n", 32),
                kernel.get("warp_tile_k", 16),
            )
        ]
    )
    pipelines = (
        valid_pipelines if needs_pipeline else [kernel.get("pipeline", "compv3")]
    )
    schedulers = (
        valid_schedulers if needs_scheduler else [kernel.get("scheduler", "intrawave")]
    )

    # Loop-invariant: tile sizes do not change per combination.
    tile_m = kernel.get("tile_m", 128)
    tile_n = kernel.get("tile_n", 128)

    expanded = []
    for wm, wn, wk in wave_configs:
        for wtm, wtn, wtk in warp_configs:
            # Workgroup constraint: the number of 64-lane warps implied by
            # the tile/warp-tile split must fit in a 1024-thread block.
            num_warps = (tile_m // wtm) * (tile_n // wtn)
            if num_warps * 64 > 1024:
                continue  # Skip invalid config

            for pipe in pipelines:
                for sched in schedulers:
                    new_kernel = kernel.copy()
                    new_kernel["warp_m"] = wm
                    new_kernel["warp_n"] = wn
                    new_kernel["warp_k"] = wk
                    new_kernel["warp_tile_m"] = wtm
                    new_kernel["warp_tile_n"] = wtn
                    new_kernel["warp_tile_k"] = wtk
                    new_kernel["pipeline"] = pipe
                    new_kernel["scheduler"] = sched
                    expanded.append(new_kernel)

    return expanded
new_kernel["warp_tile_m"] = wtm + new_kernel["warp_tile_n"] = wtn + new_kernel["warp_tile_k"] = wtk + new_kernel["pipeline"] = pipe + new_kernel["scheduler"] = sched + expanded.append(new_kernel) + + if expanded: + print(f" [WILDCARD] Expanded 1 declaration -> {len(expanded)} kernel(s)") + + return expanded if expanded else [kernel] + + +def auto_fill_gemm_defaults(kernel: Dict) -> Dict: + """Auto-fill missing GEMM parameters with sensible defaults (autofill + autocorrect). + + This implements: + 1. AUTOFILL: Missing parameters are filled with valid defaults + 2. AUTOCORRECT: Invalid values are corrected to valid ones (e.g., wave(1,1,1) -> wave(2,2,1)) + """ + defaults = { + "tile_m": 128, + "tile_n": 128, + "tile_k": 64, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + "pipeline": "compv3", + "scheduler": "intrawave", + "epilogue": "cshuffle", + "pad_m": False, + "pad_n": False, + "pad_k": False, + "layout": "rcr", + } + + # AUTOFILL: Fill missing parameters with defaults + autofilled = [] + for key, value in defaults.items(): + if key not in kernel or kernel[key] is None or kernel[key] == -1: + kernel[key] = value + autofilled.append(f"{key}={value}") + + if autofilled: + print(f" [AUTOFILL] {', '.join(autofilled)}") + + # AUTOCORRECT: Fix invalid wave configurations for gfx942 + # Valid wave configs: (1,4,1), (2,2,1), (4,1,1) + valid_wave_configs = [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + current_wave = ( + kernel.get("warp_m", 2), + kernel.get("warp_n", 2), + kernel.get("warp_k", 1), + ) + + if current_wave not in valid_wave_configs: + # Correct to (2,2,1) which is a balanced default + old = current_wave + kernel["warp_m"] = 2 + kernel["warp_n"] = 2 + kernel["warp_k"] = 1 + print(f" [AUTOCORRECT] wave{old} -> wave(2,2,1) (invalid for gfx942)") + + # AUTOCORRECT: Fix invalid pipeline/scheduler combinations + invalid_combos = [ + ("compv3", "interwave"), + ("compv4", "interwave"), + ] + current_combo = ( + 
kernel.get("pipeline", "compv3"), + kernel.get("scheduler", "intrawave"), + ) + if current_combo in invalid_combos: + old = current_combo + kernel["scheduler"] = "intrawave" + print( + f" [AUTOCORRECT] {old[0]}/{old[1]} -> {old[0]}/intrawave (invalid combo)" + ) + + # AUTOCORRECT: Fix warp tile to avoid exceeding max block size (1024 threads) + # Block size = (tile_m / warp_tile_m) * (tile_n / warp_tile_n) * 64 + tile_m = kernel.get("tile_m", 128) + tile_n = kernel.get("tile_n", 128) + warp_tile_m = kernel.get("warp_tile_m", 32) + warp_tile_n = kernel.get("warp_tile_n", 32) + + num_warps = (tile_m // warp_tile_m) * (tile_n // warp_tile_n) + block_size = num_warps * 64 # 64 threads per warp + + if block_size > 1024: + # Find valid warp tile that fits + old_warp = (warp_tile_m, warp_tile_n, kernel.get("warp_tile_k", 16)) + + # For large tiles, use larger warp tiles + if tile_m >= 256: + kernel["warp_tile_m"] = 64 + if tile_n >= 256: + kernel["warp_tile_n"] = 64 + + # Recalculate + num_warps = (tile_m // kernel["warp_tile_m"]) * ( + tile_n // kernel["warp_tile_n"] + ) + block_size = num_warps * 64 + + if block_size <= 1024: + new_warp = ( + kernel["warp_tile_m"], + kernel["warp_tile_n"], + kernel["warp_tile_k"], + ) + print( + f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size={block_size})" + ) + else: + # Still too large, try even larger warp tiles + kernel["warp_tile_m"] = tile_m // 4 + kernel["warp_tile_n"] = tile_n // 4 + new_warp = ( + kernel["warp_tile_m"], + kernel["warp_tile_n"], + kernel["warp_tile_k"], + ) + print( + f" [AUTOCORRECT] warp{old_warp} -> warp{new_warp} (block_size adjusted)" + ) + + return kernel + + +def strip_cpp_strings_and_comments(content: str) -> str: + """Strip C++ string literals and comments that could cause false positives. 
+ + Only strips: + - Comments (// and /* */) - always stripped + - Raw string literals (R"...") - always stripped (can contain anything) + - Regular strings ONLY if they contain problematic patterns like DECL_KERNEL_SET + + Preserves normal string literals like "fp16", "rcr" which are needed for parsing. + """ + result = [] + i = 0 + n = len(content) + + # Patterns that indicate a string is problematic and should be stripped + problematic_patterns = ["DECL_KERNEL_SET", "DECL_CONV_KERNEL_SET", ".add("] + + while i < n: + # Check for raw string literal: R"delimiter(...)delimiter" + # Always strip these as they can contain arbitrary content + if i < n - 1 and content[i] == "R" and content[i + 1] == '"': + # Find the delimiter (between R" and () + j = i + 2 + delimiter_start = j + while j < n and content[j] != "(": + j += 1 + delimiter = content[delimiter_start:j] + # Find the closing )delimiter" + end_marker = ")" + delimiter + '"' + end_pos = content.find(end_marker, j + 1) + if end_pos != -1: + # Replace with spaces to preserve line numbers + span = content[i : end_pos + len(end_marker)] + result.append("".join("\n" if c == "\n" else " " for c in span)) + i = end_pos + len(end_marker) + continue + + # Check for regular string literal - only strip if it contains problematic patterns + if content[i] == '"': + j = i + 1 + while j < n: + if content[j] == "\\" and j + 1 < n: + j += 2 # Skip escaped character + elif content[j] == '"': + j += 1 + break + else: + j += 1 + string_content = content[i:j] + + # Only strip if this string contains problematic patterns + should_strip = any(pat in string_content for pat in problematic_patterns) + if should_strip: + result.append(" " * len(string_content)) + else: + result.append(string_content) + i = j + continue + + # Check for single-line comment - always strip + if i < n - 1 and content[i : i + 2] == "//": + j = i + while j < n and content[j] != "\n": + j += 1 + result.append(" " * (j - i)) + i = j + continue + + # Check for 
multi-line comment - always strip + if i < n - 1 and content[i : i + 2] == "/*": + end_pos = content.find("*/", i + 2) + if end_pos != -1: + span = content[i : end_pos + 2] + # Preserve newlines in multi-line comments + result.append("".join("\n" if c == "\n" else " " for c in span)) + i = end_pos + 2 + continue + + result.append(content[i]) + i += 1 + + return "".join(result) + + +def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]: + """Detect example type and parse kernel declarations. + + Properly strips string literals and comments before parsing to avoid + picking up declarations inside strings or commented-out code. + """ + content = source_path.read_text() + content = strip_cpp_strings_and_comments(content) + + if "DECL_CONV_KERNEL_SET" in content: + return "conv", parse_conv_declarations(content) + elif "DECL_KERNEL_SET" in content: + return "gemm", parse_gemm_declarations(content) + return "unknown", [] + + +def generate_gemm_registration( + kernel_headers: List[Path], example_name: str, kernels: List[Dict] = None +) -> str: + """Generate GEMM kernel registration code for the dispatcher registry. + + Uses GeneratedKernelInstance to wrap the generated kernels + and provide the KernelInstance interface for the Dispatcher. + + If kernels list is provided with kernel_set info, generates separate + registration functions per kernel set. 
+ """ + if not kernel_headers: + return " // No kernels to register" + + # Build mapping from kernel config pattern to kernel set + kernel_to_set = {} + kernel_sets = set() + if kernels: + for k in kernels: + tile_m = k.get("tile_m", 128) + tile_n = k.get("tile_n", 128) + tile_k = k.get("tile_k", 64) + warp_m = k.get("warp_m", 2) + warp_n = k.get("warp_n", 2) + warp_k = k.get("warp_k", 1) + warp_tile_m = k.get("warp_tile_m", 32) + warp_tile_n = k.get("warp_tile_n", 32) + warp_tile_k = k.get("warp_tile_k", 16) + + # Pattern that appears in kernel filename + key_pattern = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + kernel_set = k.get("kernel_set", "default") + kernel_to_set[key_pattern] = kernel_set + kernel_sets.add(kernel_set) + + def generate_registration_block(h: Path) -> str: + """Generate registration code for a single kernel.""" + kernel_name = h.stem + ns = f"ns_{kernel_name}" + + # Parse pipeline, scheduler, and layout from kernel name + # Format: gemm_fp16_rcr_compv3_cshuffle_intrawave_... 
+ parts = kernel_name.split("_") + pipeline = "CompV3" + scheduler = "Intrawave" + epilogue = "CShuffle" + datatype = "FP16" + layout_a = "RowMajor" + layout_b = "ColMajor" + layout_c = "RowMajor" + + # Parse datatype (e.g., fp16, bf16, fp32) + dtype_map = { + "fp16": "FP16", + "bf16": "BF16", + "fp32": "FP32", + "fp64": "FP64", + "int8": "INT8", + } + + # Parse layout from 3-char codes (e.g., rcr, rrr, rrc, ccc) + # r = RowMajor, c = ColMajor + layout_map = {"r": "RowMajor", "c": "ColMajor"} + + # Find pipeline, epilogue, scheduler in the name parts + pipeline_map = { + "mem": "Mem", + "compv1": "CompV1", + "compv2": "CompV2", + "compv3": "CompV3", + "compv4": "CompV4", + "compv5": "CompV5", + "preshufflev1": "PreShuffleV1", + "preshufflev2": "PreShuffleV2", + } + scheduler_map = { + "intrawave": "Intrawave", + "interwave": "Interwave", + "auto": "Auto", + } + epilogue_map = {"default": "Default", "cshuffle": "CShuffle", "none": "None"} + + for part in parts: + if part in pipeline_map: + pipeline = pipeline_map[part] + if part in scheduler_map: + scheduler = scheduler_map[part] + if part in epilogue_map: + epilogue = epilogue_map[part] + if part in dtype_map: + datatype = dtype_map[part] + # Parse 3-char layout codes (e.g., rcr, rrr) + if len(part) == 3 and all(c in "rc" for c in part): + layout_a = layout_map[part[0]] + layout_b = layout_map[part[1]] + layout_c = layout_map[part[2]] + + block = [] + block.append(f" // Register kernel: {kernel_name}") + block.append(" {") + block.append(f" using SelectedKernel = {ns}::SelectedKernel;") + block.append(" ck_tile::dispatcher::KernelKey key;") + block.append( + f" key.signature.dtype_a = ck_tile::dispatcher::DataType::{datatype};" + ) + block.append( + f" key.signature.dtype_b = ck_tile::dispatcher::DataType::{datatype};" + ) + block.append( + f" key.signature.dtype_c = ck_tile::dispatcher::DataType::{datatype};" + ) + block.append( + " key.signature.dtype_acc = ck_tile::dispatcher::DataType::FP32;" + ) + 
block.append( + f" key.signature.layout_a = ck_tile::dispatcher::LayoutTag::{layout_a};" + ) + block.append( + f" key.signature.layout_b = ck_tile::dispatcher::LayoutTag::{layout_b};" + ) + block.append( + f" key.signature.layout_c = ck_tile::dispatcher::LayoutTag::{layout_c};" + ) + block.append(" key.algorithm.tile_shape.m = SelectedKernel::TileM;") + block.append(" key.algorithm.tile_shape.n = SelectedKernel::TileN;") + block.append(" key.algorithm.tile_shape.k = SelectedKernel::TileK;") + block.append( + " key.algorithm.wave_shape.m = SelectedKernel::WarpPerBlock_M;" + ) + block.append( + " key.algorithm.wave_shape.n = SelectedKernel::WarpPerBlock_N;" + ) + block.append( + " key.algorithm.wave_shape.k = SelectedKernel::WarpPerBlock_K;" + ) + block.append( + " key.algorithm.warp_tile_shape.m = SelectedKernel::WarpTileM;" + ) + block.append( + " key.algorithm.warp_tile_shape.n = SelectedKernel::WarpTileN;" + ) + block.append( + " key.algorithm.warp_tile_shape.k = SelectedKernel::WarpTileK;" + ) + block.append( + " key.algorithm.block_size = SelectedKernel::BlockSize;" + ) + block.append( + f" key.algorithm.pipeline = ck_tile::dispatcher::Pipeline::{pipeline};" + ) + block.append( + f" key.algorithm.scheduler = ck_tile::dispatcher::Scheduler::{scheduler};" + ) + block.append( + f" key.algorithm.epilogue = ck_tile::dispatcher::Epilogue::{epilogue};" + ) + block.append(" key.gfx_arch = arch;") + block.append( + f' auto instance = std::make_shared>(key, "{kernel_name}");' + ) + block.append(" registry.register_kernel(instance);") + block.append(" }") + return "\n".join(block) + + def find_kernel_set(header: Path) -> str: + """Find which kernel set a header belongs to.""" + name = header.stem + for pattern, kset in kernel_to_set.items(): + if pattern in name: + return kset + return "default" + + # Group kernels by set + kernels_by_set = {} + for h in kernel_headers: + kset = find_kernel_set(h) + if kset not in kernels_by_set: + kernels_by_set[kset] = [] + 
kernels_by_set[kset].append(h) + + # If only one set or no set info, use simple registration + if len(kernels_by_set) <= 1: + lines = [" (void)arch;", ""] + for h in kernel_headers: + lines.append(generate_registration_block(h)) + return "\n".join(lines) + + # Multiple sets - generate registration for all, plus store per-set info + lines = [" // Register ALL kernels from all sets", " (void)arch;", ""] + for h in kernel_headers: + lines.append(generate_registration_block(h)) + + # Store per-set mapping for separate function generation + global _kernels_by_set_cache + _kernels_by_set_cache = (kernels_by_set, generate_registration_block) + + return "\n".join(lines) + + +# Global cache for per-set kernel info +_kernels_by_set_cache = None + + +def generate_per_set_functions(source_stem: str) -> str: + """Generate separate registration functions for each kernel set. + + Generates: + 1. Per-set functions: register_(registry, arch) + 2. String-based dispatcher: register_kernel_set("set_name", registry, arch) + 3. 
get_kernel_set_names() to list available sets + """ + global _kernels_by_set_cache + if not _kernels_by_set_cache: + return "" + + kernels_by_set, gen_block = _kernels_by_set_cache + _kernels_by_set_cache = None # Clear cache + + lines = [] + set_names = [] + + # Generate per-set functions + for set_name, headers in kernels_by_set.items(): + safe_name = set_name.replace("-", "_") + set_names.append((set_name, safe_name)) + lines.append( + f"inline void register_{safe_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{" + ) + lines.append(" (void)arch;") + for h in headers: + lines.append(gen_block(h)) + lines.append("}") + lines.append("") + + # Generate string-based dispatcher (only if multiple sets) + if len(set_names) > 0: + lines.append("// Dynamic registration by kernel set name") + lines.append( + "inline bool register_kernel_set(const std::string& set_name, ck_tile::dispatcher::Registry& registry, const std::string& arch) {" + ) + for set_name, safe_name in set_names: + lines.append( + f' if (set_name == "{set_name}") {{ register_{safe_name}(registry, arch); return true; }}' + ) + lines.append(" return false; // Unknown set name") + lines.append("}") + lines.append("") + + # Generate helper to list available set names + lines.append("// Get list of available kernel set names") + lines.append("inline std::vector get_kernel_set_names() {") + names_str = ", ".join(f'"{name}"' for name, _ in set_names) + lines.append(f" return {{{names_str}}};") + lines.append("}") + lines.append("") + + return "\n".join(lines) + + +def generate_conv_registration( + kernel_headers: List[Path], example_name: str, kernels: List[Dict] +) -> str: + """Generate Conv kernel registration code for the dispatcher registry.""" + if not kernel_headers: + return " // No kernels to register" + + lines = [] + lines.append( + " (void)registry; (void)arch; // Conv uses direct launcher pattern for now" + ) + + # For conv, we provide direct access to kernel launchers + for 
i, h in enumerate(kernel_headers): + kernel_name = h.stem + lines.append(f" // Kernel {i + 1}: {kernel_name}") + + return "\n".join(lines) + + +def generate_conv_kernels( + kernels: List[Dict], output_dir: Path, codegen_dir: Path +) -> bool: + """Generate Conv kernels for ALL declarations using unified codegen.""" + if not kernels: + return False + + variant_map = { + "forward": "forward", + "bwd_data": "bwd_data", + "backward_data": "bwd_data", + "bwd_weight": "bwd_weight", + "backward_weight": "bwd_weight", + } + + success_count = 0 + + # Generate a kernel for EACH declaration + for idx, k in enumerate(kernels): + variant = variant_map.get(k.get("conv_type", "forward"), "forward") + + cmd = [ + sys.executable, + str(codegen_dir / "unified_conv_codegen.py"), + "--datatype", + k.get("dtype", "fp16"), + "--variant", + variant, + "--ndim", + str(k.get("ndim", 2)), + "--output", + str(output_dir), + ] + + # Add optional parameters if specified + if k.get("tile_m"): + cmd.extend(["--tile-m", str(k["tile_m"])]) + if k.get("tile_n"): + cmd.extend(["--tile-n", str(k["tile_n"])]) + if k.get("warp_m"): + cmd.extend(["--warp-m", str(k["warp_m"])]) + if k.get("warp_n"): + cmd.extend(["--warp-n", str(k["warp_n"])]) + if k.get("warp_k"): + cmd.extend(["--warp-k", str(k["warp_k"])]) + if k.get("warp_tile_m"): + cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])]) + if k.get("warp_tile_n"): + cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])]) + if k.get("warp_tile_k"): + cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])]) + if k.get("pipeline"): + cmd.extend(["--pipeline", k["pipeline"]]) + if k.get("scheduler"): + cmd.extend(["--scheduler", k["scheduler"]]) + if k.get("epilogue"): + cmd.extend(["--epilogue", k["epilogue"]]) + if k.get("vector_a"): + cmd.extend(["--vector-a", str(k["vector_a"])]) + if k.get("vector_b"): + cmd.extend(["--vector-b", str(k["vector_b"])]) + if k.get("vector_c"): + cmd.extend(["--vector-c", str(k["vector_c"])]) + if k.get("block_per_cu"): + 
cmd.extend(["--block-per-cu", str(k["block_per_cu"])]) + if k.get("num_wave_groups"): + cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])]) + if k.get("num_groups_to_merge"): + cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])]) + if k.get("double_smem_buffer") is not None: + cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()]) + if k.get("tile_k"): + cmd.extend(["--tile-k", str(k["tile_k"])]) + + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=str(codegen_dir) + ) + if result.returncode != 0: + print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") + else: + success_count += 1 + + return success_count > 0 + + +def generate_gemm_kernels( + kernels: List[Dict], output_dir: Path, codegen_dir: Path +) -> bool: + """Generate GEMM kernels for ALL declarations using unified codegen.""" + import json + + if not kernels: + return False + + success_count = 0 + + # Generate a kernel for EACH declaration + for idx, k in enumerate(kernels): + variant = "multi_d" if k.get("elementwise_op") else "standard" + + # Build tile config JSON for this specific kernel + tile_config = { + "tile_m": [k.get("tile_m", 128)], + "tile_n": [k.get("tile_n", 128)], + "tile_k": [k.get("tile_k", 32)], + "warp_m": [k.get("warp_m", 2)], + "warp_n": [k.get("warp_n", 2)], + "warp_k": [k.get("warp_k", 1)], + "warp_tile_m": [k.get("warp_tile_m", 32)], + "warp_tile_n": [k.get("warp_tile_n", 32)], + "warp_tile_k": [k.get("warp_tile_k", 16)], + } + + trait_config = { + "pipeline": [k.get("pipeline", "compv3")], + "epilogue": [k.get("epilogue", "cshuffle")], + "scheduler": [k.get("scheduler", "intrawave")], + "pad_m": [k.get("pad_m", False)], + "pad_n": [k.get("pad_n", False)], + "pad_k": [k.get("pad_k", False)], + "persistent": [False], + } + + config_json = json.dumps( + {"tile_config": tile_config, "trait_config": trait_config} + ) + + cmd = [ + sys.executable, + str(codegen_dir / "unified_gemm_codegen.py"), + "--datatype", 
+ k.get("dtype", "fp16"), + "--layout", + k.get("layout", "rcr"), + "--variants", + variant, + "--output", + str(output_dir), + "--tile-config-json", + config_json, + ] + + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=str(codegen_dir) + ) + if result.returncode != 0: + print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") + else: + success_count += 1 + + return success_count > 0 + + +def compile_kernel(args: Tuple) -> Tuple[str, bool, str]: + """Compile a single kernel to object file.""" + kernel_hpp, output_dir, include_dirs, hipcc, gpu_target, idx, total = args + kernel_name = kernel_hpp.stem + + wrapper_cpp = output_dir / f"{kernel_name}.cpp" + wrapper_cpp.write_text( + f'#include "{kernel_hpp.name}"\nnamespace {{ volatile bool _k{idx} = true; }}\n' + ) + + obj_file = output_dir / f"{kernel_name}.o" + + cmd = [ + hipcc, + "-c", + "-fPIC", + "-std=c++17", + "-O3", + f"--offload-arch={gpu_target}", + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + "--offload-compress", + ] + + for inc_dir in include_dirs: + cmd.extend(["-I", str(inc_dir)]) + cmd.extend(["-I", str(kernel_hpp.parent)]) + cmd.extend(["-o", str(obj_file), str(wrapper_cpp)]) + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + return (kernel_name, False, result.stderr[:500]) + return (kernel_name, True, str(obj_file)) + + +def main(): + parser = argparse.ArgumentParser(description="Build example kernels") + parser.add_argument("source", type=Path, help="C++ source file") + parser.add_argument("--output-dir", type=Path, required=True) + parser.add_argument("--include-dirs", type=str, required=True) + parser.add_argument("--gpu-target", type=str, default="gfx942") + parser.add_argument("--jobs", type=int, default=os.cpu_count()) + parser.add_argument( + "--target-name", type=str, help="CMake target name (for library naming)" + ) + args = parser.parse_args() + + 
script_dir = Path(__file__).parent + codegen_dir = script_dir.parent / "codegen" + source_stem = args.source.stem # e.g., "01_basic_gemm" + target_name = args.target_name or source_stem # e.g., "gemm_01_basic" from CMake + + args.output_dir.mkdir(parents=True, exist_ok=True) + + # Detect and parse + example_type, kernels = detect_and_parse(args.source) + + if example_type == "conv": + k = kernels[0] if kernels else {} + variant = k.get("conv_type", "forward") + print( + f"[{target_name}] Conv {k.get('dtype', 'fp16')} {variant} {k.get('ndim', 2)}D ({len(kernels)} declarations)" + ) + elif example_type == "gemm": + k = kernels[0] if kernels else {} + print( + f"[{target_name}] GEMM {k.get('dtype', 'fp16')} {k.get('layout', 'rcr')} ({len(kernels)} declarations)" + ) + else: + print(f"[{target_name}] No kernel declarations - creating empty library") + lib_path = args.output_dir / f"lib{target_name}_kernels.a" + subprocess.run([find_ar(), "rcs", str(lib_path)], check=True) + header = args.output_dir / f"{source_stem}_kernels.hpp" + header.write_text(f"// No kernels for {target_name}\n#pragma once\n") + return 0 + + # Generate kernels + print(f"[{target_name}] Generating kernels...") + if example_type == "conv": + success = generate_conv_kernels(kernels, args.output_dir, codegen_dir) + else: + success = generate_gemm_kernels(kernels, args.output_dir, codegen_dir) + + if not success: + print(f"[{target_name}] Kernel generation failed!") + return 1 + + # Find generated headers + if example_type == "gemm": + kernel_headers = list(args.output_dir.glob("gemm_*.hpp")) + else: + k = kernels[0] if kernels else {} + variant = k.get("conv_type", "forward") + prefix_map = { + "forward": "conv_fwd", + "bwd_data": "conv_bwdd", + "bwd_weight": "conv_bwdw", + } + prefix = prefix_map.get(variant, "conv_fwd") + kernel_headers = list(args.output_dir.glob(f"{prefix}_*.hpp")) + + if not kernel_headers: + print(f"[{target_name}] No kernel headers generated!") + return 1 + + 
print(f"[{target_name}] Compiling {len(kernel_headers)} kernels...") + + include_dirs = [Path(p.strip()) for p in args.include_dirs.split(",")] + hipcc = find_hipcc() + + work = [ + ( + h, + args.output_dir, + include_dirs, + hipcc, + args.gpu_target, + i + 1, + len(kernel_headers), + ) + for i, h in enumerate(kernel_headers) + ] + + obj_files = [] + failed = [] + + with ProcessPoolExecutor(max_workers=args.jobs) as executor: + futures = {executor.submit(compile_kernel, w): w[0].name for w in work} + for future in as_completed(futures): + name, ok, result = future.result() + if ok: + obj_files.append(result) + else: + failed.append((name, result)) + print(f"[{target_name}] FAILED: {name}") + + if failed: + print(f"[{target_name}] {len(failed)} kernels failed") + for name, err in failed[:3]: + print(f" {name}: {err[:200]}") + return 1 + + # Create static library (use target_name for CMake compatibility) + lib_path = args.output_dir / f"lib{target_name}_kernels.a" + subprocess.run([find_ar(), "rcs", str(lib_path)] + obj_files, check=True) + + # Generate registration header (use source_stem for header name to match CMake's EXAMPLE_STEM) + header_path = args.output_dir / f"{source_stem}_kernels.hpp" + + # Build includes + includes = "\n".join(f'#include "{h.name}"' for h in kernel_headers) + + # Build kernel registration entries + # Function name uses source_stem (e.g., register_01_basic_gemm_kernels) + func_name = f"register_{source_stem}_kernels" + + # Generate registration code based on example type + if example_type == "gemm": + register_body = generate_gemm_registration(kernel_headers, target_name, kernels) + else: + register_body = generate_conv_registration(kernel_headers, target_name, kernels) + + # Generate appropriate header based on example type + if example_type == "conv" and kernel_headers: + launcher_aliases = [] + + # Helper to find kernel by dtype and type + def find_kernel_by_dtype_type(headers, dtype, conv_type_marker): + """Find kernel matching dtype 
and conv type, prioritize fp16.""" + matching = [h for h in headers if conv_type_marker in h.stem] + # Prefer fp16 over bf16 for default launchers + fp16_kernels = [h for h in matching if f"_{dtype}_" in h.stem] + return ( + fp16_kernels[0] if fp16_kernels else (matching[0] if matching else None) + ) + + # Check what conv types are in the declarations + has_fwd = any("forward" in k.get("conv_type", "forward") for k in kernels) + has_bwd_data = any("bwd_data" in k.get("conv_type", "") for k in kernels) + has_bwd_weight = any("bwd_weight" in k.get("conv_type", "") for k in kernels) + + # Export dtype-specific launcher aliases for each available dtype + for dtype in ["fp16", "bf16", "fp32"]: + dtype_fwd_kernels = [ + h + for h in kernel_headers + if "_fwd_" in h.stem and f"_{dtype}_" in h.stem + ] + if dtype_fwd_kernels: + k = dtype_fwd_kernels[0] + ns = f"ns_{k.stem}" + dtype_upper = dtype.upper() + launcher_aliases.append( + f"using {dtype_upper}FwdKernelLauncher = {ns}::{k.stem}_Launcher;" + ) + + # Export generic launcher aliases (prioritize fp16) + if has_fwd: + fwd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_fwd_") + if fwd_kernel: + fwd_ns = f"ns_{fwd_kernel.stem}" + launcher_aliases.append( + f"using FwdKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;" + ) + launcher_aliases.append( + f"using FirstKernelLauncher = {fwd_ns}::{fwd_kernel.stem}_Launcher;" + ) + + if has_bwd_data: + bwdd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdd_") + if bwdd_kernel: + bwdd_ns = f"ns_{bwdd_kernel.stem}" + launcher_aliases.append( + f"using BwdDataKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + ) + if not has_fwd: # If no fwd, use bwd_data as first + launcher_aliases.append( + f"using FirstKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + ) + + if has_bwd_weight: + bwdw_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdw_") + if bwdw_kernel: + bwdw_ns = f"ns_{bwdw_kernel.stem}" + 
launcher_aliases.append( + f"using BwdWeightKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + ) + if ( + not has_fwd and not has_bwd_data + ): # If no fwd or bwdd, use bwdw as first + launcher_aliases.append( + f"using FirstKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + ) + + launcher_section = "\n".join(launcher_aliases) + + header_content = f"""// Auto-generated for {target_name} +#pragma once + +{includes} + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" + +namespace generated {{ + +// Kernel launchers for direct use +{launcher_section} + +// Registration function +inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{ +{register_body} +}} + +}} // namespace generated + +// Generic registration - avoids hardcoding the example name in user code +// Safe for single-example executables (typical use case) +#ifndef REGISTER_GENERATED_KERNELS +#define REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch) +#endif +""" + else: + # GEMM: Generate per-set functions if multiple kernel sets declared + per_set_funcs = generate_per_set_functions(source_stem) + + header_content = f"""// Auto-generated for {target_name} +#pragma once + +{includes} + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/backends/generated_kernel_backend.hpp" + +namespace generated {{ + +// Register ALL kernels from all declared sets +inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{ +{register_body} +}} + +{per_set_funcs} +}} // namespace generated + +// Generic registration - avoids hardcoding the example name in user code +// Safe for single-example executables (typical use case) +#ifndef REGISTER_GENERATED_KERNELS +#define 
REGISTER_GENERATED_KERNELS(registry, arch) generated::{func_name}(registry, arch) +#endif + +// Register a specific kernel set by name (for multi-registry patterns) +// Usage: REGISTER_KERNEL_SET("compute_bound_set", registry, arch) +#ifndef REGISTER_KERNEL_SET +#define REGISTER_KERNEL_SET(set_name, registry, arch) generated::register_kernel_set(set_name, registry, arch) +#endif +""" + header_path.write_text(header_content) + + print(f"[{target_name}] ✓ {len(obj_files)} kernels compiled") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/scripts/parallel_kernel_builder.py b/dispatcher/scripts/parallel_kernel_builder.py new file mode 100755 index 00000000000..911ea61bd7e --- /dev/null +++ b/dispatcher/scripts/parallel_kernel_builder.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Build kernels in parallel - one translation unit per kernel. + +This script is called at make time (not cmake time) to avoid slow cmake configuration. 
import argparse
import os
import shutil
import subprocess
import sys
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed


def find_hipcc():
    """Find the hipcc compiler.

    Checks, in order: the HIPCC environment variable, the default ROCm
    install location, then PATH. Falls back to the bare name "hipcc".

    BUGFIX: shutil is now imported at module scope. Previously it was only
    imported under the __main__ guard, so calling find_hipcc() from an
    importing module raised NameError.
    """
    candidates = [
        os.environ.get("HIPCC"),
        "/opt/rocm/bin/hipcc",
        shutil.which("hipcc"),
    ]
    for path in candidates:
        if path and os.path.isfile(path):
            return path
    return "hipcc"  # Assume in PATH


def compile_kernel(args):
    """Compile a single kernel header into an object file.

    ``args`` is a (kernel_hpp, output_dir, include_dirs, hipcc) tuple so this
    function can be submitted directly to a ProcessPoolExecutor.

    Returns (kernel_name, success, payload) where payload is the object file
    path on success or the compiler stderr on failure.
    """
    kernel_hpp, output_dir, include_dirs, hipcc = args
    kernel_name = kernel_hpp.stem

    # Create a wrapper .cpp that includes the header; the volatile anchor
    # keeps the translation unit non-empty so symbols are emitted.
    wrapper_cpp = output_dir / f"{kernel_name}.cpp"
    wrapper_cpp.write_text(f'''// Auto-generated wrapper
#include "{kernel_hpp.name}"
namespace {{ volatile bool _k = true; }}
''')

    # Compile to object
    obj_file = output_dir / f"{kernel_name}.o"

    cmd = [
        hipcc,
        "-c",
        "-fPIC",
        "-std=c++17",
        "-O3",
        "--offload-arch=gfx942",
        "-mllvm",
        "-enable-noalias-to-md-conversion=0",
        "-Wno-undefined-func-template",
        "-Wno-float-equal",
        "--offload-compress",
    ]

    for inc_dir in include_dirs:
        cmd.extend(["-I", str(inc_dir)])
    cmd.extend(["-I", str(kernel_hpp.parent)])

    cmd.extend(["-o", str(obj_file), str(wrapper_cpp)])

    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode != 0:
        return (kernel_name, False, result.stderr)
    return (kernel_name, True, str(obj_file))


def main():
    """Compile every gemm_*/conv_* kernel header in parallel and link them
    into a single shared library. Returns a process exit code (0 on success).
    """
    parser = argparse.ArgumentParser(description="Build kernels in parallel")
    parser.add_argument("--kernel-dir", type=Path, required=True)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--include-dirs", type=str, required=True)
    parser.add_argument("--jobs", type=int, default=os.cpu_count())
    args = parser.parse_args()

    # Find kernel headers
    kernel_headers = list(args.kernel_dir.glob("gemm_*.hpp")) + list(
        args.kernel_dir.glob("conv_*.hpp")
    )

    if not kernel_headers:
        print("No kernels found to build")
        return 0

    print(f"Building {len(kernel_headers)} kernels with {args.jobs} parallel jobs...")

    include_dirs = [Path(p.strip()) for p in args.include_dirs.split(",")]
    hipcc = find_hipcc()

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Prepare work items
    work = [(h, args.output_dir, include_dirs, hipcc) for h in kernel_headers]

    # Compile in parallel
    obj_files = []
    failed = []

    with ProcessPoolExecutor(max_workers=args.jobs) as executor:
        futures = {executor.submit(compile_kernel, w): w[0].name for w in work}

        for i, future in enumerate(as_completed(futures), 1):
            name, success, result = future.result()
            if success:
                obj_files.append(result)
                print(f"[{i}/{len(kernel_headers)}] Built: {name}")
            else:
                failed.append((name, result))
                print(f"[{i}/{len(kernel_headers)}] FAILED: {name}")

    if failed:
        print(f"\n{len(failed)} kernels failed to compile:")
        for name, err in failed[:5]:
            print(f"  {name}: {err[:100]}")
        return 1

    # Link into shared library
    print(f"\nLinking {len(obj_files)} objects into libdispatcher_kernels.so...")
    lib_path = args.output_dir / "libdispatcher_kernels.so"

    link_cmd = [hipcc, "-shared", "-fPIC", "-o", str(lib_path)] + obj_files
    result = subprocess.run(link_cmd, capture_output=True, text=True)

    if result.returncode != 0:
        print(f"Linking failed: {result.stderr}")
        return 1

    print(f"✓ Built: {lib_path}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
"""
Stress Test for Auto-Correction and Codegen

This script tests the robustness of:
1. GEMM auto-correction (Python)
2. Conv auto-correction (Python)
3. C++ kernel declaration validation and wildcard expansion
4. Architecture filtering

Usage:
    python3 scripts/stress_test_autocorrect.py [--arch gfx942] [--samples 50] [--verbose]
"""

import argparse
import random
import sys
from pathlib import Path

# Add paths for imports: make the dispatcher's python/, codegen/ and scripts/
# directories importable regardless of the caller's working directory.
dispatcher_root = Path(__file__).parent.parent
sys.path.insert(0, str(dispatcher_root / "python"))
sys.path.insert(0, str(dispatcher_root / "codegen"))
sys.path.insert(0, str(dispatcher_root / "scripts"))

from ctypes_utils import auto_correct_kernel_config, KernelConfig  # noqa: E402

# Import validation/expansion functions from compile scripts
from compile_gemm_examples import (  # noqa: E402
    validate_kernel_config,
    expand_declaration_with_arch_filter,
)
from compile_conv_examples import (  # noqa: E402
    validate_conv_kernel_config,
    expand_conv_declaration_with_arch_filter,
)


# =============================================================================
# TEST PARAMETERS
# =============================================================================

# Valid dtypes
DTYPES = ["fp16", "bf16", "fp32", "fp8", "bf8", "int8"]

# Valid layouts
LAYOUTS = ["rcr", "rrr", "crr", "ccr"]

# Tile sizes (some valid, some invalid)
TILE_SIZES = [
    (32, 32, 16),
    (64, 64, 32),
    (128, 128, 32),
    (256, 256, 64),
    (128, 256, 32),
    (256, 128, 32),
    # Invalid sizes to test auto-correction
    (100, 100, 50),
    (17, 17, 17),
    (512, 512, 128),
]

# Wave configs (some valid, some invalid)
WAVE_CONFIGS = [
    (1, 1, 1),
    (1, 2, 1),
    (2, 1, 1),
    (2, 2, 1),
    (1, 4, 1),
    (4, 1, 1),
    (2, 4, 1),
    (4, 2, 1),
    # Invalid configs to test auto-correction
    (3, 3, 1),
    (5, 5, 1),
    (1, 1, 2),
]

# Warp tile sizes (some valid, some invalid)
WARP_TILES = [
    (16, 16, 16),
    (16, 16, 32),
    (32, 32, 8),
    (32, 32, 16),
    # Invalid tiles to test auto-correction
    (48, 48, 24),
    (64, 64, 32),
]

# Pipelines and schedulers (last entries are deliberately invalid)
PIPELINES = ["compv3", "compv4", "flatmma", "invalid_pipeline"]
SCHEDULERS = ["intrawave", "interwave", "invalid_scheduler"]

# Architectures
# NOTE(review): "gfx908" appears in expected_archs lists below but not here,
# so it is never exercised by the Test 4 loop — confirm whether intentional.
ARCHS = ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1200", "gfx1201"]


# =============================================================================
# TEST FUNCTIONS
# =============================================================================


def generate_random_gemm_config():
    """Generate a random GEMM configuration (may be invalid)."""
    dtype = random.choice(DTYPES)
    layout = random.choice(LAYOUTS)
    tile = random.choice(TILE_SIZES)
    wave = random.choice(WAVE_CONFIGS)
    warp = random.choice(WARP_TILES)
    pipeline = random.choice(PIPELINES)
    scheduler = random.choice(SCHEDULERS)
    arch = random.choice(ARCHS)

    return {
        "name": f"test_{dtype}_{layout}_{tile[0]}x{tile[1]}x{tile[2]}",
        "dtype_a": dtype,
        "dtype_b": dtype,
        "dtype_c": dtype,
        "dtype_acc": "fp32",
        "layout": layout,
        "tile_m": tile[0],
        "tile_n": tile[1],
        "tile_k": tile[2],
        "wave_m": wave[0],
        "wave_n": wave[1],
        "wave_k": wave[2],
        "warp_m": warp[0],
        "warp_n": warp[1],
        "warp_k": warp[2],
        "pipeline": pipeline,
        "scheduler": scheduler,
        "arch": arch,
    }


def generate_random_conv_config():
    """Generate a random Conv configuration (may be invalid)."""
    dtype = random.choice(["fp16", "bf16"])
    tile_k = random.choice([64, 128, 256])
    tile_c = random.choice([64, 128, 256])
    wave = random.choice(WAVE_CONFIGS)
    warp = random.choice(WARP_TILES)
    pipeline = random.choice(["compv3", "compv4"])
    scheduler = random.choice(["intrawave"])
    arch = random.choice(ARCHS)

    return {
        "name": f"test_conv_{dtype}_{tile_k}x{tile_c}",
        "dtype": dtype,
        "layout": "nhwgc",
        "conv_type": "forward",
        "tile_k": tile_k,
        "tile_c": tile_c,
        "wave_m": wave[0],
        "wave_n": wave[1],
        "wave_k": wave[2],
        "warp_m": warp[0],
        "warp_n": warp[1],
        "warp_k": warp[2],
        "pipeline": pipeline,
        "scheduler": scheduler,
        "arch": arch,
    }


def test_gemm_validation(config, verbose=False):
    """Test GEMM validation and auto-correction.

    Validates ``config`` for its target arch; on failure, retries with all
    tunable fields wildcarded (-1 / "*") to check that expansion can recover.
    Returns a dict with the validation outcome and any expanded configs.
    """
    arch = config.get("arch", "gfx942")
    is_valid, error_msg = validate_kernel_config(config, arch)

    # NOTE(review): "auto_corrected" is initialized but never populated below
    # — either wire it up or drop the key.
    result = {
        "config": config,
        "is_valid": is_valid,
        "error_msg": error_msg,
        "expanded": [],
        "auto_corrected": None,
    }

    if not is_valid:
        # Try wildcard expansion: -1 / "*" mean "any supported value"
        wildcard_config = config.copy()
        wildcard_config["wave_m"] = -1
        wildcard_config["wave_n"] = -1
        wildcard_config["warp_m"] = -1
        wildcard_config["warp_n"] = -1
        wildcard_config["pipeline"] = "*"
        wildcard_config["scheduler"] = "*"

        expanded = expand_declaration_with_arch_filter(wildcard_config, arch)
        result["expanded"] = expanded

    if verbose:
        print(f"\n  Config: {config['name']}")
        print(f"    Valid: {is_valid}")
        if not is_valid:
            print(f"    Error: {error_msg[:80]}...")
            print(f"    Expanded to: {len(result['expanded'])} configurations")

    return result


def test_python_autocorrect(verbose=False):
    """Test Python auto-correction for GEMM KernelConfig.

    Builds a handful of hand-written configs (one valid, two invalid) and
    runs them through auto_correct_kernel_config, counting pass/fail.
    Returns a results dict with "passed", "failed" and per-case "details".
    """
    print("\n" + "=" * 70)
    print("  PYTHON AUTO-CORRECTION TEST (GEMM KernelConfig)")
    print("=" * 70)

    test_cases = [
        # Valid config
        {
            "name": "valid_fp16",
            "dtype_a": "fp16",
            "dtype_b": "fp16",
            "dtype_c": "fp16",
            "dtype_acc": "fp32",
            "layout": "rcr",
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 32,
            "wave_m": 2,
            "wave_n": 2,
            "wave_k": 1,
            "warp_m": 32,
            "warp_n": 32,
            "warp_k": 16,
            "pipeline": "compv4",
            "scheduler": "intrawave",
            "gfx_arch": "gfx942",
        },
        # Invalid wave config
        {
            "name": "invalid_wave",
            "dtype_a": "fp16",
            "dtype_b": "fp16",
            "dtype_c": "fp16",
            "dtype_acc": "fp32",
            "layout": "rcr",
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 32,
            "wave_m": 1,
            "wave_n": 1,
            "wave_k": 1,  # Invalid for gfx942
            "warp_m": 32,
            "warp_n": 32,
            "warp_k": 16,
            "pipeline": "compv4",
            "scheduler": "intrawave",
            "gfx_arch": "gfx942",
        },
        # Invalid scheduler
        {
            "name": "invalid_scheduler",
            "dtype_a": "fp16",
            "dtype_b": "fp16",
            "dtype_c": "fp16",
            "dtype_acc": "fp32",
            "layout": "rcr",
            "tile_m": 128,
            "tile_n": 128,
            "tile_k": 32,
            "wave_m": 2,
            "wave_n": 2,
            "wave_k": 1,
            "warp_m": 32,
            "warp_n": 32,
            "warp_k": 16,
            "pipeline": "compv4",
            "scheduler": "interwave",  # May not be valid for all archs
            "gfx_arch": "gfx942",
        },
    ]

    results = {"passed": 0, "failed": 0, "details": []}

    for tc in test_cases:
        try:
            # NOTE(review): tc["layout"] is never copied onto config — confirm
            # whether KernelConfig needs it or uses a default.
            config = KernelConfig()
            config.dtype_a = tc["dtype_a"]
            config.dtype_b = tc["dtype_b"]
            config.dtype_c = tc["dtype_c"]
            config.dtype_acc = tc["dtype_acc"]
            config.tile_m = tc["tile_m"]
            config.tile_n = tc["tile_n"]
            config.tile_k = tc["tile_k"]
            config.wave_m = tc["wave_m"]
            config.wave_n = tc["wave_n"]
            config.wave_k = tc["wave_k"]
            config.warp_m = tc["warp_m"]
            config.warp_n = tc["warp_n"]
            config.warp_k = tc["warp_k"]
            config.pipeline = tc["pipeline"]
            config.scheduler = tc["scheduler"]
            config.gfx_arch = tc["gfx_arch"]

            corrected, was_modified, corrections = auto_correct_kernel_config(
                config, verbose=verbose
            )

            results["passed"] += 1
            results["details"].append(
                {
                    "name": tc["name"],
                    "status": "PASS",
                    "was_modified": was_modified,
                    "corrections": corrections,
                }
            )

            if verbose:
                print(f"\n  {tc['name']}: PASS")
                if was_modified:
                    print(f"    Modified: {len(corrections)} correction(s)")
                    for c in corrections:
                        print(f"      • {c}")

        except Exception as e:
            results["failed"] += 1
            results["details"].append(
                {"name": tc["name"], "status": "FAIL", "error": str(e)}
            )
            if verbose:
                print(f"\n  {tc['name']}: FAIL - {e}")

    print(f"\n  Summary: {results['passed']} passed, {results['failed']} failed")
    return results


def run_stress_test(arch, num_samples, verbose):
    """Run the full stress test.

    Executes four phases: random GEMM validation/expansion, random Conv
    validation/expansion, Python auto-correction, and per-arch dtype support
    checks. Returns True when no wildcard expansion failed.
    """
    print("\n" + "=" * 70)
    print("  DISPATCHER AUTO-CORRECTION & CODEGEN STRESS TEST")
    print("=" * 70)
    print(f"  Target Architecture: {arch}")
    print(f"  Number of Samples:   {num_samples}")
    print("=" * 70)

    # Test 1: GEMM Validation
    print("\n" + "-" * 70)
    print("  TEST 1: GEMM Validation & Wildcard Expansion")
    print("-" * 70)

    gemm_results = {"valid": 0, "invalid": 0, "expanded": 0, "expansion_failed": 0}

    for i in range(num_samples):
        config = generate_random_gemm_config()
        config["arch"] = arch  # Override with target arch

        result = test_gemm_validation(config, verbose)

        if result["is_valid"]:
            gemm_results["valid"] += 1
        else:
            gemm_results["invalid"] += 1
            if result["expanded"]:
                gemm_results["expanded"] += 1
            else:
                gemm_results["expansion_failed"] += 1

    print("\n  GEMM Results:")
    print(f"    Valid configs:        {gemm_results['valid']}")
    print(f"    Invalid configs:      {gemm_results['invalid']}")
    print(f"    Successfully expanded: {gemm_results['expanded']}")
    print(f"    Expansion failed:     {gemm_results['expansion_failed']}")

    # Test 2: Conv Validation
    print("\n" + "-" * 70)
    print("  TEST 2: Conv Validation & Wildcard Expansion")
    print("-" * 70)

    conv_results = {"valid": 0, "invalid": 0, "expanded": 0, "expansion_failed": 0}

    for i in range(num_samples):
        config = generate_random_conv_config()
        config["arch"] = arch  # Override with target arch

        is_valid, error_msg = validate_conv_kernel_config(config, arch)

        if is_valid:
            conv_results["valid"] += 1
        else:
            conv_results["invalid"] += 1
            # Try wildcard expansion
            wildcard_config = config.copy()
            wildcard_config["wave_m"] = -1
            wildcard_config["wave_n"] = -1
            wildcard_config["warp_m"] = -1
            wildcard_config["warp_n"] = -1

            expanded = expand_conv_declaration_with_arch_filter(wildcard_config, arch)
            if expanded:
                conv_results["expanded"] += 1
            else:
                conv_results["expansion_failed"] += 1

    print("\n  Conv Results:")
    print(f"    Valid configs:        {conv_results['valid']}")
    print(f"    Invalid configs:      {conv_results['invalid']}")
    print(f"    Successfully expanded: {conv_results['expanded']}")
    print(f"    Expansion failed:     {conv_results['expansion_failed']}")

    # Test 3: Python Auto-Correction
    print("\n" + "-" * 70)
    print("  TEST 3: Python Auto-Correction (KernelConfig)")
    print("-" * 70)

    py_results = test_python_autocorrect(verbose)

    # Test 4: Architecture-specific tests
    print("\n" + "-" * 70)
    print("  TEST 4: Architecture-Specific Validation")
    print("-" * 70)

    arch_test_configs = [
        # fp16 should work on all archs
        {"dtype": "fp16", "expected_archs": ARCHS},
        # bf16 works on all archs that have bf16_bf16_fp32 in warp_tile_combos
        {
            "dtype": "bf16",
            "expected_archs": [
                "gfx908",
                "gfx90a",
                "gfx942",
                "gfx950",
                "gfx1100",
                "gfx1200",
                "gfx1201",
            ],
        },
        # fp8 works on archs that have fp8_fp8_fp32 in warp_tile_combos
        {
            "dtype": "fp8",
            "expected_archs": ["gfx90a", "gfx942", "gfx950", "gfx1200", "gfx1201"],
        },
    ]

    for test in arch_test_configs:
        dtype = test["dtype"]
        print(f"\n  Testing {dtype}:")

        for test_arch in ARCHS:
            config = {
                "name": f"arch_test_{dtype}_{test_arch}",
                "dtype_a": dtype,
                "dtype_b": dtype,
                "dtype_c": dtype,
                "dtype_acc": "fp32",
                "layout": "rcr",
                "tile_m": 128,
                "tile_n": 128,
                "tile_k": 32,
                "wave_m": -1,  # Wildcard
                "wave_n": -1,
                "wave_k": 1,
                "warp_m": -1,
                "warp_n": -1,
                "warp_k": -1,
                "pipeline": "*",
                "scheduler": "*",
                "arch": test_arch,
            }

            expanded = expand_declaration_with_arch_filter(config, test_arch)
            status = "✓" if expanded else "✗"
            expected = test_arch in test["expected_archs"]
            match = "OK" if (bool(expanded) == expected) else "MISMATCH"

            if verbose or match == "MISMATCH":
                print(f"    {test_arch}: {status} ({len(expanded)} configs) [{match}]")

    # Summary
    print("\n" + "=" * 70)
    print("  STRESS TEST SUMMARY")
    print("=" * 70)
    print(
        f"  GEMM: {gemm_results['valid'] + gemm_results['expanded']}/{num_samples} handled"
    )
    print(
        f"  Conv: {conv_results['valid'] + conv_results['expanded']}/{num_samples} handled"
    )
    print(
        f"  Python Auto-Correct: {py_results['passed']}/{py_results['passed'] + py_results['failed']} passed"
    )

    total_success = (
        gemm_results["valid"]
        + gemm_results["expanded"]
        + conv_results["valid"]
        + conv_results["expanded"]
        + py_results["passed"]
    )
    total_tests = num_samples * 2 + py_results["passed"] + py_results["failed"]

    print(f"\n  Overall: {total_success}/{total_tests} tests handled successfully")
    print("=" * 70)

    # Success criterion: every invalid config was recoverable by expansion
    return (
        gemm_results["expansion_failed"] == 0 and conv_results["expansion_failed"] == 0
    )


def main():
    """CLI entry point: parse args, optionally seed the RNG, run the stress
    test and map the boolean result to a process exit code."""
    parser = argparse.ArgumentParser(
        description="Stress test auto-correction and codegen"
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        choices=ARCHS,
        help="Target GPU architecture (default: gfx942)",
    )
    parser.add_argument(
        "--samples",
        type=int,
        default=50,
        help="Number of random samples to test (default: 50)",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Show detailed output"
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="Random seed for reproducibility"
    )

    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)

    success = run_stress_test(args.arch, args.samples, args.verbose)

    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())
// NOTE(review): the constructor signature "Dispatcher::Dispatcher(Registry*
// registry)" opens immediately before this span; reproduced in full here.
// A null registry falls back to the global singleton.
Dispatcher::Dispatcher(Registry* registry)
    : registry_(registry ? registry : &Registry::instance()),
      heuristic_(nullptr),
      strategy_(SelectionStrategy::FirstFit)
{
}

// Install a heuristic; a non-null heuristic also switches the selection
// strategy to Heuristic. NOTE(review): passing nullptr clears the heuristic
// but does NOT reset strategy_ back to FirstFit (select_heuristic then falls
// back to first-fit at lookup time) — confirm this is intended.
void Dispatcher::set_heuristic(HeuristicFunction heuristic)
{
    heuristic_ = heuristic;
    if(heuristic_)
    {
        strategy_ = SelectionStrategy::Heuristic;
    }
}

void Dispatcher::set_strategy(SelectionStrategy strategy) { strategy_ = strategy; }

// Pick a kernel for the problem according to the current strategy.
// Returns nullptr for invalid problems or when no kernel matches.
KernelInstancePtr Dispatcher::select_kernel(const Problem& problem) const
{
    if(!problem.is_valid())
    {
        return nullptr;
    }

    switch(strategy_)
    {
    case SelectionStrategy::FirstFit: return select_first_fit(problem);
    case SelectionStrategy::Heuristic: return select_heuristic(problem);
    default: return nullptr;
    }
}

// Convenience overload: plain GEMM run with no fused (D) operands.
float Dispatcher::run(
    const void* a_ptr, const void* b_ptr, void* c_ptr, const Problem& problem, void* stream) const
{
    return run_fused(a_ptr, b_ptr, c_ptr, nullptr, problem, stream);
}

// Select a kernel and run it; throws std::runtime_error when no kernel
// supports the problem. Returns whatever the kernel's run() reports
// (presumably elapsed time in ms — confirm against KernelInstance::run).
float Dispatcher::run_fused(const void* a_ptr,
                            const void* b_ptr,
                            void* c_ptr,
                            const void** d_ptrs,
                            const Problem& problem,
                            void* stream) const
{
    auto kernel = select_kernel(problem);
    if(!kernel)
    {
        std::ostringstream oss;
        oss << "No suitable kernel found for problem: M=" << problem.M << " N=" << problem.N
            << " K=" << problem.K;
        throw std::runtime_error(oss.str());
    }

    return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream);
}

// Run a specific kernel selected by identifier, bypassing selection.
// Throws std::runtime_error if the id is unknown or the kernel rejects
// the problem.
float Dispatcher::run_explicit(const std::string& kernel_id,
                               const void* a_ptr,
                               const void* b_ptr,
                               void* c_ptr,
                               const void** d_ptrs,
                               const Problem& problem,
                               void* stream) const
{
    auto kernel = registry_->lookup(kernel_id);
    if(!kernel)
    {
        throw std::runtime_error("Kernel not found: " + kernel_id);
    }

    if(!kernel->supports(problem))
    {
        std::ostringstream oss;
        oss << "Kernel " << kernel_id << " does not support problem: M=" << problem.M
            << " N=" << problem.N << " K=" << problem.K;
        throw std::runtime_error(oss.str());
    }

    return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream);
}

// Validate the selected kernel's output against a reference within
// `tolerance`; returns false (rather than throwing) when no kernel matches.
bool Dispatcher::validate(const void* a_ptr,
                          const void* b_ptr,
                          const void* c_ptr,
                          const void** d_ptrs,
                          const Problem& problem,
                          float tolerance) const
{
    auto kernel = select_kernel(problem);
    if(!kernel)
    {
        return false;
    }

    return kernel->validate(a_ptr, b_ptr, c_ptr, d_ptrs, problem, tolerance);
}

// Linear scan over the registry in its iteration order (unspecified),
// returning the first kernel that supports the problem.
KernelInstancePtr Dispatcher::select_first_fit(const Problem& problem) const
{
    auto all_kernels = registry_->get_all();

    for(const auto& kernel : all_kernels)
    {
        if(kernel->supports(problem))
        {
            return kernel;
        }
    }

    return nullptr;
}

// Ask the heuristic for a ranked candidate list and take the first candidate
// that exists in the registry and supports the problem; falls back to
// first-fit when there is no heuristic or no candidate works.
KernelInstancePtr Dispatcher::select_heuristic(const Problem& problem) const
{
    if(!heuristic_)
    {
        // Fall back to first-fit if no heuristic available
        return select_first_fit(problem);
    }

    // Get ranked list of kernel identifiers from heuristic
    auto candidates = heuristic_(problem);

    // Try each candidate in order
    for(const auto& kernel_id : candidates)
    {
        auto kernel = registry_->lookup(kernel_id);
        if(kernel && kernel->supports(problem))
        {
            return kernel;
        }
    }

    // If no heuristic candidate works, fall back to first-fit
    return select_first_fit(problem);
}

} // namespace dispatcher
} // namespace ck_tile
+// SPDX-License-Identifier: MIT + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/json_export.hpp" +#include "ck_tile/dispatcher/arch_filter.hpp" +#include + +namespace ck_tile { +namespace dispatcher { + +Registry::Registry() + : name_("default"), + auto_export_enabled_(false), + auto_export_include_statistics_(true), + auto_export_on_every_registration_(true) +{ +} + +Registry::~Registry() +{ + // Perform auto-export on destruction if enabled (regardless of export_on_every_registration + // setting) + if(auto_export_enabled_) + { + perform_auto_export(); + } +} + +Registry::Registry(Registry&& other) noexcept + : mutex_() // mutex is not movable, create new one + , + kernels_(std::move(other.kernels_)), + name_(std::move(other.name_)), + auto_export_enabled_(other.auto_export_enabled_), + auto_export_filename_(std::move(other.auto_export_filename_)), + auto_export_include_statistics_(other.auto_export_include_statistics_), + auto_export_on_every_registration_(other.auto_export_on_every_registration_) +{ + // Disable auto-export on the moved-from object to prevent double export + other.auto_export_enabled_ = false; +} + +Registry& Registry::operator=(Registry&& other) noexcept +{ + if(this != &other) + { + std::lock_guard lock(mutex_); + std::lock_guard other_lock(other.mutex_); + + kernels_ = std::move(other.kernels_); + name_ = std::move(other.name_); + auto_export_enabled_ = other.auto_export_enabled_; + auto_export_filename_ = std::move(other.auto_export_filename_); + auto_export_include_statistics_ = other.auto_export_include_statistics_; + auto_export_on_every_registration_ = other.auto_export_on_every_registration_; + + // Disable auto-export on the moved-from object + other.auto_export_enabled_ = false; + } + return *this; +} + +bool Registry::register_kernel(KernelInstancePtr instance, Priority priority) +{ + if(!instance) + { + return false; + } + + const std::string identifier = instance->get_key().encode_identifier(); + + 
bool registered = false; + { + std::lock_guard lock(mutex_); + + auto it = kernels_.find(identifier); + if(it != kernels_.end()) + { + // Kernel with this identifier already exists + // Only replace if new priority is higher + if(priority > it->second.priority) + { + it->second.instance = instance; + it->second.priority = priority; + registered = true; + } + } + else + { + // New kernel, insert it + kernels_[identifier] = RegistryEntry{instance, priority}; + registered = true; + } + } + + // Perform auto-export if enabled and configured to export on every registration + if(registered && auto_export_enabled_ && auto_export_on_every_registration_) + { + perform_auto_export(); + } + + return registered; +} + +KernelInstancePtr Registry::lookup(const std::string& identifier) const +{ + std::lock_guard lock(mutex_); + + auto it = kernels_.find(identifier); + if(it != kernels_.end()) + { + return it->second.instance; + } + + return nullptr; +} + +KernelInstancePtr Registry::lookup(const KernelKey& key) const +{ + return lookup(key.encode_identifier()); +} + +std::vector Registry::get_all() const +{ + std::lock_guard lock(mutex_); + + std::vector result; + result.reserve(kernels_.size()); + + for(const auto& pair : kernels_) + { + result.push_back(pair.second.instance); + } + + return result; +} + +std::vector +Registry::filter(std::function predicate) const +{ + std::lock_guard lock(mutex_); + + std::vector result; + + for(const auto& pair : kernels_) + { + if(predicate(*pair.second.instance)) + { + result.push_back(pair.second.instance); + } + } + + return result; +} + +std::size_t Registry::size() const +{ + std::lock_guard lock(mutex_); + return kernels_.size(); +} + +bool Registry::empty() const +{ + std::lock_guard lock(mutex_); + return kernels_.empty(); +} + +void Registry::clear() +{ + std::lock_guard lock(mutex_); + kernels_.clear(); +} + +const std::string& Registry::get_name() const +{ + std::lock_guard lock(mutex_); + return name_; +} + +void 
Registry::set_name(const std::string& name) +{ + std::lock_guard lock(mutex_); + name_ = name; +} + +Registry& Registry::instance() +{ + static Registry global_registry; + return global_registry; +} + +std::string Registry::export_json(bool include_statistics) const +{ + return export_registry_json(*this, include_statistics); +} + +bool Registry::export_json_to_file(const std::string& filename, bool include_statistics) const +{ + return export_registry_json_to_file(*this, filename, include_statistics); +} + +void Registry::enable_auto_export(const std::string& filename, + bool include_statistics, + bool export_on_every_registration) +{ + std::lock_guard lock(mutex_); + auto_export_enabled_ = true; + auto_export_filename_ = filename; + auto_export_include_statistics_ = include_statistics; + auto_export_on_every_registration_ = export_on_every_registration; +} + +void Registry::disable_auto_export() +{ + std::lock_guard lock(mutex_); + auto_export_enabled_ = false; +} + +bool Registry::is_auto_export_enabled() const +{ + std::lock_guard lock(mutex_); + return auto_export_enabled_; +} + +void Registry::perform_auto_export() +{ + // Don't hold the lock during file I/O + std::string filename; + bool include_stats; + + { + std::lock_guard lock(mutex_); + if(!auto_export_enabled_) + { + return; + } + filename = auto_export_filename_; + include_stats = auto_export_include_statistics_; + } + + // Export without holding the lock + export_json_to_file(filename, include_stats); +} + +std::size_t Registry::merge_from(const Registry& other, Priority priority) +{ + auto other_kernels = other.get_all(); + std::size_t merged_count = 0; + + for(const auto& kernel : other_kernels) + { + if(register_kernel(kernel, priority)) + { + merged_count++; + } + } + + return merged_count; +} + +std::size_t Registry::filter_by_arch(const std::string& gpu_arch) +{ + ArchFilter filter(gpu_arch); + std::vector to_remove; + + { + std::lock_guard lock(mutex_); + + for(const auto& pair : kernels_) + { 
+ if(!filter.is_valid(pair.second.instance->get_key())) + { + to_remove.push_back(pair.first); + } + } + + for(const auto& key : to_remove) + { + kernels_.erase(key); + } + } + + return to_remove.size(); +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/tests/CMakeLists.txt b/dispatcher/tests/CMakeLists.txt new file mode 100644 index 00000000000..6c20c18c957 --- /dev/null +++ b/dispatcher/tests/CMakeLists.txt @@ -0,0 +1,343 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# ============================================================================= +# CK Tile Dispatcher Tests (C++ and Python) +# ============================================================================= + +cmake_minimum_required(VERSION 3.16) + +# Find Python +find_package(Python3 COMPONENTS Interpreter REQUIRED) + +# ============================================================================= +# Python Tests +# ============================================================================= + +# Auto-correction and validation stress test +add_test( + NAME dispatcher_test_autocorrect + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_autocorrect.py + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_test_autocorrect PROPERTIES + LABELS "dispatcher;python;validation" + TIMEOUT 120 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Verbose version of the test +add_test( + NAME dispatcher_test_autocorrect_verbose + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test_autocorrect.py -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. 
+) + +set_tests_properties(dispatcher_test_autocorrect_verbose PROPERTIES + LABELS "dispatcher;python;validation;verbose" + TIMEOUT 180 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Individual Python Test Categories +add_test( + NAME dispatcher_test_gemm_validation + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestGemmValidation test_autocorrect.TestGemmExpansion -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_gemm_validation PROPERTIES + LABELS "dispatcher;python;gemm;validation" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +add_test( + NAME dispatcher_test_python_autocorrect + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestPythonAutoCorrect -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_python_autocorrect PROPERTIES + LABELS "dispatcher;python;autocorrect" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +add_test( + NAME dispatcher_test_stress + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestStressRandom -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_stress PROPERTIES + LABELS "dispatcher;python;stress" + TIMEOUT 120 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +add_test( + NAME dispatcher_test_arch_support + COMMAND ${Python3_EXECUTABLE} -m unittest test_autocorrect.TestArchitectureSupport -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_tests_properties(dispatcher_test_arch_support PROPERTIES + LABELS "dispatcher;python;arch" + TIMEOUT 
60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Stress Test Script +add_test( + NAME dispatcher_stress_test + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/stress_test_autocorrect.py + --arch gfx942 --samples 30 --seed 42 + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_stress_test PROPERTIES + LABELS "dispatcher;python;stress;integration" + TIMEOUT 180 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# ============================================================================= +# Integration Tests (mimic examples) +# ============================================================================= + +# Full integration test suite +add_test( + NAME dispatcher_integration_tests + COMMAND ${Python3_EXECUTABLE} -m pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_examples_integration.py -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_integration_tests PROPERTIES + LABELS "dispatcher;python;integration;examples" + TIMEOUT 600 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# Quick integration test (utilities only) +add_test( + NAME dispatcher_integration_quick + COMMAND ${Python3_EXECUTABLE} -m pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_examples_integration.py::TestUtilityImports -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. 
+) + +set_tests_properties(dispatcher_integration_quick PROPERTIES + LABELS "dispatcher;python;integration;quick" + TIMEOUT 60 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# GEMM examples integration +add_test( + NAME dispatcher_integration_gemm + COMMAND ${Python3_EXECUTABLE} -m pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_examples_integration.py::TestGemmPythonExamples -v + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/.. +) + +set_tests_properties(dispatcher_integration_gemm PROPERTIES + LABELS "dispatcher;python;integration;gemm" + TIMEOUT 300 + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/../python:${CMAKE_CURRENT_SOURCE_DIR}/../codegen:${CMAKE_CURRENT_SOURCE_DIR}/../scripts" +) + +# ============================================================================= +# C++ Tests (Google Test) +# ============================================================================= + +# Include Google Test setup +if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../../cmake/gtest.cmake") + include(${CMAKE_CURRENT_SOURCE_DIR}/../../cmake/gtest.cmake) +else() + include(gtest) +endif() + +# Mock kernel instance for testing (shared across tests) +add_library(dispatcher_test_utils STATIC + test_mock_kernel.cpp +) + +target_include_directories(dispatcher_test_utils PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${CMAKE_CURRENT_SOURCE_DIR}/../../include +) + +target_link_libraries(dispatcher_test_utils PRIVATE + ck_tile_dispatcher +) + +# Test executables using Google Test +set(TEST_SOURCES + # Core unit tests + test_kernel_key.cpp + test_problem.cpp + test_registry.cpp + test_dispatcher.cpp + test_tile_backend.cpp + + # Extended unit tests (more comprehensive coverage) + test_kernel_key_extended.cpp + test_problem_extended.cpp + test_registry_extended.cpp + test_dispatcher_extended.cpp + + # Regression tests (known issues and edge cases) + 
test_regression.cpp + + # JSON export tests + test_json_export.cpp +) + +foreach(test_source ${TEST_SOURCES}) + get_filename_component(test_name ${test_source} NAME_WE) + + add_executable(${test_name} ${test_source}) + + target_link_libraries(${test_name} PRIVATE + ck_tile_dispatcher + dispatcher_test_utils + GTest::gtest_main + ) + + target_compile_options(${test_name} PRIVATE + -Wno-global-constructors + -Wno-undef + ) + + add_test(NAME ${test_name} COMMAND ${test_name}) + set_tests_properties(${test_name} PROPERTIES LABELS "dispatcher;cpp;unit") +endforeach() + +# Standalone integration tests (with their own main()) +set(STANDALONE_TESTS + test_minimal.cpp +) + +foreach(test_source ${STANDALONE_TESTS}) + get_filename_component(test_name ${test_source} NAME_WE) + + add_executable(${test_name} ${test_source}) + + target_link_libraries(${test_name} PRIVATE + ck_tile_dispatcher + dispatcher_test_utils + ) + + target_compile_options(${test_name} PRIVATE + -Wno-global-constructors + -Wno-undef + ) + + add_test(NAME ${test_name} COMMAND ${test_name}) + set_tests_properties(${test_name} PROPERTIES LABELS "dispatcher;cpp;integration") +endforeach() + +# ============================================================================= +# Real Kernel Tests (requires generated kernels) +# ============================================================================= + +set(KERNEL_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/../generated_kernels") +set(KERNEL_REGISTRATION_HEADER "${KERNEL_OUTPUT_DIR}/dispatcher_wrappers/register_all_kernels.hpp") +set(CODEGEN_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py") + +option(BUILD_DISPATCHER_REAL_KERNEL_TESTS "Build tests with real GPU kernels" ON) + +if(BUILD_DISPATCHER_REAL_KERNEL_TESTS AND EXISTS "${CODEGEN_SCRIPT}") + message(STATUS "Setting up real kernel test generation") + + add_custom_command( + OUTPUT ${KERNEL_REGISTRATION_HEADER} + COMMAND ${CMAKE_COMMAND} -E make_directory ${KERNEL_OUTPUT_DIR} + COMMAND 
${Python3_EXECUTABLE} ${CODEGEN_SCRIPT} + --output-dir ${KERNEL_OUTPUT_DIR} + --datatype fp16 + --layout rcr + --gpu-target gfx942 + --preselected fp16_rcr_essential + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen + COMMENT "Generating CK Tile kernels for real kernel tests..." + VERBATIM + ) + + add_custom_target(generate_test_kernels DEPENDS ${KERNEL_REGISTRATION_HEADER}) + + set(SINGLE_KERNEL_HEADER "${KERNEL_OUTPUT_DIR}/gemm_fp16_rcr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16.hpp") + + set(REAL_KERNEL_TESTS + test_real_kernel_simple + test_real_kernel_multi_size + test_real_kernel_performance + test_real_kernel_correctness + test_sanity_ck_tile + ) + + if(EXISTS "${SINGLE_KERNEL_HEADER}") + foreach(test_name ${REAL_KERNEL_TESTS}) + add_executable(${test_name} ${test_name}.cpp) + + add_dependencies(${test_name} generate_test_kernels) + + target_link_libraries(${test_name} PRIVATE + ck_tile_dispatcher + ) + + target_include_directories(${test_name} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${KERNEL_OUTPUT_DIR} + ) + + target_compile_options(${test_name} PRIVATE + -include ${SINGLE_KERNEL_HEADER} + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + ) + + if(hip_FOUND) + target_link_libraries(${test_name} PRIVATE hip::device hip::host) + endif() + + add_test(NAME ${test_name} COMMAND ${test_name}) + set_tests_properties(${test_name} PROPERTIES LABELS "dispatcher;cpp;gpu;kernel") + endforeach() + endif() +endif() + +# ============================================================================= +# Custom Targets +# ============================================================================= + +add_custom_target(run_dispatcher_tests + COMMAND ${CMAKE_CTEST_COMMAND} -L dispatcher --output-on-failure + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + COMMENT "Running all dispatcher tests" +) + +add_custom_target(test_dispatcher_python + COMMAND 
${CMAKE_CTEST_COMMAND} -L "dispatcher;python" --output-on-failure + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + COMMENT "Running Python dispatcher tests" +) + +add_custom_target(test_dispatcher_cpp + COMMAND ${CMAKE_CTEST_COMMAND} -L "dispatcher;cpp" --output-on-failure + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + COMMENT "Running C++ dispatcher tests" +) + +# ============================================================================= +# Summary +# ============================================================================= + +message(STATUS "Dispatcher tests configured:") +message(STATUS " Run all: ctest -L dispatcher") +message(STATUS " Run Python: ctest -L 'dispatcher;python' or make test_dispatcher_python") +message(STATUS " Run C++: ctest -L 'dispatcher;cpp' or make test_dispatcher_cpp") +message(STATUS " Run verbose: ctest -R dispatcher_test_autocorrect_verbose") diff --git a/dispatcher/tests/test_autocorrect.py b/dispatcher/tests/test_autocorrect.py new file mode 100644 index 00000000000..0ec3ebda3ce --- /dev/null +++ b/dispatcher/tests/test_autocorrect.py @@ -0,0 +1,625 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Comprehensive Test Suite for Auto-Correction and Validation + +Tests: +1. GEMM validation and wildcard expansion +2. Conv validation and wildcard expansion +3. Python KernelConfig auto-correction +4. Architecture-specific dtype support +5. 
Edge cases and error handling + +Can be run as: + python3 tests/test_autocorrect.py # Run all tests + python3 tests/test_autocorrect.py -v # Verbose output + python3 tests/test_autocorrect.py TestGemmValidation # Run specific test class + ctest -R test_autocorrect # Via ctest + +Exit codes: + 0 = All tests passed + 1 = Some tests failed +""" + +import sys +import unittest +import random +from pathlib import Path + +# Setup paths +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) +sys.path.insert(0, str(DISPATCHER_DIR / "scripts")) + +# Import modules under test +from compile_gemm_examples import ( # noqa: E402 + validate_kernel_config, + expand_declaration_with_arch_filter, + is_wildcard_declaration, +) +from compile_conv_examples import ( # noqa: E402 + validate_conv_kernel_config, + expand_conv_declaration_with_arch_filter, + is_conv_wildcard_declaration, +) +from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402 + + +# ============================================================================= +# TEST DATA +# ============================================================================= + +VALID_ARCHS = ["gfx90a", "gfx942", "gfx950"] +VALID_DTYPES = ["fp16", "bf16"] +VALID_LAYOUTS = ["rcr", "rrr"] +VALID_PIPELINES = ["compv3", "compv4"] +VALID_SCHEDULERS = ["intrawave"] + +# Known valid wave configs for gfx942 +VALID_WAVE_CONFIGS_GFX942 = [[1, 4, 1], [2, 2, 1], [4, 1, 1]] + +# Known valid warp tiles for fp16 on gfx942 +VALID_WARP_TILES_FP16_GFX942 = [[16, 16, 16], [16, 16, 32], [32, 32, 8], [32, 32, 16]] + + +# ============================================================================= +# GEMM VALIDATION TESTS +# ============================================================================= + + +class TestGemmValidation(unittest.TestCase): + """Test GEMM kernel validation.""" + + def 
test_valid_config(self): + """Valid configuration should pass validation.""" + config = { + "name": "test_valid", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_kernel_config(config, "gfx942") + self.assertTrue(is_valid, f"Expected valid, got error: {error}") + + def test_invalid_wave_config(self): + """Invalid wave config should fail validation.""" + config = { + "name": "test_invalid_wave", + "dtype_a": "fp16", + "wave_m": 3, # Invalid + "wave_n": 3, # Invalid + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_kernel_config(config, "gfx942") + self.assertFalse(is_valid) + self.assertIn("wave", error.lower()) + + def test_invalid_scheduler(self): + """Invalid scheduler should fail validation.""" + config = { + "name": "test_invalid_scheduler", + "dtype_a": "fp16", + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "epilogue": "cshuffle", + "scheduler": "interwave", # Invalid with compv4+cshuffle + } + is_valid, error = validate_kernel_config(config, "gfx942") + self.assertFalse(is_valid) + self.assertIn("trait", error.lower()) + + def test_wildcard_skips_validation(self): + """Wildcard declarations should skip validation.""" + config = { + "name": "test_wildcard", + "dtype_a": "fp16", + "wave_m": -1, # Wildcard + "wave_n": -1, # Wildcard + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + self.assertTrue(is_wildcard_declaration(config)) + is_valid, _ = validate_kernel_config(config, "gfx942") + self.assertTrue(is_valid) + + def test_unsupported_arch(self): + """Unsupported 
architecture should fail validation.""" + config = { + "name": "test_bad_arch", + "dtype_a": "fp16", + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_kernel_config(config, "gfx_invalid") + self.assertFalse(is_valid) + self.assertIn("unsupported", error.lower()) + + +class TestGemmExpansion(unittest.TestCase): + """Test GEMM wildcard expansion.""" + + def test_wave_expansion(self): + """Wave wildcard should expand to valid configs.""" + config = { + "name": "test_wave_expand", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": -1, # Wildcard + "wave_n": -1, # Wildcard + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "Should expand to at least one config") + + # All expanded configs should be valid + for exp in expanded: + is_valid, error = validate_kernel_config(exp, "gfx942") + self.assertTrue(is_valid, f"Expanded config invalid: {error}") + + def test_full_wildcard_expansion(self): + """Full wildcard should expand to multiple valid configs.""" + config = { + "name": "test_full_wildcard", + "dtype_a": "fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": -1, + "warp_n": -1, + "warp_k": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater( + len(expanded), 1, "Full wildcard should expand to multiple configs" + ) + + def test_explicit_config_not_expanded(self): + """Explicit (non-wildcard) config should not expand.""" + config = { + "name": "test_explicit", + "dtype_a": 
"fp16", + "dtype_b": "fp16", + "dtype_c": "fp16", + "layout": "rcr", + "tile_m": 128, + "tile_n": 128, + "tile_k": 32, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertEqual(len(expanded), 1, "Explicit config should not expand") + + +# ============================================================================= +# CONV VALIDATION TESTS +# ============================================================================= + + +class TestConvValidation(unittest.TestCase): + """Test Conv kernel validation.""" + + def test_valid_conv_config(self): + """Valid conv configuration should pass validation.""" + config = { + "name": "test_valid_conv", + "dtype": "fp16", + "layout": "nhwgc", + "conv_type": "forward", + "tile_k": 128, + "tile_c": 128, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_conv_kernel_config(config, "gfx942") + self.assertTrue(is_valid, f"Expected valid, got error: {error}") + + def test_invalid_conv_wave(self): + """Invalid wave config should fail conv validation.""" + config = { + "name": "test_invalid_conv_wave", + "dtype": "fp16", + "wave_m": 5, # Invalid + "wave_n": 5, # Invalid + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + is_valid, error = validate_conv_kernel_config(config, "gfx942") + self.assertFalse(is_valid) + self.assertIn("wave", error.lower()) + + def test_conv_wildcard_detection(self): + """Should correctly detect conv wildcards.""" + wildcard_config = { + "wave_m": -1, + "wave_n": 2, + "warp_m": 32, + "warp_n": 32, + "pipeline": "compv4", + "scheduler": "intrawave", + } + self.assertTrue(is_conv_wildcard_declaration(wildcard_config)) + + explicit_config = { + 
"wave_m": 2, + "wave_n": 2, + "warp_m": 32, + "warp_n": 32, + "pipeline": "compv4", + "scheduler": "intrawave", + } + self.assertFalse(is_conv_wildcard_declaration(explicit_config)) + + +class TestConvExpansion(unittest.TestCase): + """Test Conv wildcard expansion.""" + + def test_conv_wave_expansion(self): + """Conv wave wildcard should expand to valid configs.""" + config = { + "name": "test_conv_wave_expand", + "dtype": "fp16", + "layout": "nhwgc", + "conv_type": "forward", + "tile_k": 128, + "tile_c": 128, + "wave_m": -1, + "wave_n": -1, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "pipeline": "compv4", + "scheduler": "intrawave", + } + expanded = expand_conv_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "Should expand to at least one config") + + +# ============================================================================= +# PYTHON AUTO-CORRECTION TESTS +# ============================================================================= + + +class TestPythonAutoCorrect(unittest.TestCase): + """Test Python KernelConfig auto-correction.""" + + def test_autocorrect_invalid_wave(self): + """Auto-correction should fix invalid wave config.""" + config = KernelConfig() + config.dtype_a = "fp16" + config.dtype_b = "fp16" + config.dtype_c = "fp16" + config.dtype_acc = "fp32" + config.layout_a = "row" + config.layout_b = "col" + config.layout_c = "row" + config.tile_m = 128 + config.tile_n = 128 + config.tile_k = 32 + config.wave_m = 1 # May be invalid + config.wave_n = 1 # May be invalid + config.wave_k = 1 + config.warp_m = 32 + config.warp_n = 32 + config.warp_k = 16 + config.pipeline = "compv4" + config.scheduler = "intrawave" + config.gfx_arch = "gfx942" + + corrected, was_modified, corrections = auto_correct_kernel_config( + config, verbose=False + ) + + # Should either be valid or corrected + self.assertIsNotNone(corrected) + if was_modified: + self.assertGreater(len(corrections), 0) + + def 
test_autocorrect_returns_three_values(self): + """Auto-correction should return (config, was_modified, corrections).""" + config = KernelConfig() + config.dtype_a = "fp16" + config.dtype_b = "fp16" + config.dtype_c = "fp16" + config.dtype_acc = "fp32" + config.layout_a = "row" + config.layout_b = "col" + config.layout_c = "row" + config.tile_m = 128 + config.tile_n = 128 + config.tile_k = 32 + config.wave_m = 2 + config.wave_n = 2 + config.wave_k = 1 + config.warp_m = 32 + config.warp_n = 32 + config.warp_k = 16 + config.pipeline = "compv4" + config.scheduler = "intrawave" + config.gfx_arch = "gfx942" + + result = auto_correct_kernel_config(config, verbose=False) + + self.assertEqual(len(result), 3, "Should return 3 values") + corrected, was_modified, corrections = result + self.assertIsInstance(was_modified, bool) + self.assertIsInstance(corrections, list) + + +# ============================================================================= +# STRESS TESTS +# ============================================================================= + + +class TestStressRandom(unittest.TestCase): + """Stress test with random configurations.""" + + def test_random_gemm_configs(self): + """Random GEMM configs should either validate or expand successfully.""" + random.seed(42) # Reproducible + + dtypes = ["fp16", "bf16"] + layouts = ["rcr", "rrr"] + tiles = [(64, 64, 32), (128, 128, 32), (256, 256, 64)] + waves = [(1, 1, 1), (2, 2, 1), (1, 4, 1), (3, 3, 1)] # Some invalid + warps = [(16, 16, 16), (32, 32, 16), (48, 48, 24)] # Some invalid + pipelines = ["compv3", "compv4", "invalid"] + schedulers = ["intrawave", "interwave"] + + success_count = 0 + total_count = 30 + + for _ in range(total_count): + config = { + "name": "random_test", + "dtype_a": random.choice(dtypes), + "dtype_b": random.choice(dtypes), + "dtype_c": random.choice(dtypes), + "layout": random.choice(layouts), + "tile_m": random.choice(tiles)[0], + "tile_n": random.choice(tiles)[1], + "tile_k": 
random.choice(tiles)[2], + "wave_m": random.choice(waves)[0], + "wave_n": random.choice(waves)[1], + "wave_k": random.choice(waves)[2], + "warp_m": random.choice(warps)[0], + "warp_n": random.choice(warps)[1], + "warp_k": random.choice(warps)[2], + "pipeline": random.choice(pipelines), + "scheduler": random.choice(schedulers), + } + + is_valid, _ = validate_kernel_config(config, "gfx942") + + if is_valid: + success_count += 1 + else: + # Try wildcard expansion + wildcard = config.copy() + wildcard["wave_m"] = -1 + wildcard["wave_n"] = -1 + wildcard["warp_m"] = -1 + wildcard["warp_n"] = -1 + wildcard["pipeline"] = "*" + wildcard["scheduler"] = "*" + + expanded = expand_declaration_with_arch_filter(wildcard, "gfx942") + if expanded: + success_count += 1 + + # At least 50% should be handleable + self.assertGreater( + success_count / total_count, + 0.5, + f"Only {success_count}/{total_count} configs were handleable", + ) + + def test_random_conv_configs(self): + """Random Conv configs should either validate or expand successfully.""" + random.seed(42) + + dtypes = ["fp16", "bf16"] + tiles = [(64, 64), (128, 128), (256, 256)] + waves = [(2, 2, 1), (1, 4, 1), (3, 3, 1)] + warps = [(16, 16, 16), (32, 32, 16)] + + success_count = 0 + total_count = 20 + + for _ in range(total_count): + config = { + "name": "random_conv_test", + "dtype": random.choice(dtypes), + "layout": "nhwgc", + "conv_type": "forward", + "tile_k": random.choice(tiles)[0], + "tile_c": random.choice(tiles)[1], + "wave_m": random.choice(waves)[0], + "wave_n": random.choice(waves)[1], + "wave_k": random.choice(waves)[2], + "warp_m": random.choice(warps)[0], + "warp_n": random.choice(warps)[1], + "warp_k": random.choice(warps)[2], + "pipeline": "compv4", + "scheduler": "intrawave", + } + + is_valid, _ = validate_conv_kernel_config(config, "gfx942") + + if is_valid: + success_count += 1 + else: + # Try wildcard expansion + wildcard = config.copy() + wildcard["wave_m"] = -1 + wildcard["wave_n"] = -1 + 
wildcard["warp_m"] = -1 + wildcard["warp_n"] = -1 + + expanded = expand_conv_declaration_with_arch_filter(wildcard, "gfx942") + if expanded: + success_count += 1 + + self.assertGreater( + success_count / total_count, + 0.5, + f"Only {success_count}/{total_count} conv configs were handleable", + ) + + +# ============================================================================= +# ARCHITECTURE TESTS +# ============================================================================= + + +class TestArchitectureSupport(unittest.TestCase): + """Test architecture-specific support.""" + + def test_gfx942_fp16_support(self): + """gfx942 should support fp16.""" + config = { + "dtype_a": "fp16", + "wave_m": -1, + "wave_n": -1, + "warp_m": -1, + "warp_n": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "gfx942 should support fp16") + + def test_gfx942_bf16_support(self): + """gfx942 should support bf16.""" + config = { + "dtype_a": "bf16", + "wave_m": -1, + "wave_n": -1, + "warp_m": -1, + "warp_n": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx942") + self.assertGreater(len(expanded), 0, "gfx942 should support bf16") + + def test_gfx90a_support(self): + """gfx90a should support fp16.""" + config = { + "dtype_a": "fp16", + "wave_m": -1, + "wave_n": -1, + "warp_m": -1, + "warp_n": -1, + "pipeline": "*", + "scheduler": "*", + } + expanded = expand_declaration_with_arch_filter(config, "gfx90a") + self.assertGreater(len(expanded), 0, "gfx90a should support fp16") + + +# ============================================================================= +# MAIN +# ============================================================================= + + +def main(): + """Run tests.""" + # Parse args for verbosity + verbosity = 2 if "-v" in sys.argv or "--verbose" in sys.argv else 1 + + # Create test suite + loader = 
unittest.TestLoader() + suite = unittest.TestSuite() + + # Add all test classes + suite.addTests(loader.loadTestsFromTestCase(TestGemmValidation)) + suite.addTests(loader.loadTestsFromTestCase(TestGemmExpansion)) + suite.addTests(loader.loadTestsFromTestCase(TestConvValidation)) + suite.addTests(loader.loadTestsFromTestCase(TestConvExpansion)) + suite.addTests(loader.loadTestsFromTestCase(TestPythonAutoCorrect)) + suite.addTests(loader.loadTestsFromTestCase(TestStressRandom)) + suite.addTests(loader.loadTestsFromTestCase(TestArchitectureSupport)) + + # Run tests + runner = unittest.TextTestRunner(verbosity=verbosity) + result = runner.run(suite) + + # Return exit code + return 0 if result.wasSuccessful() else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/tests/test_dispatcher.cpp b/dispatcher/tests/test_dispatcher.cpp new file mode 100644 index 00000000000..1e3893756c5 --- /dev/null +++ b/dispatcher/tests/test_dispatcher.cpp @@ -0,0 +1,296 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/// Unit tests for Dispatcher using Google Test + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +class DispatcherTest : public ::testing::Test +{ + protected: + void SetUp() override + { + // Clear registry before each test + Registry::instance().clear(); + } + + void TearDown() override + { + // Clean up after each test + Registry::instance().clear(); + } +}; + +TEST_F(DispatcherTest, SelectKernelFirstFit) +{ + Dispatcher dispatcher; + + // Register kernels + auto key1 = make_test_key(256); + auto key2 = make_test_key(128); + auto kernel1 = std::make_shared(key1, "kernel1"); + auto kernel2 = std::make_shared(key2, "kernel2"); + + Registry::instance().register_kernel(kernel1); + Registry::instance().register_kernel(kernel2); + + // Select kernel for valid problem + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + // Should select a kernel that supports the problem + // (order is not guaranteed, so just verify one is selected) + EXPECT_TRUE(selected->get_name() == "kernel1" || selected->get_name() == "kernel2"); + EXPECT_TRUE(selected->supports(problem)); +} + +TEST_F(DispatcherTest, SelectKernelInvalidProblem) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + // Invalid problem + Problem invalid_problem(0, 0, 0); + auto selected = dispatcher.select_kernel(invalid_problem); + + EXPECT_EQ(selected, nullptr); +} + +TEST_F(DispatcherTest, SelectKernelNoMatch) +{ + Dispatcher dispatcher; + + // Register kernel that doesn't support the problem + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1", false); + Registry::instance().register_kernel(kernel); + + // Problem with 
dimensions not divisible by tile size + Problem problem(100, 100, 100); // Not divisible by 256 + auto selected = dispatcher.select_kernel(problem); + + EXPECT_EQ(selected, nullptr); +} + +TEST_F(DispatcherTest, SelectKernelHeuristic) +{ + Dispatcher dispatcher; + + // Register kernels + auto key1 = make_test_key(256); + auto key2 = make_test_key(128); + auto kernel1 = std::make_shared(key1, "kernel1"); + auto kernel2 = std::make_shared(key2, "kernel2"); + + Registry::instance().register_kernel(kernel1); + Registry::instance().register_kernel(kernel2); + + // Set heuristic that prefers kernel2 + dispatcher.set_heuristic([](const Problem&) { + std::vector candidates; + auto key2 = make_test_key(128); + candidates.push_back(key2.encode_identifier()); + auto key1 = make_test_key(256); + candidates.push_back(key1.encode_identifier()); + return candidates; + }); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel2"); +} + +TEST_F(DispatcherTest, SelectKernelHeuristicFallback) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + // Set heuristic that returns non-existent kernel + dispatcher.set_heuristic( + [](const Problem&) { return std::vector{"nonexistent_kernel"}; }); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + // Should fall back to first-fit + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel1"); +} + +TEST_F(DispatcherTest, RunBasic) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + + // Mock pointers (not actually used) + float a[1], b[1], c[1]; + + float time_ms = 
dispatcher.run(a, b, c, problem); + + EXPECT_GT(time_ms, 0.0f); + EXPECT_EQ(kernel->get_execution_count(), 1); +} + +TEST_F(DispatcherTest, RunNoKernel) +{ + Dispatcher dispatcher; + + // No kernels registered + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + EXPECT_THROW((void)dispatcher.run(a, b, c, problem), std::runtime_error); +} + +TEST_F(DispatcherTest, RunExplicit) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + std::string kernel_id = key.encode_identifier(); + + float a[1], b[1], c[1]; + + float time_ms = dispatcher.run_explicit(kernel_id, a, b, c, nullptr, problem); + + EXPECT_GT(time_ms, 0.0f); + EXPECT_EQ(kernel->get_execution_count(), 1); +} + +TEST_F(DispatcherTest, RunExplicitNotFound) +{ + Dispatcher dispatcher; + + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + EXPECT_THROW((void)dispatcher.run_explicit("nonexistent", a, b, c, nullptr, problem), + std::runtime_error); +} + +TEST_F(DispatcherTest, RunExplicitNotSupported) +{ + Dispatcher dispatcher; + + // Register kernel that doesn't support the problem + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1", false); + Registry::instance().register_kernel(kernel); + + Problem problem(100, 100, 100); // Not divisible by 256 + std::string kernel_id = key.encode_identifier(); + + float a[1], b[1], c[1]; + + EXPECT_THROW((void)dispatcher.run_explicit(kernel_id, a, b, c, nullptr, problem), + std::runtime_error); +} + +TEST_F(DispatcherTest, Validate) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + bool valid = dispatcher.validate(a, b, c, nullptr, problem); + + 
EXPECT_TRUE(valid); +} + +TEST_F(DispatcherTest, ValidateNoKernel) +{ + Dispatcher dispatcher; + + // No kernels registered + Problem problem(1024, 1024, 1024); + + float a[1], b[1], c[1]; + + bool valid = dispatcher.validate(a, b, c, nullptr, problem); + + EXPECT_FALSE(valid); +} + +TEST_F(DispatcherTest, StrategySelection) +{ + Dispatcher dispatcher; + + // Register kernel + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + Registry::instance().register_kernel(kernel); + + Problem problem(1024, 1024, 1024); + + // Test FirstFit strategy + dispatcher.set_strategy(Dispatcher::SelectionStrategy::FirstFit); + auto selected1 = dispatcher.select_kernel(problem); + ASSERT_NE(selected1, nullptr); + + // Test Heuristic strategy (without heuristic function - should fallback) + dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic); + auto selected2 = dispatcher.select_kernel(problem); + ASSERT_NE(selected2, nullptr); +} + +TEST_F(DispatcherTest, CustomRegistry) +{ + // Create custom registry instance (not singleton) + // Note: This requires Registry to allow non-singleton instances + // For now, we'll test with a separate registry instance + // In practice, custom registry would be created differently + + // Since Registry is singleton-only, we'll test that dispatcher + // can work with the singleton registry + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel1"); + registry.register_kernel(kernel); + + // Dispatcher defaults to singleton registry + Dispatcher dispatcher; + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel1"); +} diff --git a/dispatcher/tests/test_dispatcher_extended.cpp b/dispatcher/tests/test_dispatcher_extended.cpp new file mode 100644 index 00000000000..e8d7e4b5d19 --- /dev/null +++ 
b/dispatcher/tests/test_dispatcher_extended.cpp @@ -0,0 +1,499 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for Dispatcher - covers selection strategies, heuristics, edge cases + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; +using SelectionStrategy = Dispatcher::SelectionStrategy; + +// ============================================================================= +// Basic Dispatcher Tests +// ============================================================================= + +class DispatcherBasicTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(DispatcherBasicTest, DefaultConstruction) +{ + Dispatcher dispatcher; + // Should not crash + SUCCEED(); +} + +TEST_F(DispatcherBasicTest, SelectKernelEmpty) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + auto kernel = dispatcher.select_kernel(problem); + EXPECT_EQ(kernel, nullptr); +} + +TEST_F(DispatcherBasicTest, SelectKernelSingle) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + auto selected = dispatcher.select_kernel(problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "test_kernel"); +} + +TEST_F(DispatcherBasicTest, SelectKernelMultiple) +{ + // Register multiple kernels + for(int tile : {128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + 
+ auto selected = dispatcher.select_kernel(problem); + ASSERT_NE(selected, nullptr); + // Should select one of the registered kernels + EXPECT_TRUE(selected->get_name() == "kernel_128" || selected->get_name() == "kernel_256" || + selected->get_name() == "kernel_512"); +} + +// ============================================================================= +// Selection Strategy Tests +// ============================================================================= + +class SelectionStrategyTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + // Register kernels with different tile sizes + for(int tile : {128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(SelectionStrategyTest, FirstFitStrategy) +{ + Dispatcher dispatcher; + dispatcher.set_strategy(SelectionStrategy::FirstFit); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + // FirstFit returns first matching kernel +} + +TEST_F(SelectionStrategyTest, HeuristicStrategy) +{ + Dispatcher dispatcher; + + // Set heuristic that prefers larger tiles for large problems + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + if(p.M >= 1024 && p.N >= 1024) + { + // For large problems, prefer 512 tile + auto key = make_test_key(512); + return {key.encode_identifier()}; + } + // For small problems, prefer 128 tile + auto key = make_test_key(128); + return {key.encode_identifier()}; + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + // Large problem should get 512 tile + Problem large_problem(2048, 2048, 2048); + auto selected = dispatcher.select_kernel(large_problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_512"); + + // 
Small problem should get 128 tile + Problem small_problem(256, 256, 256); + selected = dispatcher.select_kernel(small_problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_128"); +} + +TEST_F(SelectionStrategyTest, HeuristicWithFallback) +{ + Dispatcher dispatcher; + + // Heuristic returns non-existent kernel first, then valid one + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + auto key = make_test_key(256); + return {"nonexistent_kernel", key.encode_identifier()}; + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_256"); +} + +TEST_F(SelectionStrategyTest, SwitchBetweenStrategies) +{ + Dispatcher dispatcher; + + // Start with FirstFit + dispatcher.set_strategy(SelectionStrategy::FirstFit); + + Problem problem(1024, 1024, 1024); + auto selected1 = dispatcher.select_kernel(problem); + ASSERT_NE(selected1, nullptr); + + // Switch to Heuristic + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + auto key = make_test_key(256); + return {key.encode_identifier()}; + }); + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + auto selected2 = dispatcher.select_kernel(problem); + ASSERT_NE(selected2, nullptr); +} + +// ============================================================================= +// Heuristic Function Tests +// ============================================================================= + +class HeuristicTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + for(int tile : {64, 128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(HeuristicTest, 
SizeBasedHeuristic) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + std::vector candidates; + + // Problem-size based selection + int size = p.M * p.N * p.K; + + if(size >= 1024 * 1024 * 1024) + { + candidates.push_back(make_test_key(512).encode_identifier()); + candidates.push_back(make_test_key(256).encode_identifier()); + } + else if(size >= 256 * 256 * 256) + { + candidates.push_back(make_test_key(256).encode_identifier()); + candidates.push_back(make_test_key(128).encode_identifier()); + } + else + { + candidates.push_back(make_test_key(64).encode_identifier()); + candidates.push_back(make_test_key(128).encode_identifier()); + } + + return candidates; + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + // Large problem + auto selected = dispatcher.select_kernel(Problem(1024, 1024, 1024)); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_512"); + + // Medium problem + selected = dispatcher.select_kernel(Problem(256, 256, 256)); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_256"); + + // Small problem + selected = dispatcher.select_kernel(Problem(64, 64, 64)); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "kernel_64"); +} + +TEST_F(HeuristicTest, EmptyHeuristicFallsBackToFirstFit) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {}; // Empty list + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + // Should fall back to FirstFit + ASSERT_NE(selected, nullptr); +} + +TEST_F(HeuristicTest, InvalidHeuristicFallsBackToFirstFit) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {"invalid_kernel_1", "invalid_kernel_2"}; // All invalid + }); + + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem 
problem(1024, 1024, 1024); + auto selected = dispatcher.select_kernel(problem); + + // Should fall back to FirstFit + ASSERT_NE(selected, nullptr); +} + +// ============================================================================= +// Dispatcher with Custom Registry Tests +// ============================================================================= + +class DispatcherCustomRegistryTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(DispatcherCustomRegistryTest, UseCustomRegistry) +{ + Registry custom_registry; + custom_registry.set_name("custom"); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "custom_kernel"); + custom_registry.register_kernel(kernel); + + Dispatcher dispatcher(&custom_registry); + Problem problem(1024, 1024, 1024); + + auto selected = dispatcher.select_kernel(problem); + ASSERT_NE(selected, nullptr); + EXPECT_EQ(selected->get_name(), "custom_kernel"); +} + +TEST_F(DispatcherCustomRegistryTest, CustomRegistryIsolation) +{ + Registry custom_registry; + + auto key_custom = make_test_key(256); + auto key_global = make_test_key(512); + + custom_registry.register_kernel( + std::make_shared(key_custom, "custom_kernel")); + Registry::instance().register_kernel( + std::make_shared(key_global, "global_kernel")); + + Dispatcher custom_dispatcher(&custom_registry); + Dispatcher global_dispatcher; + + Problem problem(1024, 1024, 1024); + + auto custom_selected = custom_dispatcher.select_kernel(problem); + auto global_selected = global_dispatcher.select_kernel(problem); + + ASSERT_NE(custom_selected, nullptr); + ASSERT_NE(global_selected, nullptr); + + EXPECT_EQ(custom_selected->get_name(), "custom_kernel"); + EXPECT_EQ(global_selected->get_name(), "global_kernel"); +} + +// ============================================================================= +// Edge Cases Tests +// ============================================================================= + 
+class DispatcherEdgeCasesTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(DispatcherEdgeCasesTest, InvalidProblem) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + + // Zero dimensions + Problem invalid(0, 1024, 1024); + EXPECT_FALSE(invalid.is_valid()); + + // The dispatcher should still attempt selection + // (validation is up to the kernel's supports() method) +} + +TEST_F(DispatcherEdgeCasesTest, KernelDoesNotSupportProblem) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "selective_kernel", false); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + + // Problem not divisible by tile size - kernel doesn't support it + Problem problem(1000, 1000, 1000); // Not divisible by 256 + + auto selected = dispatcher.select_kernel(problem); + // Should return nullptr since kernel doesn't support this problem + EXPECT_EQ(selected, nullptr); +} + +TEST_F(DispatcherEdgeCasesTest, MultipleSelectionsConsistent) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // Multiple selections should return the same kernel + auto selected1 = dispatcher.select_kernel(problem); + auto selected2 = dispatcher.select_kernel(problem); + auto selected3 = dispatcher.select_kernel(problem); + + ASSERT_NE(selected1, nullptr); + EXPECT_EQ(selected1, selected2); + EXPECT_EQ(selected2, selected3); +} + +// ============================================================================= +// Validate Method Tests +// ============================================================================= + +class DispatcherValidateTest : public ::testing::Test +{ + protected: 
+ void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + kernel_ = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel_); + } + + void TearDown() override { Registry::instance().clear(); } + + std::shared_ptr kernel_; +}; + +TEST_F(DispatcherValidateTest, ValidateWithMockKernel) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // MockKernelInstance always validates successfully + bool valid = dispatcher.validate(nullptr, nullptr, nullptr, nullptr, problem); + + // This depends on implementation - mock returns true + // Real validation would need actual data +} + +// ============================================================================= +// Run Method Tests (with mock) +// ============================================================================= + +class DispatcherRunTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + kernel_ = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel_); + } + + void TearDown() override { Registry::instance().clear(); } + + std::shared_ptr kernel_; +}; + +TEST_F(DispatcherRunTest, RunWithMockKernel) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // Mock run (with null pointers - mock doesn't use them) + float time = dispatcher.run(nullptr, nullptr, nullptr, problem); + + // Mock kernel returns 1.0f + EXPECT_FLOAT_EQ(time, 1.0f); + + // Verify execution count + EXPECT_EQ(kernel_->get_execution_count(), 1); +} + +TEST_F(DispatcherRunTest, MultipleRuns) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + for(int i = 0; i < 10; i++) + { + (void)dispatcher.run(nullptr, nullptr, nullptr, problem); + } + + EXPECT_EQ(kernel_->get_execution_count(), 10); +} + +TEST_F(DispatcherRunTest, RunWithNoKernelThrows) +{ + Registry::instance().clear(); + + Dispatcher dispatcher; + Problem problem(1024, 
1024, 1024); + + // Should throw when no kernel found + EXPECT_THROW((void)dispatcher.run(nullptr, nullptr, nullptr, problem), std::runtime_error); +} diff --git a/dispatcher/tests/test_examples_integration.py b/dispatcher/tests/test_examples_integration.py new file mode 100644 index 00000000000..cfd18a33056 --- /dev/null +++ b/dispatcher/tests/test_examples_integration.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Integration tests that verify examples work correctly. + +These tests mimic the examples to ensure they continue working. +Run with: pytest test_examples_integration.py -v +""" + +import unittest +import subprocess +import sys +import os +from pathlib import Path + +# Get paths +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_ROOT = SCRIPT_DIR.parent +EXAMPLES_DIR = DISPATCHER_ROOT / "examples" +BUILD_DIR = DISPATCHER_ROOT / "build" +PYTHON_DIR = DISPATCHER_ROOT / "python" + +# Add python utilities to path +sys.path.insert(0, str(PYTHON_DIR)) + + +def run_python_example( + example_path: Path, timeout: int = 120 +) -> subprocess.CompletedProcess: + """Run a Python example and capture output.""" + env = os.environ.copy() + env["PYTHONPATH"] = str(PYTHON_DIR) + + return subprocess.run( + [sys.executable, str(example_path)], + capture_output=True, + text=True, + timeout=timeout, + cwd=example_path.parent, + env=env, + ) + + +def run_cpp_example( + example_name: str, timeout: int = 60 +) -> subprocess.CompletedProcess: + """Run a C++ example and capture output.""" + example_path = BUILD_DIR / "examples" / example_name + + if not example_path.exists(): + return None + + return subprocess.run( + [str(example_path)], + capture_output=True, + text=True, + timeout=timeout, + ) + + +class TestGemmPythonExamples(unittest.TestCase): + """Test GEMM Python examples.""" + + @classmethod + def setUpClass(cls): + """Check if examples directory exists.""" + 
cls.gemm_examples_dir = EXAMPLES_DIR / "gemm" / "python" + if not cls.gemm_examples_dir.exists(): + raise unittest.SkipTest("GEMM Python examples not found") + + def test_01_basic_gemm(self): + """Test basic GEMM example.""" + example = self.gemm_examples_dir / "01_basic_gemm.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_02_batch_gemm(self): + """Test batch GEMM example.""" + example = self.gemm_examples_dir / "02_batch_gemm.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_03_benchmark(self): + """Test benchmark example.""" + example = self.gemm_examples_dir / "03_benchmark.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_04_validation(self): + """Test validation example.""" + example = self.gemm_examples_dir / "04_validation.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + # Should pass validation + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestConvPythonExamples(unittest.TestCase): + """Test Conv Python examples.""" + + @classmethod + def setUpClass(cls): + """Check if examples directory exists.""" + cls.conv_examples_dir = EXAMPLES_DIR / "conv" / "python" + if not cls.conv_examples_dir.exists(): + raise unittest.SkipTest("Conv Python examples not found") + + def test_01_basic_conv(self): + """Test basic conv example.""" + example = 
self.conv_examples_dir / "01_basic_conv.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_02_conv2d_fwd(self): + """Test 2D forward conv example.""" + example = self.conv_examples_dir / "02_conv2d_fwd.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_03_conv3d_fwd(self): + """Test 3D forward conv example.""" + example = self.conv_examples_dir / "03_conv3d_fwd.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_07_validation(self): + """Test validation example.""" + example = self.conv_examples_dir / "07_validation.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + + result = run_python_example(example) + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestGemmCppExamples(unittest.TestCase): + """Test GEMM C++ examples.""" + + @classmethod + def setUpClass(cls): + """Check if build directory exists.""" + cls.examples_dir = BUILD_DIR / "examples" + if not cls.examples_dir.exists(): + raise unittest.SkipTest("C++ examples not built") + + def test_gemm_01_basic(self): + """Test basic GEMM C++ example.""" + result = run_cpp_example("gemm_01_basic") + if result is None: + self.skipTest("gemm_01_basic not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_gemm_02_multi_size(self): + 
"""Test multi-size GEMM C++ example.""" + result = run_cpp_example("gemm_02_multi_size") + if result is None: + self.skipTest("gemm_02_multi_size not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + + def test_gemm_04_validation(self): + """Test validation GEMM C++ example.""" + result = run_cpp_example("gemm_04_validation") + if result is None: + self.skipTest("gemm_04_validation not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestConvCppExamples(unittest.TestCase): + """Test Conv C++ examples.""" + + @classmethod + def setUpClass(cls): + """Check if build directory exists.""" + cls.examples_dir = BUILD_DIR / "examples" + if not cls.examples_dir.exists(): + raise unittest.SkipTest("C++ examples not built") + + def test_conv_01_forward(self): + """Test forward conv C++ example.""" + result = run_cpp_example("conv_01_forward") + if result is None: + self.skipTest("conv_01_forward not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + + def test_conv_02_validation(self): + """Test validation conv C++ example.""" + result = run_cpp_example("conv_02_validation") + if result is None: + self.skipTest("conv_02_validation not built") + + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + + +class TestUtilityImports(unittest.TestCase): + """Test that utility modules can be imported.""" + + def test_import_ctypes_utils(self): + """Test importing ctypes_utils.""" + try: + from ctypes_utils import KernelConfig, setup_gemm_dispatcher # noqa: F401 + + self.assertTrue(True) + except ImportError as e: + self.fail(f"Failed to import ctypes_utils: {e}") + + def test_import_conv_utils(self): + """Test importing 
conv_utils.""" + try: + from conv_utils import ConvSignature, ConvAlgorithm, ConvProblem # noqa: F401 + + self.assertTrue(True) + except ImportError as e: + self.fail(f"Failed to import conv_utils: {e}") + + def test_kernel_config_creation(self): + """Test creating a KernelConfig.""" + from ctypes_utils import KernelConfig + + config = KernelConfig( + dtype_a="fp16", + dtype_b="fp16", + dtype_c="fp16", + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + ) + + self.assertEqual(config.dtype_a, "fp16") + self.assertEqual(config.layout_a, "row") + + def test_conv_signature_creation(self): + """Test creating a ConvSignature.""" + from conv_utils import ConvSignature + + sig = ConvSignature( + dtype_in="fp16", + dtype_wei="fp16", + dtype_out="fp16", + dtype_acc="fp32", + layout="nhwgc", + direction="forward", + num_dims=2, + ) + + self.assertEqual(sig.dtype_in, "fp16") + self.assertEqual(sig.direction, "forward") + + +class TestAutoCorrection(unittest.TestCase): + """Test auto-correction functionality.""" + + def test_gemm_auto_correct(self): + """Test GEMM auto-correction.""" + from ctypes_utils import KernelConfig, auto_correct_kernel_config + + # Create a config with invalid wave config + config = KernelConfig( + dtype_a="fp16", + dtype_b="fp16", + dtype_c="fp16", + dtype_acc="fp32", + layout_a="row", + layout_b="col", + layout_c="row", + wave_m=99, # Invalid + wave_n=99, # Invalid + wave_k=99, # Invalid + ) + + corrected, was_modified, corrections = auto_correct_kernel_config(config) + + self.assertTrue(was_modified, "Config should be modified") + self.assertGreater(len(corrections), 0, "Should have corrections") + + def test_conv_auto_correct(self): + """Test Conv auto-correction.""" + from conv_utils import auto_correct_conv_config + + # Call with invalid wave config parameters + corrected, was_modified, corrections = auto_correct_conv_config( + wave_m=99, # Invalid + wave_n=99, # Invalid + wave_k=99, # Invalid + dtype="fp16", + 
arch="gfx942", + ) + + self.assertTrue(was_modified, "Config should be modified") + self.assertGreater(len(corrections), 0, "Should have corrections") + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_json_export.cpp b/dispatcher/tests/test_json_export.cpp new file mode 100644 index 00000000000..43927295548 --- /dev/null +++ b/dispatcher/tests/test_json_export.cpp @@ -0,0 +1,448 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for JSON export functionality + +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/json_export.hpp" +#include "test_mock_kernel.hpp" +#include +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +// ============================================================================= +// Basic Export Tests +// ============================================================================= + +class JSONExportBasicTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONExportBasicTest, ExportEmptyRegistry) +{ + std::string json = Registry::instance().export_json(false); + + EXPECT_FALSE(json.empty()); + EXPECT_NE(json.find("\"kernels\""), std::string::npos); + // Empty registry should still produce valid JSON with kernels section +} + +TEST_F(JSONExportBasicTest, ExportSingleKernel) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(false); + + EXPECT_FALSE(json.empty()); + EXPECT_NE(json.find("\"test_kernel\""), std::string::npos); +} + +TEST_F(JSONExportBasicTest, ExportMultipleKernels) +{ + for(int i = 0; i < 5; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" 
+ std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + std::string json = Registry::instance().export_json(false); + + // Should contain all kernel names + for(int i = 0; i < 5; i++) + { + EXPECT_NE(json.find("\"kernel_" + std::to_string(i) + "\""), std::string::npos); + } +} + +// ============================================================================= +// Export with Statistics Tests +// ============================================================================= + +class JSONExportStatisticsTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONExportStatisticsTest, ExportWithStatistics) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); // Include statistics + + EXPECT_NE(json.find("\"statistics\""), std::string::npos); + EXPECT_NE(json.find("\"by_datatype\""), std::string::npos); + EXPECT_NE(json.find("\"by_pipeline\""), std::string::npos); +} + +TEST_F(JSONExportStatisticsTest, ExportWithoutStatistics) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(false); // No statistics + + // Statistics section might be minimal or absent + EXPECT_NE(json.find("\"kernels\""), std::string::npos); +} + +// ============================================================================= +// Metadata Tests +// ============================================================================= + +class JSONExportMetadataTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONExportMetadataTest, MetadataPresent) +{ + 
std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"metadata\""), std::string::npos); + EXPECT_NE(json.find("\"timestamp\""), std::string::npos); + EXPECT_NE(json.find("\"total_kernels\""), std::string::npos); +} + +TEST_F(JSONExportMetadataTest, CorrectKernelCount) +{ + const int num_kernels = 7; + for(int i = 0; i < num_kernels; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"total_kernels\": " + std::to_string(num_kernels)), std::string::npos); +} + +TEST_F(JSONExportMetadataTest, RegistryNameIncluded) +{ + Registry::instance().set_name("test_registry"); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"registry_name\""), std::string::npos); + EXPECT_NE(json.find("\"test_registry\""), std::string::npos); +} + +// ============================================================================= +// Export to File Tests +// ============================================================================= + +class JSONExportToFileTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + test_file_ = "/tmp/test_export_" + std::to_string(time(nullptr)) + ".json"; + } + + void TearDown() override + { + Registry::instance().clear(); + std::remove(test_file_.c_str()); + } + + std::string test_file_; +}; + +TEST_F(JSONExportToFileTest, ExportToFile) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + bool success = Registry::instance().export_json_to_file(test_file_, true); + EXPECT_TRUE(success); + + // Verify file exists + std::ifstream 
file(test_file_); + EXPECT_TRUE(file.good()); + + // Verify content + std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + EXPECT_NE(content.find("\"kernel\""), std::string::npos); +} + +TEST_F(JSONExportToFileTest, ExportToInvalidPath) +{ + bool success = Registry::instance().export_json_to_file("/invalid/path/file.json", true); + EXPECT_FALSE(success); +} + +// ============================================================================= +// Auto-Export Tests +// ============================================================================= + +class JSONAutoExportTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + Registry::instance().disable_auto_export(); + test_file_ = "/tmp/test_auto_export_" + std::to_string(time(nullptr)) + ".json"; + } + + void TearDown() override + { + Registry::instance().disable_auto_export(); + Registry::instance().clear(); + std::remove(test_file_.c_str()); + } + + std::string test_file_; +}; + +TEST_F(JSONAutoExportTest, EnableAutoExport) +{ + EXPECT_FALSE(Registry::instance().is_auto_export_enabled()); + + Registry::instance().enable_auto_export(test_file_, true, false); + + EXPECT_TRUE(Registry::instance().is_auto_export_enabled()); +} + +TEST_F(JSONAutoExportTest, DisableAutoExport) +{ + Registry::instance().enable_auto_export(test_file_, true, false); + EXPECT_TRUE(Registry::instance().is_auto_export_enabled()); + + Registry::instance().disable_auto_export(); + EXPECT_FALSE(Registry::instance().is_auto_export_enabled()); +} + +TEST_F(JSONAutoExportTest, AutoExportOnRegistration) +{ + // Enable auto-export with export_on_every_registration=true + Registry::instance().enable_auto_export(test_file_, true, false); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "auto_kernel"); + Registry::instance().register_kernel(kernel); + + // File might be created on registration or on exit depending on implementation + // Just verify 
auto-export is enabled + EXPECT_TRUE(Registry::instance().is_auto_export_enabled()); +} + +// ============================================================================= +// JSON Validity Tests +// ============================================================================= + +class JSONValidityTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } + + // Simple JSON syntax checker + bool isValidJSON(const std::string& json) + { + int braces = 0; + int brackets = 0; + bool in_string = false; + char prev = '\0'; + + for(char c : json) + { + if(c == '"' && prev != '\\') + { + in_string = !in_string; + } + + if(!in_string) + { + if(c == '{') + braces++; + else if(c == '}') + braces--; + else if(c == '[') + brackets++; + else if(c == ']') + brackets--; + } + + if(braces < 0 || brackets < 0) + return false; + prev = c; + } + + return braces == 0 && brackets == 0 && !in_string; + } +}; + +TEST_F(JSONValidityTest, EmptyRegistryProducesValidJSON) +{ + std::string json = Registry::instance().export_json(true); + EXPECT_TRUE(isValidJSON(json)); +} + +TEST_F(JSONValidityTest, SingleKernelProducesValidJSON) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + EXPECT_TRUE(isValidJSON(json)); +} + +TEST_F(JSONValidityTest, ManyKernelsProduceValidJSON) +{ + for(int i = 0; i < 50; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + std::string json = Registry::instance().export_json(true); + EXPECT_TRUE(isValidJSON(json)); +} + +TEST_F(JSONValidityTest, NoNullBytesInJSON) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + 
std::string json = Registry::instance().export_json(true); + + // Check for null bytes + EXPECT_EQ(json.find('\0'), std::string::npos); +} + +TEST_F(JSONValidityTest, NoPrintableGarbageInJSON) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + // All characters should be printable or whitespace + for(char c : json) + { + EXPECT_TRUE(std::isprint(c) || std::isspace(c)) + << "Non-printable character: " << static_cast(c); + } +} + +// ============================================================================= +// Kernel Details Tests +// ============================================================================= + +class JSONKernelDetailsTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONKernelDetailsTest, SignatureIncluded) +{ + auto key = make_test_key(256); + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"signature\""), std::string::npos); + EXPECT_NE(json.find("\"dtype_a\""), std::string::npos); + EXPECT_NE(json.find("\"fp16\""), std::string::npos); +} + +TEST_F(JSONKernelDetailsTest, AlgorithmIncluded) +{ + auto key = make_test_key(256, 256, 32); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"algorithm\""), std::string::npos); + EXPECT_NE(json.find("\"tile_shape\""), std::string::npos); +} + +TEST_F(JSONKernelDetailsTest, IdentifierIncluded) +{ + auto key = make_test_key(256); + auto 
kernel = std::make_shared(key, "my_kernel"); + Registry::instance().register_kernel(kernel); + + std::string json = Registry::instance().export_json(true); + + EXPECT_NE(json.find("\"identifier\""), std::string::npos); + EXPECT_NE(json.find("\"name\""), std::string::npos); + EXPECT_NE(json.find("\"my_kernel\""), std::string::npos); +} + +// ============================================================================= +// Multiple Registries Export Tests +// ============================================================================= + +class JSONMultipleRegistriesTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(JSONMultipleRegistriesTest, DifferentRegistriesDifferentJSON) +{ + Registry reg1; + reg1.set_name("registry1"); + + Registry reg2; + reg2.set_name("registry2"); + + auto key1 = make_test_key(128); + auto key2 = make_test_key(256); + + reg1.register_kernel(std::make_shared(key1, "k1")); + reg2.register_kernel(std::make_shared(key2, "k2")); + + std::string json1 = reg1.export_json(true); + std::string json2 = reg2.export_json(true); + + EXPECT_NE(json1, json2); + + EXPECT_NE(json1.find("\"registry1\""), std::string::npos); + EXPECT_NE(json2.find("\"registry2\""), std::string::npos); + + EXPECT_NE(json1.find("\"k1\""), std::string::npos); + EXPECT_NE(json2.find("\"k2\""), std::string::npos); +} diff --git a/dispatcher/tests/test_kernel_key.cpp b/dispatcher/tests/test_kernel_key.cpp new file mode 100644 index 00000000000..b35641952af --- /dev/null +++ b/dispatcher/tests/test_kernel_key.cpp @@ -0,0 +1,147 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/// Unit tests for KernelKey using Google Test + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +TEST(KernelKeyTest, Construction) +{ + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + + key.gfx_arch = "gfx942"; + + EXPECT_EQ(key.signature.dtype_a, DataType::FP16); + EXPECT_EQ(key.algorithm.tile_shape.m, 256); + EXPECT_EQ(key.gfx_arch, "gfx942"); +} + +TEST(KernelKeyTest, Equality) +{ + // Use helper function to ensure all fields are initialized + KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + KernelKey key2 = make_test_key(256, 256, 32, "gfx942"); + + EXPECT_EQ(key1, key2); + EXPECT_FALSE(key1 != key2); + + // Change one value + KernelKey key3 = make_test_key(128, 256, 32, "gfx942"); + EXPECT_NE(key1, key3); + EXPECT_FALSE(key1 == key3); +} + +TEST(KernelKeyTest, EncodeIdentifier) +{ + KernelKey key; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.persistent = true; + key.algorithm.preshuffle = false; + key.signature.structured_sparsity = false; + + std::string id = key.encode_identifier(); + + // Check that identifier contains expected components + 
EXPECT_NE(id.find("256x256x32"), std::string::npos); // tile shape + EXPECT_NE(id.find("2x2x1"), std::string::npos); // wave shape + EXPECT_NE(id.find("32x32x16"), std::string::npos); // warp tile shape + EXPECT_NE(id.find("persist"), std::string::npos); // persistent flag +} + +TEST(KernelKeyTest, EncodeIdentifierWithFusion) +{ + KernelKey key; + key.signature.split_k = 1; + key.signature.elementwise_op = "Relu"; + key.signature.num_d_tensors = 2; + key.algorithm.tile_shape.m = 128; + key.algorithm.tile_shape.n = 128; + key.algorithm.tile_shape.k = 64; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 16; + key.algorithm.warp_tile_shape.n = 16; + key.algorithm.warp_tile_shape.k = 32; + key.algorithm.persistent = false; + key.signature.structured_sparsity = false; + + std::string id = key.encode_identifier(); + + // Check fusion-specific components + EXPECT_NE(id.find("Relu"), std::string::npos); + EXPECT_NE(id.find("_d2"), std::string::npos); + EXPECT_NE(id.find("nopers"), std::string::npos); +} + +TEST(KernelKeyTest, EncodeIdentifierWithSplitK) +{ + KernelKey key; + key.signature.split_k = 4; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.persistent = false; + key.signature.structured_sparsity = false; + + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("_splitk4"), std::string::npos); +} + +TEST(KernelKeyTest, EncodeIdentifierWithSparsity) +{ + KernelKey key; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + 
key.signature.structured_sparsity = true; + key.algorithm.tile_shape.m = 256; + key.algorithm.tile_shape.n = 256; + key.algorithm.tile_shape.k = 32; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.persistent = false; + + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("_sparse"), std::string::npos); +} diff --git a/dispatcher/tests/test_kernel_key_extended.cpp b/dispatcher/tests/test_kernel_key_extended.cpp new file mode 100644 index 00000000000..1c6b5bcba01 --- /dev/null +++ b/dispatcher/tests/test_kernel_key_extended.cpp @@ -0,0 +1,453 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for KernelKey - covers all data types, layouts, pipelines + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "test_mock_kernel.hpp" +#include +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +// ============================================================================= +// DataType Tests +// ============================================================================= + +class DataTypeTest : public ::testing::Test +{ + protected: + void SetUp() override {} +}; + +TEST_F(DataTypeTest, AllDataTypesExist) +{ + // Every DataType should be accessible + std::vector all_types = {DataType::FP16, + DataType::BF16, + DataType::FP32, + DataType::FP64, + DataType::INT8, + DataType::INT4, + DataType::INT32, + DataType::FP8, + DataType::BF8, + DataType::UNKNOWN}; + + EXPECT_EQ(all_types.size(), 10); +} + +TEST_F(DataTypeTest, DataTypesAreDifferent) +{ + EXPECT_NE(DataType::FP16, DataType::BF16); + EXPECT_NE(DataType::FP16, DataType::FP32); + EXPECT_NE(DataType::INT8, DataType::INT4); +} + +// 
============================================================================= +// LayoutTag Tests +// ============================================================================= + +class LayoutTagTest : public ::testing::Test +{ +}; + +TEST_F(LayoutTagTest, AllLayoutsExist) +{ + std::vector all_layouts = { + LayoutTag::RowMajor, LayoutTag::ColMajor, LayoutTag::PackedExternal}; + + EXPECT_EQ(all_layouts.size(), 3); +} + +TEST_F(LayoutTagTest, LayoutsAreDifferent) { EXPECT_NE(LayoutTag::RowMajor, LayoutTag::ColMajor); } + +// ============================================================================= +// Pipeline Tests +// ============================================================================= + +class PipelineTest : public ::testing::Test +{ +}; + +TEST_F(PipelineTest, AllPipelinesExist) +{ + std::vector all_pipelines = {Pipeline::Mem, + Pipeline::CompV1, + Pipeline::CompV2, + Pipeline::CompV3, + Pipeline::CompV4, + Pipeline::CompV5, + Pipeline::PreShuffleV1, + Pipeline::PreShuffleV2}; + + EXPECT_EQ(all_pipelines.size(), 8); +} + +TEST_F(PipelineTest, PipelinesAreDifferent) +{ + EXPECT_NE(Pipeline::Mem, Pipeline::CompV4); + EXPECT_NE(Pipeline::CompV3, Pipeline::CompV4); +} + +// ============================================================================= +// Scheduler Tests +// ============================================================================= + +class SchedulerTest : public ::testing::Test +{ +}; + +TEST_F(SchedulerTest, AllSchedulersExist) +{ + std::vector all_schedulers = { + Scheduler::Auto, Scheduler::Intrawave, Scheduler::Interwave}; + + EXPECT_EQ(all_schedulers.size(), 3); +} + +// ============================================================================= +// Epilogue Tests +// ============================================================================= + +class EpilogueTest : public ::testing::Test +{ +}; + +TEST_F(EpilogueTest, AllEpiloguesExist) +{ + std::vector all_epilogues = {Epilogue::None, + Epilogue::Default, + 
Epilogue::CShuffle, + Epilogue::Bias, + Epilogue::Activation, + Epilogue::BiasActivation}; + + EXPECT_EQ(all_epilogues.size(), 6); +} + +// ============================================================================= +// KernelKey::Signature Tests +// ============================================================================= + +class SignatureTest : public ::testing::Test +{ + protected: + KernelKey::Signature CreateDefaultSignature() + { + KernelKey::Signature sig; + sig.dtype_a = DataType::FP16; + sig.dtype_b = DataType::FP16; + sig.dtype_c = DataType::FP16; + sig.dtype_acc = DataType::FP32; + sig.layout_a = LayoutTag::RowMajor; + sig.layout_b = LayoutTag::ColMajor; + sig.layout_c = LayoutTag::RowMajor; + sig.transpose_a = false; + sig.transpose_b = false; + sig.grouped = false; + sig.split_k = 1; + sig.elementwise_op = "PassThrough"; + sig.num_d_tensors = 0; + sig.structured_sparsity = false; + return sig; + } +}; + +TEST_F(SignatureTest, DefaultValuesAreReasonable) +{ + KernelKey::Signature sig = CreateDefaultSignature(); + EXPECT_EQ(sig.split_k, 1); + EXPECT_FALSE(sig.grouped); + EXPECT_FALSE(sig.structured_sparsity); +} + +TEST_F(SignatureTest, AllDataTypeCombinations) +{ + // Test various data type combinations that should be valid + std::vector> valid_combos = { + {DataType::FP16, DataType::FP16, DataType::FP16, DataType::FP32}, + {DataType::BF16, DataType::BF16, DataType::BF16, DataType::FP32}, + {DataType::FP32, DataType::FP32, DataType::FP32, DataType::FP32}, + {DataType::INT8, DataType::INT8, DataType::INT8, DataType::INT32}, + }; + + for(const auto& [a, b, c, acc] : valid_combos) + { + KernelKey::Signature sig; + sig.dtype_a = a; + sig.dtype_b = b; + sig.dtype_c = c; + sig.dtype_acc = acc; + + EXPECT_EQ(sig.dtype_a, a); + EXPECT_EQ(sig.dtype_b, b); + EXPECT_EQ(sig.dtype_c, c); + EXPECT_EQ(sig.dtype_acc, acc); + } +} + +TEST_F(SignatureTest, AllLayoutCombinations) +{ + std::vector layout_codes = { + "rrr", "rcr", "crr", "ccr", "rrc", "rcc", "crc", 
"ccc"}; + + for(const std::string& code : layout_codes) + { + KernelKey::Signature sig = CreateDefaultSignature(); + sig.layout_a = (code[0] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor; + sig.layout_b = (code[1] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor; + sig.layout_c = (code[2] == 'r') ? LayoutTag::RowMajor : LayoutTag::ColMajor; + + // Just verify assignment works + EXPECT_TRUE(sig.layout_a == LayoutTag::RowMajor || sig.layout_a == LayoutTag::ColMajor); + } +} + +TEST_F(SignatureTest, SplitKValues) +{ + KernelKey::Signature sig = CreateDefaultSignature(); + + std::vector valid_split_k = {1, 2, 4, 8, 16}; + for(auto sk : valid_split_k) + { + sig.split_k = sk; + EXPECT_EQ(sig.split_k, sk); + } +} + +// ============================================================================= +// KernelKey::Algorithm Tests +// ============================================================================= + +class AlgorithmTest : public ::testing::Test +{ + protected: + KernelKey::Algorithm CreateDefaultAlgorithm() + { + KernelKey::Algorithm algo; + algo.tile_shape = {256, 256, 32}; + algo.wave_shape = {2, 2, 1}; + algo.warp_tile_shape = {32, 32, 16}; + algo.pipeline = Pipeline::CompV4; + algo.scheduler = Scheduler::Intrawave; + algo.epilogue = Epilogue::CShuffle; + algo.block_size = 256; + algo.double_buffer = true; + algo.persistent = false; + algo.preshuffle = false; + algo.transpose_c = false; + algo.num_wave_groups = 1; + return algo; + } +}; + +TEST_F(AlgorithmTest, CommonTileShapes) +{ + std::vector> valid_tiles = { + {64, 64, 32}, + {128, 128, 32}, + {128, 128, 64}, + {256, 256, 32}, + {256, 256, 64}, + {256, 128, 32}, + {128, 256, 32}, + }; + + for(const auto& [m, n, k] : valid_tiles) + { + KernelKey::Algorithm algo = CreateDefaultAlgorithm(); + algo.tile_shape = {static_cast(m), + static_cast(n), + static_cast(k)}; + + EXPECT_EQ(algo.tile_shape.m, m); + EXPECT_EQ(algo.tile_shape.n, n); + EXPECT_EQ(algo.tile_shape.k, k); + } +} + +TEST_F(AlgorithmTest, 
CommonWarpConfigs) +{ + std::vector> valid_warps = { + {1, 4, 1}, + {2, 2, 1}, + {4, 1, 1}, + {1, 2, 1}, + {2, 1, 1}, + }; + + for(const auto& [m, n, k] : valid_warps) + { + KernelKey::Algorithm algo = CreateDefaultAlgorithm(); + algo.wave_shape = {static_cast(m), + static_cast(n), + static_cast(k)}; + + EXPECT_EQ(algo.wave_shape.m, m); + EXPECT_EQ(algo.wave_shape.n, n); + EXPECT_EQ(algo.wave_shape.k, k); + } +} + +TEST_F(AlgorithmTest, AllPipelines) +{ + KernelKey::Algorithm algo = CreateDefaultAlgorithm(); + + std::vector pipelines = {Pipeline::Mem, + Pipeline::CompV3, + Pipeline::CompV4, + Pipeline::PreShuffleV1, + Pipeline::PreShuffleV2}; + + for(Pipeline p : pipelines) + { + algo.pipeline = p; + EXPECT_EQ(algo.pipeline, p); + } +} + +// ============================================================================= +// KernelKey Identifier Encoding Tests +// ============================================================================= + +class IdentifierEncodingTest : public ::testing::Test +{ +}; + +TEST_F(IdentifierEncodingTest, UniqueIdentifiersForDifferentConfigs) +{ + std::set identifiers; + + // Generate multiple configurations + for(int tile_m : {128, 256}) + { + for(int wave_m : {1, 2, 4}) + { + for(bool persistent : {true, false}) + { + KernelKey key = make_test_key(tile_m); + key.algorithm.wave_shape.m = wave_m; + key.algorithm.persistent = persistent; + + std::string id = key.encode_identifier(); + EXPECT_TRUE(identifiers.find(id) == identifiers.end()) + << "Duplicate identifier: " << id; + identifiers.insert(id); + } + } + } + + // Should have generated 2 * 3 * 2 = 12 unique identifiers + EXPECT_EQ(identifiers.size(), 12); +} + +TEST_F(IdentifierEncodingTest, IdentifierContainsTileShape) +{ + KernelKey key = make_test_key(256, 128, 64); + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("256x128x64"), std::string::npos) + << "Identifier should contain tile shape: " << id; +} + +TEST_F(IdentifierEncodingTest, 
IdentifierContainsWarpConfig) +{ + KernelKey key = make_test_key(256); + key.algorithm.wave_shape = {4, 2, 1}; + std::string id = key.encode_identifier(); + + EXPECT_NE(id.find("4x2x1"), std::string::npos) + << "Identifier should contain warp config: " << id; +} + +TEST_F(IdentifierEncodingTest, IdentifierReflectsPersistence) +{ + KernelKey persistent_key = make_test_key(256); + persistent_key.algorithm.persistent = true; + + KernelKey non_persistent_key = make_test_key(256); + non_persistent_key.algorithm.persistent = false; + + std::string persistent_id = persistent_key.encode_identifier(); + std::string non_persistent_id = non_persistent_key.encode_identifier(); + + EXPECT_NE(persistent_id, non_persistent_id); + EXPECT_NE(persistent_id.find("persist"), std::string::npos); + EXPECT_NE(non_persistent_id.find("nopers"), std::string::npos); +} + +// ============================================================================= +// KernelKey Equality Tests +// ============================================================================= + +class KeyEqualityTest : public ::testing::Test +{ +}; + +TEST_F(KeyEqualityTest, IdenticalKeysAreEqual) +{ + KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + KernelKey key2 = make_test_key(256, 256, 32, "gfx942"); + + EXPECT_EQ(key1, key2); + EXPECT_FALSE(key1 != key2); +} + +TEST_F(KeyEqualityTest, DifferentTileShapesNotEqual) +{ + KernelKey key1 = make_test_key(256, 256, 32); + KernelKey key2 = make_test_key(128, 128, 32); + + EXPECT_NE(key1, key2); +} + +TEST_F(KeyEqualityTest, DifferentDataTypesNotEqual) +{ + KernelKey key1 = make_test_key(256); + KernelKey key2 = make_test_key(256); + key2.signature.dtype_a = DataType::BF16; + + EXPECT_NE(key1, key2); +} + +TEST_F(KeyEqualityTest, DifferentLayoutsNotEqual) +{ + KernelKey key1 = make_test_key(256); + KernelKey key2 = make_test_key(256); + key2.signature.layout_a = LayoutTag::ColMajor; + + EXPECT_NE(key1, key2); +} + +TEST_F(KeyEqualityTest, DifferentGfxArchNotEqual) +{ 
+ KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + KernelKey key2 = make_test_key(256, 256, 32, "gfx90a"); + + EXPECT_NE(key1, key2); +} + +// ============================================================================= +// ElementwiseOps Tests +// ============================================================================= + +class ElementwiseOpsTest : public ::testing::Test +{ +}; + +TEST_F(ElementwiseOpsTest, CanUseInKernelKey) +{ + KernelKey key = make_test_key(256); + + key.signature.elementwise_op = "Relu"; + EXPECT_EQ(key.signature.elementwise_op, "Relu"); + + key.signature.elementwise_op = "Gelu"; + EXPECT_EQ(key.signature.elementwise_op, "Gelu"); + + key.signature.elementwise_op = "PassThrough"; + EXPECT_EQ(key.signature.elementwise_op, "PassThrough"); +} diff --git a/dispatcher/tests/test_minimal.cpp b/dispatcher/tests/test_minimal.cpp new file mode 100644 index 00000000000..22efc2524c5 --- /dev/null +++ b/dispatcher/tests/test_minimal.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +// Minimal test: Verify dispatcher can select and run a kernel +#include +#include +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +int main() +{ + std::cout << "Minimal Dispatcher Test\n"; + std::cout << "=======================\n\n"; + + // Create a mock kernel for testing + KernelKey key = make_test_key(128, 128, 64, "gfx942"); + auto kernel = std::make_shared(key, "test_kernel_128x128x64", true); + + // Register kernel + Registry::instance().clear(); + Registry::instance().register_kernel(kernel); + + std::cout << "OK Registered kernel: " << kernel->get_name() << "\n"; + + // Create dispatcher and problem + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + std::cout << "OK Created problem: M=" << problem.M << " N=" << problem.N << " K=" << problem.K + << "\n"; + + // Select kernel + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "[FAIL] Failed to select kernel\n"; + return 1; + } + + std::cout << "OK Selected kernel: " << selected->get_name() << "\n"; + + // Mock execution (no actual GPU computation in mock kernel) + void* a_ptr = nullptr; + void* b_ptr = nullptr; + void* c_ptr = nullptr; + + float time = dispatcher.run(a_ptr, b_ptr, c_ptr, problem); + + std::cout << "OK Executed kernel: " << time << " ms\n"; + std::cout << "\n[OK] Minimal test passed!\n"; + + return 0; +} diff --git a/dispatcher/tests/test_mock_kernel.cpp b/dispatcher/tests/test_mock_kernel.cpp new file mode 100644 index 00000000000..fd8f3f4baa7 --- /dev/null +++ b/dispatcher/tests/test_mock_kernel.cpp @@ -0,0 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_mock_kernel.hpp" + +// Empty file - implementation is in header diff --git a/dispatcher/tests/test_mock_kernel.hpp b/dispatcher/tests/test_mock_kernel.hpp new file mode 100644 index 00000000000..7d511719a8c --- /dev/null +++ b/dispatcher/tests/test_mock_kernel.hpp @@ -0,0 +1,134 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include + +namespace ck_tile { +namespace dispatcher { +namespace test { + +/// Mock kernel instance for testing dispatcher functionality +/// Supports configurable behavior for testing different scenarios +class MockKernelInstance : public KernelInstance +{ + public: + /// Constructor + /// @param key Kernel configuration key + /// @param name Human-readable kernel name + /// @param supports_all Whether this kernel supports all problems (default: true) + explicit MockKernelInstance(const KernelKey& key, + const std::string& name, + bool supports_all = true) + : key_(key), name_(name), supports_all_(supports_all), execution_count_(0) + { + } + + const KernelKey& get_key() const override { return key_; } + + bool supports(const Problem& problem) const override + { + if(supports_all_) + { + return problem.is_valid(); + } + // For testing: only support problems where M/N/K are divisible by tile sizes + return problem.is_valid() && (problem.M % key_.algorithm.tile_shape.m == 0) && + (problem.N % key_.algorithm.tile_shape.n == 0) && + (problem.K % key_.algorithm.tile_shape.k == 0); + } + + std::string get_name() const override { return name_; } + + float run(const void* a_ptr, + const void* b_ptr, + void* c_ptr, + const void** d_ptrs, + const Problem& problem, + void* stream) const override + { + execution_count_++; + // Simulate execution time (1ms for testing) + return 1.0f; + } + + 
bool validate(const void* a_ptr, + const void* b_ptr, + const void* c_ptr, + const void** d_ptrs, + const Problem& problem, + float tolerance) const override + { + // Mock validation always passes + return true; + } + + /// Get execution count (for testing) + int get_execution_count() const { return execution_count_; } + + /// Reset execution count + void reset_execution_count() { execution_count_ = 0; } + + /// Set whether this kernel supports all problems + void set_supports_all(bool supports_all) { supports_all_ = supports_all; } + + private: + KernelKey key_; + std::string name_; + bool supports_all_; + mutable int execution_count_; +}; + +/// Helper function to create a test kernel key +inline KernelKey make_test_key(std::uint16_t tile_m = 256, + std::uint16_t tile_n = 256, + std::uint16_t tile_k = 32, + const std::string& gfx_arch = "gfx942") +{ + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape.m = tile_m; + key.algorithm.tile_shape.n = tile_n; + key.algorithm.tile_shape.k = tile_k; + key.algorithm.wave_shape.m = 2; + key.algorithm.wave_shape.n = 2; + key.algorithm.wave_shape.k = 1; + key.algorithm.warp_tile_shape.m = 32; + key.algorithm.warp_tile_shape.n = 32; + key.algorithm.warp_tile_shape.k = 16; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + 
key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + + key.gfx_arch = gfx_arch; + + return key; +} + +} // namespace test +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/tests/test_problem.cpp b/dispatcher/tests/test_problem.cpp new file mode 100644 index 00000000000..7d5500e320e --- /dev/null +++ b/dispatcher/tests/test_problem.cpp @@ -0,0 +1,96 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for Problem using Google Test + +#include "ck_tile/dispatcher/problem.hpp" +#include + +using namespace ck_tile::dispatcher; + +TEST(ProblemTest, DefaultConstruction) +{ + Problem p; + EXPECT_EQ(p.M, 0); + EXPECT_EQ(p.N, 0); + EXPECT_EQ(p.K, 0); + EXPECT_EQ(p.k_batch, 1); + EXPECT_FALSE(p.is_valid()); +} + +TEST(ProblemTest, ConstructorWithDimensions) +{ + Problem p(1024, 1024, 1024); + EXPECT_EQ(p.M, 1024); + EXPECT_EQ(p.N, 1024); + EXPECT_EQ(p.K, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST(ProblemTest, Validation) +{ + Problem p; + + // Invalid: all zeros + p.M = 0; + p.N = 0; + p.K = 0; + EXPECT_FALSE(p.is_valid()); + + // Invalid: negative + p.M = -1; + p.N = 1024; + p.K = 1024; + EXPECT_FALSE(p.is_valid()); + + // Invalid: zero K + p.M = 1024; + p.N = 1024; + p.K = 0; + EXPECT_FALSE(p.is_valid()); + + // Valid + p.M = 1024; + p.N = 1024; + p.K = 1024; + EXPECT_TRUE(p.is_valid()); + + // Invalid k_batch + p.k_batch = 0; + EXPECT_FALSE(p.is_valid()); + + p.k_batch = 1; + EXPECT_TRUE(p.is_valid()); +} + +TEST(ProblemTest, NumOps) +{ + Problem p(100, 200, 300); + + // 2 * M * N * K (multiply-add = 2 ops) + std::int64_t expected = 2 * 100 * 200 * 300; + EXPECT_EQ(p.num_ops(), expected); +} + +TEST(ProblemTest, Configuration) +{ + Problem p(1024, 1024, 1024); + + // Set preferences + p.prefer_persistent = true; + p.enable_validation = true; + 
p.smem_budget = 65536; + p.k_batch = 2; + + EXPECT_TRUE(p.prefer_persistent); + EXPECT_TRUE(p.enable_validation); + EXPECT_EQ(p.smem_budget, 65536); + EXPECT_EQ(p.k_batch, 2); +} + +TEST(ProblemTest, LargeDimensions) +{ + Problem p(1024, 1024, 1024); // Use smaller but still large dimensions + EXPECT_TRUE(p.is_valid()); + EXPECT_GT(p.num_ops(), 0); +} diff --git a/dispatcher/tests/test_problem_extended.cpp b/dispatcher/tests/test_problem_extended.cpp new file mode 100644 index 00000000000..21ea5452921 --- /dev/null +++ b/dispatcher/tests/test_problem_extended.cpp @@ -0,0 +1,457 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for Problem - covers dimension inference, validation, edge cases + +#include "ck_tile/dispatcher/problem.hpp" +#include +#include + +using namespace ck_tile::dispatcher; + +// ============================================================================= +// Dimension Inference Tests +// ============================================================================= + +class ProblemDimensionInferenceTest : public ::testing::Test +{ +}; + +TEST_F(ProblemDimensionInferenceTest, FromAB_Basic) +{ + // A: M×K (1024×512), B: K×N (512×2048) + auto problem = Problem::from_ab(1024, 512, 512, 2048); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemDimensionInferenceTest, FromDimensions_Valid) +{ + // A: 1024×512, B: 512×2048, C: 1024×2048 + auto problem = Problem::from_dimensions(1024, 512, 512, 2048, 1024, 2048); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemDimensionInferenceTest, FromShapes_WithC) +{ + TensorShape A{1024, 512, false}; + TensorShape B{512, 2048, false}; + TensorShape C{1024, 2048, false}; + + auto problem = Problem::from_shapes(A, B, C); + + EXPECT_EQ(problem.M, 
1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA) +{ + // A stored as K×M (transposed) + TensorShape A{512, 1024, true}; + TensorShape B{512, 2048, false}; + TensorShape C{1024, 2048, false}; + + auto problem = Problem::from_shapes(A, B, C); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); +} + +TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedB) +{ + TensorShape A{1024, 512, false}; + // B stored as N×K (transposed) + TensorShape B{2048, 512, true}; + TensorShape C{1024, 2048, false}; + + auto problem = Problem::from_shapes(A, B, C); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); +} + +// ============================================================================= +// Validation Tests +// ============================================================================= + +class ProblemValidationTest : public ::testing::Test +{ +}; + +TEST_F(ProblemValidationTest, ValidProblem) +{ + Problem p(1024, 1024, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroM) +{ + Problem p(0, 1024, 1024); + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroN) +{ + Problem p(1024, 0, 1024); + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroK) +{ + Problem p(1024, 1024, 0); + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, NegativeM) +{ + Problem p; + p.M = -1; + p.N = 1024; + p.K = 1024; + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ZeroKBatch) +{ + Problem p(1024, 1024, 1024); + p.k_batch = 0; + EXPECT_FALSE(p.is_valid()); +} + +TEST_F(ProblemValidationTest, ValidKBatch) +{ + Problem p(1024, 1024, 1024); + p.k_batch = 4; + EXPECT_TRUE(p.is_valid()); +} + +// ============================================================================= +// num_ops Tests +// 
============================================================================= + +class ProblemNumOpsTest : public ::testing::Test +{ +}; + +TEST_F(ProblemNumOpsTest, SmallProblem) +{ + Problem p(10, 20, 30); + // 2 * M * N * K = 2 * 10 * 20 * 30 = 12000 + EXPECT_EQ(p.num_ops(), 12000); +} + +TEST_F(ProblemNumOpsTest, SymmetricProblem) +{ + Problem p(1024, 1024, 1024); + // 2 * 1024^3 = 2,147,483,648 + EXPECT_EQ(p.num_ops(), 2LL * 1024 * 1024 * 1024); +} + +TEST_F(ProblemNumOpsTest, AsymmetricProblem) +{ + Problem p(512, 2048, 256); + EXPECT_EQ(p.num_ops(), 2LL * 512 * 2048 * 256); +} + +TEST_F(ProblemNumOpsTest, LargeProblem) +{ + Problem p(4096, 4096, 4096); + std::int64_t expected = 2LL * 4096 * 4096 * 4096; + EXPECT_EQ(p.num_ops(), expected); + EXPECT_GT(p.num_ops(), 0); // No overflow +} + +// ============================================================================= +// Edge Cases +// ============================================================================= + +class ProblemEdgeCasesTest : public ::testing::Test +{ +}; + +TEST_F(ProblemEdgeCasesTest, MinimumValidSize) +{ + Problem p(1, 1, 1); + EXPECT_TRUE(p.is_valid()); + EXPECT_EQ(p.num_ops(), 2); +} + +TEST_F(ProblemEdgeCasesTest, NonSquare_TallMatrix) +{ + Problem p(8192, 64, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, NonSquare_WideMatrix) +{ + Problem p(64, 8192, 1024); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, NonSquare_DeepK) +{ + Problem p(1024, 1024, 8192); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, SmallK) +{ + Problem p(1024, 1024, 16); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemEdgeCasesTest, NonPowerOf2Dimensions) +{ + Problem p(1000, 2000, 300); + EXPECT_TRUE(p.is_valid()); + EXPECT_EQ(p.num_ops(), 2LL * 1000 * 2000 * 300); +} + +TEST_F(ProblemEdgeCasesTest, PrimeDimensions) +{ + Problem p(997, 1009, 1013); // All prime numbers + EXPECT_TRUE(p.is_valid()); +} + +// 
============================================================================= +// Configuration Tests +// ============================================================================= + +class ProblemConfigurationTest : public ::testing::Test +{ +}; + +TEST_F(ProblemConfigurationTest, DefaultConfiguration) +{ + Problem p(1024, 1024, 1024); + + EXPECT_FALSE(p.prefer_persistent); + EXPECT_FALSE(p.enable_validation); + EXPECT_EQ(p.smem_budget, 0); + EXPECT_EQ(p.k_batch, 1); +} + +TEST_F(ProblemConfigurationTest, SetPersistentPreference) +{ + Problem p(1024, 1024, 1024); + p.prefer_persistent = true; + + EXPECT_TRUE(p.prefer_persistent); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemConfigurationTest, SetSmemBudget) +{ + Problem p(1024, 1024, 1024); + p.smem_budget = 65536; // 64KB + + EXPECT_EQ(p.smem_budget, 65536); + EXPECT_TRUE(p.is_valid()); +} + +TEST_F(ProblemConfigurationTest, SetKBatch) +{ + Problem p(1024, 1024, 1024); + + for(int kb : {1, 2, 4, 8, 16}) + { + p.k_batch = kb; + EXPECT_EQ(p.k_batch, kb); + EXPECT_TRUE(p.is_valid()); + } +} + +// ============================================================================= +// Copy and Assignment Tests +// ============================================================================= + +class ProblemCopyTest : public ::testing::Test +{ +}; + +TEST_F(ProblemCopyTest, CopyConstruction) +{ + Problem p1(1024, 2048, 512); + p1.prefer_persistent = true; + p1.k_batch = 4; + + Problem p2(p1); + + EXPECT_EQ(p2.M, 1024); + EXPECT_EQ(p2.N, 2048); + EXPECT_EQ(p2.K, 512); + EXPECT_TRUE(p2.prefer_persistent); + EXPECT_EQ(p2.k_batch, 4); +} + +TEST_F(ProblemCopyTest, Assignment) +{ + Problem p1(1024, 2048, 512); + Problem p2(256, 256, 256); + + p2 = p1; + + EXPECT_EQ(p2.M, 1024); + EXPECT_EQ(p2.N, 2048); + EXPECT_EQ(p2.K, 512); +} + +// ============================================================================= +// Builder Tests +// ============================================================================= + +class 
ProblemBuilderTest : public ::testing::Test +{ +}; + +TEST_F(ProblemBuilderTest, BasicBuild) +{ + auto problem = ProblemBuilder().dimensions(1024, 2048, 512).build(); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); + EXPECT_TRUE(problem.is_valid()); +} + +TEST_F(ProblemBuilderTest, WithSplitK) +{ + auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).split_k(4).build(); + + EXPECT_EQ(problem.k_batch, 4); +} + +TEST_F(ProblemBuilderTest, WithPersistent) +{ + auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).persistent(true).build(); + + EXPECT_TRUE(problem.prefer_persistent); +} + +TEST_F(ProblemBuilderTest, WithSmemBudget) +{ + auto problem = ProblemBuilder().dimensions(1024, 1024, 1024).smem_budget(65536).build(); + + EXPECT_EQ(problem.smem_budget, 65536); +} + +TEST_F(ProblemBuilderTest, ChainedConfiguration) +{ + auto problem = ProblemBuilder() + .dimensions(2048, 2048, 1024) + .split_k(2) + .persistent(true) + .smem_budget(32768) + .validate(true) + .build(); + + EXPECT_EQ(problem.M, 2048); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 1024); + EXPECT_EQ(problem.k_batch, 2); + EXPECT_TRUE(problem.prefer_persistent); + EXPECT_EQ(problem.smem_budget, 32768); + EXPECT_TRUE(problem.enable_validation); +} + +TEST_F(ProblemBuilderTest, FromAB) +{ + auto problem = ProblemBuilder().from_ab(1024, 512, 512, 2048).build(); + + EXPECT_EQ(problem.M, 1024); + EXPECT_EQ(problem.N, 2048); + EXPECT_EQ(problem.K, 512); +} + +// ============================================================================= +// Dimension Mismatch Error Tests +// ============================================================================= + +class ProblemDimensionErrorTest : public ::testing::Test +{ +}; + +TEST_F(ProblemDimensionErrorTest, KMismatchThrows) +{ + EXPECT_THROW((void)Problem::from_ab(1024, 512, 256, 2048), // K mismatch: 512 vs 256 + std::invalid_argument); +} + +TEST_F(ProblemDimensionErrorTest, 
MDimensionMismatchThrows) +{ + TensorShape A{1024, 512, false}; + TensorShape B{512, 2048, false}; + TensorShape C{512, 2048, false}; // M mismatch: A says M=1024, C says M=512 + + EXPECT_THROW((void)Problem::from_shapes(A, B, C), std::invalid_argument); +} + +TEST_F(ProblemDimensionErrorTest, NDimensionMismatchThrows) +{ + TensorShape A{1024, 512, false}; + TensorShape B{512, 2048, false}; + TensorShape C{1024, 1024, false}; // N mismatch: B says N=2048, C says N=1024 + + EXPECT_THROW((void)Problem::from_shapes(A, B, C), std::invalid_argument); +} + +// ============================================================================= +// Validate Sizes Tests +// ============================================================================= + +class ProblemValidateSizesTest : public ::testing::Test +{ +}; + +TEST_F(ProblemValidateSizesTest, CorrectSizes) +{ + Problem p(1024, 2048, 512); + + // This should not throw + EXPECT_NO_THROW(p.validate_sizes(1024 * 512, // A size + 512 * 2048, // B size + 1024 * 2048 // C size + )); +} + +TEST_F(ProblemValidateSizesTest, WrongASizeThrows) +{ + Problem p(1024, 2048, 512); + + EXPECT_THROW(p.validate_sizes(1024 * 256, // Wrong A size + 512 * 2048, + 1024 * 2048), + std::invalid_argument); +} + +TEST_F(ProblemValidateSizesTest, WrongBSizeThrows) +{ + Problem p(1024, 2048, 512); + + EXPECT_THROW(p.validate_sizes(1024 * 512, + 256 * 2048, // Wrong B size + 1024 * 2048), + std::invalid_argument); +} + +TEST_F(ProblemValidateSizesTest, WrongCSizeThrows) +{ + Problem p(1024, 2048, 512); + + EXPECT_THROW(p.validate_sizes(1024 * 512, + 512 * 2048, + 512 * 1024 // Wrong C size + ), + std::invalid_argument); +} diff --git a/dispatcher/tests/test_real_kernel_correctness.cpp b/dispatcher/tests/test_real_kernel_correctness.cpp new file mode 100644 index 00000000000..e753f04e191 --- /dev/null +++ b/dispatcher/tests/test_real_kernel_correctness.cpp @@ -0,0 +1,232 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * Correctness test with real GPU kernel + * Validates GPU results against CPU reference implementation + */ + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +// CPU reference GEMM +// A: RowMajor (M x K) - A[m,k] = A[m*K + k] +// B: ColumnMajor (K x N) - B[k,n] = B[k + n*K] +// C: RowMajor (M x N) - C[m,n] = C[m*N + n] +template +void cpu_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + // A is row-major: A[m,k] = A[m*K + k] + // B is column-major: B[k,n] = B[k + n*K] + acc += float(A[m * K + k]) * float(B[k + n * K]); + } + C[m * N + n] = T(acc); + } + } +} + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Correctness Test - Real GPU Kernel\n"; + std::cout << "=======================================\n\n"; + + std::cout << "Kernel: " << KERNEL_NAME << "\n\n"; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = 
false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + + // Test with random matrices + const int M = 256; + const int N = 256; + const int K = 256; + + std::cout << "Test configuration:\n"; + std::cout << " Problem: M=" << M << " N=" << N << " K=" << K << "\n"; + std::cout << " Method: Random matrices vs CPU reference\n\n"; + + // Random number generation + std::mt19937 rng(42); // Fixed seed for reproducibility + std::uniform_real_distribution dist(-1.0f, 1.0f); + + std::vector A_host(M * K); + std::vector B_host(K * N); + std::vector C_gpu(M * N); + std::vector C_cpu(M * N); + + // Initialize with random values + std::cout << "Initializing random matrices...\n"; + for(int i = 0; i < M * K; i++) + { + A_host[i] = ADataType(dist(rng)); + } + for(int i = 0; i < K * N; i++) + { + B_host[i] = BDataType(dist(rng)); + } + + // GPU execution + std::cout << "Executing on GPU...\n"; + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + 
HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Problem problem(M, N, K); + float gpu_time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + std::cout << "OK GPU execution complete: " << gpu_time << " ms\n"; + + double flops = 2.0 * M * N * K; + double tflops = (flops / (gpu_time * 1e-3)) / 1e12; + std::cout << "OK GPU performance: " << tflops << " TFLOPS\n\n"; + + // CPU reference + std::cout << "Computing CPU reference...\n"; + cpu_gemm(A_host, B_host, C_cpu, M, N, K); + std::cout << "OK CPU reference complete\n\n"; + + // Validation + std::cout << "Validating results...\n"; + + int num_correct = 0; + float max_rel_error = 0.0f; + float max_abs_error = 0.0f; + const float tolerance = 0.02f; // 2% for FP16 + + for(int i = 0; i < M * N; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + + float abs_error = std::abs(gpu_val - cpu_val); + float rel_error = abs_error / (std::abs(cpu_val) + 1e-5f); + + max_abs_error = std::max(max_abs_error, abs_error); + max_rel_error = std::max(max_rel_error, rel_error); + + if(rel_error < tolerance) + { + num_correct++; + } + } + + float accuracy = 100.0f * num_correct / (M * N); + + std::cout << "\nValidation Results:\n"; + std::cout << " Correct elements: " << num_correct << "/" << M * N << "\n"; + std::cout << " Accuracy: " << accuracy << "%\n"; + std::cout << " Max absolute error: " << max_abs_error << "\n"; + std::cout << " Max relative error: " << max_rel_error << "\n"; + std::cout << " Tolerance: " << tolerance << " (2%)\n\n"; + + // Show sample comparisons + std::cout << "Sample results (first 5 elements):\n"; + std::cout << " Index | GPU Result | CPU Result | Error\n"; + std::cout << " 
------|------------|------------|-------\n"; + + for(int i = 0; i < 5; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + float error = std::abs(gpu_val - cpu_val); + printf(" %-5d | %10.4f | %10.4f | %.4f\n", i, gpu_val, cpu_val, error); + } + std::cout << "\n"; + + // Cleanup + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + if(accuracy > 99.0f) + { + std::cout << "[OK] CORRECTNESS TEST PASSED\n"; + std::cout << " GPU results match CPU reference within tolerance\n"; + return 0; + } + else + { + std::cout << "[FAIL] CORRECTNESS TEST FAILED\n"; + std::cout << " Accuracy too low: " << accuracy << "%\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_real_kernel_multi_size.cpp b/dispatcher/tests/test_real_kernel_multi_size.cpp new file mode 100644 index 00000000000..f23f6846313 --- /dev/null +++ b/dispatcher/tests/test_real_kernel_multi_size.cpp @@ -0,0 +1,213 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * Multi-size real kernel test: Test multiple problem sizes with real GPU kernel + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +struct TestResult +{ + int M, N, K; + float time_ms; + double tflops; + int correct; + int total; + bool passed; +}; + +TestResult run_test(Dispatcher& dispatcher, int M, int N, int K) +{ + TestResult result = {M, N, K, 0.0f, 0.0, 0, M * N, false}; + + // Allocate and prepare data + std::vector A_host(M * K); + std::vector B_host(K * N); + std::vector C_gpu(M * N); + + // Initialize: A=1, B=1, expected C=K + for(int i = 0; i < M * K; i++) + A_host[i] = ADataType(1.0f); + for(int i = 0; i < K * N; i++) + B_host[i] = BDataType(1.0f); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + // Execute + Problem problem(M, N, K); + result.time_ms = dispatcher.run(A_dev, B_dev, C_dev, problem); + + // Calculate performance + double flops = 2.0 * M * N * K; + result.tflops = (flops / (result.time_ms * 1e-3)) / 1e12; + + // Copy result 
and validate + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C_gpu[i]) - float(K)) < 1.0f) + { + result.correct++; + } + } + + result.passed = (result.correct == result.total); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + return result; +} + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Multi-Size Real Kernel Test\n"; + std::cout << "=======================================\n\n"; + + std::cout << "Using kernel: " << KERNEL_NAME << "\n\n"; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + + std::cout << 
"Running tests on multiple problem sizes...\n"; + std::cout << "===========================================\n\n"; + + // Test various sizes (all multiples of tile size) + std::vector> test_sizes = { + {128, 128, 128}, // Small + {256, 256, 256}, // Medium + {512, 512, 512}, // Large + {1024, 1024, 1024}, // Very large + {128, 512, 256}, // Non-square + {512, 128, 384}, // Non-square + }; + + std::vector results; + int num_passed = 0; + + for(const auto& [M, N, K] : test_sizes) + { + std::cout << "Testing M=" << M << " N=" << N << " K=" << K << "...\n"; + + auto result = run_test(dispatcher, M, N, K); + results.push_back(result); + + std::cout << " Time: " << result.time_ms << " ms\n"; + std::cout << " Performance: " << result.tflops << " TFLOPS\n"; + std::cout << " Accuracy: " << (100.0f * result.correct / result.total) << "%\n"; + std::cout << " Status: " << (result.passed ? "[OK] PASS" : "[FAIL] FAIL") << "\n\n"; + + if(result.passed) + num_passed++; + } + + // Summary + std::cout << "===========================================\n"; + std::cout << "Summary\n"; + std::cout << "===========================================\n\n"; + + std::cout << "Results by size:\n"; + std::cout << " Size | Time (ms) | TFLOPS | Accuracy | Status\n"; + std::cout << " ---------------|-----------|--------|----------|--------\n"; + + for(const auto& r : results) + { + char size_str[32]; + snprintf(size_str, sizeof(size_str), "%4d×%4d×%4d", r.M, r.N, r.K); + + printf(" %-14s | %9.4f | %6.2f | %7.2f%% | %s\n", + size_str, + r.time_ms, + r.tflops, + 100.0f * r.correct / r.total, + r.passed ? 
"[OK]" : "[FAIL]"); + } + + std::cout << "\n"; + std::cout << "Tests passed: " << num_passed << "/" << results.size() << "\n"; + + if(num_passed == results.size()) + { + std::cout << "\n[OK] ALL TESTS PASSED\n"; + return 0; + } + else + { + std::cout << "\n[FAIL] SOME TESTS FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_real_kernel_performance.cpp b/dispatcher/tests/test_real_kernel_performance.cpp new file mode 100644 index 00000000000..ff3d635968c --- /dev/null +++ b/dispatcher/tests/test_real_kernel_performance.cpp @@ -0,0 +1,173 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Performance test with real GPU kernel + * Measures and reports detailed performance metrics + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Performance Test - Real GPU Kernel\n"; + std::cout << "=======================================\n\n"; + + std::cout << "Kernel: " << KERNEL_NAME << "\n"; + std::cout << "Device: AMD Instinct MI325X (gfx942)\n\n"; + + // Register kernel + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + 
key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 32}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + Dispatcher dispatcher; + + // Performance benchmark sizes + std::vector> benchmarks = { + {128, 128, 128, "Tiny"}, + {256, 256, 256, "Small"}, + {512, 512, 512, "Medium"}, + {1024, 1024, 1024, "Large"}, + {2048, 2048, 2048, "Very Large"}, + }; + + std::cout << "Performance Benchmark Results\n"; + std::cout << "=============================\n\n"; + + std::cout << " Size | Time (ms) | TFLOPS | BW (GB/s) | Status\n"; + std::cout << " ----------|-----------|--------|-----------|--------\n"; + + bool all_passed = true; + + for(const auto& [M, N, K, label] : benchmarks) + { + // Prepare data + std::vector A_host(M * K, ADataType(1.0f)); + std::vector B_host(K * N, BDataType(1.0f)); + std::vector C_gpu(M * N); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK( + hipMemcpy(A_dev, A_host.data(), M * 
K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + // Execute + Problem problem(M, N, K); + float time_ms = dispatcher.run(A_dev, B_dev, C_dev, problem); + + // Calculate metrics + double flops = 2.0 * M * N * K; + double tflops = (flops / (time_ms * 1e-3)) / 1e12; + + // Bandwidth (A + B read, C write) + double bytes = (M * K + K * N + M * N) * sizeof(CDataType); + double bandwidth_gbs = (bytes / (time_ms * 1e-3)) / 1e9; + + // Validate + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + int correct = 0; + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C_gpu[i]) - float(K)) < 1.0f) + correct++; + } + + bool passed = (correct == M * N); + all_passed = all_passed && passed; + + char size_label[32]; + snprintf(size_label, sizeof(size_label), "%s %d³", label, M); + + printf(" %-9s | %9.4f | %6.2f | %9.1f | %s\n", + size_label, + time_ms, + tflops, + bandwidth_gbs, + passed ? "[OK]" : "[FAIL]"); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + } + + std::cout << "\n"; + + if(all_passed) + { + std::cout << "[OK] ALL PERFORMANCE TESTS PASSED\n"; + return 0; + } + else + { + std::cout << "[FAIL] SOME TESTS FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_real_kernel_simple.cpp b/dispatcher/tests/test_real_kernel_simple.cpp new file mode 100644 index 00000000000..72e3a5fc87b --- /dev/null +++ b/dispatcher/tests/test_real_kernel_simple.cpp @@ -0,0 +1,201 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/** + * Simple real kernel test using tile_engine style (single kernel with -include) + * This follows the proven pattern from the examples + */ + +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header will be included via -include compiler flag +// It defines: ADataType, BDataType, CDataType, AccDataType, SelectedKernel, KERNEL_NAME + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using Priority = ck_tile::dispatcher::Registry::Priority; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error: " << hipGetErrorString(err) << "\n"; \ + exit(1); \ + } \ + } + +// Reference CPU GEMM +template +void reference_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + acc += float(A[m * K + k]) * float(B[k * N + n]); + } + C[m * N + n] = T(acc); + } + } +} + +int main() +{ + std::cout << "=======================================\n"; + std::cout << "Simple Real Kernel Test\n"; + std::cout << "=======================================\n\n"; + + // Test size (must be multiple of tile size) + const int M = 256; + const int N = 256; + const int K = 256; + + std::cout << "Problem: M=" << M << " N=" << N << " K=" << K << "\n"; + std::cout << "Kernel: " << KERNEL_NAME << "\n\n"; + + // Create kernel key + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = 
LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 64}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + // Create and register kernel + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Priority::High); + + std::cout << "OK Registered kernel\n"; + + // Create dispatcher + Dispatcher dispatcher; + Problem problem(M, N, K); + + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "[FAIL] Failed to select kernel\n"; + return 1; + } + std::cout << "OK Selected kernel: " << selected->get_name() << "\n\n"; + + // Prepare data + std::cout << "Preparing test data...\n"; + std::vector A_host(M * K); + std::vector B_host(K * N); + std::vector C_gpu(M * N); + std::vector C_cpu(M * N); + + // Simple test: A=1, B=1, C should be K + for(int i = 0; i < M * K; i++) + A_host[i] = ADataType(1.0f); + for(int i = 0; i < K * N; i++) + B_host[i] = BDataType(1.0f); + + // Allocate GPU memory + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K 
* sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + std::cout << "OK Data ready on GPU\n\n"; + + // Execute + std::cout << "Executing GPU kernel...\n"; + float gpu_time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + std::cout << "OK GPU time: " << gpu_time << " ms\n"; + + double flops = 2.0 * M * N * K; + double tflops = (flops / (gpu_time * 1e-3)) / 1e12; + std::cout << "OK Performance: " << tflops << " TFLOPS\n\n"; + + // Copy result + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Validate + std::cout << "Validating (expected: all elements = " << K << ")...\n"; + + int correct = 0; + for(int i = 0; i < M * N; i++) + { + float val = float(C_gpu[i]); + if(std::abs(val - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + std::cout << "Accuracy: " << accuracy << "% (" << correct << "/" << M * N << ")\n"; + + // Show samples + std::cout << "\nFirst 5 results:\n"; + for(int i = 0; i < 5; i++) + { + std::cout << " C[" << i << "] = " << float(C_gpu[i]) << " (expected " << K << ")\n"; + } + std::cout << "\n"; + + // Cleanup + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + if(accuracy > 99.0f) + { + std::cout << "[OK] TEST PASSED\n"; + return 0; + } + else + { + std::cout << "[FAIL] TEST FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_registry.cpp b/dispatcher/tests/test_registry.cpp new file mode 100644 index 00000000000..4e5bf718df7 --- /dev/null +++ b/dispatcher/tests/test_registry.cpp @@ -0,0 +1,166 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/// Unit tests for Registry using Google Test + +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +TEST(RegistryTest, Registration) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + + bool registered = registry.register_kernel(kernel); + EXPECT_TRUE(registered); + EXPECT_EQ(registry.size(), 1); +} + +TEST(RegistryTest, Lookup) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + registry.register_kernel(kernel); + + // Lookup by key + auto found = registry.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "test_kernel"); + + // Lookup by identifier + std::string id = key.encode_identifier(); + auto found2 = registry.lookup(id); + ASSERT_NE(found2, nullptr); + EXPECT_EQ(found2->get_name(), "test_kernel"); + + // Lookup non-existent + auto key2 = make_test_key(128); + auto not_found = registry.lookup(key2); + EXPECT_EQ(not_found, nullptr); +} + +TEST(RegistryTest, Priority) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel1 = std::make_shared(key, "kernel_low"); + auto kernel2 = std::make_shared(key, "kernel_high"); + + // Register with low priority + registry.register_kernel(kernel1, Registry::Priority::Low); + + // Try to register with normal priority (should replace) + bool replaced = registry.register_kernel(kernel2, Registry::Priority::Normal); + EXPECT_TRUE(replaced); + + auto found = registry.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_high"); + + // Try to register with low priority again (should fail) + auto kernel3 = std::make_shared(key, "kernel_low2"); + bool 
not_replaced = registry.register_kernel(kernel3, Registry::Priority::Low); + EXPECT_FALSE(not_replaced); + + found = registry.lookup(key); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_high"); +} + +TEST(RegistryTest, GetAll) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key1 = make_test_key(256); + auto key2 = make_test_key(128); + auto kernel1 = std::make_shared(key1, "kernel1"); + auto kernel2 = std::make_shared(key2, "kernel2"); + + registry.register_kernel(kernel1); + registry.register_kernel(kernel2); + + auto all = registry.get_all(); + EXPECT_EQ(all.size(), 2); +} + +TEST(RegistryTest, Filter) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + // Create kernels with different tile sizes + for(int tile_m : {128, 256, 512}) + { + auto key = make_test_key(tile_m); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(tile_m)); + registry.register_kernel(kernel); + } + + // Filter for large tiles (>= 256) + auto large_tiles = registry.filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m >= 256; }); + + EXPECT_EQ(large_tiles.size(), 2); +} + +TEST(RegistryTest, Clear) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + registry.register_kernel(kernel); + + EXPECT_EQ(registry.size(), 1); + + registry.clear(); + EXPECT_EQ(registry.size(), 0); +} + +TEST(RegistryTest, MultipleKernels) +{ + Registry& registry = Registry::instance(); + registry.clear(); + + // Register multiple kernels + for(int i = 0; i < 10; ++i) + { + auto key = make_test_key(256 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + registry.register_kernel(kernel); + } + + EXPECT_EQ(registry.size(), 10); + + // Verify all can be looked up + for(int i = 0; i < 10; ++i) + { + auto key = make_test_key(256 + i); + auto found = registry.lookup(key); + 
ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_" + std::to_string(i)); + } +} + +TEST(RegistryTest, Singleton) +{ + Registry& reg1 = Registry::instance(); + Registry& reg2 = Registry::instance(); + + // Should be the same instance + EXPECT_EQ(®1, ®2); +} diff --git a/dispatcher/tests/test_registry_extended.cpp b/dispatcher/tests/test_registry_extended.cpp new file mode 100644 index 00000000000..d173e1a38d4 --- /dev/null +++ b/dispatcher/tests/test_registry_extended.cpp @@ -0,0 +1,503 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Extended unit tests for Registry - covers multiple registries, merging, filtering + +#include "ck_tile/dispatcher/registry.hpp" +#include "test_mock_kernel.hpp" +#include +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +// ============================================================================= +// Basic Registration Tests +// ============================================================================= + +class RegistryBasicTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryBasicTest, RegisterSingleKernel) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "test_kernel"); + + EXPECT_TRUE(Registry::instance().register_kernel(kernel)); + EXPECT_EQ(Registry::instance().size(), 1); +} + +TEST_F(RegistryBasicTest, RegisterNullKernel) +{ + EXPECT_FALSE(Registry::instance().register_kernel(nullptr)); + EXPECT_EQ(Registry::instance().size(), 0); +} + +TEST_F(RegistryBasicTest, RegisterMultipleKernels) +{ + for(int i = 0; i < 100; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + EXPECT_TRUE(Registry::instance().register_kernel(kernel)); + } + 
EXPECT_EQ(Registry::instance().size(), 100); +} + +TEST_F(RegistryBasicTest, RegisterDuplicateKey) +{ + auto key = make_test_key(256); + auto kernel1 = std::make_shared(key, "kernel1"); + auto kernel2 = std::make_shared(key, "kernel2"); + + EXPECT_TRUE(Registry::instance().register_kernel(kernel1, Registry::Priority::Normal)); + + // Same priority should not replace + EXPECT_FALSE(Registry::instance().register_kernel(kernel2, Registry::Priority::Normal)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "kernel1"); +} + +// ============================================================================= +// Priority Tests +// ============================================================================= + +class RegistryPriorityTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryPriorityTest, HigherPriorityReplaces) +{ + auto key = make_test_key(256); + + auto low = std::make_shared(key, "low"); + auto normal = std::make_shared(key, "normal"); + auto high = std::make_shared(key, "high"); + + EXPECT_TRUE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "low"); + + EXPECT_TRUE(Registry::instance().register_kernel(normal, Registry::Priority::Normal)); + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "normal"); + + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "high"); +} + +TEST_F(RegistryPriorityTest, LowerPriorityDoesNotReplace) +{ + auto key = make_test_key(256); + + auto high = std::make_shared(key, "high"); + auto low = std::make_shared(key, "low"); + + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + EXPECT_FALSE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + + 
EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "high"); +} + +TEST_F(RegistryPriorityTest, SamePriorityDoesNotReplace) +{ + auto key = make_test_key(256); + + auto first = std::make_shared(key, "first"); + auto second = std::make_shared(key, "second"); + + EXPECT_TRUE(Registry::instance().register_kernel(first, Registry::Priority::Normal)); + EXPECT_FALSE(Registry::instance().register_kernel(second, Registry::Priority::Normal)); + + EXPECT_EQ(Registry::instance().lookup(key)->get_name(), "first"); +} + +// ============================================================================= +// Lookup Tests +// ============================================================================= + +class RegistryLookupTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + // Register several kernels + for(int tile : {128, 256, 512}) + { + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryLookupTest, LookupByKey) +{ + auto key = make_test_key(256); + auto found = Registry::instance().lookup(key); + + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_256"); +} + +TEST_F(RegistryLookupTest, LookupByIdentifier) +{ + auto key = make_test_key(256); + std::string id = key.encode_identifier(); + + auto found = Registry::instance().lookup(id); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->get_name(), "kernel_256"); +} + +TEST_F(RegistryLookupTest, LookupNonExistent) +{ + auto key = make_test_key(1024); // Not registered + EXPECT_EQ(Registry::instance().lookup(key), nullptr); + EXPECT_EQ(Registry::instance().lookup("nonexistent_id"), nullptr); +} + +TEST_F(RegistryLookupTest, LookupEmptyIdentifier) +{ + EXPECT_EQ(Registry::instance().lookup(""), nullptr); +} + +// 
============================================================================= +// Filter Tests +// ============================================================================= + +class RegistryFilterTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + // Register kernels with various tile sizes + for(int tile : {64, 128, 256, 512, 1024}) + { + auto key = make_test_key(tile); + key.signature.dtype_a = (tile < 256) ? DataType::FP16 : DataType::BF16; + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryFilterTest, FilterByTileSize) +{ + auto large = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m >= 256; }); + + EXPECT_EQ(large.size(), 3); // 256, 512, 1024 +} + +TEST_F(RegistryFilterTest, FilterByDataType) +{ + auto fp16 = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().signature.dtype_a == DataType::FP16; }); + + EXPECT_EQ(fp16.size(), 2); // 64, 128 +} + +TEST_F(RegistryFilterTest, FilterMatchesNone) +{ + auto none = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m > 2048; }); + + EXPECT_EQ(none.size(), 0); +} + +TEST_F(RegistryFilterTest, FilterMatchesAll) +{ + auto all = Registry::instance().filter([](const KernelInstance& k) { return true; }); + + EXPECT_EQ(all.size(), 5); +} + +// ============================================================================= +// Multiple Registries Tests +// ============================================================================= + +class MultipleRegistriesTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(MultipleRegistriesTest, CreateIndependentRegistries) +{ + Registry reg1; + Registry reg2; + + 
reg1.set_name("registry1"); + reg2.set_name("registry2"); + + auto key1 = make_test_key(256); + auto key2 = make_test_key(512); + + reg1.register_kernel(std::make_shared(key1, "kernel1")); + reg2.register_kernel(std::make_shared(key2, "kernel2")); + + EXPECT_EQ(reg1.size(), 1); + EXPECT_EQ(reg2.size(), 1); + + EXPECT_NE(reg1.lookup(key1), nullptr); + EXPECT_EQ(reg1.lookup(key2), nullptr); + + EXPECT_EQ(reg2.lookup(key1), nullptr); + EXPECT_NE(reg2.lookup(key2), nullptr); +} + +TEST_F(MultipleRegistriesTest, RegistryNaming) +{ + Registry reg; + reg.set_name("my_custom_registry"); + + EXPECT_EQ(reg.get_name(), "my_custom_registry"); +} + +TEST_F(MultipleRegistriesTest, MergeRegistries) +{ + Registry reg1; + Registry reg2; + + auto key1 = make_test_key(128); + auto key2 = make_test_key(256); + auto key3 = make_test_key(512); + + reg1.register_kernel(std::make_shared(key1, "k1")); + reg1.register_kernel(std::make_shared(key2, "k2")); + + reg2.register_kernel(std::make_shared(key3, "k3")); + + Registry combined; + combined.merge_from(reg1, Registry::Priority::Normal); + combined.merge_from(reg2, Registry::Priority::Normal); + + EXPECT_EQ(combined.size(), 3); + EXPECT_NE(combined.lookup(key1), nullptr); + EXPECT_NE(combined.lookup(key2), nullptr); + EXPECT_NE(combined.lookup(key3), nullptr); +} + +TEST_F(MultipleRegistriesTest, MergeWithPriorityConflict) +{ + Registry reg1; + Registry reg2; + + auto key = make_test_key(256); + + reg1.register_kernel(std::make_shared(key, "from_reg1")); + reg2.register_kernel(std::make_shared(key, "from_reg2")); + + Registry combined; + combined.merge_from(reg1, Registry::Priority::Low); + combined.merge_from(reg2, Registry::Priority::High); + + EXPECT_EQ(combined.size(), 1); + EXPECT_EQ(combined.lookup(key)->get_name(), "from_reg2"); +} + +TEST_F(MultipleRegistriesTest, SingletonIndependence) +{ + Registry local_reg; + local_reg.set_name("local"); + + auto key1 = make_test_key(256); + auto key2 = make_test_key(512); + + 
local_reg.register_kernel(std::make_shared(key1, "local_kernel")); + Registry::instance().register_kernel( + std::make_shared(key2, "global_kernel")); + + EXPECT_EQ(local_reg.size(), 1); + EXPECT_EQ(Registry::instance().size(), 1); + + EXPECT_NE(local_reg.lookup(key1), nullptr); + EXPECT_EQ(local_reg.lookup(key2), nullptr); + + EXPECT_EQ(Registry::instance().lookup(key1), nullptr); + EXPECT_NE(Registry::instance().lookup(key2), nullptr); +} + +// ============================================================================= +// Thread Safety Tests +// ============================================================================= + +class RegistryThreadSafetyTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryThreadSafetyTest, ConcurrentRegistrations) +{ + const int num_threads = 10; + const int kernels_per_thread = 100; + + std::vector threads; + std::atomic success_count{0}; + + for(int t = 0; t < num_threads; t++) + { + threads.emplace_back([t, kernels_per_thread, &success_count]() { + for(int k = 0; k < kernels_per_thread; k++) + { + int tile = t * 1000 + k; // Unique tile size + auto key = make_test_key(tile); + auto kernel = + std::make_shared(key, "kernel_" + std::to_string(tile)); + + if(Registry::instance().register_kernel(kernel)) + { + success_count++; + } + } + }); + } + + for(auto& t : threads) + { + t.join(); + } + + EXPECT_EQ(success_count.load(), num_threads * kernels_per_thread); + EXPECT_EQ(Registry::instance().size(), num_threads * kernels_per_thread); +} + +TEST_F(RegistryThreadSafetyTest, ConcurrentLookups) +{ + // Pre-register kernels + for(int i = 0; i < 100; i++) + { + auto key = make_test_key(i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + const int num_threads = 10; + const int lookups_per_thread = 1000; + std::atomic 
found_count{0}; + + std::vector threads; + for(int t = 0; t < num_threads; t++) + { + threads.emplace_back([lookups_per_thread, &found_count]() { + for(int k = 0; k < lookups_per_thread; k++) + { + auto key = make_test_key(k % 100); + if(Registry::instance().lookup(key) != nullptr) + { + found_count++; + } + } + }); + } + + for(auto& t : threads) + { + t.join(); + } + + EXPECT_EQ(found_count.load(), num_threads * lookups_per_thread); +} + +// ============================================================================= +// Clear and Size Tests +// ============================================================================= + +class RegistryClearTest : public ::testing::Test +{ + protected: + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryClearTest, ClearEmptyRegistry) +{ + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); + + Registry::instance().clear(); // Should not crash + EXPECT_EQ(Registry::instance().size(), 0); +} + +TEST_F(RegistryClearTest, ClearNonEmptyRegistry) +{ + for(int i = 0; i < 10; i++) + { + auto key = make_test_key(i); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + EXPECT_EQ(Registry::instance().size(), 10); + + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); +} + +TEST_F(RegistryClearTest, RegisterAfterClear) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); + + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); + + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); +} + +// ============================================================================= +// GetAll Tests +// ============================================================================= + +class RegistryGetAllTest : public ::testing::Test +{ + 
protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegistryGetAllTest, GetAllEmpty) +{ + auto all = Registry::instance().get_all(); + EXPECT_EQ(all.size(), 0); +} + +TEST_F(RegistryGetAllTest, GetAllMultiple) +{ + for(int i = 0; i < 5; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + + auto all = Registry::instance().get_all(); + EXPECT_EQ(all.size(), 5); +} diff --git a/dispatcher/tests/test_regression.cpp b/dispatcher/tests/test_regression.cpp new file mode 100644 index 00000000000..8b5a416ecf4 --- /dev/null +++ b/dispatcher/tests/test_regression.cpp @@ -0,0 +1,492 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Regression tests for known issues and edge cases. + * Add a new test here whenever a bug is fixed to prevent regression. 
+ */ + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "test_mock_kernel.hpp" +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; +using SelectionStrategy = Dispatcher::SelectionStrategy; + +// ============================================================================= +// Issue: Uninitialized 'grouped' field in KernelKey caused JSON corruption +// Fix: Ensure all fields in make_test_key() are initialized +// ============================================================================= + +class RegressionGroupedFieldTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionGroupedFieldTest, GroupedFieldInitialized) +{ + KernelKey key = make_test_key(256); + + // grouped should be explicitly initialized + EXPECT_FALSE(key.signature.grouped); + + // Encoding should not crash or produce garbage + std::string id = key.encode_identifier(); + EXPECT_FALSE(id.empty()); + + // ID should not contain garbage characters + for(char c : id) + { + EXPECT_TRUE(std::isprint(c) || c == '_' || c == '-') + << "Invalid character in identifier: " << static_cast(c); + } +} + +TEST_F(RegressionGroupedFieldTest, GroupedFieldInJSON) +{ + KernelKey key = make_test_key(256); + key.signature.grouped = false; + + auto kernel = std::make_shared(key, "test_kernel"); + Registry::instance().register_kernel(kernel); + + // Export to JSON + std::string json = Registry::instance().export_json(true); + + // JSON should be valid (not contain null bytes or garbage) + EXPECT_FALSE(json.empty()); + + // Should contain the grouped field with proper value + EXPECT_NE(json.find("\"grouped\""), std::string::npos); + EXPECT_NE(json.find("false"), std::string::npos); +} + +// 
============================================================================= +// Issue: Priority comparison was incorrect +// Fix: Higher priority should replace lower, same priority should not replace +// ============================================================================= + +class RegressionPriorityTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionPriorityTest, LowThenHighReplaces) +{ + auto key = make_test_key(256); + auto low = std::make_shared(key, "low"); + auto high = std::make_shared(key, "high"); + + EXPECT_TRUE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "high"); +} + +TEST_F(RegressionPriorityTest, HighThenLowDoesNotReplace) +{ + auto key = make_test_key(256); + auto high = std::make_shared(key, "high"); + auto low = std::make_shared(key, "low"); + + EXPECT_TRUE(Registry::instance().register_kernel(high, Registry::Priority::High)); + EXPECT_FALSE(Registry::instance().register_kernel(low, Registry::Priority::Low)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "high"); +} + +TEST_F(RegressionPriorityTest, SamePriorityDoesNotReplace) +{ + auto key = make_test_key(256); + auto first = std::make_shared(key, "first"); + auto second = std::make_shared(key, "second"); + + EXPECT_TRUE(Registry::instance().register_kernel(first, Registry::Priority::Normal)); + EXPECT_FALSE(Registry::instance().register_kernel(second, Registry::Priority::Normal)); + + auto found = Registry::instance().lookup(key); + EXPECT_EQ(found->get_name(), "first"); +} + +// ============================================================================= +// Issue: Empty heuristic caused crash +// Fix: Fall back to FirstFit 
when heuristic returns empty or invalid results +// ============================================================================= + +class RegressionHeuristicTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionHeuristicTest, EmptyHeuristicFallback) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {}; // Empty + }); + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + + // Should not crash, should fall back to FirstFit + auto selected = dispatcher.select_kernel(problem); + EXPECT_NE(selected, nullptr); +} + +TEST_F(RegressionHeuristicTest, AllInvalidHeuristicFallback) +{ + Dispatcher dispatcher; + + dispatcher.set_heuristic([](const Problem& p) -> std::vector { + return {"invalid1", "invalid2", "invalid3"}; + }); + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + + // Should not crash, should fall back to FirstFit + auto selected = dispatcher.select_kernel(problem); + EXPECT_NE(selected, nullptr); +} + +TEST_F(RegressionHeuristicTest, NullHeuristicSafe) +{ + Dispatcher dispatcher; + + // Don't set any heuristic + dispatcher.set_strategy(SelectionStrategy::Heuristic); + + Problem problem(1024, 1024, 1024); + + // Should not crash + auto selected = dispatcher.select_kernel(problem); + // Behavior depends on implementation - may return nullptr or fall back +} + +// ============================================================================= +// Issue: Lookup by empty string caused crash or undefined behavior +// ============================================================================= + +class RegressionLookupTest : public ::testing::Test +{ + 
protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionLookupTest, EmptyStringLookup) +{ + EXPECT_EQ(Registry::instance().lookup(""), nullptr); +} + +TEST_F(RegressionLookupTest, VeryLongStringLookup) +{ + std::string very_long(10000, 'x'); + EXPECT_EQ(Registry::instance().lookup(very_long), nullptr); +} + +TEST_F(RegressionLookupTest, SpecialCharactersLookup) +{ + EXPECT_EQ(Registry::instance().lookup("kernel\0name"), nullptr); + EXPECT_EQ(Registry::instance().lookup("kernel\nname"), nullptr); + EXPECT_EQ(Registry::instance().lookup("kernel\tname"), nullptr); +} + +// ============================================================================= +// Issue: Problem with zero dimensions passed to dispatcher +// ============================================================================= + +class RegressionProblemTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionProblemTest, ZeroMDimension) +{ + Problem problem; + problem.M = 0; + problem.N = 1024; + problem.K = 1024; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionProblemTest, ZeroNDimension) +{ + Problem problem; + problem.M = 1024; + problem.N = 0; + problem.K = 1024; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionProblemTest, ZeroKDimension) +{ + Problem problem; + problem.M = 1024; + problem.N = 1024; + problem.K = 0; + + EXPECT_FALSE(problem.is_valid()); +} + +// ============================================================================= +// Issue: Dispatcher run with null pointers +// ============================================================================= + +class RegressionNullPointerTest : public 
::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionNullPointerTest, RunWithNullPointers) +{ + Dispatcher dispatcher; + Problem problem(1024, 1024, 1024); + + // Mock kernel doesn't use pointers, so this should work + float time = dispatcher.run(nullptr, nullptr, nullptr, problem); + + // Mock returns 1.0f + EXPECT_FLOAT_EQ(time, 1.0f); +} + +// ============================================================================= +// Issue: Thread safety - concurrent access to singleton +// ============================================================================= + +class RegressionThreadSafetyTest : public ::testing::Test +{ + protected: + void SetUp() override { Registry::instance().clear(); } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionThreadSafetyTest, SingletonAddressStable) +{ + Registry* addr1 = &Registry::instance(); + Registry* addr2 = &Registry::instance(); + Registry* addr3 = &Registry::instance(); + + EXPECT_EQ(addr1, addr2); + EXPECT_EQ(addr2, addr3); +} + +// ============================================================================= +// Issue: encode_identifier could produce duplicate IDs for different configs +// ============================================================================= + +class RegressionIdentifierTest : public ::testing::Test +{ +}; + +TEST_F(RegressionIdentifierTest, DifferentConfigsDifferentIDs) +{ + // Create two keys that differ only in one field + KernelKey key1 = make_test_key(256); + KernelKey key2 = make_test_key(256); + key2.algorithm.persistent = true; // Only difference + + std::string id1 = key1.encode_identifier(); + std::string id2 = key2.encode_identifier(); + + EXPECT_NE(id1, id2) << "Different persistent flag 
should produce different IDs"; +} + +TEST_F(RegressionIdentifierTest, DifferentTileShapesDifferentIDs) +{ + KernelKey key1 = make_test_key(128, 128, 32); + KernelKey key2 = make_test_key(256, 256, 32); + + EXPECT_NE(key1.encode_identifier(), key2.encode_identifier()); +} + +TEST_F(RegressionIdentifierTest, DifferentWarpConfigsDifferentIDs) +{ + KernelKey key1 = make_test_key(256); + key1.algorithm.wave_shape = {2, 2, 1}; + + KernelKey key2 = make_test_key(256); + key2.algorithm.wave_shape = {4, 1, 1}; + + EXPECT_NE(key1.encode_identifier(), key2.encode_identifier()); +} + +// ============================================================================= +// Issue: Negative k_batch could cause issues +// ============================================================================= + +class RegressionKBatchTest : public ::testing::Test +{ +}; + +TEST_F(RegressionKBatchTest, ZeroKBatchInvalid) +{ + Problem problem(1024, 1024, 1024); + problem.k_batch = 0; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionKBatchTest, NegativeKBatchInvalid) +{ + Problem problem(1024, 1024, 1024); + problem.k_batch = -1; + + EXPECT_FALSE(problem.is_valid()); +} + +TEST_F(RegressionKBatchTest, LargeKBatchValid) +{ + Problem problem(1024, 1024, 1024); + problem.k_batch = 1000; + + EXPECT_TRUE(problem.is_valid()); +} + +// ============================================================================= +// Issue: Filter returning shared_ptr leaks +// ============================================================================= + +class RegressionFilterTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + for(int i = 0; i < 10; i++) + { + auto key = make_test_key(100 + i); + auto kernel = std::make_shared(key, "kernel_" + std::to_string(i)); + Registry::instance().register_kernel(kernel); + } + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionFilterTest, FilterResultsAreValid) +{ + auto 
results = Registry::instance().filter( + [](const KernelInstance& k) { return k.get_key().algorithm.tile_shape.m >= 105; }); + + EXPECT_EQ(results.size(), 5); + + for(const auto& kernel : results) + { + EXPECT_NE(kernel, nullptr); + EXPECT_GE(kernel->get_key().algorithm.tile_shape.m, 105); + } +} + +// ============================================================================= +// Issue: Double clear() could cause issues +// ============================================================================= + +class RegressionDoubleClearTest : public ::testing::Test +{ +}; + +TEST_F(RegressionDoubleClearTest, DoubleClearSafe) +{ + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); + + Registry::instance().clear(); + EXPECT_EQ(Registry::instance().size(), 0); + + Registry::instance().clear(); // Second clear + EXPECT_EQ(Registry::instance().size(), 0); + + // Should still work after double clear + Registry::instance().register_kernel(kernel); + EXPECT_EQ(Registry::instance().size(), 1); +} + +// ============================================================================= +// Issue: Multiple dispatchers with same registry +// ============================================================================= + +class RegressionMultiDispatcherTest : public ::testing::Test +{ + protected: + void SetUp() override + { + Registry::instance().clear(); + + auto key = make_test_key(256); + auto kernel = std::make_shared(key, "kernel"); + Registry::instance().register_kernel(kernel); + } + + void TearDown() override { Registry::instance().clear(); } +}; + +TEST_F(RegressionMultiDispatcherTest, MultipleDispatchersShareRegistry) +{ + Dispatcher d1; + Dispatcher d2; + Dispatcher d3; + + Problem problem(1024, 1024, 1024); + + auto k1 = d1.select_kernel(problem); + auto k2 = d2.select_kernel(problem); + auto k3 = d3.select_kernel(problem); + + // All should select the 
same kernel + EXPECT_NE(k1, nullptr); + EXPECT_EQ(k1, k2); + EXPECT_EQ(k2, k3); +} diff --git a/dispatcher/tests/test_sanity_ck_tile.cpp b/dispatcher/tests/test_sanity_ck_tile.cpp new file mode 100644 index 00000000000..fd28b7e54c0 --- /dev/null +++ b/dispatcher/tests/test_sanity_ck_tile.cpp @@ -0,0 +1,607 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Sanity check tests to verify CK Tile kernels are actually running on GPU. + * + * These tests verify: + * 1. GPU memory allocation and transfer work correctly + * 2. The dispatcher calls CK Tile infrastructure + * 3. GPU computes correct results (not just zeros) + * 4. Performance is reasonable (not CPU fallback) + * 5. Different problem sizes work correctly + */ + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" + +// Kernel header will be included via -include compiler flag + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; + +#define HIP_CHECK(call) \ + { \ + hipError_t err = call; \ + if(err != hipSuccess) \ + { \ + std::cerr << "HIP Error at " << __FILE__ << ":" << __LINE__ << ": " \ + << hipGetErrorString(err) << "\n"; \ + return 1; \ + } \ + } + +// Reference CPU GEMM for validation +template +void cpu_gemm( + const std::vector& A, const std::vector& B, std::vector& C, int M, int N, int K) +{ + for(int m = 0; m < M; m++) + { + for(int n = 0; n < N; n++) + { + float acc = 0.0f; + for(int k = 0; k < K; k++) + { + acc += float(A[m * K + k]) * float(B[k * N + n]); + } + C[m * N + n] = T(acc); + } + } +} + +// Test helper to setup dispatcher +void setup_dispatcher() +{ + KernelKey key; + key.signature.dtype_a = DataType::FP16; + key.signature.dtype_b = DataType::FP16; + key.signature.dtype_c = DataType::FP16; + key.signature.dtype_acc = 
DataType::FP32; + key.signature.layout_a = LayoutTag::RowMajor; + key.signature.layout_b = LayoutTag::ColMajor; + key.signature.layout_c = LayoutTag::RowMajor; + key.signature.transpose_a = false; + key.signature.transpose_b = false; + key.signature.grouped = false; + key.signature.split_k = 1; + key.signature.elementwise_op = "PassThrough"; + key.signature.num_d_tensors = 0; + key.signature.structured_sparsity = false; + + key.algorithm.tile_shape = {128, 128, 64}; + key.algorithm.wave_shape = {2, 2, 1}; + key.algorithm.warp_tile_shape = {32, 32, 16}; + key.algorithm.pipeline = Pipeline::CompV4; + key.algorithm.scheduler = Scheduler::Intrawave; + key.algorithm.epilogue = Epilogue::CShuffle; + key.algorithm.block_size = 256; + key.algorithm.double_buffer = true; + key.algorithm.persistent = false; + key.algorithm.preshuffle = false; + key.algorithm.transpose_c = false; + key.algorithm.num_wave_groups = 1; + key.gfx_arch = "gfx942"; + + auto kernel = + create_generated_tile_kernel( + key, KERNEL_NAME); + + Registry::instance().clear(); + Registry::instance().register_kernel(kernel, Registry::Priority::High); +} + +// ============================================================================= +// Test 1: Basic Sanity - All ones multiplication +// ============================================================================= +int test_all_ones() +{ + std::cout << "\n=== Test: All Ones Multiplication ===\n"; + + const int M = 256, N = 256, K = 256; + + std::vector A(M * K, ADataType(1.0f)); + std::vector B(K * N, BDataType(1.0f)); + std::vector C(M * N, CDataType(0.0f)); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), 
hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // All ones * all ones with K=256 should give K=256 for each element + int correct = 0; + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C[i]) - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Time: " << time << " ms\n"; + std::cout << " Expected: " << K << "\n"; + std::cout << " Sample C[0]: " << float(C[0]) << "\n"; + std::cout << " Accuracy: " << accuracy << "%\n"; + + if(accuracy < 99.0f) + { + std::cerr << " FAILED: Accuracy too low\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 2: Non-Zero Results - Verify GPU actually computed something +// ============================================================================= +int test_non_zero_results() +{ + std::cout << "\n=== Test: Non-Zero Results ===\n"; + + const int M = 256, N = 256, K = 256; + + std::vector A(M * K, ADataType(2.0f)); // All 2s + std::vector B(K * N, BDataType(3.0f)); // All 3s + std::vector C(M * N, CDataType(0.0f)); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Dispatcher dispatcher; + Problem problem(M, N, 
K); + + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // 2 * 3 * K = 6 * 256 = 1536 + float expected = 6.0f * K; + int correct = 0; + int non_zero = 0; + + for(int i = 0; i < M * N; i++) + { + if(float(C[i]) != 0.0f) + non_zero++; + if(std::abs(float(C[i]) - expected) < 10.0f) + { + correct++; + } + } + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Time: " << time << " ms\n"; + std::cout << " Expected: " << expected << "\n"; + std::cout << " Sample C[0]: " << float(C[0]) << "\n"; + std::cout << " Non-zero elements: " << non_zero << "/" << M * N << "\n"; + + if(non_zero == 0) + { + std::cerr << " FAILED: All zeros - GPU may not have run\n"; + return 1; + } + + float accuracy = 100.0f * correct / (M * N); + std::cout << " Accuracy: " << accuracy << "%\n"; + + if(accuracy < 99.0f) + { + std::cerr << " FAILED: Accuracy too low\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 3: Performance Check - Ensure not CPU fallback +// ============================================================================= +int test_performance() +{ + std::cout << "\n=== Test: Performance Check ===\n"; + + const int M = 1024, N = 1024, K = 1024; + const int num_runs = 5; + + std::vector A(M * K, ADataType(1.0f)); + std::vector B(K * N, BDataType(1.0f)); + std::vector C(M * N); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + + Dispatcher dispatcher; + 
Problem problem(M, N, K); + + // Warmup + dispatcher.run(A_dev, B_dev, C_dev, problem); + HIP_CHECK(hipDeviceSynchronize()); + + // Timed runs + std::vector times; + for(int i = 0; i < num_runs; i++) + { + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + times.push_back(time); + } + + float avg_time = std::accumulate(times.begin(), times.end(), 0.0f) / times.size(); + float min_time = *std::min_element(times.begin(), times.end()); + + double flops = 2.0 * M * N * K; + double tflops = (flops / (min_time * 1e-3)) / 1e12; + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Problem: " << M << "x" << N << "x" << K << "\n"; + std::cout << " Avg time: " << avg_time << " ms\n"; + std::cout << " Min time: " << min_time << " ms\n"; + std::cout << " Performance: " << tflops << " TFLOPS\n"; + + // GPU should achieve at least 1 TFLOPS for this size + // CPU would be ~0.001 TFLOPS + if(tflops < 1.0) + { + std::cerr << " FAILED: Performance too low - may be CPU fallback\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 4: CPU vs GPU Correctness +// ============================================================================= +int test_vs_cpu_reference() +{ + std::cout << "\n=== Test: CPU vs GPU Correctness ===\n"; + + const int M = 128, N = 128, K = 128; // Small for CPU reference + + // Random-ish values + std::vector A(M * K); + std::vector B(K * N); + std::vector C_gpu(M * N); + std::vector C_cpu(M * N); + + for(int i = 0; i < M * K; i++) + { + A[i] = ADataType(float((i % 10) + 1) * 0.1f); + } + for(int i = 0; i < K * N; i++) + { + B[i] = BDataType(float((i % 7) + 1) * 0.1f); + } + + // CPU reference + cpu_gemm(A, B, C_cpu, M, N, K); + + // GPU + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, K * N * 
sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType))); + + HIP_CHECK(hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType))); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Compare + float max_diff = 0.0f; + float sum_diff = 0.0f; + int correct = 0; + + for(int i = 0; i < M * N; i++) + { + float gpu_val = float(C_gpu[i]); + float cpu_val = float(C_cpu[i]); + float diff = std::abs(gpu_val - cpu_val); + + max_diff = std::max(max_diff, diff); + sum_diff += diff; + + // FP16 has limited precision (~3-4 decimal digits) + // For K=128, values can reach ~10-30, so allow 5% relative error + absolute tolerance + float tolerance = std::max(std::abs(cpu_val) * 0.05f, 1.0f); + if(diff < tolerance) + { + correct++; + } + } + + float avg_diff = sum_diff / (M * N); + float accuracy = 100.0f * correct / (M * N); + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + std::cout << " Max diff: " << max_diff << "\n"; + std::cout << " Avg diff: " << avg_diff << "\n"; + std::cout << " Sample CPU C[0]: " << float(C_cpu[0]) << "\n"; + std::cout << " Sample GPU C[0]: " << float(C_gpu[0]) << "\n"; + std::cout << " Accuracy: " << accuracy << "%\n"; + + // FP16 accumulation can have significant rounding differences from CPU FP32 + // 90% is reasonable for FP16 with K=128 accumulation + if(accuracy < 90.0f) + { + std::cerr << " FAILED: Too many mismatches vs CPU\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 5: Different Problem Sizes +// 
============================================================================= +int test_multiple_sizes() +{ + std::cout << "\n=== Test: Multiple Problem Sizes ===\n"; + + std::vector> sizes = { + {128, 128, 128}, + {256, 256, 256}, + {512, 512, 512}, + {128, 256, 512}, + {512, 256, 128}, + {1024, 1024, 256}, + }; + + int passed = 0; + int total = sizes.size(); + + for(const auto& [M, N, K] : sizes) + { + std::cout << " Testing " << M << "x" << N << "x" << K << "... "; + + std::vector A(M * K, ADataType(1.0f)); + std::vector B(K * N, BDataType(1.0f)); + std::vector C(M * N); + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + hipMalloc(&A_dev, M * K * sizeof(ADataType)); + hipMalloc(&B_dev, K * N * sizeof(BDataType)); + hipMalloc(&C_dev, M * N * sizeof(CDataType)); + + hipMemcpy(A_dev, A.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice); + hipMemcpy(B_dev, B.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice); + hipMemset(C_dev, 0, M * N * sizeof(CDataType)); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + float time = dispatcher.run(A_dev, B_dev, C_dev, problem); + + hipMemcpy(C.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); + + hipFree(A_dev); + hipFree(B_dev); + hipFree(C_dev); + + // Check result + int correct = 0; + for(int i = 0; i < M * N; i++) + { + if(std::abs(float(C[i]) - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + + if(accuracy > 99.0f && time > 0) + { + std::cout << "PASS (" << time << " ms)\n"; + passed++; + } + else + { + std::cout << "FAIL (acc=" << accuracy << "%, time=" << time << ")\n"; + } + } + + std::cout << "\n Passed: " << passed << "/" << total << "\n"; + + if(passed < total) + { + std::cerr << " FAILED: Some sizes failed\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Test 6: Memory Bounds Check +// 
============================================================================= +int test_memory_bounds() +{ + std::cout << "\n=== Test: Memory Bounds Check ===\n"; + + const int M = 256, N = 256, K = 256; + const float sentinel = -999.0f; + + // Allocate with extra padding and sentinel values + const int padding = 16; + std::vector A(M * K + padding, ADataType(1.0f)); + std::vector B(K * N + padding, BDataType(1.0f)); + std::vector C(M * N + padding, CDataType(sentinel)); + + // Set sentinels at the end + for(int i = 0; i < padding; i++) + { + A[M * K + i] = ADataType(sentinel); + B[K * N + i] = BDataType(sentinel); + } + + ADataType *A_dev, *B_dev; + CDataType* C_dev; + + HIP_CHECK(hipMalloc(&A_dev, (M * K + padding) * sizeof(ADataType))); + HIP_CHECK(hipMalloc(&B_dev, (K * N + padding) * sizeof(BDataType))); + HIP_CHECK(hipMalloc(&C_dev, (M * N + padding) * sizeof(CDataType))); + + HIP_CHECK( + hipMemcpy(A_dev, A.data(), (M * K + padding) * sizeof(ADataType), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(B_dev, B.data(), (K * N + padding) * sizeof(BDataType), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(C_dev, C.data(), (M * N + padding) * sizeof(CDataType), hipMemcpyHostToDevice)); + + Dispatcher dispatcher; + Problem problem(M, N, K); + + dispatcher.run(A_dev, B_dev, C_dev, problem); + + HIP_CHECK( + hipMemcpy(C.data(), C_dev, (M * N + padding) * sizeof(CDataType), hipMemcpyDeviceToHost)); + + // Check sentinels weren't overwritten + bool sentinels_intact = true; + for(int i = 0; i < padding; i++) + { + if(float(C[M * N + i]) != sentinel) + { + sentinels_intact = false; + std::cerr << " Sentinel overwritten at position " << (M * N + i) << "\n"; + } + } + + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + + if(!sentinels_intact) + { + std::cerr << " FAILED: Memory bounds violated\n"; + return 1; + } + + // Also check actual results are correct + int correct = 0; + for(int i = 0; i < M * N; i++) + { + 
if(std::abs(float(C[i]) - float(K)) < 1.0f) + { + correct++; + } + } + + float accuracy = 100.0f * correct / (M * N); + std::cout << " Sentinels intact: Yes\n"; + std::cout << " Result accuracy: " << accuracy << "%\n"; + + if(accuracy < 99.0f) + { + std::cerr << " FAILED: Results incorrect\n"; + return 1; + } + + std::cout << " PASSED\n"; + return 0; +} + +// ============================================================================= +// Main +// ============================================================================= +int main() +{ + std::cout << "========================================\n"; + std::cout << "CK Tile Sanity Check Tests\n"; + std::cout << "========================================\n"; + std::cout << "Kernel: " << KERNEL_NAME << "\n"; + + // Setup + setup_dispatcher(); + + int failures = 0; + + // Run all tests + failures += test_all_ones(); + failures += test_non_zero_results(); + failures += test_performance(); + failures += test_vs_cpu_reference(); + failures += test_multiple_sizes(); + failures += test_memory_bounds(); + + std::cout << "\n========================================\n"; + if(failures == 0) + { + std::cout << "ALL TESTS PASSED\n"; + std::cout << "CK Tile is running correctly on GPU.\n"; + return 0; + } + else + { + std::cout << failures << " TEST(S) FAILED\n"; + return 1; + } +} diff --git a/dispatcher/tests/test_tile_backend.cpp b/dispatcher/tests/test_tile_backend.cpp new file mode 100644 index 00000000000..4e7c6930717 --- /dev/null +++ b/dispatcher/tests/test_tile_backend.cpp @@ -0,0 +1,155 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +/// Unit tests for CK Tile backend using Google Test +/// Note: This test validates the dispatcher wrapper infrastructure, not actual kernel execution + +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "test_mock_kernel.hpp" +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::test; + +namespace { + +// Note: Actual CK Tile backend tests require real generated kernels and GPU hardware. +// These tests verify the dispatcher's tile backend interface and wrapper functionality +// using mock kernels instead of real tile kernels. +} // anonymous namespace + +// These tests verify the tile backend can be used with mock kernels +// Real tile kernel integration would require generated CK Tile kernels + +TEST(TileBackendTest, KernelKeyCreation) +{ + // Test creating a kernel key for tile backend + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + + EXPECT_EQ(key.algorithm.tile_shape.m, 256); + EXPECT_EQ(key.algorithm.tile_shape.n, 256); + EXPECT_EQ(key.algorithm.tile_shape.k, 32); + EXPECT_EQ(key.gfx_arch, "gfx942"); + EXPECT_EQ(key.signature.dtype_a, DataType::FP16); +} + +TEST(TileBackendTest, MockKernelRegistration) +{ + // Clear registry for clean test + Registry::instance().clear(); + + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + auto kernel = + std::make_shared(key, "mock_tile_kernel", false); // strict divisibility + + // Register kernel + bool registered = Registry::instance().register_kernel(kernel); + EXPECT_TRUE(registered); + + // Lookup kernel + std::string kernel_id = key.encode_identifier(); + auto found_kernel = Registry::instance().lookup(kernel_id); + EXPECT_NE(found_kernel, nullptr); + EXPECT_EQ(found_kernel->get_name(), "mock_tile_kernel"); + + Registry::instance().clear(); +} + +TEST(TileBackendTest, 
DispatcherWithMockTileKernel) +{ + // Clear registry + Registry::instance().clear(); + + // Create and register mock tile kernel + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + auto kernel = + std::make_shared(key, "mock_tile_kernel", false); // strict divisibility + Registry::instance().register_kernel(kernel); + + // Create dispatcher + Dispatcher dispatcher; + + // Test kernel selection - divisible dimensions + Problem problem1(512, 512, 512); // Divisible by 256, 256, 32 + auto selected1 = dispatcher.select_kernel(problem1); + EXPECT_NE(selected1, nullptr); + EXPECT_EQ(selected1->get_name(), "mock_tile_kernel"); + + // Test with non-divisible problem + Problem problem2(100, 200, 300); // Not divisible + auto not_selected = dispatcher.select_kernel(problem2); + EXPECT_EQ(not_selected, nullptr); + + Registry::instance().clear(); +} + +TEST(TileBackendTest, TileKernelIdentifierEncoding) +{ + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + + std::string id = key.encode_identifier(); + + // Should contain tile dimensions + EXPECT_NE(id.find("256x256x32"), std::string::npos); + EXPECT_NE(id.find("2x2x1"), std::string::npos); + EXPECT_NE(id.find("32x32x16"), std::string::npos); + + // Should contain persistent flag + EXPECT_NE(id.find("nopers"), std::string::npos); // persistent = false +} + +TEST(TileBackendTest, MultipleKernelRegistration) +{ + // Clear registry + Registry::instance().clear(); + + // Register multiple kernels with different tile sizes + KernelKey key1 = make_test_key(256, 256, 32, "gfx942"); + auto kernel1 = std::make_shared(key1, "kernel_256x256x32", false); + + KernelKey key2 = make_test_key(128, 128, 64, "gfx942"); + auto kernel2 = std::make_shared(key2, "kernel_128x128x64", false); + + Registry::instance().register_kernel(kernel1); + Registry::instance().register_kernel(kernel2); + + EXPECT_EQ(Registry::instance().size(), 2); + + // Verify both are accessible + auto found1 = 
Registry::instance().lookup(key1.encode_identifier()); + auto found2 = Registry::instance().lookup(key2.encode_identifier()); + + EXPECT_NE(found1, nullptr); + EXPECT_NE(found2, nullptr); + EXPECT_EQ(found1->get_name(), "kernel_256x256x32"); + EXPECT_EQ(found2->get_name(), "kernel_128x128x64"); + + Registry::instance().clear(); +} + +TEST(TileBackendTest, TileSizeSupport) +{ + Registry::instance().clear(); + + // Create kernel with 256x256x32 tiles (no padding) + KernelKey key = make_test_key(256, 256, 32, "gfx942"); + auto kernel = + std::make_shared(key, "test_kernel", false); // strict divisibility + + // Should support 512x512x512 (divisible) + EXPECT_TRUE(kernel->supports(Problem(512, 512, 512))); + + // Should support 256x256x32 (exact match) + EXPECT_TRUE(kernel->supports(Problem(256, 256, 32))); + + // Should NOT support 100x200x300 (not divisible) + EXPECT_FALSE(kernel->supports(Problem(100, 200, 300))); + + // Should support 1024x1024x1024 (divisible) + EXPECT_TRUE(kernel->supports(Problem(1024, 1024, 1024))); + + Registry::instance().clear(); +}