From b2b94a26060ea8809b563f90c0d90b10a8793bf5 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Sun, 25 Jan 2026 12:37:55 +0000 Subject: [PATCH 001/113] sync pre_v0.1 --- .clang-format | 2 + .gitignore | 59 +- CMakeLists.txt | 45 + examples/01-vector-add.py | 102 ++ examples/02-layout_algebra.py | 63 + examples/03-mma_atom.py | 114 ++ {flydsl => flydsl_}/requirements.txt | 0 {flydsl => flydsl_}/src/_mlir | 0 .../src/flydsl_}/__init__.py | 0 .../src/flydsl_}/compiler/__init__.py | 0 .../src/flydsl_}/compiler/cache.py | 0 .../src/flydsl_}/compiler/compiler.py | 0 .../src/flydsl_}/compiler/context.py | 0 .../src/flydsl_}/compiler/executor.py | 0 .../src/flydsl_}/compiler/flir_opt_helper.py | 0 .../src/flydsl_}/compiler/pipeline.py | 0 .../src/flydsl_}/dialects/__init__.py | 0 .../src/flydsl_}/dialects/ext/__init__.py | 0 .../src/flydsl_}/dialects/ext/_loc.py | 0 .../src/flydsl_}/dialects/ext/arith.py | 0 .../flydsl_}/dialects/ext/block_reduce_ops.py | 0 .../src/flydsl_}/dialects/ext/buffer_ops.py | 0 .../src/flydsl_}/dialects/ext/flir.py | 0 .../src/flydsl_}/dialects/ext/func.py | 0 .../src/flydsl_}/dialects/ext/gpu.py | 0 .../src/flydsl_}/dialects/ext/llvm.py | 0 .../src/flydsl_}/dialects/ext/math.py | 0 .../src/flydsl_}/dialects/ext/memref.py | 0 .../dialects/ext/mlir_extras/__init__.py | 0 .../dialects/ext/mlir_extras/_shaped_value.py | 0 .../dialects/ext/mlir_extras/arith.py | 0 .../flydsl_}/dialects/ext/mlir_extras/scf.py | 0 .../flydsl_}/dialects/ext/mlir_extras/util.py | 0 .../dialects/ext/python_control_flow.py | 0 .../src/flydsl_}/dialects/ext/rocdl.py | 0 .../src/flydsl_}/dialects/ext/rocm.py | 0 .../src/flydsl_}/dialects/ext/scf.py | 0 .../src/flydsl_}/dialects/ext/vector.py | 0 .../src/flydsl_}/lang/__init__.py | 0 .../src/flydsl_}/lang/ir/__init__.py | 0 .../src/flydsl_}/lang/ir/module.py | 0 .../src/flydsl_}/lang/ir/types.py | 0 .../flydsl => flydsl_/src/flydsl_}/passes.py | 0 .../src/flydsl_}/runtime/__init__.py | 0 .../src/flydsl_}/runtime/device.py | 0 
.../src/flydsl_}/utils/__init__.py | 0 .../src/flydsl_}/utils/smem_allocator.py | 0 include/flydsl-c/FlyDialect.h | 16 + include/flydsl-c/FlyPasses.h | 0 include/flydsl/CMakeLists.txt | 2 + include/flydsl/Conversion/CMakeLists.txt | 8 + .../flydsl/Conversion/FlyToROCDL/FlyToROCDL.h | 11 + include/flydsl/Conversion/Passes.h | 14 + include/flydsl/Conversion/Passes.td | 21 + include/flydsl/Dialect/CMakeLists.txt | 1 + include/flydsl/Dialect/Fly/CMakeLists.txt | 2 + include/flydsl/Dialect/Fly/IR/CMakeLists.txt | 33 + include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td | 260 +++ include/flydsl/Dialect/Fly/IR/FlyDialect.h | 36 + include/flydsl/Dialect/Fly/IR/FlyDialect.td | 36 + .../flydsl/Dialect/Fly/IR/FlyInterfaces.td | 56 + include/flydsl/Dialect/Fly/IR/FlyOps.td | 567 ++++++ include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td | 196 ++ .../Dialect/Fly/Transforms/CMakeLists.txt | 12 + .../Dialect/Fly/Transforms/LayoutLowering.td | 68 + .../flydsl/Dialect/Fly/Transforms/Passes.h | 19 + .../flydsl/Dialect/Fly/Transforms/Passes.td | 24 + .../flydsl/Dialect/Fly/Utils/IntTupleUtils.h | 1004 ++++++++++ include/flydsl/Dialect/Fly/Utils/IntUtils.h | 53 + .../flydsl/Dialect/Fly/Utils/LayoutUtils.h | 771 ++++++++ include/flydsl/Dialect/Fly/Utils/NormalForm.h | 25 + include/flydsl/Dialect/FlyROCDL/IR/Dialect.td | 22 + lib/Bindings/Python/MainModules.cpp | 228 +++ lib/CAPI/CMakeLists.txt | 6 + lib/CAPI/FlyDialect.cpp | 6 + lib/CMakeLists.txt | 3 + lib/Conversion/CMakeLists.txt | 1 + lib/Conversion/FlyToROCDL/CMakeLists.txt | 22 + lib/Conversion/FlyToROCDL/FlyToROCDL.cpp | 599 ++++++ lib/Dialect/CMakeLists.txt | 1 + lib/Dialect/Fly/CMakeLists.txt | 18 + lib/Dialect/Fly/IR/FlyAttrDefs.cpp | 455 +++++ lib/Dialect/Fly/IR/FlyDialect.cpp | 41 + lib/Dialect/Fly/IR/FlyOps.cpp | 1235 +++++++++++++ lib/Dialect/Fly/IR/FlyTypeDefs.cpp | 78 + .../Fly/Transforms/FlyCanonicalize.cpp | 104 ++ lib/Dialect/Fly/Transforms/LayoutLowering.cpp | 1609 +++++++++++++++++ lib/Dialect/Fly/Utils/IntTupleUtils.cpp | 448 
+++++ lib/Dialect/Fly/Utils/IntUtils.cpp | 231 +++ lib/Dialect/Fly/Utils/NormalForm.cpp | 220 +++ python/flydsl/__init__.py | 1 + python/flydsl/compiler/__init__.py | 3 + python/flydsl/compiler/compiler.py | 148 ++ python/flydsl/compiler/executor.py | 44 + python/flydsl/lang/__init__.py | 2 + python/flydsl/lang/ir/__init__.py | 9 + python/flydsl/lang/ir/core.py | 480 +++++ python/flydsl/lang/ir/gpu.py | 457 +++++ python/flydsl/lang/ir/module.py | 212 +++ python/flydsl/lang/ir/types.py | 5 + python/flydsl/lang/meta.py | 30 + python/flydsl/lang/typing.py | 32 + python/flydsl/utils/__init__.py | 0 python/flydsl/utils/env_manager.py | 0 python/flydsl/utils/hip_utils.py | 2 + python/flydsl/utils/logger.py | 0 python/mlir_flydsl/CMakeLists.txt | 129 ++ python/mlir_flydsl/FlyRegisterEverything.cpp | 33 + .../_mlirRegisterEverything/py.typed | 0 python/mlir_flydsl/dialects/FlyOps.td | 7 + python/mlir_flydsl/dialects/fly.py | 4 + 111 files changed, 10527 insertions(+), 18 deletions(-) create mode 100644 .clang-format create mode 100644 CMakeLists.txt create mode 100644 examples/01-vector-add.py create mode 100644 examples/02-layout_algebra.py create mode 100644 examples/03-mma_atom.py rename {flydsl => flydsl_}/requirements.txt (100%) rename {flydsl => flydsl_}/src/_mlir (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/__init__.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/compiler/__init__.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/compiler/cache.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/compiler/compiler.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/compiler/context.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/compiler/executor.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/compiler/flir_opt_helper.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/compiler/pipeline.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/__init__.py (100%) rename 
{flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/__init__.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/_loc.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/arith.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/block_reduce_ops.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/buffer_ops.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/flir.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/func.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/gpu.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/llvm.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/math.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/memref.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/mlir_extras/__init__.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/mlir_extras/_shaped_value.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/mlir_extras/arith.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/mlir_extras/scf.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/mlir_extras/util.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/python_control_flow.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/rocdl.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/rocm.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/scf.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/dialects/ext/vector.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/lang/__init__.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/lang/ir/__init__.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/lang/ir/module.py (100%) rename {flydsl/src/flydsl => 
flydsl_/src/flydsl_}/lang/ir/types.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/passes.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/runtime/__init__.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/runtime/device.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/utils/__init__.py (100%) rename {flydsl/src/flydsl => flydsl_/src/flydsl_}/utils/smem_allocator.py (100%) create mode 100644 include/flydsl-c/FlyDialect.h create mode 100644 include/flydsl-c/FlyPasses.h create mode 100644 include/flydsl/CMakeLists.txt create mode 100644 include/flydsl/Conversion/CMakeLists.txt create mode 100644 include/flydsl/Conversion/FlyToROCDL/FlyToROCDL.h create mode 100644 include/flydsl/Conversion/Passes.h create mode 100644 include/flydsl/Conversion/Passes.td create mode 100644 include/flydsl/Dialect/CMakeLists.txt create mode 100644 include/flydsl/Dialect/Fly/CMakeLists.txt create mode 100644 include/flydsl/Dialect/Fly/IR/CMakeLists.txt create mode 100644 include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td create mode 100644 include/flydsl/Dialect/Fly/IR/FlyDialect.h create mode 100644 include/flydsl/Dialect/Fly/IR/FlyDialect.td create mode 100644 include/flydsl/Dialect/Fly/IR/FlyInterfaces.td create mode 100644 include/flydsl/Dialect/Fly/IR/FlyOps.td create mode 100644 include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td create mode 100644 include/flydsl/Dialect/Fly/Transforms/CMakeLists.txt create mode 100644 include/flydsl/Dialect/Fly/Transforms/LayoutLowering.td create mode 100644 include/flydsl/Dialect/Fly/Transforms/Passes.h create mode 100644 include/flydsl/Dialect/Fly/Transforms/Passes.td create mode 100644 include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h create mode 100644 include/flydsl/Dialect/Fly/Utils/IntUtils.h create mode 100644 include/flydsl/Dialect/Fly/Utils/LayoutUtils.h create mode 100644 include/flydsl/Dialect/Fly/Utils/NormalForm.h create mode 100644 include/flydsl/Dialect/FlyROCDL/IR/Dialect.td create mode 100644 
lib/Bindings/Python/MainModules.cpp create mode 100644 lib/CAPI/CMakeLists.txt create mode 100644 lib/CAPI/FlyDialect.cpp create mode 100644 lib/CMakeLists.txt create mode 100644 lib/Conversion/CMakeLists.txt create mode 100644 lib/Conversion/FlyToROCDL/CMakeLists.txt create mode 100644 lib/Conversion/FlyToROCDL/FlyToROCDL.cpp create mode 100644 lib/Dialect/CMakeLists.txt create mode 100644 lib/Dialect/Fly/CMakeLists.txt create mode 100644 lib/Dialect/Fly/IR/FlyAttrDefs.cpp create mode 100644 lib/Dialect/Fly/IR/FlyDialect.cpp create mode 100644 lib/Dialect/Fly/IR/FlyOps.cpp create mode 100644 lib/Dialect/Fly/IR/FlyTypeDefs.cpp create mode 100644 lib/Dialect/Fly/Transforms/FlyCanonicalize.cpp create mode 100644 lib/Dialect/Fly/Transforms/LayoutLowering.cpp create mode 100644 lib/Dialect/Fly/Utils/IntTupleUtils.cpp create mode 100644 lib/Dialect/Fly/Utils/IntUtils.cpp create mode 100644 lib/Dialect/Fly/Utils/NormalForm.cpp create mode 100644 python/flydsl/__init__.py create mode 100644 python/flydsl/compiler/__init__.py create mode 100644 python/flydsl/compiler/compiler.py create mode 100644 python/flydsl/compiler/executor.py create mode 100644 python/flydsl/lang/__init__.py create mode 100644 python/flydsl/lang/ir/__init__.py create mode 100644 python/flydsl/lang/ir/core.py create mode 100644 python/flydsl/lang/ir/gpu.py create mode 100644 python/flydsl/lang/ir/module.py create mode 100644 python/flydsl/lang/ir/types.py create mode 100644 python/flydsl/lang/meta.py create mode 100644 python/flydsl/lang/typing.py create mode 100644 python/flydsl/utils/__init__.py create mode 100644 python/flydsl/utils/env_manager.py create mode 100644 python/flydsl/utils/hip_utils.py create mode 100644 python/flydsl/utils/logger.py create mode 100644 python/mlir_flydsl/CMakeLists.txt create mode 100644 python/mlir_flydsl/FlyRegisterEverything.cpp create mode 100644 python/mlir_flydsl/_mlir_libs/_mlirRegisterEverything/py.typed create mode 100644 python/mlir_flydsl/dialects/FlyOps.td 
create mode 100644 python/mlir_flydsl/dialects/fly.py diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..135937b9 --- /dev/null +++ b/.clang-format @@ -0,0 +1,2 @@ +BasedOnStyle: LLVM +ColumnLimit: 100 diff --git a/.gitignore b/.gitignore index e8947221..568c9b84 100644 --- a/.gitignore +++ b/.gitignore @@ -1,30 +1,53 @@ -# Build outputs +# Build directories build/ -.flydsl*/ -.flir -## Legacy build outputs (pre-rename) +build_*/ +cmake-build-*/ + +# git *.log *.diff -# Backups -*.old - -# Python caches +# Python __pycache__/ -*.pyc -.cache/ -flir/python_bindings/mlir/ -my_ir_dumps*/ -*.mlir +*.py[cod] +*$py.class +*.so +*.egg +*.egg-info/ +dist/ +.eggs/ +lib/python*/site-packages/ # Virtualenvs .venv/ .venv*/ .venvs/ -# setuptools metadata from editable installs -*.egg-info/ -flydsl/src/*.egg-info/ +# C++ +.cache/ -# Wheels / packaging outputs -dist/ +# IDE +.vscode/ +.cursor/ +.idea/ +*.swp +*.swo +*~ + +# Compiled files +*.o +*.obj +*.a +*.s +*.asm +*.hasco +*.cubin +*.ptx + +# Temporary files +/tmp/ +*.tmp + +# MACOS +.DS_Store +Thumbs.db diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..2e4efb8f --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.20) +project(FLYDSL LANGUAGES C CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + + +option(FLY_UNITTEST "Build Fly Unit Tests" OFF) + +find_package(MLIR REQUIRED CONFIG) + +message(STATUS "Found MLIR: ${MLIR_DIR}") + +set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) +set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") + +include(TableGen) +include(AddLLVM) +include(AddMLIR) +include(HandleLLVMOptions) +include(MLIRDetectPythonEnv) +mlir_configure_python_dev_packages() + +include_directories(SYSTEM 
${LLVM_INCLUDE_DIRS}) +include_directories(SYSTEM ${MLIR_INCLUDE_DIRS}) +include_directories(${PROJECT_SOURCE_DIR}/include) +include_directories(${PROJECT_BINARY_DIR}/include) + +link_directories(${LLVM_BUILD_LIBRARY_DIR}) +add_definitions(${LLVM_DEFINITIONS}) + + +set(MLIR_PYTHON_PACKAGE_PREFIX "_mlir" CACHE STRING "" FORCE) +set(MLIR_BINDINGS_PYTHON_INSTALL_PREFIX "python_packages/flydsl/${MLIR_PYTHON_PACKAGE_PREFIX}" CACHE STRING "" FORCE) + + +add_subdirectory(include/flydsl) +add_subdirectory(lib) + +add_subdirectory(python/mlir_flydsl) diff --git a/examples/01-vector-add.py b/examples/01-vector-add.py new file mode 100644 index 00000000..bdc1f323 --- /dev/null +++ b/examples/01-vector-add.py @@ -0,0 +1,102 @@ +import flydsl +from flydsl import lang as fx + +N = 64 +memrefTy = fx.ir.Type.parse(f"!fly.memref") + + +class VecAdd(fx.MlirModule): + def __init__(self): + super().__init__() + + @fx.kernel + def kernel( + self: fx.T.i64(), + A: memrefTy, + B: memrefTy, + C: memrefTy, + ): + tid = fx.arith.IndexCastOp(fx.T.i32(), fx.thread_idx.x) + bid = fx.arith.IndexCastOp(fx.T.i32(), fx.block_idx.x) + + tA = fx.logical_divide(A, fx.make_layout(16, 1)) + tB = fx.logical_divide(B, fx.make_layout(16, 1)) + tC = fx.logical_divide(C, fx.make_layout(16, 1)) + + tA = fx.slice(tA, (None, bid)) + tB = fx.slice(tB, (None, bid)) + tC = fx.slice(tC, (None, bid)) + tA = fx.logical_divide(tA, fx.make_layout(1, 1)) + tB = fx.logical_divide(tB, fx.make_layout(1, 1)) + tC = fx.logical_divide(tC, fx.make_layout(1, 1)) + + RABMemRefTy = fx.ir.Type.parse(f"!fly.memref") + copyAtom = fx.make_atom(fx.ir.Type.parse("!fly.atom.universal_copy_32b")) + rA = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) + rB = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) + rC = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) + + fx.copy_atom_call(copyAtom, fx.slice(tA, (None, tid)), rA) + fx.copy_atom_call(copyAtom, fx.slice(tB, (None, tid)), rB) + + vC = 
fx.arith.addf(fx.memref_load_vec(rA), fx.memref_load_vec(rB)) + fx.memref_store_vec(vC, rC) + + fx.copy_atom_call(copyAtom, rC, fx.slice(tC, (None, tid))) + + @fx.jit + def __call__( + self: fx.T.i64(), + A: memrefTy, + B: memrefTy, + C: memrefTy, + ): + size = fx.size(A) + + size = fx.get_scalar(size) + + x = fx.arith.constant(fx.T.i64(), 16) + c1 = fx.arith.constant(fx.T.index(), 1) + c16 = fx.arith.constant(fx.T.index(), 16) + + gN = fx.arith.ceildivsi(size, fx.arith.constant(fx.T.i32(), 16)) + gN = fx.arith.IndexCastOp(fx.T.index(), gN) + + kernel_sym = fx.ir.SymbolRefAttr.get(["kernels", "kernel"]) + fx.LaunchFuncOp( + kernel_sym, + grid_size=[gN, c1, c1], + block_size=[c16, c1, c1], + kernel_operands=[x, A, B, C], + ) + + +VecAdd_Module = VecAdd() +print(VecAdd_Module) + + +VecAdd_Executor = flydsl.compile(VecAdd_Module, print_after_all=True) +# VecAdd_Asm = flydsl.compile(VecAdd_Module, output_format="assembly") +# print(VecAdd_Asm) + +import torch + +tA = torch.randint(0, 10, (N,), dtype=torch.float32, device="cuda") + +tB = torch.randint(0, 10, (N,), dtype=torch.float32, device="cuda") +tC = torch.randint(0, 10, (N,), dtype=torch.float32, device="cuda") + +tAmk = torch.randint(0, 10, (N, N), dtype=torch.float32, device="cuda") + +VecAdd_Executor(tA, tB, tC) +is_closed = torch.allclose(tC, tA + tB) +print("Result correct:", is_closed) + + +if not is_closed: + print("tA:", tA[:32]) + print("tB:", tB[:32]) + print("tC:", tC[:32]) + + +print("Hello, Fly!") diff --git a/examples/02-layout_algebra.py b/examples/02-layout_algebra.py new file mode 100644 index 00000000..b940aaef --- /dev/null +++ b/examples/02-layout_algebra.py @@ -0,0 +1,63 @@ +import flydsl +from flydsl import lang as fx + +M = 16 +N = 32 +memrefTy = fx.ir.Type.parse(f"!fly.memref") + + +class VecCopy(fx.MlirModule): + def __init__(self, thr_dim, val_dim): + super().__init__() + + @fx.kernel + def kernel( + self: fx.T.i64(), + A: memrefTy, + B: memrefTy, + ): + tid = 
fx.arith.IndexCastOp(fx.T.i32(), fx.thread_idx.x) + bid = fx.arith.IndexCastOp(fx.T.i32(), fx.block_idx.x) + + l16 = fx.make_layout(16, 1) + tile = fx.make_tile([l16, l16]) + + tA = fx.logical_divide(A, tile) + tB = fx.logical_divide(B, tile) + + tA = fx.zipped_divide(A, tile) + tB = fx.zipped_divide(B, tile) + + tA = fx.slice(tA, ((None, None), bid)) + tB = fx.slice(tB, ((None, None), bid)) + + vec = fx.memref_load(tA, tid) + fx.memref_store(vec, tB, tid) + + @fx.jit + def __call__( + self: fx.T.i64(), + A: memrefTy, + B: memrefTy, + ): + x = fx.arith.constant(fx.T.i64(), 16) + c1 = fx.arith.constant(fx.T.index(), 1) + c256 = fx.arith.constant(fx.T.index(), 256) + gN = fx.arith.constant(fx.T.index(), N // 16) + + kernel_sym = fx.ir.SymbolRefAttr.get(["kernels", "kernel"]) + fx.LaunchFuncOp( + kernel_sym, + grid_size=[gN, c1, c1], + block_size=[c256, c1, c1], + kernel_operands=[x, A, B], + ) + + +ThrPerBlock = 256 +ValPerThr = 8 + +VecCopy_Module = VecCopy(thr_dim=ThrPerBlock, val_dim=ValPerThr) +print(VecCopy_Module) + +VecCopy_Executor = flydsl.compile(VecCopy_Module, print_after_all=False) diff --git a/examples/03-mma_atom.py b/examples/03-mma_atom.py new file mode 100644 index 00000000..4862e375 --- /dev/null +++ b/examples/03-mma_atom.py @@ -0,0 +1,114 @@ +import flydsl +from flydsl import lang as fx + +MN = 16 +K = 4 +ABMemRefTy = fx.ir.Type.parse(f"!fly.memref") +CMemRefTy = fx.ir.Type.parse(f"!fly.memref") +RABMemRefTy = fx.ir.Type.parse(f"!fly.memref") +RCMemRefTy = fx.ir.Type.parse(f"!fly.memref") + + +class MmaAtom(fx.MlirModule): + def __init__(self): + super().__init__() + + @fx.kernel + def kernel( + self: fx.T.i64(), + A: ABMemRefTy, + B: ABMemRefTy, + C: CMemRefTy, + ): + tid = fx.arith.IndexCastOp(fx.T.i32(), fx.thread_idx.x) + + rA = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) + rB = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) + + copyAtom = fx.make_atom(fx.ir.Type.parse("!fly.atom.universal_copy_32b")) + mmaAtom = fx.make_atom( + 
fx.ir.Type.parse("!fly.atom.amdgpu.mfma.f32.16x16x4f32") + ) + + tA = fx.logical_divide(A, fx.make_layout(1, 1)) + tB = fx.logical_divide(B, fx.make_layout(1, 1)) + fx.copy_atom_call(copyAtom, fx.slice(tA, (None, tid)), rA) + fx.copy_atom_call(copyAtom, fx.slice(tB, (None, tid)), rB) + + rAcc = fx.memref_alloca(RCMemRefTy, fx.make_layout(4, 1)) + f0 = fx.arith.constant(fx.T.f32(), 0.0) + fx.memref_store(f0, rAcc, 0) + fx.memref_store(f0, rAcc, 1) + fx.memref_store(f0, rAcc, 2) + fx.memref_store(f0, rAcc, 3) + fx.mma_atom_call(mmaAtom, rAcc, rA, rB, rAcc) + + tC = fx.zipped_divide( + C, fx.make_tile([fx.make_layout(4, 1), fx.make_layout(1, 1)]) + ) + permutation_tile = fx.make_tile([fx.make_layout(1, 1), fx.make_layout(16, 4)]) + tC = fx.logical_divide(tC, permutation_tile) + + fx.copy_atom_call(copyAtom, rAcc, fx.slice(tC, (None, tid))) + + @fx.jit + def __call__( + self: fx.T.i64(), + A: ABMemRefTy, + B: ABMemRefTy, + C: CMemRefTy, + ): + x = fx.arith.constant(fx.T.i64(), 16) + c1 = fx.arith.constant(fx.T.index(), 1) + c64 = fx.arith.constant(fx.T.index(), 64) + + kernel_sym = fx.ir.SymbolRefAttr.get(["kernels", "kernel"]) + fx.LaunchFuncOp( + kernel_sym, + grid_size=[c1, c1, c1], + block_size=[c64, c1, c1], + kernel_operands=[x, A, B, C], + ) + + +MmaAtom_Module = MmaAtom() +print(MmaAtom_Module) + +MmaAtom_Executor = flydsl.compile(MmaAtom_Module, print_after_all=True) +MmaAtom_Asm = flydsl.compile(MmaAtom_Module, output_format="assembly") +print(MmaAtom_Asm) + +import torch + +tA = torch.randint( + 0, + 10, + (MN, K), + dtype=torch.float32, + device="cuda", +) +tB = torch.randint( + 0, + 10, + (MN, K), + dtype=torch.float32, + device="cuda", +) +tC = torch.empty( + (MN, MN), + dtype=torch.float32, + device="cuda", +) +tC_ref = tA @ tB.T + +MmaAtom_Executor(tA, tB, tC) +is_closed = torch.allclose(tC.T, tC_ref) +print("Result correct:", is_closed) + +if not is_closed: + print("tA:", tA) + print("tB:", tB) + print("tC:", tC.T) + print("tC:", tC_ref) + 
+print("Hello, Fly!") diff --git a/flydsl/requirements.txt b/flydsl_/requirements.txt similarity index 100% rename from flydsl/requirements.txt rename to flydsl_/requirements.txt diff --git a/flydsl/src/_mlir b/flydsl_/src/_mlir similarity index 100% rename from flydsl/src/_mlir rename to flydsl_/src/_mlir diff --git a/flydsl/src/flydsl/__init__.py b/flydsl_/src/flydsl_/__init__.py similarity index 100% rename from flydsl/src/flydsl/__init__.py rename to flydsl_/src/flydsl_/__init__.py diff --git a/flydsl/src/flydsl/compiler/__init__.py b/flydsl_/src/flydsl_/compiler/__init__.py similarity index 100% rename from flydsl/src/flydsl/compiler/__init__.py rename to flydsl_/src/flydsl_/compiler/__init__.py diff --git a/flydsl/src/flydsl/compiler/cache.py b/flydsl_/src/flydsl_/compiler/cache.py similarity index 100% rename from flydsl/src/flydsl/compiler/cache.py rename to flydsl_/src/flydsl_/compiler/cache.py diff --git a/flydsl/src/flydsl/compiler/compiler.py b/flydsl_/src/flydsl_/compiler/compiler.py similarity index 100% rename from flydsl/src/flydsl/compiler/compiler.py rename to flydsl_/src/flydsl_/compiler/compiler.py diff --git a/flydsl/src/flydsl/compiler/context.py b/flydsl_/src/flydsl_/compiler/context.py similarity index 100% rename from flydsl/src/flydsl/compiler/context.py rename to flydsl_/src/flydsl_/compiler/context.py diff --git a/flydsl/src/flydsl/compiler/executor.py b/flydsl_/src/flydsl_/compiler/executor.py similarity index 100% rename from flydsl/src/flydsl/compiler/executor.py rename to flydsl_/src/flydsl_/compiler/executor.py diff --git a/flydsl/src/flydsl/compiler/flir_opt_helper.py b/flydsl_/src/flydsl_/compiler/flir_opt_helper.py similarity index 100% rename from flydsl/src/flydsl/compiler/flir_opt_helper.py rename to flydsl_/src/flydsl_/compiler/flir_opt_helper.py diff --git a/flydsl/src/flydsl/compiler/pipeline.py b/flydsl_/src/flydsl_/compiler/pipeline.py similarity index 100% rename from flydsl/src/flydsl/compiler/pipeline.py rename to 
flydsl_/src/flydsl_/compiler/pipeline.py diff --git a/flydsl/src/flydsl/dialects/__init__.py b/flydsl_/src/flydsl_/dialects/__init__.py similarity index 100% rename from flydsl/src/flydsl/dialects/__init__.py rename to flydsl_/src/flydsl_/dialects/__init__.py diff --git a/flydsl/src/flydsl/dialects/ext/__init__.py b/flydsl_/src/flydsl_/dialects/ext/__init__.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/__init__.py rename to flydsl_/src/flydsl_/dialects/ext/__init__.py diff --git a/flydsl/src/flydsl/dialects/ext/_loc.py b/flydsl_/src/flydsl_/dialects/ext/_loc.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/_loc.py rename to flydsl_/src/flydsl_/dialects/ext/_loc.py diff --git a/flydsl/src/flydsl/dialects/ext/arith.py b/flydsl_/src/flydsl_/dialects/ext/arith.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/arith.py rename to flydsl_/src/flydsl_/dialects/ext/arith.py diff --git a/flydsl/src/flydsl/dialects/ext/block_reduce_ops.py b/flydsl_/src/flydsl_/dialects/ext/block_reduce_ops.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/block_reduce_ops.py rename to flydsl_/src/flydsl_/dialects/ext/block_reduce_ops.py diff --git a/flydsl/src/flydsl/dialects/ext/buffer_ops.py b/flydsl_/src/flydsl_/dialects/ext/buffer_ops.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/buffer_ops.py rename to flydsl_/src/flydsl_/dialects/ext/buffer_ops.py diff --git a/flydsl/src/flydsl/dialects/ext/flir.py b/flydsl_/src/flydsl_/dialects/ext/flir.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/flir.py rename to flydsl_/src/flydsl_/dialects/ext/flir.py diff --git a/flydsl/src/flydsl/dialects/ext/func.py b/flydsl_/src/flydsl_/dialects/ext/func.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/func.py rename to flydsl_/src/flydsl_/dialects/ext/func.py diff --git a/flydsl/src/flydsl/dialects/ext/gpu.py b/flydsl_/src/flydsl_/dialects/ext/gpu.py similarity index 100% 
rename from flydsl/src/flydsl/dialects/ext/gpu.py rename to flydsl_/src/flydsl_/dialects/ext/gpu.py diff --git a/flydsl/src/flydsl/dialects/ext/llvm.py b/flydsl_/src/flydsl_/dialects/ext/llvm.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/llvm.py rename to flydsl_/src/flydsl_/dialects/ext/llvm.py diff --git a/flydsl/src/flydsl/dialects/ext/math.py b/flydsl_/src/flydsl_/dialects/ext/math.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/math.py rename to flydsl_/src/flydsl_/dialects/ext/math.py diff --git a/flydsl/src/flydsl/dialects/ext/memref.py b/flydsl_/src/flydsl_/dialects/ext/memref.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/memref.py rename to flydsl_/src/flydsl_/dialects/ext/memref.py diff --git a/flydsl/src/flydsl/dialects/ext/mlir_extras/__init__.py b/flydsl_/src/flydsl_/dialects/ext/mlir_extras/__init__.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/mlir_extras/__init__.py rename to flydsl_/src/flydsl_/dialects/ext/mlir_extras/__init__.py diff --git a/flydsl/src/flydsl/dialects/ext/mlir_extras/_shaped_value.py b/flydsl_/src/flydsl_/dialects/ext/mlir_extras/_shaped_value.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/mlir_extras/_shaped_value.py rename to flydsl_/src/flydsl_/dialects/ext/mlir_extras/_shaped_value.py diff --git a/flydsl/src/flydsl/dialects/ext/mlir_extras/arith.py b/flydsl_/src/flydsl_/dialects/ext/mlir_extras/arith.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/mlir_extras/arith.py rename to flydsl_/src/flydsl_/dialects/ext/mlir_extras/arith.py diff --git a/flydsl/src/flydsl/dialects/ext/mlir_extras/scf.py b/flydsl_/src/flydsl_/dialects/ext/mlir_extras/scf.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/mlir_extras/scf.py rename to flydsl_/src/flydsl_/dialects/ext/mlir_extras/scf.py diff --git a/flydsl/src/flydsl/dialects/ext/mlir_extras/util.py 
b/flydsl_/src/flydsl_/dialects/ext/mlir_extras/util.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/mlir_extras/util.py rename to flydsl_/src/flydsl_/dialects/ext/mlir_extras/util.py diff --git a/flydsl/src/flydsl/dialects/ext/python_control_flow.py b/flydsl_/src/flydsl_/dialects/ext/python_control_flow.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/python_control_flow.py rename to flydsl_/src/flydsl_/dialects/ext/python_control_flow.py diff --git a/flydsl/src/flydsl/dialects/ext/rocdl.py b/flydsl_/src/flydsl_/dialects/ext/rocdl.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/rocdl.py rename to flydsl_/src/flydsl_/dialects/ext/rocdl.py diff --git a/flydsl/src/flydsl/dialects/ext/rocm.py b/flydsl_/src/flydsl_/dialects/ext/rocm.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/rocm.py rename to flydsl_/src/flydsl_/dialects/ext/rocm.py diff --git a/flydsl/src/flydsl/dialects/ext/scf.py b/flydsl_/src/flydsl_/dialects/ext/scf.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/scf.py rename to flydsl_/src/flydsl_/dialects/ext/scf.py diff --git a/flydsl/src/flydsl/dialects/ext/vector.py b/flydsl_/src/flydsl_/dialects/ext/vector.py similarity index 100% rename from flydsl/src/flydsl/dialects/ext/vector.py rename to flydsl_/src/flydsl_/dialects/ext/vector.py diff --git a/flydsl/src/flydsl/lang/__init__.py b/flydsl_/src/flydsl_/lang/__init__.py similarity index 100% rename from flydsl/src/flydsl/lang/__init__.py rename to flydsl_/src/flydsl_/lang/__init__.py diff --git a/flydsl/src/flydsl/lang/ir/__init__.py b/flydsl_/src/flydsl_/lang/ir/__init__.py similarity index 100% rename from flydsl/src/flydsl/lang/ir/__init__.py rename to flydsl_/src/flydsl_/lang/ir/__init__.py diff --git a/flydsl/src/flydsl/lang/ir/module.py b/flydsl_/src/flydsl_/lang/ir/module.py similarity index 100% rename from flydsl/src/flydsl/lang/ir/module.py rename to flydsl_/src/flydsl_/lang/ir/module.py diff 
--git a/flydsl/src/flydsl/lang/ir/types.py b/flydsl_/src/flydsl_/lang/ir/types.py similarity index 100% rename from flydsl/src/flydsl/lang/ir/types.py rename to flydsl_/src/flydsl_/lang/ir/types.py diff --git a/flydsl/src/flydsl/passes.py b/flydsl_/src/flydsl_/passes.py similarity index 100% rename from flydsl/src/flydsl/passes.py rename to flydsl_/src/flydsl_/passes.py diff --git a/flydsl/src/flydsl/runtime/__init__.py b/flydsl_/src/flydsl_/runtime/__init__.py similarity index 100% rename from flydsl/src/flydsl/runtime/__init__.py rename to flydsl_/src/flydsl_/runtime/__init__.py diff --git a/flydsl/src/flydsl/runtime/device.py b/flydsl_/src/flydsl_/runtime/device.py similarity index 100% rename from flydsl/src/flydsl/runtime/device.py rename to flydsl_/src/flydsl_/runtime/device.py diff --git a/flydsl/src/flydsl/utils/__init__.py b/flydsl_/src/flydsl_/utils/__init__.py similarity index 100% rename from flydsl/src/flydsl/utils/__init__.py rename to flydsl_/src/flydsl_/utils/__init__.py diff --git a/flydsl/src/flydsl/utils/smem_allocator.py b/flydsl_/src/flydsl_/utils/smem_allocator.py similarity index 100% rename from flydsl/src/flydsl/utils/smem_allocator.py rename to flydsl_/src/flydsl_/utils/smem_allocator.py diff --git a/include/flydsl-c/FlyDialect.h b/include/flydsl-c/FlyDialect.h new file mode 100644 index 00000000..fefe8791 --- /dev/null +++ b/include/flydsl-c/FlyDialect.h @@ -0,0 +1,16 @@ +#ifndef FLY_C_DIALECTS_H +#define FLY_C_DIALECTS_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Fly, fly); + +#ifdef __cplusplus +} +#endif + +#endif // FLY_C_DIALECTS_H diff --git a/include/flydsl-c/FlyPasses.h b/include/flydsl-c/FlyPasses.h new file mode 100644 index 00000000..e69de29b diff --git a/include/flydsl/CMakeLists.txt b/include/flydsl/CMakeLists.txt new file mode 100644 index 00000000..629c08af --- /dev/null +++ b/include/flydsl/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Conversion) 
+add_subdirectory(Dialect) diff --git a/include/flydsl/Conversion/CMakeLists.txt b/include/flydsl/Conversion/CMakeLists.txt new file mode 100644 index 00000000..7cda5937 --- /dev/null +++ b/include/flydsl/Conversion/CMakeLists.txt @@ -0,0 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion) +mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix Conversion) +mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Conversion) + +add_mlir_generic_tablegen_target(FlyConversionPassIncGen) + + diff --git a/include/flydsl/Conversion/FlyToROCDL/FlyToROCDL.h b/include/flydsl/Conversion/FlyToROCDL/FlyToROCDL.h new file mode 100644 index 00000000..f42c7387 --- /dev/null +++ b/include/flydsl/Conversion/FlyToROCDL/FlyToROCDL.h @@ -0,0 +1,11 @@ +#ifndef CONVERSION_FLYTOROCDL_FLYTOROCDL_H +#define CONVERSION_FLYTOROCDL_FLYTOROCDL_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +#define GEN_PASS_DECL_FLYTOROCDLCONVERSIONPASS +#include "flydsl/Conversion/Passes.h.inc" +} // namespace mlir + +#endif // CONVERSION_FLYTOROCDL_FLYTOROCDL_H diff --git a/include/flydsl/Conversion/Passes.h b/include/flydsl/Conversion/Passes.h new file mode 100644 index 00000000..20789e1e --- /dev/null +++ b/include/flydsl/Conversion/Passes.h @@ -0,0 +1,14 @@ + +#ifndef FLY_CONVERSION_PASSES_H +#define FLY_CONVERSION_PASSES_H + +#include "flydsl/Conversion/FlyToROCDL/FlyToROCDL.h" + +namespace mlir { + +#define GEN_PASS_REGISTRATION +#include "flydsl/Conversion/Passes.h.inc" + +} // namespace mlir + +#endif // FLY_CONVERSION_PASSES_H diff --git a/include/flydsl/Conversion/Passes.td b/include/flydsl/Conversion/Passes.td new file mode 100644 index 00000000..d70ffaff --- /dev/null +++ b/include/flydsl/Conversion/Passes.td @@ -0,0 +1,21 @@ +#ifndef FLY_PASSES +#define FLY_PASSES + +include "mlir/Pass/PassBase.td" + +def FlyToROCDLConversionPass : Pass<"convert-fly-to-rocdl"> { + let summary = "Lower Fly to MLIR upstream and rocdl 
dialects "; + let description = [{ + + }]; + + let dependentDialects = [ + "arith::ArithDialect", + "scf::SCFDialect", + "vector::VectorDialect", + "LLVM::LLVMDialect", + "ROCDL::ROCDLDialect", + ]; +} + +#endif // FLY_PASSES diff --git a/include/flydsl/Dialect/CMakeLists.txt b/include/flydsl/Dialect/CMakeLists.txt new file mode 100644 index 00000000..08c0cd63 --- /dev/null +++ b/include/flydsl/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Fly) diff --git a/include/flydsl/Dialect/Fly/CMakeLists.txt b/include/flydsl/Dialect/Fly/CMakeLists.txt new file mode 100644 index 00000000..9f57627c --- /dev/null +++ b/include/flydsl/Dialect/Fly/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/include/flydsl/Dialect/Fly/IR/CMakeLists.txt b/include/flydsl/Dialect/Fly/IR/CMakeLists.txt new file mode 100644 index 00000000..1912c8a1 --- /dev/null +++ b/include/flydsl/Dialect/Fly/IR/CMakeLists.txt @@ -0,0 +1,33 @@ +set(LLVM_TARGET_DEFINITIONS FlyDialect.td) +mlir_tablegen(FlyDialect.h.inc -gen-dialect-decls) +mlir_tablegen(FlyDialect.cpp.inc -gen-dialect-defs) + +set(LLVM_TARGET_DEFINITIONS FlyInterfaces.td) +mlir_tablegen(FlyAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(FlyAttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(FlyTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(FlyTypeInterfaces.cpp.inc -gen-type-interface-defs) +mlir_tablegen(FlyOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(FlyOpInterfaces.cpp.inc -gen-op-interface-defs) + +set(LLVM_TARGET_DEFINITIONS FlyTypeDefs.td) +mlir_tablegen(FlyTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=fly) +mlir_tablegen(FlyTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=fly) +mlir_tablegen(FlyTypeConstraints.h.inc -gen-type-constraint-decls -typedefs-dialect=fly) +mlir_tablegen(FlyTypeConstraints.cpp.inc -gen-type-constraint-defs -typedefs-dialect=fly) + +set(LLVM_TARGET_DEFINITIONS FlyAttrDefs.td) 
+mlir_tablegen(FlyEnums.h.inc -gen-enum-decls) +mlir_tablegen(FlyEnums.cpp.inc -gen-enum-defs) + +set(LLVM_TARGET_DEFINITIONS FlyAttrDefs.td) +mlir_tablegen(FlyAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=fly) +mlir_tablegen(FlyAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=fly) +mlir_tablegen(FlyAttrConstraints.h.inc -gen-attr-constraint-decls -attrdefs-dialect=fly) +mlir_tablegen(FlyAttrConstraints.cpp.inc -gen-attr-constraint-defs -attrdefs-dialect=fly) + +set(LLVM_TARGET_DEFINITIONS FlyOps.td) +mlir_tablegen(FlyOps.h.inc -gen-op-decls) +mlir_tablegen(FlyOps.cpp.inc -gen-op-defs) + +add_public_tablegen_target(MLIRFlyIncGen) diff --git a/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td b/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td new file mode 100644 index 00000000..61e9fd71 --- /dev/null +++ b/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td @@ -0,0 +1,260 @@ +#ifndef FLY_ATTRDEFS +#define FLY_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/Constraints.td" + +include "flydsl/Dialect/Fly/IR/FlyDialect.td" + +//===----------------------------------------------------------------------===// +// Enum attributes +//===----------------------------------------------------------------------===// +def Fly_CachePolicy : I32EnumAttr<"CachePolicy", "", [ + I32EnumAttrCase<"CacheGlobal", 0, "cache_global">, + I32EnumAttrCase<"CacheAlways", 1, "cache_always"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = Fly_Dialect.cppNamespace; +} +def Fly_CachePolicyAttr : EnumAttr {} + +def Fly_AddressSpace : I32EnumAttr<"AddressSpace", "", [ + I32EnumAttrCase<"Flat", 0, "flat">, + I32EnumAttrCase<"Global", 1, "global">, + I32EnumAttrCase<"Shared", 2, "shared">, + I32EnumAttrCase<"Register", 3, "register"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = Fly_Dialect.cppNamespace; +} +def Fly_AddressSpaceAttr : EnumAttr {} + 
+//===----------------------------------------------------------------------===// +// Type attributes +//===----------------------------------------------------------------------===// + +def Fly_AlignAttr : Fly_Attr<"Align", "align", []> { + let parameters = (ins "int32_t":$alignment); + let assemblyFormat = "`align` `<` `` $alignment `` `>`"; + + let extraClassDeclaration = [{ + static AlignAttr getTrivialAlignment(MLIRContext *context); + }]; + + let extraClassDefinition = [{ + AlignAttr $cppClass::getTrivialAlignment(MLIRContext *context) { + return get(context, 1); + } + }]; +} + +/// Unified integer attribute that represents both static and dynamic integers. +/// - None : value == 0 and width == 0. it could be used as a Static 0 in the arithmetic operations. +/// - Static integer: value is the actual integer value (value != INT32_MIN) +/// - Dynamic integer: value == INT32_MIN (sentinel), width and divisibility describe the dynamic value +def Fly_IntAttr : Fly_Attr<"Int", "int", [ + DeclareAttrInterfaceMethods +]> { + let parameters = (ins + "int32_t":$value, + DefaultValuedParameter<"int32_t", "32">:$width, + DefaultValuedParameter<"int32_t", "1">:$divisibility); + let hasCustomAssemblyFormat = 1; + + let builders = [ + AttrBuilder<(ins "int32_t":$value), [{ + return $_get($_ctxt, value, 32, 1); + }]>, + AttrBuilder<(ins "int32_t":$width, "int32_t":$divisibility), [{ + return $_get($_ctxt, std::numeric_limits::min(), width, divisibility); + }]> + ]; + + let extraClassDeclaration = [{ + bool isNone() const; + // value can't be INT32_MIN here + bool isStaticValue(int32_t value) const; + static IntAttr getNone(MLIRContext *ctx); + static IntAttr getStatic(MLIRContext *ctx, int32_t value); + static IntAttr getDynamic(MLIRContext *ctx, int32_t width = 32, int32_t divisibility = 1); + }]; + + let extraClassDefinition = [{ + bool $cppClass::isNone() const { + return getValue() == 0 && getWidth() == 0; + } + bool $cppClass::isStaticValue(int32_t value) const { + 
return getValue() == value; + } + IntAttr $cppClass::getNone(MLIRContext *ctx) { + return get(ctx, 0, 0, 0); + } + IntAttr $cppClass::getStatic(MLIRContext *ctx, int32_t value) { + return get(ctx, value); + } + IntAttr $cppClass::getDynamic(MLIRContext *ctx, int32_t width, int32_t divisibility) { + return get(ctx, width, divisibility); + } + }]; +} + +def Fly_SwizzleAttr : Fly_Attr<"Swizzle", "swizzle", []> { + let parameters = (ins "int32_t":$mask, "int32_t":$base, "int32_t":$shift); + let assemblyFormat = "`` `S` `` `<` $mask `` `,` `` $base `` `,` `` $shift `` `>`"; + + let extraClassDeclaration = [{ + bool isTrivialSwizzle() const; + static SwizzleAttr getTrivialSwizzle(MLIRContext *context); + }]; + + let extraClassDefinition = [{ + bool $cppClass::isTrivialSwizzle() const { + return getMask() == 0; + } + SwizzleAttr $cppClass::getTrivialSwizzle(MLIRContext *context) { + return get(context, 0, 0, 0); + } + }]; +} + +def Fly_BasisAttr : Fly_Attr<"Basis", "basis", [ + DeclareAttrInterfaceMethods +]> { + let parameters = (ins + "Attribute":$value, + ArrayRefParameter<"int32_t">:$modes + ); + let hasCustomAssemblyFormat = 1; + + let builders = [ + AttrBuilderWithInferredContext<(ins "Attribute":$value, "int32_t":$mode), [{ + ::llvm::SmallVector modes; + modes.push_back(mode); + return $_get(value.getContext(), value, modes); + }]> + ]; + + let extraClassDeclaration = [{ + int32_t depth(); + }]; +} + +def Fly_IntTupleAttr : Fly_Attr<"IntTuple", "int_tuple", [ + DeclareAttrInterfaceMethods, + DeclareAttrInterfaceMethods +]> { + let parameters = (ins + "Attribute":$value + ); + let hasCustomAssemblyFormat = 1; + + let builders = [ + AttrBuilder<(ins "int32_t":$value), [{ + return $_get($_ctxt, IntAttr::get($_ctxt, value)); + }]>, + AttrBuilderWithInferredContext<(ins "Attribute":$value), [{ + return $_get(value.getContext(), value); + }]> + ]; + + let extraClassDeclaration = [{ + static IntTupleAttr getLeafNone(MLIRContext *ctx); + static IntTupleAttr 
getLeafStatic(MLIRContext *ctx, int32_t value); + static IntTupleAttr getLeafDynamic(MLIRContext *ctx, int32_t width = 32, int32_t divisibility = 1); + + bool isLeafNone() const; + bool isLeafStaticValue(int32_t value) const; + IntAttr getLeafAsInt() const; + BasisAttr getLeafAsBasis() const; + + IntTupleAttr at(int32_t idx) const; + IntTupleAttr at(const ArrayRef idxs) const; + + int32_t dyncLeafCount() const; + }]; +} + +def Fly_LayoutAttr : Fly_Attr<"Layout", "layout", [ + DeclareAttrInterfaceMethods, + DeclareAttrInterfaceMethods +]> { + let parameters = (ins + Fly_IntTupleAttr:$shape, + Fly_IntTupleAttr:$stride + ); + let assemblyFormat = "`` $shape `` `:` `` $stride"; + + let builders = [ + AttrBuilderWithInferredContext<(ins "IntTupleAttr":$shape, "IntTupleAttr":$stride), [{ + return $_get(shape.getContext(), shape, stride); + }]> + ]; + + let extraClassDeclaration = [{ + bool isStaticShape() const; + bool isStaticStride() const; + + LayoutAttr at(int32_t idx) const; + LayoutAttr at(const ArrayRef idxs) const; + }]; +} + +def Fly_ComposedLayoutAttr : Fly_Attr<"ComposedLayout", "composed_layout", [ + DeclareAttrInterfaceMethods, + DeclareAttrInterfaceMethods +]> { + let parameters = (ins + "Attribute":$inner, + Fly_IntTupleAttr:$offset, + Fly_LayoutAttr:$outer + ); + let assemblyFormat = "`<` $inner `o` $offset `o` $outer `>`"; + + let builders = [ + AttrBuilderWithInferredContext<(ins "Attribute":$inner, "IntTupleAttr":$offset, "LayoutAttr":$outer), [{ + return $_get(inner.getContext(), inner, offset, outer); + }]> + ]; + + let extraClassDeclaration = [{ + bool isStaticOuter() const; + bool isStaticInner() const; + bool isStaticOffset() const; + + ComposedLayoutAttr at(int32_t idx) const; + ComposedLayoutAttr at(const ArrayRef idxs) const; + }]; +} + +def Fly_TileAttr : Fly_Attr<"Tile", "tile", []> { + let parameters = (ins + "Attribute":$value + ); + let hasCustomAssemblyFormat = 1; + + let builders = [ + AttrBuilderWithInferredContext<(ins 
"Attribute":$value), [{ + return $_get(value.getContext(), value); + }]> + ]; + + let extraClassDeclaration = [{ + int32_t rank() const; + + bool isLeaf() const; + bool isNoneMode() const; + bool isNoneMode(int32_t idx) const; + + Attribute at(int32_t idx) const; + }]; +} + +def Fly_LeafAttr : AnyAttrOf<[Fly_IntAttr, Fly_BasisAttr]> { + let cppFunctionName = "isValidLeafAttr"; +} +def Fly_AnyLayoutAttr : AnyAttrOf<[Fly_LayoutAttr, Fly_ComposedLayoutAttr]> { + let cppFunctionName = "isAnyValidLayoutAttr"; +} + +#endif // FLY_ATTRDEFS diff --git a/include/flydsl/Dialect/Fly/IR/FlyDialect.h b/include/flydsl/Dialect/Fly/IR/FlyDialect.h new file mode 100644 index 00000000..64de1156 --- /dev/null +++ b/include/flydsl/Dialect/Fly/IR/FlyDialect.h @@ -0,0 +1,36 @@ +#ifndef FLY_DIALECT_FLY_IR_DIALECT_H +#define FLY_DIALECT_FLY_IR_DIALECT_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h.inc" +#include "flydsl/Dialect/Fly/IR/FlyEnums.h.inc" + +namespace mlir::fly { +#include "flydsl/Dialect/Fly/IR/FlyAttrInterfaces.h.inc" +#include "flydsl/Dialect/Fly/IR/FlyTypeInterfaces.h.inc" +} // namespace mlir::fly + +#define GET_ATTRDEF_CLASSES +#include "flydsl/Dialect/Fly/IR/FlyAttrDefs.h.inc" +#define GET_TYPEDEF_CLASSES +#include "flydsl/Dialect/Fly/IR/FlyTypeDefs.h.inc" +#define GET_OP_CLASSES +#include "flydsl/Dialect/Fly/IR/FlyOps.h.inc" + +namespace mlir::fly { +#include "flydsl/Dialect/Fly/IR/FlyAttrConstraints.h.inc" +#include "flydsl/Dialect/Fly/IR/FlyTypeConstraints.h.inc" +} // namespace mlir::fly + +#endif // FLY_DIALECT_FLY_IR_DIALECT_H 
diff --git a/include/flydsl/Dialect/Fly/IR/FlyDialect.td b/include/flydsl/Dialect/Fly/IR/FlyDialect.td new file mode 100644 index 00000000..68f3626e --- /dev/null +++ b/include/flydsl/Dialect/Fly/IR/FlyDialect.td @@ -0,0 +1,36 @@ +#ifndef FLY_DIALECT +#define FLY_DIALECT + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +include "flydsl/Dialect/Fly/IR/FlyInterfaces.td" + +def Fly_Dialect : Dialect { + let name = "fly"; + let cppNamespace = "::mlir::fly"; + + // let hasConstantMaterializer = 1 + let hasConstantMaterializer = 0; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; +} + +class Fly_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +class Fly_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + +class Fly_Op traits = []> + : Op; + + +#endif // FLY_DIALECT diff --git a/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td b/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td new file mode 100644 index 00000000..c84b2cbc --- /dev/null +++ b/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td @@ -0,0 +1,56 @@ +#ifndef FLY_INTERFACES +#define FLY_INTERFACES + +include "mlir/IR/OpBase.td" + +def Fly_NestedInterfaceMethods { + list methods = [ + InterfaceMethod<"", "bool", "isLeaf", (ins)>, + InterfaceMethod<"", "int32_t", "rank", (ins)>, + InterfaceMethod<"", "int32_t", "rank", (ins "int32_t":$idx)>, + InterfaceMethod<"", "int32_t", "rank", (ins "ArrayRef":$idxs)>, + InterfaceMethod<"", "int32_t", "depth", (ins)>, + InterfaceMethod<"", "int32_t", "depth", (ins "int32_t":$idx)>, + InterfaceMethod<"", "int32_t", "depth", (ins "ArrayRef":$idxs)> + ]; +} + +def Fly_MayStaticInterface { + list methods = [ + InterfaceMethod<"", "bool", "isStatic", (ins)> + ]; +} + +def Fly_NestedAttrInterface : AttrInterface<"NestedAttrInterface"> { + let methods = Fly_NestedInterfaceMethods.methods; +} +def Fly_NestedTypeInterface : 
TypeInterface<"NestedTypeInterface"> { + let methods = Fly_NestedInterfaceMethods.methods; +} + +def Fly_MayStaticAttrInterface : AttrInterface<"MayStaticAttrInterface"> { + let methods = Fly_MayStaticInterface.methods; +} +def Fly_MayStaticTypeInterface : TypeInterface<"MayStaticTypeInterface"> { + let methods = Fly_MayStaticInterface.methods; +} + + +def Fly_CopyAtomTypeInterface : TypeInterface<"CopyAtomTypeInterface"> { + let methods = [ + InterfaceMethod<"", "Attribute", "getThrSize", (ins)>, + InterfaceMethod<"", "Attribute", "getThrValLayoutSrc", (ins)>, + InterfaceMethod<"", "Attribute", "getThrValLayoutDst", (ins)> + ]; +} + +def Fly_MmaAtomTypeInterface : TypeInterface<"MmaAtomTypeInterface"> { + let methods = [ + InterfaceMethod<"", "Attribute", "getThrSize", (ins)>, + InterfaceMethod<"", "Attribute", "getThrValLayoutA", (ins)>, + InterfaceMethod<"", "Attribute", "getThrValLayoutB", (ins)>, + InterfaceMethod<"", "Attribute", "getThrValLayoutC", (ins)> + ]; +} + +#endif // FLY_INTERFACES diff --git a/include/flydsl/Dialect/Fly/IR/FlyOps.td b/include/flydsl/Dialect/Fly/IR/FlyOps.td new file mode 100644 index 00000000..0db7b3d5 --- /dev/null +++ b/include/flydsl/Dialect/Fly/IR/FlyOps.td @@ -0,0 +1,567 @@ +#ifndef FLY_OPS +#define FLY_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Bytecode/BytecodeOpInterface.td" + +include "flydsl/Dialect/Fly/IR/FlyDialect.td" +include "flydsl/Dialect/Fly/IR/FlyTypeDefs.td" +include "flydsl/Dialect/Fly/IR/FlyAttrDefs.td" + + +//===----------------------------------------------------------------------===// +// Constructors +//===----------------------------------------------------------------------===// + +def Fly_StaticOp : Fly_Op<"static", [Pure]> { + let arguments = (ins); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeIntTupleOp : 
Fly_Op<"make_int_tuple", [Pure]> { + let arguments = (ins Variadic>:$dyncElems); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $dyncElems `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeShapeOp : Fly_Op<"make_shape", [Pure]> { + let arguments = (ins Variadic>:$dyncElems); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $dyncElems `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeStrideOp : Fly_Op<"make_stride", [Pure]> { + let arguments = (ins Variadic>:$dyncElems); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $dyncElems `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeCoordOp : Fly_Op<"make_coord", [Pure]> { + let arguments = (ins Variadic>:$dyncElems); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $dyncElems `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeLayoutOp : Fly_Op<"make_layout", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_IntTuple:$shape, Optional:$stride); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $shape (`,` $stride^)? 
`)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeTileOp: Fly_Op<"make_tile", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Variadic>:$modes); + let results = (outs Fly_Tile:$result); + let assemblyFormat = "`(` $modes `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeViewOp : Fly_Op<"make_view", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins IteratorLikeType:$iter, AnyLayoutType:$layout); + let results = (outs TensorLikeType:$result); + let assemblyFormat = "`(` $iter `,` $layout `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeLayoutLikeOp : Fly_Op<"make_layout_like", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$src); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $src `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeOrderedLayoutOp : Fly_Op<"make_ordered_layout", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_IntTuple:$shape, Fly_IntTuple:$order); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $shape `,` $order `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeComposedLayoutOp : Fly_Op<"make_composed_layout", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins AnyLayoutType:$inner, + Fly_IntTuple:$offset, + Fly_Layout:$outer); + let results = (outs Fly_ComposedLayout:$result); + let assemblyFormat = "`(` $inner `,` $offset `,` $outer `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeIdentityLayoutOp : Fly_Op<"make_identity_layout", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_IntTuple:$shape); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $shape `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeIdentityTensorOp : Fly_Op<"make_identity_tensor", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_IntTuple:$shape); + let 
results = (outs Fly_CoordTensor:$result); + let assemblyFormat = "`(` $shape `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeFragmentLikeOp : Fly_Op<"make_fragment_like", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$src); + let results = (outs Fly_MemRef:$result); + let assemblyFormat = "`(` $src `)` attr-dict `:` functional-type(operands, results)"; +} + +//===----------------------------------------------------------------------===// +// Extractors +//===----------------------------------------------------------------------===// + +def Fly_GetOp : Fly_Op<"get", [Pure]> { + let arguments = (ins AnyTypeOf<[Fly_IntTuple, Fly_Layout, Fly_MemRef]>:$input, + OptionalAttr:$mode); + let results = (outs AnyType:$result); + let assemblyFormat = "`(` $input (`,` $mode^)? `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_GetScalarOp : Fly_Op<"get_scalar", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_IntTuple:$int_tuple); + let results = (outs AnyType:$result); + let assemblyFormat = "`(` $int_tuple `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_GetLeavesOp : Fly_Op<"get_leaves", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins AnyTypeOf<[Fly_IntTuple, Fly_Layout]>:$input); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $input `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_GetShapeOp : Fly_Op<"get_shape", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$layout); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $layout `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_GetStrideOp : Fly_Op<"get_stride", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$layout); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $layout `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_GetLayoutOp : 
Fly_Op<"get_layout", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_MemRef:$memref); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $memref `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_GetIterOp : Fly_Op<"get_iter", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_MemRef:$memref); + let results = (outs Fly_Pointer:$result); + let assemblyFormat = "`(` $memref `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_GetLayoutsFromTileOp : Fly_Op<"get_layouts_from_tile", [Pure]> { + let arguments = (ins Fly_Tile:$tile); + let results = (outs AnyType:$result); + let assemblyFormat = "`(` $tile `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_GetLeafOp : Fly_Op<"get_leaf", [Pure, DeclareOpInterfaceMethods]> { + let summary = "Get a leaf element from an IntTuple or Layout"; + let description = [{}]; + + let arguments = (ins AnyTypeOf<[Fly_IntTuple, Fly_Layout]>:$tuple, I32Attr:$leaf_idx); + let results = (outs AnyTypeOf<[Fly_IntTuple, Fly_Layout]>:$leaf); + let assemblyFormat = "`(` $tuple `,` $leaf_idx `)` attr-dict `:` functional-type(operands, $leaf)"; +} + +def Fly_ComposedGetInnerOp : Fly_Op<"composed_get_inner", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_ComposedLayout:$input); + let results = (outs AnyType:$result); + let assemblyFormat = "`(` $input `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_ComposedGetOffsetOp : Fly_Op<"composed_get_offset", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_ComposedLayout:$input); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $input `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_ComposedGetOuterOp : Fly_Op<"composed_get_outer", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_ComposedLayout:$input); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $input `)` attr-dict `:` 
functional-type(operands, results)"; +} + +//===----------------------------------------------------------------------===// +// IntTuple operations +//===----------------------------------------------------------------------===// + +class Fly_IntTupleUnaryOp + : Fly_Op]> { + let arguments = (ins Fly_IntTuple:$input); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $input `)` attr-dict `:` functional-type(operands, results)"; +} + +class Fly_IntTupleBinaryOp + : Fly_Op]> { + let arguments = (ins Fly_IntTuple:$lhs, Fly_IntTuple:$rhs); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $lhs `,` $rhs `)` attr-dict `:` functional-type(operands, results)"; +} + +class Fly_IntTupleUnaryWithProfileOp + : Fly_Op]> { + let arguments = (ins Fly_IntTuple:$input, Optional:$target_profile); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $input (`,` $target_profile^)? `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_IntTupleAddOp : Fly_IntTupleBinaryOp<"int_tuple_add">; +def Fly_IntTupleSubOp : Fly_IntTupleBinaryOp<"int_tuple_sub">; +def Fly_IntTupleMulOp : Fly_IntTupleBinaryOp<"int_tuple_mul">; +def Fly_IntTupleDivOp : Fly_IntTupleBinaryOp<"int_tuple_div">; +def Fly_IntTupleModOp : Fly_IntTupleBinaryOp<"int_tuple_mod">; + +def Fly_IntTupleProductEachOp : Fly_IntTupleUnaryOp<"int_tuple_product_each">; +def Fly_IntTupleProductOp : Fly_IntTupleUnaryOp<"int_tuple_product">; + +def Fly_ShapeDivOp : Fly_IntTupleBinaryOp<"shape_div">; +def Fly_CeilDivOp : Fly_IntTupleBinaryOp<"ceil_div">; +def Fly_ElemLessOp : Fly_IntTupleBinaryOp<"elem_less">; +def Fly_EqualOp : Fly_IntTupleBinaryOp<"equal">; + +def Fly_AppendOp : Fly_Op<"append", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$tuple, + Fly_Layout:$elem, + OptionalAttr:$n); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "(`<` $n^ `>`)? 
`(` $tuple `,` $elem `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_PrependOp : Fly_Op<"prepend", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$tuple, + Fly_Layout:$elem, + OptionalAttr:$n); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "(`<` $n^ `>`)? `(` $tuple `,` $elem `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_SelectOp : Fly_Op<"select", [Pure, DeclareOpInterfaceMethods]> { + let summary = "Select elements from an IntTuple or Layout by indices"; + let description = [{}]; + + let arguments = (ins IntTupleLikeType:$tuple, DenseI32ArrayAttr:$indices); + let results = (outs IntTupleLikeType:$result); + let assemblyFormat = "`(` $tuple `,` $indices `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_GroupOp : Fly_Op<"group", [Pure, DeclareOpInterfaceMethods]> { + let summary = "Group elements in an IntTuple or Layout"; + let description = [{}]; + + let arguments = (ins AnyTypeOf<[Fly_IntTuple, Fly_Layout, Fly_MemRef]>:$tuple, I32Attr:$begin, I32Attr:$end); + let results = (outs AnyTypeOf<[Fly_IntTuple, Fly_Layout, Fly_MemRef]>:$result); + let assemblyFormat = "`(` $tuple `,` $begin `,` $end `)` attr-dict `:` functional-type(operands, results)"; +} + + +def Fly_SliceOp : Fly_Op<"slice", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins AnyTypeOf<[Fly_IntTuple, Fly_Layout, Fly_MemRef]>:$src, Fly_IntTuple:$coord); + let results = (outs AnyTypeOf<[Fly_IntTuple, Fly_Layout, Fly_MemRef]>:$result); + let assemblyFormat = "`(` $src `,` $coord `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_DiceOp : Fly_Op<"dice", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins AnyTypeOf<[Fly_IntTuple, Fly_Layout, Fly_MemRef]>:$src, Fly_IntTuple:$coord); + let results = (outs AnyTypeOf<[Fly_IntTuple, Fly_Layout, Fly_MemRef]>:$result); + let assemblyFormat = "`(` $src `,` $coord `)` attr-dict `:` functional-type(operands, results)"; +} + 
+//===----------------------------------------------------------------------===// +// Layout operations +//===----------------------------------------------------------------------===// + +def Fly_SizeOp : Fly_Op<"size", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins AnyTypeOf<[Fly_IntTuple, Fly_Layout, Fly_MemRef]>:$int_tuple); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $int_tuple `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_CosizeOp : Fly_Op<"cosize", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$layout); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "$layout attr-dict `:` functional-type(operands, results)"; +} + +def Fly_Crd2IdxOp : Fly_Op<"crd2idx", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_IntTuple:$coord, Fly_Layout:$layout); + let results = (outs Fly_IntTuple:$index); + let assemblyFormat = "`(` $coord `,` $layout `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_Idx2CrdOp : Fly_Op<"idx2crd", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_IntTuple:$coord, Fly_Layout:$layout); + let results = (outs Fly_IntTuple:$index); + let assemblyFormat = "`(` $coord `,` $layout `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_GetFlatCoordOp : Fly_Op<"get_flat_coord", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_IntTuple:$index, AnyTypeOf<[Fly_IntTuple, Fly_Layout]>:$input); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $index `,` $input `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_GetHierCoordOp : Fly_Op<"get_hier_coord", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_IntTuple:$index, AnyTypeOf<[Fly_IntTuple, Fly_Layout]>:$input); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $index `,` $input `)` attr-dict `:` functional-type(operands, results)"; +} + + +def 
Fly_CoalesceOp : Fly_Op<"coalesce", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$layout, Optional:$attr); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $layout (`,` $attr^)? `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_CompositionOp : Fly_Op<"composition", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$outer, AnyTypeOf<[Fly_Layout, Fly_Tile]>:$inner); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $outer `,` $inner `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_ComplementOp : Fly_Op<"complement", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$layout, Optional:$codomain_size); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $layout (`,` $codomain_size^)? `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_RightInverseOp : Fly_Op<"right_inverse", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$layout); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $layout `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_LeftInverseOp : Fly_Op<"left_inverse", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Layout:$layout); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $layout `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_RecastLayoutOp : Fly_Op<"recast_layout", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins AnyTypeOf<[I32, I64]>:$new_type_bits, + AnyTypeOf<[I32, I64]>:$old_type_bits, + Fly_Layout:$src); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $new_type_bits `,` $old_type_bits `,` $src `)` attr-dict `:` functional-type(operands, results)"; +} + +class Fly_LayoutDivideOp + : Fly_Op]> { + let arguments = (ins LayoutLikeType:$layout, AnyTypeOf<[Fly_Layout, Fly_Tile]>:$divisor); + let results = (outs 
LayoutLikeType:$result); + let assemblyFormat = "`(` $layout `,` $divisor `)` attr-dict `:` functional-type(operands, results)"; +} + +class Fly_LayoutProductOp + : Fly_Op]> { + let arguments = (ins LayoutLikeType:$layout, Fly_Layout:$tile); + let results = (outs LayoutLikeType:$result); + let assemblyFormat = "`(` $layout `,` $tile `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_LogicalDivideOp : Fly_LayoutDivideOp<"logical_divide">; +def Fly_ZippedDivideOp : Fly_LayoutDivideOp<"zipped_divide">; +def Fly_TiledDivideOp : Fly_LayoutDivideOp<"tiled_divide">; +def Fly_FlatDivideOp : Fly_LayoutDivideOp<"flat_divide">; + +def Fly_LogicalProductOp : Fly_LayoutProductOp<"logical_product">; +def Fly_ZippedProductOp : Fly_LayoutProductOp<"zipped_product">; +def Fly_TiledProductOp : Fly_LayoutProductOp<"tiled_product">; +def Fly_FlatProductOp : Fly_LayoutProductOp<"flat_product">; +def Fly_BlockedProductOp : Fly_LayoutProductOp<"blocked_product">; +def Fly_RakedProductOp : Fly_LayoutProductOp<"raked_product">; + +def Fly_TileToShapeOp : Fly_Op<"tile_to_shape", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins AnyTypeOf<[Fly_Layout, Fly_Tile]>:$block, + Fly_IntTuple:$trg_shape, + Fly_IntTuple:$ord_shape); + let results = (outs Fly_Layout:$result); + let assemblyFormat = "`(` $block `,` $trg_shape `,` $ord_shape `)` attr-dict `:` functional-type(operands, results)"; +} + + +//===----------------------------------------------------------------------===// +// Atom and Tiled Mma/Copy ops +//===----------------------------------------------------------------------===// +def Fly_MakeAtomOp : Fly_Op<"make_atom", [Pure]> { + let arguments = (ins); + let results = (outs AnyType:$result); + let assemblyFormat = "attr-dict `:` functional-type(operands, results)"; +} + +def Fly_CopyAtomCall : Fly_Op<"copy_atom_call"> { + let arguments = (ins AnyType:$copyAtom, Fly_MemRef:$src, Fly_MemRef:$dst); + let results = (outs); + let assemblyFormat = "`(` $copyAtom 
`,` $src `,` $dst `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_MmaAtomCall : Fly_Op<"mma_atom_call"> { + let arguments = (ins AnyType:$mmaAtom, Fly_MemRef:$d, Fly_MemRef:$a, Fly_MemRef:$b, Fly_MemRef:$c); + let results = (outs); + let assemblyFormat = "`(` $mmaAtom `,` $d `,` $a `,` $b `,` $c `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MakeTiledCopyOp : Fly_Op<"make_tiled_copy", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins AnyType:$copyAtom, Fly_Layout:$layoutTV, Fly_Tile:$tileMN); + let results = (outs Fly_TiledCopy:$result); + let assemblyFormat = "`(` $copyAtom `,` $layoutTV `,` $tileMN `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_MakeTiledMmaOp : Fly_Op<"make_tiled_mma", [Pure]> { + let arguments = (ins AnyType:$atom); + let results = (outs AnyType:$result); + let assemblyFormat = "`(` $atom `)` attr-dict `:` functional-type(operands, results)"; +} + + +def Fly_TiledCopyPartitionSrcOp : Fly_Op<"tiled_copy.partition_src", [Pure, DeclareOpInterfaceMethods]> { + let summary = ""; + let description = [{}]; + + let arguments = (ins Fly_TiledCopy:$tiledCopy, Fly_MemRef:$src); + let results = (outs Fly_MemRef:$result); + let assemblyFormat = "`(` $tiledCopy `,` $src `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_TiledCopyPartitionDstOp : Fly_Op<"tiled_copy.partition_dst", [Pure, DeclareOpInterfaceMethods]> { + let summary = ""; + let description = [{}]; + + let arguments = (ins Fly_TiledCopy:$tiledCopy, Fly_MemRef:$dst); + let results = (outs Fly_MemRef:$result); + let assemblyFormat = "`(` $tiledCopy `,` $dst `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_CopyOp : Fly_Op<"copy"> { + let arguments = (ins AnyType:$copyAtom, Fly_MemRef:$src, Fly_MemRef:$dst, OptionalAttr:$pred); + let results = (outs); + let assemblyFormat = "`(` $copyAtom `,` $src `,` $dst (`,` $pred^)? 
`)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_GemmOp : Fly_Op<"gemm"> { + let arguments = (ins AnyType:$mmaAtom, Fly_MemRef:$d, Fly_MemRef:$a, + Fly_MemRef:$b, Fly_MemRef:$c); + let results = (outs); + let assemblyFormat = "`(` $mmaAtom `,` $d `,` $a `,` $b `,` $c `)` attr-dict `:` functional-type(operands, results)"; +} + + + +def Fly_MmaMakeFragmentOp : Fly_Op<"mma_make_fragment", [Pure]> { + let arguments = (ins AnyTypeOf<[I32, I64]>:$operand_id, AnyType:$atom, AnyType:$input); + let results = (outs AnyType:$result); + let assemblyFormat = "`(` $operand_id `,` $atom `,` $input `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_TiledCopyPartitionDOp : Fly_Op<"tiled_copy.partition_D", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_TiledCopy:$tiledCopy, Fly_MemRef:$input, Fly_IntTuple:$coord); + let results = (outs Fly_MemRef:$result); + let assemblyFormat = "`(` $tiledCopy `,` $input `,` $coord `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_TiledCopyPartitionSOp : Fly_Op<"tiled_copy.partition_S", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_TiledCopy:$tiledCopy, Fly_MemRef:$input, Fly_IntTuple:$coord); + let results = (outs Fly_MemRef:$result); + let assemblyFormat = "`(` $tiledCopy `,` $input `,` $coord `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_TiledCopyRetileOp : Fly_Op<"tiled_copy.retile", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_TiledCopy:$tiledCopy, Fly_MemRef:$input); + let results = (outs Fly_MemRef:$result); + let assemblyFormat = "`(` $tiledCopy `,` $input `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_TiledMmaPartitionOp : Fly_Op<"tiled_mma_partition", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins AnyTypeOf<[I32, I64]>:$operand_id, AnyType:$tiled_mma, + Fly_MemRef:$input, Fly_IntTuple:$coord); + let results = (outs Fly_MemRef:$result); + let assemblyFormat = "`(` 
$operand_id `,` $tiled_mma `,` $input `,` $coord `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_TiledMmaPartitionShapeOp : Fly_Op<"tiled_mma_partition_shape", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins AnyTypeOf<[I32, I64]>:$operand_id, AnyType:$tiled_mma, Fly_MemRef:$input); + let results = (outs Fly_IntTuple:$result); + let assemblyFormat = "`(` $operand_id `,` $tiled_mma `,` $input `)` attr-dict `:` functional-type(operands, results)"; +} + + +def Fly_CooperativeCopyOp : Fly_Op<"cooperative_copy", []> { + let summary = ""; + let description = [{}]; + + let arguments = (ins Fly_TiledCopy:$tiledCopy, Fly_IntTuple:$partitionIdx, Fly_MemRef:$src, Fly_MemRef:$dst); + let results = (outs); + let assemblyFormat = "`(` $tiledCopy `,` $partitionIdx `,` $src `,` $dst `)` attr-dict `:` functional-type(operands, results)"; +} + + +//===----------------------------------------------------------------------===// +// MemRef and Ptr operations +//===----------------------------------------------------------------------===// + +def Fly_MemRefAllocaOp : Fly_Op<"memref.alloca", []> { + let arguments = (ins Fly_Layout:$layout); + let results = (outs Fly_MemRef:$result); + let assemblyFormat = "`(` $layout `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_MemRefAllocSharedOp : Fly_Op<"memref.alloc_shared", [DeclareOpInterfaceMethods]> { + let arguments = (ins AnyLayoutType:$memref); + let results = (outs Fly_MemRef:$result); + let assemblyFormat = "`(` $memref `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MemRefLoadOp : Fly_Op<"memref.load", [DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_MemRef:$memref, Fly_IntTuple:$indices); + let results = (outs AnyType:$result); + let assemblyFormat = "`(` $memref `,` $indices `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_MemRefStoreOp : Fly_Op<"memref.store", []> { + let arguments = (ins AnyType:$value, Fly_MemRef:$memref, 
Fly_IntTuple:$indices); + let assemblyFormat = "`(` $value `,` $memref `,` $indices `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_MemRefLoadVecOp : Fly_Op<"memref.load_vec", [DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_MemRef:$memref); // restrict to register tensor + let results = (outs AnyType:$result); + let assemblyFormat = "`(` $memref `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_MemRefStoreVecOp : Fly_Op<"memref.store_vec", []> { + let arguments = (ins AnyType:$vector, Fly_MemRef:$memref); // restrict to register tensor + let results = (outs); + let assemblyFormat = "`(` $vector `,` $memref `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_PtrLoadOp : Fly_Op<"ptr.load"> { + let arguments = (ins Fly_Pointer:$ptr); + let results = (outs AnyType:$result); + let assemblyFormat = "`(` $ptr `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_PtrStoreOp : Fly_Op<"ptr.store"> { + let arguments = (ins AnyType:$value, Fly_Pointer:$ptr); + let assemblyFormat = "`(` $value `,` $ptr `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_ApplySwizzleOp : Fly_Op<"apply_swizzle", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Pointer:$ptr); + let results = (outs Fly_Pointer:$result); + let assemblyFormat = "`(` $ptr `)` attr-dict `:` functional-type(operands, results)"; +} +def Fly_RecastIterOp : Fly_Op<"recast_iter", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_Pointer:$src); + let results = (outs Fly_Pointer:$result); + let assemblyFormat = "`(` $src `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_AddOffsetOp : Fly_Op<"add_offset", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins IteratorLikeType:$ptr, Fly_IntTuple:$offset); + let results = (outs IteratorLikeType:$result); + let assemblyFormat = "`(` $ptr `,` $offset `)` attr-dict `:` functional-type(operands, results)"; +} + 
+//===----------------------------------------------------------------------===// +// Utility ops +//===----------------------------------------------------------------------===// + +def Fly_PrintOp : Fly_Op<"print"> { + let arguments = (ins StrAttr:$format, Variadic:$values); + let assemblyFormat = "`(` $values `)` attr-dict `:` functional-type(operands, results)"; +} + +def Fly_AssumeOp : Fly_Op<"assume", [Pure]> { + let arguments = (ins AnyType:$dst, AnyType:$src); + let results = (outs AnyType:$result); + let assemblyFormat = "`(` $dst `,` $src `)` attr-dict `:` functional-type(operands, results)"; +} + +#endif // FLY_OPS diff --git a/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td new file mode 100644 index 00000000..301f20bc --- /dev/null +++ b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td @@ -0,0 +1,196 @@ +#ifndef FLY_TYPEDEFS +#define FLY_TYPEDEFS + +include "flydsl/Dialect/Fly/IR/FlyDialect.td" +include "flydsl/Dialect/Fly/IR/FlyAttrDefs.td" + +def Fly_Basis : Fly_Type<"Basis", "basis", [ + DeclareTypeInterfaceMethods +]> { + let parameters = (ins Fly_BasisAttr:$attr); + let assemblyFormat = "`<` $attr `>`"; + + let extraClassDeclaration = [{ + int32_t depth(); + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins "BasisAttr":$attr), [{ + return $_get(attr.getContext(), attr); + }]> + ]; +} + +def Fly_IntTuple : Fly_Type<"IntTuple", "int_tuple", [ + DeclareTypeInterfaceMethods, + DeclareTypeInterfaceMethods +]> { + let parameters = (ins Fly_IntTupleAttr:$attr); + let assemblyFormat = "`<` $attr `>`"; + + let extraClassDeclaration = [{ + IntTupleType at(int32_t idx) const; + IntTupleType at(ArrayRef idxs) const; + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins "IntTupleAttr":$attr), [{ + return $_get(attr.getContext(), attr); + }]> + ]; +} + +def Fly_Layout : Fly_Type<"Layout", "layout", [ + DeclareTypeInterfaceMethods, + DeclareTypeInterfaceMethods +]> { + let parameters = (ins 
Fly_LayoutAttr:$attr); + let assemblyFormat = "`<` $attr `>`"; + + let extraClassDeclaration = [{ + bool isStaticShape() const; + bool isStaticStride() const; + + LayoutType at(int32_t idx) const; + LayoutType at(ArrayRef idxs) const; + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins "LayoutAttr":$attr), [{ + return $_get(attr.getContext(), attr); + }]> + ]; +} + +def Fly_Swizzle : Fly_Type<"Swizzle", "swizzle", []> { + let parameters = (ins Fly_SwizzleAttr:$attr); + let assemblyFormat = "`<` $attr `>`"; +} + +def Fly_ComposedLayout : Fly_Type<"ComposedLayout", "composed_layout", [ + DeclareTypeInterfaceMethods, + DeclareTypeInterfaceMethods +]> { + let parameters = (ins Fly_ComposedLayoutAttr:$attr); + let assemblyFormat = "`<` $attr `>`"; + + let extraClassDeclaration = [{ + bool isStaticOuter() const; + bool isStaticInner() const; + bool isStaticOffset() const; + + ComposedLayoutType at(int32_t idx) const; + ComposedLayoutType at(ArrayRef idxs) const; + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins "ComposedLayoutAttr":$attr), [{ + return $_get(attr.getContext(), attr); + }]> + ]; +} + +def Fly_Tile : Fly_Type<"Tile", "tile", []> { + let parameters = (ins Fly_TileAttr:$attr); + let assemblyFormat = "`<` $attr `>`"; + + let extraClassDeclaration = [{ + int32_t rank() const; + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins "TileAttr":$attr), [{ + return $_get(attr.getContext(), attr); + }]> + ]; +} + +def Fly_Pointer : Fly_Type<"Pointer", "ptr", []> { + let parameters = (ins + "Type":$elemTy, + "AddressSpaceAttr":$addressSpace, + DefaultValuedParameter<"AlignAttr","AlignAttr::getTrivialAlignment($_ctxt)">:$alignment, + DefaultValuedParameter<"SwizzleAttr","SwizzleAttr::getTrivialSwizzle($_ctxt)">:$swizzle + ); + let assemblyFormat = "`<` $elemTy `,` `` $addressSpace (`,` $alignment^)? (`,` $swizzle^)? 
`>`"; + + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elemTy, "AddressSpaceAttr":$addressSpace), [{ + return $_get(elemTy.getContext(), elemTy, addressSpace, + AlignAttr::getTrivialAlignment(elemTy.getContext()), + SwizzleAttr::getTrivialSwizzle(elemTy.getContext())); + }]> + ]; + let extraClassDeclaration = [{}]; +} + +def Fly_CoordTensor : Fly_Type<"CoordTensor", "coord_tensor", [ + DeclareTypeInterfaceMethods, + DeclareTypeInterfaceMethods +]> { + let parameters = (ins Fly_IntTupleAttr:$base, Fly_LayoutAttr:$layout); + let assemblyFormat = "`<` $base `,` $layout `>`"; + + let extraClassDeclaration = [{ + CoordTensorType at(int32_t idx) const; + CoordTensorType at(ArrayRef idxs) const; + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins "IntTupleAttr":$base, "LayoutAttr":$layout), [{ + return $_get(base.getContext(), base, layout); + }]> + ]; +} + +def Fly_MemRef : Fly_Type<"MemRef", "memref", []> { + let parameters = (ins + "Type":$elemTy, + "AddressSpaceAttr":$addressSpace, + Fly_LayoutAttr:$layout, + DefaultValuedParameter<"AlignAttr","AlignAttr::getTrivialAlignment($_ctxt)">:$alignment, + DefaultValuedParameter<"SwizzleAttr","SwizzleAttr::getTrivialSwizzle($_ctxt)">:$swizzle + ); + let assemblyFormat = "`<` $elemTy `,` `` $addressSpace `,` $layout (`,` $alignment^)? (`,` $swizzle^)? 
`>`"; + + let builders = [ + AttrBuilderWithInferredContext<(ins "Type":$elemTy, "AddressSpaceAttr":$addressSpace, "LayoutAttr":$layout), [{ + return $_get(elemTy.getContext(), elemTy, addressSpace, layout, + AlignAttr::getTrivialAlignment(elemTy.getContext()), + SwizzleAttr::getTrivialSwizzle(elemTy.getContext())); + }]> + ]; + let extraClassDeclaration = [{}]; +} + +def IteratorLikeType : AnyTypeOf<[Fly_IntTuple, Fly_Pointer]>; +def TensorLikeType : AnyTypeOf<[Fly_CoordTensor, Fly_MemRef]>; +def AnyLayoutType : AnyTypeOf<[Fly_Layout, Fly_ComposedLayout]>; +def LayoutLikeType : AnyTypeOf<[AnyLayoutType, TensorLikeType]>; +def IntTupleLikeType : AnyTypeOf<[Fly_IntTuple, LayoutLikeType]>; + +def Fly_TiledCopy : Fly_Type<"TiledCopy", "tiled_copy", []> { + let parameters = (ins + "Type":$copyAtom, + Fly_Layout:$layoutThrVal, + Fly_Tile:$tileMN + ); + let assemblyFormat = "`<` $copyAtom `,` $layoutThrVal `,` $tileMN `>`"; +} + +def Fly_TiledMma : Fly_Type<"TiledMma", "tiled_mma", []> { + let parameters = (ins + "Type":$mmaAtom, + Fly_Layout:$atomLayout, + Fly_Tile:$permutation + ); + let assemblyFormat = "`<` $mmaAtom `,` $atomLayout `,` $permutation `>`"; +} + + +def Fly_CopyAtomGlobalLoad4B : Fly_Type<"CopyAtomGlobalLoad4B", "atom.global_load_4B", []> {} +def Fly_CopyAtomUniversalCopy32b : Fly_Type<"CopyAtomUniversalCopy32b", "atom.universal_copy_32b", []> {} + +def Fly_MmaAtomMFMA_F32_16x16x4F32 : Fly_Type<"MmaAtomMFMA_F32_16x16x4F32", "atom.amdgpu.mfma.f32.16x16x4f32", []> {} + +#endif // FLY_TYPEDEFS diff --git a/include/flydsl/Dialect/Fly/Transforms/CMakeLists.txt b/include/flydsl/Dialect/Fly/Transforms/CMakeLists.txt new file mode 100644 index 00000000..3ee211fb --- /dev/null +++ b/include/flydsl/Dialect/Fly/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Fly) +mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header -name Fly) +mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl 
-name Fly) + +add_mlir_generic_tablegen_target(FlyTransformPassIncGen) + +set(LLVM_TARGET_DEFINITIONS LayoutLowering.td) +mlir_tablegen(LayoutLowering.cpp.inc -gen-rewriters) + +add_mlir_generic_tablegen_target(FlyTransformPatternsIncGen) + diff --git a/include/flydsl/Dialect/Fly/Transforms/LayoutLowering.td b/include/flydsl/Dialect/Fly/Transforms/LayoutLowering.td new file mode 100644 index 00000000..43b1cef0 --- /dev/null +++ b/include/flydsl/Dialect/Fly/Transforms/LayoutLowering.td @@ -0,0 +1,68 @@ +#ifndef FLY_TRANSFORMS_LAYOUT_LOWERING +#define FLY_TRANSFORMS_LAYOUT_LOWERING + +include "mlir/IR/PatternBase.td" +include "flydsl/Dialect/Fly/IR/FlyOps.td" + +def : Pat<(Fly_GetLeafOp Fly_Layout:$layout, I32Attr:$leaf_idx), + (Fly_MakeLayoutOp + (Fly_GetLeafOp (Fly_GetShapeOp $layout), $leaf_idx), + (Fly_GetLeafOp (Fly_GetStrideOp $layout), $leaf_idx) + )>; + +def : Pat<(Fly_GetShapeOp Fly_MemRef:$memref), + (Fly_GetShapeOp (Fly_GetLayoutOp $memref))>; +def : Pat<(Fly_GetStrideOp Fly_MemRef:$memref), + (Fly_GetStrideOp (Fly_GetLayoutOp $memref))>; + + +def : Pat<(Fly_SliceOp Fly_Layout:$layout, Fly_IntTuple:$coord), + (Fly_MakeLayoutOp + (Fly_SliceOp (Fly_GetShapeOp $layout), $coord), + (Fly_SliceOp (Fly_GetStrideOp $layout), $coord) + )>; +def : Pat<(Fly_SliceOp Fly_MemRef:$memref, Fly_IntTuple:$coord), + (Fly_MakeViewOp + (Fly_AddOffsetOp (Fly_GetIterOp $memref), + (Fly_Crd2IdxOp $coord, (Fly_GetLayoutOp $memref))), + (Fly_SliceOp (Fly_GetLayoutOp $memref), $coord) + )>; + +def : Pat<(Fly_SizeOp Fly_Layout:$layout), + (Fly_SizeOp (Fly_GetShapeOp $layout))>; +def : Pat<(Fly_SizeOp Fly_MemRef:$memref), + (Fly_SizeOp (Fly_GetLayoutOp $memref))>; + +def : Pat<(Fly_SelectOp Fly_Layout:$layout, DenseI32ArrayAttr:$indices), + (Fly_MakeLayoutOp + (Fly_SelectOp (Fly_GetShapeOp $layout), $indices), + (Fly_SelectOp (Fly_GetStrideOp $layout), $indices) + )>; + +def : Pat<(Fly_GroupOp Fly_Layout:$layout, I32Attr:$begin, I32Attr:$end), + (Fly_MakeLayoutOp + (Fly_GroupOp 
(Fly_GetShapeOp $layout), $begin, $end), + (Fly_GroupOp (Fly_GetStrideOp $layout), $begin, $end) + )>; + +def : Pat<(Fly_LogicalDivideOp Fly_ComposedLayout:$layout, AnyTypeOf<[Fly_Layout, Fly_Tile]>:$divisor), + (Fly_MakeComposedLayoutOp + (Fly_ComposedGetInnerOp $layout), + (Fly_ComposedGetOffsetOp $layout), + (Fly_LogicalDivideOp (Fly_ComposedGetOuterOp $layout), $divisor) + )>; + + +def : Pat<(Fly_LogicalDivideOp Fly_MemRef:$mem, AnyTypeOf<[Fly_Layout, Fly_Tile]>:$divisor), + (Fly_MakeViewOp + (Fly_GetIterOp $mem), + (Fly_LogicalDivideOp (Fly_GetLayoutOp $mem), $divisor) + )>; + +def : Pat<(Fly_ZippedDivideOp Fly_MemRef:$mem, AnyTypeOf<[Fly_Layout, Fly_Tile]>:$divisor), + (Fly_MakeViewOp + (Fly_GetIterOp $mem), + (Fly_ZippedDivideOp (Fly_GetLayoutOp $mem), $divisor) + )>; + +#endif // FLY_TRANSFORMS_LAYOUT_LOWERING diff --git a/include/flydsl/Dialect/Fly/Transforms/Passes.h b/include/flydsl/Dialect/Fly/Transforms/Passes.h new file mode 100644 index 00000000..ff2319b8 --- /dev/null +++ b/include/flydsl/Dialect/Fly/Transforms/Passes.h @@ -0,0 +1,19 @@ +#ifndef FLY_TRANSFORM_H +#define FLY_TRANSFORM_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace fly { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "flydsl/Dialect/Fly/Transforms/Passes.h.inc" + +#define GEN_PASS_REGISTRATION +#include "flydsl/Dialect/Fly/Transforms/Passes.h.inc" + +} // namespace fly +} // namespace mlir + +#endif // FLY_TRANSFORM_H diff --git a/include/flydsl/Dialect/Fly/Transforms/Passes.td b/include/flydsl/Dialect/Fly/Transforms/Passes.td new file mode 100644 index 00000000..fd14d5ae --- /dev/null +++ b/include/flydsl/Dialect/Fly/Transforms/Passes.td @@ -0,0 +1,24 @@ +#ifndef FLY_PASSES +#define FLY_PASSES + +include "mlir/Pass/PassBase.td" + +def FlyCanonicalizePass : Pass<"fly-canonicalize"> { + let summary = "Canonicalize Pattern"; + let description = [{ + Canonicalize Fly operations. 
+ }]; +} + +def FlyLayoutLoweringPass : Pass<"fly-layout-lowering"> { + let summary = "Lower layout algebra operations"; + let description = [{ + Lowers layout algebra operations to simpler forms. + }]; + + let dependentDialects = [ + + ]; +} + +#endif // FLY_PASSES diff --git a/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h new file mode 100644 index 00000000..2ffc8121 --- /dev/null +++ b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h @@ -0,0 +1,1004 @@ +#ifndef FLY_DIALECT_UTILS_INTTUPLEUTILS_H +#define FLY_DIALECT_UTILS_INTTUPLEUTILS_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/IntUtils.h" + +//===----------------------------------------------------------------------===// +// IntTupleAttr constexpr utilities +//===----------------------------------------------------------------------===// + +namespace mlir::fly { + +bool intTupleHasNone(IntTupleAttr attr); +bool intTupleAllNone(IntTupleAttr attr); + +bool intTupleIsCongruent(IntTupleAttr lhs, IntTupleAttr rhs); +bool intTupleIsWeaklyCongruent(IntTupleAttr lhs, IntTupleAttr rhs); + +} // namespace mlir::fly + +//===----------------------------------------------------------------------===// +// Universal IntTuple utilities +//===----------------------------------------------------------------------===// + +namespace mlir::fly { + +template class IntTupleBuilder; + +class IntTupleValueAdaptor { +private: + int32_t prefixSumDyncElems(int32_t idx) const { + int dyncOffset = 0; + for (int32_t i = 0; i < idx; ++i) { + dyncOffset += attr.at(i).dyncLeafCount(); + } + return dyncOffset; + } + + IntTupleValueAdaptor(Value value, IntTupleAttr attr, int32_t dyncIdxStart = 0, + int32_t dyncIdxEnd = -1) + : value(value), attr(attr), dyncIdxStart(dyncIdxStart), dyncIdxEnd(dyncIdxEnd) 
{} + + Value value; + IntTupleAttr attr; + int32_t dyncIdxStart, dyncIdxEnd; + +public: + template + static IntTupleValueAdaptor create(Builder &builder, Value value, IntTupleAttr attr) { + auto defOp = value.getDefiningOp(); + assert(defOp && "Value must be a MakeIntTupleOp"); + if (attr.isLeaf()) { + if (attr.isStatic()) { + return IntTupleValueAdaptor( + builder.materializeConstantArith(attr.getLeafAsInt().getValue()).value, attr); + } else { + return IntTupleValueAdaptor(value.getDefiningOp()->getOperand(0), attr); + } + } else { + return IntTupleValueAdaptor(value, attr); + } + } + + bool isLeaf() const { return attr.isLeaf(); } + int32_t rank() const { return attr.rank(); } + int32_t depth() const { return attr.depth(); } + + friend class IntTupleBuilder; +}; + +template <> class IntTupleBuilder { +protected: + MLIRContext *ctx; + +public: + IntTupleBuilder(MLIRContext *ctx) : ctx(ctx) {} + + using ArithValue = IntAttr; + struct ElemCollector { + SmallVector attrCollector; + + void push_back(Attribute attr) { attrCollector.push_back(attr); } + size_t size() const { return attrCollector.size(); } + bool empty() const { return attrCollector.empty(); } + void reverse() { std::reverse(attrCollector.begin(), attrCollector.end()); } + }; + + ArithValue add(ArithValue lhs, ArithValue rhs) const { return lhs + rhs; } + ArithValue sub(ArithValue lhs, ArithValue rhs) const { return lhs - rhs; } + ArithValue mul(ArithValue lhs, ArithValue rhs) const { return lhs * rhs; } + ArithValue div(ArithValue lhs, ArithValue rhs) const { return lhs / rhs; } + ArithValue mod(ArithValue lhs, ArithValue rhs) const { return lhs % rhs; } + + ArithValue logicalAnd(ArithValue lhs, ArithValue rhs) const { return lhs && rhs; } + ArithValue logicalOr(ArithValue lhs, ArithValue rhs) const { return lhs || rhs; } + ArithValue logicalNot(ArithValue val) const { return !val; } + ArithValue lt(ArithValue lhs, ArithValue rhs) const { return lhs < rhs; } + ArithValue le(ArithValue lhs, ArithValue 
rhs) const { return lhs <= rhs; } + ArithValue gt(ArithValue lhs, ArithValue rhs) const { return lhs > rhs; } + ArithValue ge(ArithValue lhs, ArithValue rhs) const { return lhs >= rhs; } + ArithValue eq(ArithValue lhs, ArithValue rhs) const { return lhs == rhs; } + ArithValue ne(ArithValue lhs, ArithValue rhs) const { return lhs != rhs; } + + ArithValue min(ArithValue lhs, ArithValue rhs) const { return intMin(lhs, rhs); } + ArithValue max(ArithValue lhs, ArithValue rhs) const { return intMax(lhs, rhs); } + ArithValue safeDiv(ArithValue lhs, ArithValue rhs) const { return intSafeDiv(lhs, rhs); } + ArithValue ceilDiv(ArithValue lhs, ArithValue rhs) const { return intCeilDiv(lhs, rhs); } + ArithValue shapeDiv(ArithValue lhs, ArithValue rhs) const { return intShapeDiv(lhs, rhs); } + + IntTupleAttr getAttr(IntTupleAttr attr) const { return attr; } + ArithValue getArithValue(IntTupleAttr attr) const { return attr.getLeafAsInt(); } + + ArithValue materializeConstantArith(int32_t value) const { + return IntAttr::getStatic(ctx, value); + } + IntTupleAttr materializeConstantTuple(IntTupleAttr attr) const { + assert(attr.isStatic() && "Tuple must be static"); + return attr; + } + + bool isNone(ArithValue val) const { return val.isNone(); } + bool isStatic(ArithValue val) const { return val.isStatic(); } + bool isStaticValue(ArithValue val, int32_t v) const { return val.isStaticValue(v); } + int32_t getStaticValue(ArithValue val) const { return val.getValue(); } + + IntTupleAttr at(IntTupleAttr attr, int32_t idx) const { return attr.at(idx); } + IntTupleAttr makeInt(ArithValue value) const { return IntTupleAttr::get(value); } + IntTupleAttr makeTuple(const ElemCollector &collector) const { + return IntTupleAttr::get(ArrayAttr::get(ctx, collector.attrCollector)); + } + const IntTupleBuilder &getAttrBuilder() const { return *this; } +}; + +template <> class IntTupleBuilder { +protected: + PatternRewriter &builder; + Location loc; + IntTupleBuilder attrBuilder; + +public: + 
IntTupleBuilder(PatternRewriter &builder, Location loc) + : builder(builder), loc(loc), attrBuilder(builder.getContext()) {} + + struct ArithValue { + Value value; + IntAttr attr; + }; + struct ElemCollector { + typename IntTupleBuilder::ElemCollector attrCollector; + SmallVector dyncElems; + + void push_back(const IntTupleValueAdaptor &element) { + auto elemAttr = element.attr; + attrCollector.push_back(elemAttr); + if (elemAttr.isLeaf()) { + if (!elemAttr.isStatic()) { + dyncElems.push_back(element.value); + } + } else { + // Handle dyncIdxEnd == -1 (meaning "to the end") + int32_t dyncIdxEnd = element.dyncIdxEnd == -1 + ? element.value.getDefiningOp()->getOperands().size() + : element.dyncIdxEnd; + dyncElems.append(element.value.getDefiningOp()->getOperands().begin() + + element.dyncIdxStart, + element.value.getDefiningOp()->getOperands().begin() + dyncIdxEnd); + } + } + size_t size() const { return attrCollector.size(); } + bool empty() const { return attrCollector.empty(); } + void reverse() { + attrCollector.reverse(); + std::reverse(dyncElems.begin(), dyncElems.end()); + } + }; + + Type getIntType(IntAttr attr) const { + assert((attr.getWidth() == 64 || attr.getWidth() == 32) && "Invalid width"); + return attr.getWidth() == 64 ? builder.getI64Type() : builder.getI32Type(); + } + Type getCommonIntType(IntAttr lhs, IntAttr rhs) const { + assert((lhs.getWidth() == 64 || lhs.getWidth() == 32) && "Invalid width"); + assert((rhs.getWidth() == 64 || rhs.getWidth() == 32) && "Invalid width"); + return lhs.getWidth() == 64 || rhs.getWidth() == 64 ? 
builder.getI64Type() + : builder.getI32Type(); + } + Value extendToIntType(Value input, Type intType) const { + if (input.getType() != intType) { + input = arith::ExtSIOp::create(builder, loc, intType, input); + } + return input; + } + + ArithValue add(ArithValue lhs, ArithValue rhs) const; + ArithValue sub(ArithValue lhs, ArithValue rhs) const; + ArithValue mul(ArithValue lhs, ArithValue rhs) const; + ArithValue div(ArithValue lhs, ArithValue rhs) const; + ArithValue mod(ArithValue lhs, ArithValue rhs) const; + + ArithValue logicalAnd(ArithValue lhs, ArithValue rhs) const; + ArithValue logicalOr(ArithValue lhs, ArithValue rhs) const; + ArithValue logicalNot(ArithValue val) const; + ArithValue lt(ArithValue lhs, ArithValue rhs) const; + ArithValue le(ArithValue lhs, ArithValue rhs) const; + ArithValue gt(ArithValue lhs, ArithValue rhs) const; + ArithValue ge(ArithValue lhs, ArithValue rhs) const; + ArithValue eq(ArithValue lhs, ArithValue rhs) const; + ArithValue ne(ArithValue lhs, ArithValue rhs) const; + + ArithValue min(ArithValue lhs, ArithValue rhs) const; + ArithValue max(ArithValue lhs, ArithValue rhs) const; + ArithValue safeDiv(ArithValue lhs, ArithValue rhs) const { return div(lhs, rhs); } + ArithValue ceilDiv(ArithValue lhs, ArithValue rhs) const; + ArithValue shapeDiv(ArithValue lhs, ArithValue rhs) const; + + IntTupleAttr getAttr(IntTupleValueAdaptor adaptor) const { return adaptor.attr; } + + ArithValue getArithValue(IntTupleValueAdaptor adaptor) const { + assert(adaptor.attr.isLeaf() && "Adaptor must be a leaf"); + return ArithValue{adaptor.value, attrBuilder.getArithValue(adaptor.attr)}; + } + + ArithValue materializeConstantArith(int32_t value) const { + return ArithValue{arith::ConstantIntOp::create(builder, loc, value, 32).getResult(), + attrBuilder.materializeConstantArith(value)}; + } + IntTupleValueAdaptor materializeConstantTuple(IntTupleAttr attr) const { + assert(attr.isStatic() && "Tuple must be static"); + if (attr.isLeaf()) { + return 
IntTupleValueAdaptor{ + arith::ConstantIntOp::create(builder, loc, attr.getLeafAsInt().getValue(), 32) + .getResult(), + attrBuilder.materializeConstantTuple(attr)}; + } else { + return IntTupleValueAdaptor{ + MakeIntTupleOp::create(builder, loc, IntTupleType::get(attr), {}).getResult(), + attrBuilder.materializeConstantTuple(attr)}; + } + } + + bool isNone(ArithValue val) const { return attrBuilder.isNone(val.attr); } + bool isStatic(ArithValue val) const { return attrBuilder.isStatic(val.attr); } + bool isStaticValue(ArithValue val, int32_t v) const { + return attrBuilder.isStaticValue(val.attr, v); + } + int32_t getStaticValue(ArithValue val) const { return attrBuilder.getStaticValue(val.attr); } + + IntTupleValueAdaptor at(IntTupleValueAdaptor adaptor, int32_t idx) const { + auto childAttr = adaptor.attr.at(idx); + if (childAttr.isLeaf()) { + if (childAttr.isStatic()) { + return makeInt(this->materializeConstantArith(childAttr.getLeafAsInt().getValue())); + } else { + return IntTupleValueAdaptor(adaptor.value.getDefiningOp()->getOperand( + adaptor.dyncIdxStart + adaptor.prefixSumDyncElems(idx)), + childAttr); + } + } else { + int32_t dyncOffset = adaptor.prefixSumDyncElems(idx); + return IntTupleValueAdaptor(adaptor.value, childAttr, adaptor.dyncIdxStart + dyncOffset, + adaptor.dyncIdxStart + dyncOffset + childAttr.dyncLeafCount()); + } + } + IntTupleValueAdaptor makeInt(ArithValue value) const { + return IntTupleValueAdaptor(value.value, IntTupleAttr::get(value.attr)); + } + IntTupleValueAdaptor makeTuple(const ElemCollector &collector) const { + auto TupleAttr = attrBuilder.makeTuple(collector.attrCollector); + return IntTupleValueAdaptor( + MakeIntTupleOp::create(builder, loc, IntTupleType::get(TupleAttr), collector.dyncElems) + .getResult(), + TupleAttr); + } + const IntTupleBuilder &getAttrBuilder() const { return attrBuilder; } + + //===----------------------------------------------------------------------===// + // IntTupleValueAdaptor only interface + 
//===----------------------------------------------------------------------===// + + TypedValue finalize(IntTupleValueAdaptor adaptor) const { + auto Ty = IntTupleType::get(adaptor.attr); + if (adaptor.isLeaf()) { + if (adaptor.attr.isStatic()) { + return MakeIntTupleOp::create(builder, loc, Ty, {}).getResult(); + } else { + return MakeIntTupleOp::create(builder, loc, Ty, adaptor.value).getResult(); + } + } else if (adaptor.dyncIdxStart == 0 && adaptor.dyncIdxEnd == -1) { + return cast>(adaptor.value); + } else { + int32_t dyncIdxEnd = adaptor.dyncIdxEnd == -1 + ? adaptor.value.getDefiningOp()->getOperands().size() + : adaptor.dyncIdxEnd; + return MakeIntTupleOp::create(builder, loc, Ty, + adaptor.value.getDefiningOp()->getOperands().slice( + adaptor.dyncIdxStart, dyncIdxEnd - adaptor.dyncIdxStart)) + .getResult(); + } + } + TypedValue reprofile(TypedValue value, + IntTupleAttr newProfile) const { + return MakeIntTupleOp::create(builder, value.getLoc(), IntTupleType::get(newProfile), + value.getDefiningOp()->getOperands()) + .getResult(); + } +}; + +template +IntTuple intTupleBinaryOp(IntTupleBuilder &builder, BinaryOp &&binaryOp, IntTuple lhs, + IntTuple rhs) { + if (lhs.isLeaf()) { + assert(rhs.isLeaf() && "Mismatched structure"); + return builder.makeInt(binaryOp(builder.getArithValue(lhs), builder.getArithValue(rhs))); + } + typename IntTupleBuilder::ElemCollector collector; + const int minRank = std::min(lhs.rank(), rhs.rank()); + for (int i = 0; i < minRank; ++i) { + collector.push_back( + intTupleBinaryOp(builder, binaryOp, builder.at(lhs, i), builder.at(rhs, i))); + } + for (int i = minRank; i < lhs.rank(); ++i) { + collector.push_back(builder.at(lhs, i)); + } + for (int i = minRank; i < rhs.rank(); ++i) { + collector.push_back(builder.at(rhs, i)); + } + return builder.makeTuple(collector); +} + +template +IntTuple intTupleAdd(IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + using ArithValue = typename IntTupleBuilder::ArithValue; + return 
intTupleBinaryOp( + builder, [&](ArithValue a, ArithValue b) { return builder.add(a, b); }, lhs, rhs); +} + +template +IntTuple intTupleSub(IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + using ArithValue = typename IntTupleBuilder::ArithValue; + return intTupleBinaryOp( + builder, [&](ArithValue a, ArithValue b) { return builder.sub(a, b); }, lhs, rhs); +} + +template +IntTuple intTupleMul(IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + using ArithValue = typename IntTupleBuilder::ArithValue; + return intTupleBinaryOp( + builder, [&](ArithValue a, ArithValue b) { return builder.mul(a, b); }, lhs, rhs); +} + +template +IntTuple intTupleDiv(IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + using ArithValue = typename IntTupleBuilder::ArithValue; + return intTupleBinaryOp( + builder, [&](ArithValue a, ArithValue b) { return builder.div(a, b); }, lhs, rhs); +} + +template +IntTuple intTupleMod(IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + using ArithValue = typename IntTupleBuilder::ArithValue; + return intTupleBinaryOp( + builder, [&](ArithValue a, ArithValue b) { return builder.mod(a, b); }, lhs, rhs); +} + +template +IntTuple intTupleMin(IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + using ArithValue = typename IntTupleBuilder::ArithValue; + return intTupleBinaryOp( + builder, [&](ArithValue a, ArithValue b) { return builder.min(a, b); }, lhs, rhs); +} + +template +IntTuple intTupleMax(IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + using ArithValue = typename IntTupleBuilder::ArithValue; + return intTupleBinaryOp( + builder, [&](ArithValue a, ArithValue b) { return builder.max(a, b); }, lhs, rhs); +} + +template +typename IntTupleBuilder::ArithValue intTupleSumImpl(IntTupleBuilder &builder, + IntTuple t) { + using ArithValue = typename IntTupleBuilder::ArithValue; + if (t.isLeaf()) { + return builder.getArithValue(t); + } + ArithValue result = intTupleSumImpl(builder, builder.at(t, 0)); + for (int i = 1; 
i < t.rank(); ++i) { + result = builder.add(result, intTupleSumImpl(builder, builder.at(t, i))); + } + return result; +} + +template IntTuple intTupleSum(IntTupleBuilder &builder, IntTuple t) { + return builder.makeInt(intTupleSumImpl(builder, t)); +} + +template +typename IntTupleBuilder::ArithValue +intTupleProductImpl(IntTupleBuilder &builder, IntTuple t) { + using ArithValue = typename IntTupleBuilder::ArithValue; + if (t.isLeaf()) { + return builder.getArithValue(t); + } + ArithValue result = intTupleProductImpl(builder, builder.at(t, 0)); + for (int i = 1; i < t.rank(); ++i) { + result = builder.mul(result, intTupleProductImpl(builder, builder.at(t, i))); + } + return result; +} + +template IntTuple intTupleProduct(IntTupleBuilder &builder, IntTuple t) { + return builder.makeInt(intTupleProductImpl(builder, t)); +} + +template +typename IntTupleBuilder::ArithValue +intTupleInnerProductImpl(IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + using ArithValue = typename IntTupleBuilder::ArithValue; + if (lhs.isLeaf() && rhs.isLeaf()) { + return builder.mul(builder.getArithValue(lhs), builder.getArithValue(rhs)); + } + assert(lhs.rank() == rhs.rank() && "Mismatched ranks"); + ArithValue result = intTupleInnerProductImpl(builder, builder.at(lhs, 0), builder.at(rhs, 0)); + for (int i = 1; i < lhs.rank(); ++i) { + result = builder.add(result, + intTupleInnerProductImpl(builder, builder.at(lhs, i), builder.at(rhs, i))); + } + return result; +} + +template +IntTuple intTupleInnerProduct(IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + return builder.makeInt(intTupleInnerProductImpl(builder, lhs, rhs)); +} + +template +std::pair::ArithValue> +intTupleCeilDivFoldImpl(IntTupleBuilder &builder, IntTuple a, + typename IntTupleBuilder::ArithValue b) { + using ArithValue = typename IntTupleBuilder::ArithValue; + if (a.isLeaf()) { + auto aVal = builder.getArithValue(a); + auto result = builder.ceilDiv(aVal, b); + auto remainder = builder.ceilDiv(b, aVal); + 
return {builder.makeInt(result), remainder}; + } + typename IntTupleBuilder::ElemCollector collector; + ArithValue remaining = b; + for (int i = 0; i < a.rank(); ++i) { + auto [res, rem] = intTupleCeilDivFoldImpl(builder, builder.at(a, i), remaining); + collector.push_back(res); + remaining = rem; + } + return {builder.makeTuple(collector), remaining}; +} + +template +IntTuple intTupleCeilDiv(IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + if (lhs.isLeaf()) { + if (rhs.isLeaf()) { + return builder.makeInt( + builder.ceilDiv(builder.getArithValue(lhs), builder.getArithValue(rhs))); + } + auto rhsProduct = intTupleProductImpl(builder, rhs); + return builder.makeInt(builder.ceilDiv(builder.getArithValue(lhs), rhsProduct)); + } + if (rhs.isLeaf()) { + auto [result, rest] = intTupleCeilDivFoldImpl(builder, lhs, builder.getArithValue(rhs)); + return result; + } + const int divRank = std::min(lhs.rank(), rhs.rank()); + typename IntTupleBuilder::ElemCollector collector; + for (int i = 0; i < divRank; ++i) { + collector.push_back(intTupleCeilDiv(builder, builder.at(lhs, i), builder.at(rhs, i))); + } + for (int i = divRank; i < lhs.rank(); ++i) { + collector.push_back(builder.at(lhs, i)); + } + return builder.makeTuple(collector); +} + +template +std::pair::ArithValue> +intTupleShapeDivFoldImpl(IntTupleBuilder &builder, IntTuple a, + typename IntTupleBuilder::ArithValue b) { + using ArithValue = typename IntTupleBuilder::ArithValue; + if (a.isLeaf()) { + auto aVal = builder.getArithValue(a); + auto result = builder.shapeDiv(aVal, b); + auto remainder = builder.shapeDiv(b, aVal); + return {builder.makeInt(result), remainder}; + } + typename IntTupleBuilder::ElemCollector collector; + ArithValue remaining = b; + for (int i = 0; i < a.rank(); ++i) { + auto [res, rem] = intTupleShapeDivFoldImpl(builder, builder.at(a, i), remaining); + collector.push_back(res); + remaining = rem; + } + return {builder.makeTuple(collector), remaining}; +} + +template +IntTuple 
intTupleShapeDiv(IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + if (lhs.isLeaf()) { + if (rhs.isLeaf()) { + return builder.makeInt( + builder.shapeDiv(builder.getArithValue(lhs), builder.getArithValue(rhs))); + } + auto rhsProduct = intTupleProductImpl(builder, rhs); + return builder.makeInt(builder.shapeDiv(builder.getArithValue(lhs), rhsProduct)); + } + if (rhs.isLeaf()) { + auto [result, rest] = intTupleShapeDivFoldImpl(builder, lhs, builder.getArithValue(rhs)); + return result; + } + const int divRank = std::min(lhs.rank(), rhs.rank()); + typename IntTupleBuilder::ElemCollector collector; + for (int i = 0; i < divRank; ++i) { + collector.push_back(intTupleShapeDiv(builder, builder.at(lhs, i), builder.at(rhs, i))); + } + for (int i = divRank; i < lhs.rank(); ++i) { + collector.push_back(builder.at(lhs, i)); + } + return builder.makeTuple(collector); +} + +template +IntTuple intTupleProductEach(IntTupleBuilder &builder, IntTuple val) { + if (val.isLeaf()) { + return val; + } + typename IntTupleBuilder::ElemCollector collector; + for (int i = 0; i < val.rank(); ++i) { + collector.push_back(intTupleProduct(builder, builder.at(val, i))); + } + return builder.makeTuple(collector); +} + +//===----------------------------------------------------------------------===// +// Attribute manipulation +//===----------------------------------------------------------------------===// + +IntTupleAttr intTupleWrap(const IntTupleBuilder &builder, IntTupleAttr attr); +IntTupleAttr intTupleUnwrap(const IntTupleBuilder &builder, IntTupleAttr attr); + +IntTupleAttr intTupleUnflatten(const IntTupleBuilder &builder, IntTupleAttr attr, + IntTupleAttr profile); + +IntTupleAttr intTupleExpand(const IntTupleBuilder &builder, IntTupleAttr attr, + ArrayRef indices); +IntTupleAttr intTupleGroup(const IntTupleBuilder &builder, IntTupleAttr attr, + int32_t begin, int32_t end); + +inline IntTupleValueAdaptor intTupleWrap(const IntTupleBuilder &builder, + IntTupleValueAdaptor adaptor) { 
+ IntTupleAttr newAttr = intTupleWrap(builder.getAttrBuilder(), builder.getAttr(adaptor)); + return IntTupleValueAdaptor::create(builder, builder.finalize(adaptor), newAttr); +} +inline IntTupleValueAdaptor intTupleUnwrap(const IntTupleBuilder &builder, + IntTupleValueAdaptor adaptor) { + IntTupleAttr newAttr = intTupleUnwrap(builder.getAttrBuilder(), builder.getAttr(adaptor)); + return IntTupleValueAdaptor::create( + builder, builder.reprofile(builder.finalize(adaptor), newAttr), newAttr); +} +inline IntTupleValueAdaptor intTupleUnflatten(const IntTupleBuilder &builder, + IntTupleValueAdaptor adaptor, IntTupleAttr profile) { + IntTupleAttr newAttr = + intTupleUnflatten(builder.getAttrBuilder(), builder.getAttr(adaptor), profile); + return IntTupleValueAdaptor::create( + builder, builder.reprofile(builder.finalize(adaptor), newAttr), newAttr); +} +inline IntTupleValueAdaptor intTupleExpand(const IntTupleBuilder &builder, + IntTupleValueAdaptor adaptor, + ArrayRef indices) { + IntTupleAttr newAttr = + intTupleExpand(builder.getAttrBuilder(), builder.getAttr(adaptor), indices); + return IntTupleValueAdaptor::create( + builder, builder.reprofile(builder.finalize(adaptor), newAttr), newAttr); +} +inline IntTupleValueAdaptor intTupleGroup(const IntTupleBuilder &builder, + IntTupleValueAdaptor adaptor, int32_t begin, + int32_t end) { + IntTupleAttr newAttr = + intTupleGroup(builder.getAttrBuilder(), builder.getAttr(adaptor), begin, end); + return IntTupleValueAdaptor::create( + builder, builder.reprofile(builder.finalize(adaptor), newAttr), newAttr); +} + +template +void intTupleFlattenToVector(const IntTupleBuilder &builder, IntTuple t, + Collector &result) { + if (t.isLeaf()) { + result.push_back(t); + } else { + for (int i = 0; i < t.rank(); ++i) { + intTupleFlattenToVector(builder, builder.at(t, i), result); + } + } +} +template +IntTuple intTupleFlatten(const IntTupleBuilder &builder, IntTuple t) { + if (t.isLeaf()) { + return t; + } + typename 
IntTupleBuilder::ElemCollector collector; + intTupleFlattenToVector(builder, t, collector); + return builder.makeTuple(collector); +} + +//===----------------------------------------------------------------------===// +// Transformation operations +//===----------------------------------------------------------------------===// + +template +IntTuple intTupleTransform(const IntTupleBuilder &builder, F &&fn, IntTuple t0) { + if (t0.isLeaf()) { + return fn(t0); + } + typename IntTupleBuilder::ElemCollector collector; + for (int i = 0; i < t0.rank(); ++i) { + collector.push_back(fn(builder.at(t0, i))); + } + return builder.makeTuple(collector); +} +template +IntTuple intTupleTransform(const IntTupleBuilder &builder, F &&fn, IntTuple t0, + IntTuple t1) { + if (t0.isLeaf()) { + return fn(t0, t1); + } + typename IntTupleBuilder::ElemCollector collector; + for (int i = 0; i < t0.rank(); ++i) { + collector.push_back(fn(builder.at(t0, i), builder.at(t1, i))); + } + return builder.makeTuple(collector); +} +template +IntTuple intTupleTransform(const IntTupleBuilder &builder, F &&fn, IntTuple t0, + IntTuple t1, IntTuple t2) { + if (t0.isLeaf()) { + return fn(t0, t1, t2); + } + typename IntTupleBuilder::ElemCollector collector; + for (int i = 0; i < t0.rank(); ++i) { + collector.push_back(fn(builder.at(t0, i), builder.at(t1, i), builder.at(t2, i))); + } + return builder.makeTuple(collector); +} + +template +IntTuple intTupleTransformLeaf(const IntTupleBuilder &builder, F &&fn, IntTuple t0) { + if (t0.isLeaf()) { + return fn(t0); + } + typename IntTupleBuilder::ElemCollector collector; + for (int i = 0; i < t0.rank(); ++i) { + collector.push_back(intTupleTransformLeaf(builder, fn, builder.at(t0, i))); + } + return builder.makeTuple(collector); +} +template +IntTuple intTupleTransformLeaf(const IntTupleBuilder &builder, F &&fn, IntTuple t0, + IntTuple t1) { + if (t0.isLeaf()) { + return fn(t0, t1); + } + typename IntTupleBuilder::ElemCollector collector; + for (int i = 0; i < 
t0.rank(); ++i) { + collector.push_back(intTupleTransformLeaf(builder, fn, builder.at(t0, i), builder.at(t1, i))); + } + return builder.makeTuple(collector); +} +template +IntTuple intTupleTransformLeaf(const IntTupleBuilder &builder, F &&fn, IntTuple t0, + IntTuple t1, IntTuple t2) { + if (t0.isLeaf()) { + return fn(t0, t1, t2); + } + typename IntTupleBuilder::ElemCollector collector; + for (int i = 0; i < t0.rank(); ++i) { + collector.push_back(intTupleTransformLeaf(builder, fn, builder.at(t0, i), builder.at(t1, i), + builder.at(t2, i))); + } + return builder.makeTuple(collector); +} + +template +IntTuple intTupleSelect(const IntTupleBuilder &builder, IntTuple val, + ArrayRef indices) { + assert(!val.isLeaf() && "intTupleSelect expects a non-leaf tuple"); + typename IntTupleBuilder::ElemCollector collector; + for (int32_t idx : indices) { + collector.push_back(builder.at(val, idx)); + } + return builder.makeTuple(collector); +} + +/// If n == -1, appends a single element. +template +IntTuple intTupleAppend(const IntTupleBuilder &builder, IntTuple val, IntTuple elem, + int32_t n = -1) { + typename IntTupleBuilder::ElemCollector collector; + if (val.isLeaf()) { + collector.push_back(val); + if (n == -1) { + collector.push_back(elem); + } else { + int32_t currentRank = 1; + while (currentRank < n) { + collector.push_back(elem); + ++currentRank; + } + } + } else { + for (int i = 0; i < val.rank(); ++i) { + collector.push_back(builder.at(val, i)); + } + if (n == -1) { + collector.push_back(elem); + } else { + int32_t currentRank = val.rank(); + assert(currentRank <= n && "intTupleAppend expects n >= current rank"); + while (currentRank < n) { + collector.push_back(elem); + ++currentRank; + } + } + } + return builder.makeTuple(collector); +} +/// If n == -1, prepends a single element. 
+template +IntTuple intTuplePrepend(const IntTupleBuilder &builder, IntTuple val, IntTuple elem, + int32_t n = -1) { + typename IntTupleBuilder::ElemCollector collector; + if (val.isLeaf()) { + if (n == -1) { + collector.push_back(elem); + } else { + int32_t targetAppend = n - 1; + for (int32_t i = 0; i < targetAppend; ++i) { + collector.push_back(elem); + } + } + collector.push_back(val); + } else { + if (n == -1) { + collector.push_back(elem); + } else { + assert(n >= val.rank() && "intTuplePrepend expects n >= current rank"); + int32_t numToPrepend = n - val.rank(); + for (int32_t i = 0; i < numToPrepend; ++i) { + collector.push_back(elem); + } + } + for (int i = 0; i < val.rank(); ++i) { + collector.push_back(builder.at(val, i)); + } + } + return builder.makeTuple(collector); +} + +template +IntTuple intTupleZip(const IntTupleBuilder &builder, IntTuple attr) { + using Collector = typename IntTupleBuilder::ElemCollector; + if (attr.isLeaf()) { + return attr; + } else { + auto firstChild = builder.at(attr, 0); + if (firstChild.isLeaf()) { + return attr; + } else { + int32_t innerRank = firstChild.rank(); + Collector result; + for (int j = 0; j < innerRank; ++j) { + Collector zipped; + for (int i = 0; i < attr.rank(); ++i) { + zipped.push_back(builder.at(builder.at(attr, i), j)); + } + result.push_back(builder.makeTuple(zipped)); + } + return builder.makeTuple(result); + } + } +} +template +IntTuple intTupleZip(const IntTupleBuilder &builder, IntTuple t0, IntTuple t1) { + typename IntTupleBuilder::ElemCollector collector; + collector.push_back(t0); + collector.push_back(t1); + return intTupleZip(builder, builder.makeTuple(collector)); +} +template +IntTuple intTupleZip(const IntTupleBuilder &builder, IntTuple t0, IntTuple t1, + IntTuple t2) { + typename IntTupleBuilder::ElemCollector collector; + collector.push_back(t0); + collector.push_back(t1); + collector.push_back(t2); + return intTupleZip(builder, builder.makeTuple(collector)); +} + +namespace detail { + 
+template +std::pair intTupleZip2ByImpl(const IntTupleBuilder &builder, + IntTuple t, IntTupleAttr guide) { + using Collector = typename IntTupleBuilder::ElemCollector; + if (guide.isLeaf()) { + assert(t.rank() == 2 && "intTupleZip2By expects rank-2 tuple at terminal"); + return {builder.at(t, 0), builder.at(t, 1)}; + } + Collector firsts; + Collector seconds; + + int32_t guideRank = guide.rank(); + int32_t tRank = t.rank(); + assert(tRank >= guideRank && "Mismatched ranks in intTupleZip2By"); + for (int i = 0; i < guideRank; ++i) { + auto [first, second] = intTupleZip2ByImpl(builder, builder.at(t, i), guide.at(i)); + firsts.push_back(first); + seconds.push_back(second); + } + for (int i = guideRank; i < tRank; ++i) { + seconds.push_back(builder.at(t, i)); + } + return {builder.makeTuple(firsts), builder.makeTuple(seconds)}; +} + +} // namespace detail + +template +IntTuple intTupleZip2By(const IntTupleBuilder &builder, IntTuple t, IntTupleAttr guide) { + using Collector = typename IntTupleBuilder::ElemCollector; + auto [first, second] = detail::intTupleZip2ByImpl(builder, t, guide); + Collector collector; + collector.push_back(first); + collector.push_back(second); + return builder.makeTuple(collector); +} + +namespace detail { + +template +void intTupleSliceImpl(const IntTupleBuilder &builder, IntTuple tuple, IntTupleAttr coord, + typename IntTupleBuilder::ElemCollector &result) { + if (coord.isLeaf()) { + if (coord.isLeafNone()) { + result.push_back(tuple); + } + return; + } + assert(coord.rank() == tuple.rank() && "Mismatched ranks in slice"); + for (int i = 0; i < coord.rank(); ++i) { + intTupleSliceImpl(builder, builder.at(tuple, i), coord.at(i), result); + } +} +template +void intTupleDiceImpl(const IntTupleBuilder &builder, IntTuple tuple, IntTupleAttr coord, + typename IntTupleBuilder::ElemCollector &result) { + if (coord.isLeaf()) { + if (!coord.isLeafNone()) { + result.push_back(tuple); + } + return; + } + assert(coord.rank() == tuple.rank() && 
"Mismatched ranks in dice"); + for (int i = 0; i < coord.rank(); ++i) { + intTupleDiceImpl(builder, builder.at(tuple, i), coord.at(i), result); + } +} + +} // namespace detail + +template +IntTuple intTupleSlice(const IntTupleBuilder &builder, IntTuple tuple, + IntTupleAttr coord) { + if (coord.isLeaf()) { + if (coord.isLeafNone()) { + return tuple; + } + llvm_unreachable("not support empty IntTuple"); + } else { + typename IntTupleBuilder::ElemCollector collector; + assert(coord.rank() == tuple.rank() && "Mismatched ranks in slice"); + for (int i = 0; i < coord.rank(); ++i) { + detail::intTupleSliceImpl(builder, builder.at(tuple, i), coord.at(i), collector); + } + assert(!collector.empty() && "not support empty IntTuple"); + return intTupleUnwrap(builder, builder.makeTuple(collector)); + } +} +template +IntTuple intTupleDice(const IntTupleBuilder &builder, IntTuple tuple, + IntTupleAttr coord) { + if (coord.isLeaf()) { + if (!coord.isLeafNone()) { + return tuple; + } + llvm_unreachable("not support empty IntTuple"); + } else { + typename IntTupleBuilder::ElemCollector collector; + assert(coord.rank() == tuple.rank() && "Mismatched ranks in dice"); + for (int i = 0; i < coord.rank(); ++i) { + detail::intTupleDiceImpl(builder, builder.at(tuple, i), coord.at(i), collector); + } + assert(!collector.empty() && "not support empty IntTuple"); + return intTupleUnwrap(builder, builder.makeTuple(collector)); + } +} + +template +IntTuple intTupleFilterZero(IntTupleBuilder &builder, IntTupleAttr guide, IntTuple val) { + using Collector = typename IntTupleBuilder::ElemCollector; + if (guide.isLeaf()) { + if (guide.isLeafStaticValue(0)) { + return intTupleTransformLeaf( + builder, [&](auto) { return builder.makeInt(builder.materializeConstantArith(1)); }, val); + } + return val; + } + assert(guide.rank() == val.rank() && "Mismatched ranks in intTupleFilterZero"); + Collector collector; + for (int i = 0; i < guide.rank(); ++i) { + collector.push_back(intTupleFilterZero(builder, 
guide.at(i), builder.at(val, i))); + } + return builder.makeTuple(collector); +} +template +IntTuple intTupleFilterZero(IntTupleBuilder &builder, IntTuple val) { + return intTupleFilterZero(builder, builder.getAttr(val), val); +} + +//===----------------------------------------------------------------------===// +// Element-wise comparison +//===----------------------------------------------------------------------===// + +namespace detail { + +template +typename IntTupleBuilder::ArithValue +intTupleElemLessImpl(const IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + using ArithValue = typename IntTupleBuilder::ArithValue; + if (lhs.isLeaf() && rhs.isLeaf()) { + return builder.lt(builder.getArithValue(lhs), builder.getArithValue(rhs)); + } + if (lhs.rank() > rhs.rank()) { + return builder.materializeConstantArith(0); + } + ArithValue result = intTupleElemLessImpl(builder, builder.at(lhs, 0), builder.at(rhs, 0)); + for (int i = 1; i < lhs.rank(); ++i) { + ArithValue ri = intTupleElemLessImpl(builder, builder.at(lhs, i), builder.at(rhs, i)); + result = builder.logicalAnd(result, ri); + } + return result; +} + +} // namespace detail + +template +IntTuple intTupleElemLess(const IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + return builder.makeInt(detail::intTupleElemLessImpl(builder, lhs, rhs)); +} +template +IntTuple intTupleElemLessEqual(const IntTupleBuilder &builder, IntTuple lhs, + IntTuple rhs) { + return builder.makeInt(builder.logicalNot(detail::intTupleElemLessImpl(builder, rhs, lhs))); +} +template +IntTuple intTupleElemGreater(const IntTupleBuilder &builder, IntTuple lhs, IntTuple rhs) { + return builder.makeInt(detail::intTupleElemLessImpl(builder, rhs, lhs)); +} +template +IntTuple intTupleElemGreaterEqual(const IntTupleBuilder &builder, IntTuple lhs, + IntTuple rhs) { + return builder.makeInt(builder.logicalNot(detail::intTupleElemLessImpl(builder, lhs, rhs))); +} + 
+//===----------------------------------------------------------------------===// +// Basis arithmetic operations +//===----------------------------------------------------------------------===// + +IntTupleAttr intTupleExpandBasis(BasisAttr attr); +IntTupleAttr intTupleMakeBasisLike(IntTupleAttr profile); + +IntTupleAttr operator+(BasisAttr lhs, BasisAttr rhs); +IntTupleAttr operator+(BasisAttr lhs, IntTupleAttr rhs); +IntTupleAttr operator+(IntTupleAttr lhs, BasisAttr rhs); +BasisAttr operator*(BasisAttr lhs, IntAttr rhs); +BasisAttr operator*(IntAttr lhs, BasisAttr rhs); +BasisAttr operator/(BasisAttr lhs, IntAttr rhs); + +BasisAttr basisSafeDiv(BasisAttr lhs, IntAttr rhs); +BasisAttr basisCeilDiv(BasisAttr lhs, IntAttr rhs); + +} // namespace mlir::fly + +#endif // FLY_DIALECT_UTILS_INTTUPLEUTILS_H diff --git a/include/flydsl/Dialect/Fly/Utils/IntUtils.h b/include/flydsl/Dialect/Fly/Utils/IntUtils.h new file mode 100644 index 00000000..3e14fbac --- /dev/null +++ b/include/flydsl/Dialect/Fly/Utils/IntUtils.h @@ -0,0 +1,53 @@ +#ifndef FLY_DIALECT_UTILS_INTUTILS_H +#define FLY_DIALECT_UTILS_INTUTILS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/Support/LogicalResult.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" + +#include + +namespace mlir::fly { +namespace utils { + +inline int32_t divisibilityAdd(int32_t lhs, int32_t rhs) { return std::gcd(lhs, rhs); } +inline int32_t divisibilitySub(int32_t lhs, int32_t rhs) { return std::gcd(lhs, rhs); } +inline int32_t divisibilityMul(int32_t lhs, int32_t rhs) { return lhs * rhs; } +inline int32_t divisibilityDiv(int32_t lhs, int32_t rhs) { return 1; } +inline int32_t divisibilityCeilDiv(int32_t lhs, int32_t rhs) { return 1; } +inline int32_t divisibilityModulo(int32_t lhs, int32_t rhs) { return std::gcd(lhs, rhs); } +inline int32_t divisibilityMin(int32_t lhs, int32_t rhs) { return std::gcd(lhs, rhs); } +inline int32_t divisibilityMax(int32_t lhs, int32_t rhs) { return std::gcd(lhs, rhs); } + +} // namespace 
utils + +/// Sentinel value for dynamic integers +constexpr int32_t kDynamicIntSentinel = std::numeric_limits::min(); + +IntAttr operator+(IntAttr lhs, IntAttr rhs); +IntAttr operator-(IntAttr lhs, IntAttr rhs); +IntAttr operator*(IntAttr lhs, IntAttr rhs); +IntAttr operator/(IntAttr lhs, IntAttr rhs); +IntAttr operator%(IntAttr lhs, IntAttr rhs); + +IntAttr operator&&(IntAttr lhs, IntAttr rhs); +IntAttr operator||(IntAttr lhs, IntAttr rhs); +IntAttr operator!(IntAttr val); + +IntAttr operator<(IntAttr lhs, IntAttr rhs); +IntAttr operator<=(IntAttr lhs, IntAttr rhs); +IntAttr operator>(IntAttr lhs, IntAttr rhs); +IntAttr operator>=(IntAttr lhs, IntAttr rhs); +IntAttr operator==(IntAttr lhs, IntAttr rhs); +IntAttr operator!=(IntAttr lhs, IntAttr rhs); + +IntAttr intMin(IntAttr lhs, IntAttr rhs); +IntAttr intMax(IntAttr lhs, IntAttr rhs); +IntAttr intSafeDiv(IntAttr lhs, IntAttr rhs); +IntAttr intCeilDiv(IntAttr lhs, IntAttr rhs); +IntAttr intShapeDiv(IntAttr lhs, IntAttr rhs); + +} // namespace mlir::fly + +#endif // FLY_DIALECT_UTILS_INTUTILS_H diff --git a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h new file mode 100644 index 00000000..39afd472 --- /dev/null +++ b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h @@ -0,0 +1,771 @@ +#ifndef FLY_DIALECT_UTILS_LAYOUTATTR_H +#define FLY_DIALECT_UTILS_LAYOUTATTR_H + +#include + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/IntTupleUtils.h" +#include "flydsl/Dialect/Fly/Utils/IntUtils.h" + +namespace mlir::fly { + +namespace detail { + +template +typename IntTupleBuilder::ArithValue layoutCrd2idxTTT(IntTupleBuilder &builder, + IntTuple coord, IntTuple shape, + IntTuple stride); + +template +typename IntTupleBuilder::ArithValue +layoutCrd2idxITT(IntTupleBuilder &builder, + typename IntTupleBuilder::ArithValue coord, IntTuple 
shape, + IntTuple stride) { + using ArithValue = typename IntTupleBuilder::ArithValue; + int32_t rank = shape.rank(); + if (rank == 1) { + return layoutCrd2idxTTT(builder, builder.makeInt(coord), builder.at(shape, 0), + builder.at(stride, 0)); + } + IntTuple si = builder.at(shape, 0); + IntTuple di = builder.at(stride, 0); + + ArithValue siProduct = intTupleProductImpl(builder, si); + ArithValue ci = builder.mod(coord, siProduct); + ArithValue remaining = builder.div(coord, siProduct); + + ArithValue result; + if (si.isLeaf()) { + result = builder.mul(ci, builder.getArithValue(di)); + } else { + result = layoutCrd2idxITT(builder, ci, si, di); + } + + for (int i = 1; i < rank; ++i) { + si = builder.at(shape, i); + di = builder.at(stride, i); + + if (i == rank - 1) { + ci = remaining; + } else { + siProduct = intTupleProductImpl(builder, si); + ci = builder.mod(remaining, siProduct); + remaining = builder.div(remaining, siProduct); + } + if (si.isLeaf()) { + result = builder.add(result, builder.mul(ci, builder.getArithValue(di))); + } else { + result = builder.add(result, layoutCrd2idxITT(builder, ci, si, di)); + } + } + return result; +} + +template +typename IntTupleBuilder::ArithValue layoutCrd2idxTTT(IntTupleBuilder &builder, + IntTuple coord, IntTuple shape, + IntTuple stride) { + using ArithValue = typename IntTupleBuilder::ArithValue; + if (coord.isLeaf()) { + if (shape.isLeaf()) { + return builder.mul(builder.getArithValue(coord), builder.getArithValue(stride)); + } else { + return layoutCrd2idxITT(builder, builder.getArithValue(coord), shape, stride); + } + } else { + assert(coord.rank() == shape.rank() && "Mismatched ranks"); + ArithValue result = layoutCrd2idxTTT(builder, builder.at(coord, 0), builder.at(shape, 0), + builder.at(stride, 0)); + for (int i = 1; i < coord.rank(); ++i) { + result = builder.add(result, layoutCrd2idxTTT(builder, builder.at(coord, i), + builder.at(shape, i), builder.at(stride, i))); + } + return result; + } +} + +} // namespace 
detail + +template +IntTuple layoutCrd2idx(IntTupleBuilder &builder, IntTuple coord, IntTuple shape, + IntTuple stride) { + return builder.makeInt(detail::layoutCrd2idxTTT(builder, coord, shape, stride)); +} + +template class LayoutBuilder; + +class LayoutValueAdaptor { +private: + Value value; + LayoutAttr attr; + +public: + LayoutValueAdaptor(Value value, LayoutAttr attr) : value(value), attr(attr) {} + + bool isLeaf() const { return attr.isLeaf(); } + int32_t rank() const { return attr.rank(); } + + friend class LayoutBuilder; +}; + +template <> class LayoutBuilder : public IntTupleBuilder { +public: + using IntTupleBuilder::IntTupleBuilder; + using IntTuple = IntTupleAttr; + + LayoutAttr getLayoutAttr(LayoutAttr attr) const { return attr; } + IntTuple getShape(LayoutAttr attr) const { return attr.getShape(); } + IntTuple getStride(LayoutAttr attr) const { return attr.getStride(); } + + LayoutAttr materializeConstantLayout(IntTupleAttr shape, IntTupleAttr stride) const { + return LayoutAttr::get(materializeConstantTuple(shape), materializeConstantTuple(stride)); + } + LayoutAttr materializeConstantLayout(LayoutAttr attr) const { + assert(attr.isStatic() && "Layout must be static"); + return attr; + } + LayoutAttr makeLayout(IntTupleAttr shape, IntTupleAttr stride) const { + return LayoutAttr::get(shape, stride); + } +}; + +template <> class LayoutBuilder : public IntTupleBuilder { +public: + using IntTupleBuilder::IntTupleBuilder; + using IntTuple = IntTupleValueAdaptor; + + LayoutAttr getLayoutAttr(LayoutValueAdaptor adaptor) const { return adaptor.attr; } + IntTuple getShape(LayoutValueAdaptor adaptor) const { + return IntTupleValueAdaptor::create(*this, adaptor.value.getDefiningOp()->getOperand(0), + adaptor.attr.getShape()); + } + IntTuple getStride(LayoutValueAdaptor adaptor) const { + return IntTupleValueAdaptor::create(*this, adaptor.value.getDefiningOp()->getOperand(1), + adaptor.attr.getStride()); + } + + LayoutValueAdaptor 
materializeConstantLayout(IntTupleAttr shape, IntTupleAttr stride) const { + return makeLayout(materializeConstantTuple(shape), materializeConstantTuple(stride)); + } + LayoutValueAdaptor materializeConstantLayout(LayoutAttr attr) const { + return materializeConstantLayout(attr.getShape(), attr.getStride()); + } + LayoutValueAdaptor makeLayout(IntTuple shape, IntTuple stride) const { + auto value = MakeLayoutOp::create(this->builder, this->loc, this->finalize(shape), + this->finalize(stride)) + .getResult(); + return LayoutValueAdaptor(value, LayoutAttr::get(this->getAttr(shape), this->getAttr(stride))); + } + Value getValue(LayoutValueAdaptor adaptor) const { return adaptor.value; } +}; + +//===----------------------------------------------------------------------===// +// Layout operations +//===----------------------------------------------------------------------===// + +template +typename LayoutBuilder::IntTuple layoutSize(LayoutBuilder &builder, Layout layout) { + return intTupleProduct(builder, builder.getShape(layout)); +} + +template +typename LayoutBuilder::IntTuple layoutCosize(LayoutBuilder &builder, + Layout layout) { + using IntTuple = typename LayoutBuilder::IntTuple; + using ArithValue = typename LayoutBuilder::ArithValue; + + auto shape = builder.getShape(layout); + auto stride = builder.getStride(layout); + + SmallVector flatShapeLeaves; + SmallVector flatStrideLeaves; + intTupleFlattenToVector(builder, shape, flatShapeLeaves); + intTupleFlattenToVector(builder, stride, flatStrideLeaves); + + ArithValue one = builder.materializeConstantArith(1); + ArithValue s = builder.getArithValue(flatShapeLeaves[0]); + ArithValue d = builder.getArithValue(flatStrideLeaves[0]); + ArithValue cosize = builder.mul(builder.sub(s, one), d); + + for (size_t i = 1; i < flatShapeLeaves.size(); ++i) { + ArithValue s = builder.getArithValue(flatShapeLeaves[i]); + ArithValue d = builder.getArithValue(flatStrideLeaves[i]); + cosize = builder.add(cosize, 
builder.mul(builder.sub(s, one), d)); + } + return builder.makeInt(cosize); +} + +namespace detail { + +template +std::pair coalesceImpl(const IntTupleBuilder &builder, IntTuple shape, + IntTuple stride) { + using ArithValue = typename IntTupleBuilder::ArithValue; + + SmallVector flatShapeLeaves; + SmallVector flatStrideLeaves; + intTupleFlattenToVector(builder, shape, flatShapeLeaves); + intTupleFlattenToVector(builder, stride, flatStrideLeaves); + + const int flatRank = flatShapeLeaves.size(); + ArithValue currShapeInt = builder.getArithValue(flatShapeLeaves[flatRank - 1]); + ArithValue currStrideInt = builder.getArithValue(flatStrideLeaves[flatRank - 1]); + + if (flatRank == 1) { + if (builder.isStaticValue(currShapeInt, 1)) { + return {builder.makeInt(builder.materializeConstantArith(1)), + builder.makeInt(builder.materializeConstantArith(0))}; + } else { + return {shape, stride}; + } + } + + typename IntTupleBuilder::ElemCollector resultShape; + typename IntTupleBuilder::ElemCollector resultStride; + for (int i = flatRank - 2; i >= 0; --i) { + ArithValue nextShapeInt = builder.getArithValue(flatShapeLeaves[i]); + ArithValue nextStrideInt = builder.getArithValue(flatStrideLeaves[i]); + + if (builder.isStaticValue(nextShapeInt, 1)) { + continue; + } + if (builder.isStaticValue(currShapeInt, 1)) { + currShapeInt = nextShapeInt; + currStrideInt = nextStrideInt; + continue; + } + + bool merged = false; + if (builder.isStatic(nextShapeInt) && builder.isStatic(nextStrideInt) && + builder.isStatic(currShapeInt) && builder.isStatic(currStrideInt)) { + if (builder.getStaticValue(nextShapeInt) * builder.getStaticValue(nextStrideInt) == + builder.getStaticValue(currStrideInt)) { + currShapeInt = builder.mul(nextShapeInt, currShapeInt); + currStrideInt = nextStrideInt; + merged = true; + } + } + if (!merged) { + resultShape.push_back(builder.makeInt(currShapeInt)); + resultStride.push_back(builder.makeInt(currStrideInt)); + currShapeInt = nextShapeInt; + currStrideInt = 
nextStrideInt; + } + } + + if (resultShape.empty()) { + return {builder.makeInt(currShapeInt), builder.makeInt(currStrideInt)}; + } + resultShape.push_back(builder.makeInt(currShapeInt)); + resultStride.push_back(builder.makeInt(currStrideInt)); + resultShape.reverse(); + resultStride.reverse(); + return {builder.makeTuple(resultShape), builder.makeTuple(resultStride)}; +} + +template +std::pair coalesceWithProfile(const IntTupleBuilder &builder, + IntTuple shape, IntTuple stride, + IntTupleAttr profile) { + if (profile.isLeaf()) { + return coalesceImpl(builder, shape, stride); + } + + typename IntTupleBuilder::ElemCollector newShapeElems; + typename IntTupleBuilder::ElemCollector newStrideElems; + + int32_t profileRank = profile.rank(); + for (int i = 0; i < shape.rank(); ++i) { + if (i < profileRank) { + auto [cs, cd] = + coalesceWithProfile(builder, builder.at(shape, i), builder.at(stride, i), profile.at(i)); + newShapeElems.push_back(cs); + newStrideElems.push_back(cd); + } else { + newShapeElems.push_back(builder.at(shape, i)); + newStrideElems.push_back(builder.at(stride, i)); + } + } + return {builder.makeTuple(newShapeElems), builder.makeTuple(newStrideElems)}; +} + +template +std::pair compositionImpl(const IntTupleBuilder &builder, + IntTuple lhsShape, IntTuple lhsStride, + IntTuple rhsShape, IntTuple rhsStride) { + using ArithValue = typename IntTupleBuilder::ArithValue; + + if (!rhsShape.isLeaf()) { + typename IntTupleBuilder::ElemCollector resultShape; + typename IntTupleBuilder::ElemCollector resultStride; + for (int i = 0; i < rhsShape.rank(); ++i) { + auto [elemShape, elemStride] = compositionImpl( + builder, lhsShape, lhsStride, builder.at(rhsShape, i), builder.at(rhsStride, i)); + resultShape.push_back(elemShape); + resultStride.push_back(elemStride); + } + return {builder.makeTuple(resultShape), builder.makeTuple(resultStride)}; + } + + ArithValue rhsStrideVal = builder.getArithValue(rhsStride); + if (builder.isStaticValue(rhsStrideVal, 0)) { + 
return {rhsShape, rhsStride}; + } + if (lhsShape.isLeaf()) { + return {rhsShape, builder.makeInt(builder.mul(builder.getArithValue(lhsStride), rhsStrideVal))}; + } + + ArithValue restShape = builder.getArithValue(rhsShape); + ArithValue restStride = rhsStrideVal; + + typename IntTupleBuilder::ElemCollector resultShape; + typename IntTupleBuilder::ElemCollector resultStride; + int32_t resultCount = 0; + IntTuple lastShapeElem = rhsShape; + IntTuple lastStrideElem = rhsStride; + + int R = lhsShape.rank(); + for (int i = 0; i < R - 1; ++i) { + ArithValue currShape = builder.getArithValue(builder.at(lhsShape, i)); + ArithValue currStride = builder.getArithValue(builder.at(lhsStride, i)); + + if (builder.isStatic(currShape) && builder.isStatic(restStride)) { + int64_t restStrideVal = builder.getStaticValue(restStride); + int64_t currShapeVal = builder.getStaticValue(currShape); + assert(restStrideVal % currShapeVal == 0 || restStrideVal < currShapeVal); + } + + ArithValue nextShape = builder.ceilDiv(currShape, restStride); + ArithValue nextStride = builder.ceilDiv(restStride, currShape); + + if (builder.isStaticValue(nextShape, 1) || builder.isStaticValue(restShape, 1)) { + restStride = nextStride; + continue; + } + + ArithValue newShape = builder.min(nextShape, restShape); + ArithValue newStride = builder.mul(restStride, currStride); + + if (builder.isStatic(newShape) && builder.isStatic(restShape)) { + int64_t restShapeVal = builder.getStaticValue(restShape); + int64_t newShapeVal = builder.getStaticValue(newShape); + assert(restShapeVal % newShapeVal == 0); + } + + IntTuple lastShapeElem = builder.makeInt(newShape); + IntTuple lastStrideElem = builder.makeInt(newStride); + resultShape.push_back(lastShapeElem); + resultStride.push_back(lastStrideElem); + restShape = builder.div(restShape, newShape); + restStride = nextStride; + + ++resultCount; + } + + ArithValue lhsLastStride = builder.getArithValue(builder.at(lhsStride, R - 1)); + if (resultCount == 0) { + return 
{builder.makeInt(restShape), builder.makeInt(builder.mul(restStride, lhsLastStride))}; + } + if (builder.isStaticValue(restShape, 1)) { + if (resultCount == 1) { + return {lastShapeElem, lastStrideElem}; + } + return {builder.makeTuple(resultShape), builder.makeTuple(resultStride)}; + } + + resultShape.push_back(builder.makeInt(restShape)); + resultStride.push_back(builder.makeInt(builder.mul(restStride, lhsLastStride))); + return {builder.makeTuple(resultShape), builder.makeTuple(resultStride)}; +} + +template +std::pair complementImpl(const IntTupleBuilder &builder, + IntTuple filteredShape, IntTuple filteredStride, + IntTuple codomainSize) { + using ArithValue = typename IntTupleBuilder::ArithValue; + + if (!codomainSize.isLeaf()) { + assert(false && "this is for basis-strided layout, maybe support this later"); + return {filteredShape, filteredStride}; + } + + auto flatShape = intTupleFlatten(builder, filteredShape); + auto flatStride = intTupleFlatten(builder, filteredStride); + + if (flatStride.isLeaf()) { + if (builder.isStaticValue(builder.getArithValue(flatStride), 0)) { + return {codomainSize, builder.makeInt(builder.materializeConstantArith(1))}; + } + } + + const int R = flatStride.rank(); + assert(R == 1 || + builder.getAttr(filteredStride).isStatic() && "stride must be static for complement"); + + struct ShapeStridePair { + ArithValue shapeVal; + ArithValue strideVal; + int64_t strideStatic; + }; + SmallVector modes; + modes.reserve(R); + + if (!flatStride.isLeaf()) { + for (int i = 0; i < R; ++i) { + ArithValue s = builder.getArithValue(builder.at(flatShape, i)); + ArithValue d = builder.getArithValue(builder.at(flatStride, i)); + modes.push_back({s, d, builder.getStaticValue(d)}); + } + std::sort(modes.begin(), modes.end(), [](const ShapeStridePair &a, const ShapeStridePair &b) { + return a.strideStatic < b.strideStatic; + }); + } else { + modes.push_back({builder.getArithValue(flatShape), builder.getArithValue(flatStride), 0}); + } + + ArithValue 
lastStride = builder.materializeConstantArith(1); + typename IntTupleBuilder::ElemCollector resultShapeVals; + typename IntTupleBuilder::ElemCollector resultStrideVals; + + resultStrideVals.push_back(builder.makeInt(lastStride)); + for (int64_t i = 0; i < R - 1; ++i) { + ArithValue minStride = modes[i].strideVal; + ArithValue newShape = builder.div(minStride, lastStride); + ArithValue newStride = builder.mul(minStride, modes[i].shapeVal); + + resultShapeVals.push_back(builder.makeInt(newShape)); + resultStrideVals.push_back(builder.makeInt(newStride)); + lastStride = newStride; + } + + auto lastMode = modes.back(); + ArithValue newShape = builder.div(lastMode.strideVal, lastStride); + resultShapeVals.push_back(builder.makeInt(newShape)); + + ArithValue newStrideForRest = builder.mul(lastMode.strideVal, lastMode.shapeVal); + ArithValue restShape = builder.ceilDiv(builder.getArithValue(codomainSize), newStrideForRest); + ArithValue restStride = newStrideForRest; + + resultShapeVals.push_back(builder.makeInt(restShape)); + resultStrideVals.push_back(builder.makeInt(restStride)); + + return coalesceImpl(builder, builder.makeTuple(resultShapeVals), + builder.makeTuple(resultStrideVals)); +} + +} // namespace detail + +template +Layout layoutCoalesce(LayoutBuilder &builder, Layout layout, + std::optional profileAttr = std::nullopt) { + auto shape = builder.getShape(layout); + auto stride = builder.getStride(layout); + + if (profileAttr) { + auto [cs, cd] = detail::coalesceWithProfile(builder, shape, stride, *profileAttr); + return builder.makeLayout(cs, cd); + } + auto [cs, cd] = detail::coalesceImpl(builder, shape, stride); + return builder.makeLayout(cs, cd); +} + +template +Layout layoutComposition(LayoutBuilder &builder, Layout outerLayout, Layout innerLayout) { + auto [coalShape, coalStride] = + detail::coalesceImpl(builder, builder.getShape(outerLayout), builder.getStride(outerLayout)); + auto [retShape, retStride] = + detail::compositionImpl(builder, coalShape, 
coalStride, builder.getShape(innerLayout), + builder.getStride(innerLayout)); + return builder.makeLayout(retShape, retStride); +} +template +Layout layoutComposition(LayoutBuilder &builder, Layout outerLayout, + TileAttr innerTileAttr) { + using IntTuple = typename LayoutBuilder::IntTuple; + + auto lhsShape = builder.getShape(outerLayout); + auto lhsStride = builder.getStride(outerLayout); + + typename LayoutBuilder::ElemCollector retShape; + typename LayoutBuilder::ElemCollector retStride; + + int32_t tileRank = innerTileAttr.rank(); + for (int i = 0; i < lhsShape.rank(); ++i) { + if (i < tileRank && !innerTileAttr.isNoneMode(i)) { + auto [coalShape, coalStride] = + detail::coalesceImpl(builder, builder.at(lhsShape, i), builder.at(lhsStride, i)); + + IntTuple rhsShape, rhsStride; + if (auto attr = dyn_cast(innerTileAttr.at(i))) { + rhsShape = builder.materializeConstantTuple(attr.getShape()); + rhsStride = builder.materializeConstantTuple(attr.getStride()); + } else { + rhsShape = builder.makeInt( + builder.materializeConstantArith(cast(innerTileAttr.at(i)).getValue())); + rhsStride = builder.makeInt(builder.materializeConstantArith(1)); + } + auto [elemShape, elemStride] = + detail::compositionImpl(builder, coalShape, coalStride, rhsShape, rhsStride); + retShape.push_back(elemShape); + retStride.push_back(elemStride); + } else { + retShape.push_back(builder.at(lhsShape, i)); + retStride.push_back(builder.at(lhsStride, i)); + } + } + return builder.makeLayout(builder.makeTuple(retShape), builder.makeTuple(retStride)); +} + +template +Layout layoutComplement( + LayoutBuilder &builder, Layout layout, + std::optional::IntTuple> codomainSize = std::nullopt) { + using IntTuple = typename LayoutBuilder::IntTuple; + + auto filteredShape = intTupleFilterZero(builder, builder.getLayoutAttr(layout).getStride(), + builder.getShape(layout)); + auto filteredStride = builder.getStride(layout); + + IntTuple codomain = + codomainSize ? 
*codomainSize + : layoutCosize(builder, builder.makeLayout(filteredShape, filteredStride)); + auto [retShape, retStride] = + detail::complementImpl(builder, filteredShape, filteredStride, codomain); + return builder.makeLayout(retShape, retStride); +} + +template Layout layoutRightInverse(LayoutBuilder &builder, Layout layout); +template Layout layoutLeftInverse(LayoutBuilder &builder, Layout layout); + +template +Layout layoutLogicalDivide(LayoutBuilder &builder, Layout layout, Layout divisorLayout) { + using IntTuple = typename LayoutBuilder::IntTuple; + + auto coalesced = layoutCoalesce(builder, layout); + IntTuple codomainSize = layoutSize(builder, coalesced); + + auto complement = layoutComplement(builder, divisorLayout, codomainSize); + + typename LayoutBuilder::ElemCollector rhsShapeElems; + typename LayoutBuilder::ElemCollector rhsStrideElems; + rhsShapeElems.push_back(builder.getShape(divisorLayout)); + rhsShapeElems.push_back(builder.getShape(complement)); + rhsStrideElems.push_back(builder.getStride(divisorLayout)); + rhsStrideElems.push_back(builder.getStride(complement)); + + IntTuple rhsShape = builder.makeTuple(rhsShapeElems); + IntTuple rhsStride = builder.makeTuple(rhsStrideElems); + Layout rhsLayout = builder.makeLayout(rhsShape, rhsStride); + return layoutComposition(builder, layout, rhsLayout); +} + +template +Layout layoutLogicalDivide(LayoutBuilder &builder, Layout layout, TileAttr divisorTile) { + using IntTuple = typename LayoutBuilder::IntTuple; + + auto leafDivide = [&](Layout currentLayout, Attribute divisor) -> Layout { + if (auto attr = dyn_cast(divisor)) { + return layoutLogicalDivide(builder, currentLayout, builder.materializeConstantLayout(attr)); + } else if (auto intDivisor = dyn_cast(divisor)) { + IntTuple divisorShape = builder.materializeConstantTuple(IntTupleAttr::get(intDivisor)); + IntTuple divisorStride = builder.makeInt(builder.materializeConstantArith(1)); + Layout divisorLayout = builder.makeLayout(divisorShape, 
divisorStride); + return layoutLogicalDivide(builder, currentLayout, divisorLayout); + } + llvm_unreachable("invalid divisor type"); + }; + + if (divisorTile.isLeaf()) { + return leafDivide(layout, divisorTile.getValue()); + } + + auto shape = builder.getShape(layout); + auto stride = builder.getStride(layout); + int32_t layoutRank = shape.rank(); + int32_t tileRank = divisorTile.rank(); + + typename LayoutBuilder::ElemCollector outShape; + typename LayoutBuilder::ElemCollector outStride; + for (int i = 0; i < layoutRank; ++i) { + IntTuple shapeElem = builder.at(shape, i); + IntTuple strideElem = builder.at(stride, i); + if (i < tileRank && !divisorTile.isNoneMode(i)) { + Layout subLayout = builder.makeLayout(shapeElem, strideElem); + Layout divided = leafDivide(subLayout, divisorTile.at(i)); + outShape.push_back(builder.getShape(divided)); + outStride.push_back(builder.getStride(divided)); + } else { + outShape.push_back(shapeElem); + outStride.push_back(strideElem); + } + } + return builder.makeLayout(builder.makeTuple(outShape), builder.makeTuple(outStride)); +} + +template +Layout layoutZippedDivide(LayoutBuilder &builder, Layout layout, Layout divisorLayout) { + using IntTuple = typename LayoutBuilder::IntTuple; + + Layout logicalDiv = layoutLogicalDivide(builder, layout, divisorLayout); + + auto *ctx = builder.getLayoutAttr(layout).getContext(); + IntTupleAttr guide = IntTupleAttr::getLeafStatic(ctx, 1); + IntTuple retShape = intTupleZip2By(builder, builder.getShape(logicalDiv), guide); + IntTuple retStride = intTupleZip2By(builder, builder.getStride(logicalDiv), guide); + return builder.makeLayout(retShape, retStride); +} + +template +Layout layoutZippedDivide(LayoutBuilder &builder, Layout layout, TileAttr divisorTile) { + using IntTuple = typename LayoutBuilder::IntTuple; + + Layout logicalDiv = layoutLogicalDivide(builder, layout, divisorTile); + auto *ctx = builder.getLayoutAttr(layout).getContext(); + + SmallVector guideElems; + for (int i = 0; i < 
divisorTile.rank(); ++i) { + guideElems.push_back(IntTupleAttr::getLeafNone(ctx)); + } + IntTupleAttr guide = IntTupleAttr::get(ArrayAttr::get(ctx, guideElems)); + IntTuple retShape = intTupleZip2By(builder, builder.getShape(logicalDiv), guide); + IntTuple retStride = intTupleZip2By(builder, builder.getStride(logicalDiv), guide); + return builder.makeLayout(retShape, retStride); +} + +template +Layout layoutTiledDivide(LayoutBuilder &builder, Layout layout, Layout divisorLayout) { + using IntTuple = typename LayoutBuilder::IntTuple; + + Layout zipped = layoutZippedDivide(builder, layout, divisorLayout); + IntTuple retShape = intTupleExpand(builder, builder.getShape(zipped), {1}); + IntTuple retStride = intTupleExpand(builder, builder.getStride(zipped), {1}); + return builder.makeLayout(retShape, retStride); +} +template +Layout layoutTiledDivide(LayoutBuilder &builder, Layout layout, TileAttr divisorTile) { + using IntTuple = typename LayoutBuilder::IntTuple; + Layout zipped = layoutZippedDivide(builder, layout, divisorTile); + IntTuple retShape = intTupleExpand(builder, builder.getShape(zipped), {1}); + IntTuple retStride = intTupleExpand(builder, builder.getStride(zipped), {1}); + return builder.makeLayout(retShape, retStride); +} + +template +Layout layoutFlatDivide(LayoutBuilder &builder, Layout layout, Layout divisorLayout) { + using IntTuple = typename LayoutBuilder::IntTuple; + Layout zipped = layoutZippedDivide(builder, layout, divisorLayout); + IntTuple retShape = intTupleExpand(builder, builder.getShape(zipped), {0, 1}); + IntTuple retStride = intTupleExpand(builder, builder.getStride(zipped), {0, 1}); + return builder.makeLayout(retShape, retStride); +} +template +Layout layoutFlatDivide(LayoutBuilder &builder, Layout layout, TileAttr divisorTile) { + using IntTuple = typename LayoutBuilder::IntTuple; + Layout zipped = layoutZippedDivide(builder, layout, divisorTile); + IntTuple retShape = intTupleExpand(builder, builder.getShape(zipped), {0, 1}); + 
IntTuple retStride = intTupleExpand(builder, builder.getStride(zipped), {0, 1}); + return builder.makeLayout(retShape, retStride); +} + +template +Layout layoutAppendToRank(LayoutBuilder &builder, Layout layout, int32_t targetRank) { + auto shape = builder.getShape(layout); + auto stride = builder.getStride(layout); + int32_t currentRank = shape.rank(); + if (targetRank <= currentRank) { + return layout; + } + + typename LayoutBuilder::ElemCollector shapeElems; + typename LayoutBuilder::ElemCollector strideElems; + if (shape.isLeaf()) { + shapeElems.push_back(shape); + strideElems.push_back(stride); + } else { + for (int i = 0; i < shape.rank(); ++i) { + shapeElems.push_back(builder.at(shape, i)); + strideElems.push_back(builder.at(stride, i)); + } + } + + for (int32_t i = currentRank; i < targetRank; ++i) { + shapeElems.push_back(builder.makeInt(builder.materializeConstantArith(1))); + strideElems.push_back(builder.makeInt(builder.materializeConstantArith(0))); + } + return builder.makeLayout(builder.makeTuple(shapeElems), builder.makeTuple(strideElems)); +} + +template +Layout layoutLogicalProduct(LayoutBuilder &builder, Layout blockLayout, + Layout tilerLayout) { + using IntTuple = typename LayoutBuilder::IntTuple; + + IntTuple blockSize = layoutSize(builder, blockLayout); + IntTuple tilerCosize = layoutCosize(builder, tilerLayout); + auto blockSizeVal = builder.getArithValue(blockSize); + auto tilerCosizeVal = builder.getArithValue(tilerCosize); + + if (!builder.isStatic(blockSizeVal) || !builder.isStatic(tilerCosizeVal)) { + return blockLayout; + } + + IntTuple codomainSize = builder.makeInt(builder.mul(blockSizeVal, tilerCosizeVal)); + Layout complement = layoutComplement(builder, blockLayout, codomainSize); + Layout composed = layoutComposition(builder, complement, tilerLayout); + + typename LayoutBuilder::ElemCollector retShapeElems; + typename LayoutBuilder::ElemCollector retStrideElems; + retShapeElems.push_back(builder.getShape(blockLayout)); + 
retShapeElems.push_back(builder.getShape(composed)); + retStrideElems.push_back(builder.getStride(blockLayout)); + retStrideElems.push_back(builder.getStride(composed)); + + return builder.makeLayout(builder.makeTuple(retShapeElems), builder.makeTuple(retStrideElems)); +} + +template +Layout layoutBlockedProduct(LayoutBuilder &builder, Layout blockLayout, + Layout tilerLayout) { + auto blockShape = builder.getShape(blockLayout); + auto tilerShape = builder.getShape(tilerLayout); + int32_t blockRank = blockShape.isLeaf() ? 1 : blockShape.rank(); + int32_t tilerRank = tilerShape.isLeaf() ? 1 : tilerShape.rank(); + int32_t targetRank = std::max(blockRank, tilerRank); + + Layout paddedBlock = layoutAppendToRank(builder, blockLayout, targetRank); + Layout paddedTiler = layoutAppendToRank(builder, tilerLayout, targetRank); + Layout logicalProd = layoutLogicalProduct(builder, paddedBlock, paddedTiler); + + auto outShape = intTupleZip(builder, builder.at(builder.getShape(logicalProd), 0), + builder.at(builder.getShape(logicalProd), 1)); + auto outStride = intTupleZip(builder, builder.at(builder.getStride(logicalProd), 0), + builder.at(builder.getStride(logicalProd), 1)); + return builder.makeLayout(outShape, outStride); +} + +template +Layout layoutRakedProduct(LayoutBuilder &builder, Layout blockLayout, Layout tilerLayout) { + auto blockShape = builder.getShape(blockLayout); + auto tilerShape = builder.getShape(tilerLayout); + int32_t blockRank = blockShape.isLeaf() ? 1 : blockShape.rank(); + int32_t tilerRank = tilerShape.isLeaf() ? 
1 : tilerShape.rank(); + int32_t targetRank = std::max(blockRank, tilerRank); + + Layout paddedBlock = layoutAppendToRank(builder, blockLayout, targetRank); + Layout paddedTiler = layoutAppendToRank(builder, tilerLayout, targetRank); + Layout logicalProd = layoutLogicalProduct(builder, paddedBlock, paddedTiler); + + auto outShape = intTupleZip(builder, builder.at(builder.getShape(logicalProd), 1), + builder.at(builder.getShape(logicalProd), 0)); + auto outStride = intTupleZip(builder, builder.at(builder.getStride(logicalProd), 1), + builder.at(builder.getStride(logicalProd), 0)); + return builder.makeLayout(outShape, outStride); +} + +} // namespace mlir::fly + +#endif // FLY_DIALECT_UTILS_LAYOUTATTR_H diff --git a/include/flydsl/Dialect/Fly/Utils/NormalForm.h b/include/flydsl/Dialect/Fly/Utils/NormalForm.h new file mode 100644 index 00000000..1b7b6e3e --- /dev/null +++ b/include/flydsl/Dialect/Fly/Utils/NormalForm.h @@ -0,0 +1,25 @@ +#ifndef FLY_DIALECT_UTILS_NORMALFORM_H +#define FLY_DIALECT_UTILS_NORMALFORM_H + +#include "mlir/IR/Attributes.h" +#include "mlir/Support/LogicalResult.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/IntTupleUtils.h" +#include "flydsl/Dialect/Fly/Utils/LayoutUtils.h" + +namespace mlir::fly { + +bool isNormalForm(TypedValue value); +bool isNormalForm(TypedValue value); +bool isNormalForm(TypedValue value); +bool isNormalForm(TypedValue value); +bool isNormalForm(TypedValue value); +bool isNormalForm(TypedValue value); + +bool isNormalForm(TypedValue value); +bool isNormalForm(TypedValue value); + +} // namespace mlir::fly + +#endif // FLY_DIALECT_UTILS_NORMALFORM_H diff --git a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td new file mode 100644 index 00000000..3a86fdb9 --- /dev/null +++ b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td @@ -0,0 +1,22 @@ +#ifndef FLYROCDL_DIALECT +#define FLYROCDL_DIALECT + +include "mlir/IR/EnumAttr.td" +include 
"mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +include "flydsl/Dialect/Fly/IR/FlyInterfaces.td" + +def FlyROCDL_Dialect : Dialect { + let name = "fly_rocdl"; + let cppNamespace = "::mlir::fly_rocdl"; + + let usePropertiesForAttributes = 1; +} + +class FlyROCDL_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +#endif // FLYROCDL_DIALECT diff --git a/lib/Bindings/Python/MainModules.cpp b/lib/Bindings/Python/MainModules.cpp new file mode 100644 index 00000000..a94fe75f --- /dev/null +++ b/lib/Bindings/Python/MainModules.cpp @@ -0,0 +1,228 @@ +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Dialect/LLVM.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Wrap.h" + +#include +#include +#include + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/IntUtils.h" + +namespace nb = nanobind; +using namespace nb::literals; +using namespace mlir; +using namespace mlir::fly; + +// ----------------------------------------------------------------------------- +// Module initialization. 
+// ----------------------------------------------------------------------------- + +namespace { + +/// Helper to convert Python value to IntTupleAttr +struct IntTupleAttrBuilder { + MLIRContext *ctx; + std::vector dyncElems{}; + + IntTupleAttrBuilder(MLIRContext *ctx) : ctx(ctx) {} + + IntTupleAttr operator()(nb::handle args) { + if (PyTuple_Check(args.ptr())) { + SmallVector elements; + for (auto item : args) { + elements.push_back((*this)(item)); + } + return IntTupleAttr::get(ArrayAttr::get(ctx, elements)); + } else if (PyLong_Check(args.ptr())) { + int32_t cInt = PyLong_AsLong(args.ptr()); + return IntTupleAttr::get(IntAttr::getStatic(ctx, cInt)); + } else if (args.is_none()) { + return IntTupleAttr::getLeafNone(ctx); + } else { + // Dynamic value - for now treat as dynamic + dyncElems.push_back(args); + return IntTupleAttr::get(IntAttr::getDynamic(ctx)); + } + } +}; + +} // namespace + +int32_t rank(nb::handle int_or_tuple) { + nb::object capsule = int_or_tuple.attr("_CAPIPtr"); + MlirValue mlirVal = mlirPythonCapsuleToValue(capsule.ptr()); + mlir::Value val = unwrap(mlirVal); + mlir::Type ty = val.getType(); + if (auto intTupleTy = dyn_cast(ty)) { + return intTupleTy.getAttr().rank(); + } else if (auto layoutTy = dyn_cast(ty)) { + return layoutTy.getAttr().rank(); + } + return 1; +} + +int32_t depth(nb::handle int_or_tuple) { + nb::object capsule = int_or_tuple.attr("_CAPIPtr"); + MlirValue mlirVal = mlirPythonCapsuleToValue(capsule.ptr()); + mlir::Value val = unwrap(mlirVal); + mlir::Type ty = val.getType(); + if (auto intTupleTy = dyn_cast(ty)) { + return intTupleTy.getAttr().depth(); + } else if (auto layoutTy = dyn_cast(ty)) { + return layoutTy.getAttr().depth(); + } + return 0; +} + +// nb::object getFlyTypingModule() { +// static nb::object typing = nb::steal(nb::module_::import_("fly.lang.typing")); +// return typing; +// } + +// nb::object make_int32(int value) { +// static nb::object int32_cls = getFlyTypingModule().attr("Int32"); + +// return 
int32_cls(value); +// } + +// nb::object make_int32_tuple(int value) { +// static nb::object int32_cls = getFlyTypingModule().attr("Int32"); + +// nb::list subList; +// subList.append(int32_cls(value + 1)); +// nb::tuple subTuple = nb::tuple(subList); + +// nb::list retList; +// retList.append(int32_cls(value)); +// retList.append(subTuple); +// retList.append(nb::int_(0)); + +// return nb::tuple(retList); +// } + +NB_MODULE(_fly, m) { + m.doc() = "MLIR Python FlyDSL Extension"; + + m.def( + "infer_int_tuple_type", + [](MlirContext context, nb::handle int_or_tuple) { + MLIRContext *ctx = unwrap(context); + IntTupleAttrBuilder builder{ctx}; + IntTupleAttr attr = builder(int_or_tuple); + auto intTupleType = IntTupleType::get(attr); + MlirType wrappedType = wrap(intTupleType); + return std::make_pair(wrappedType, builder.dyncElems); + }, + nb::arg("context"), nb::arg("int_or_tuple")); + + m.def( + "infer_layout_type", + [](MlirContext context, nb::handle shape, nb::handle stride) { + MLIRContext *ctx = unwrap(context); + IntTupleAttrBuilder builder{ctx}; + IntTupleAttr shapeAttr = builder(shape); + IntTupleAttr strideAttr = builder(stride); + auto layoutAttr = LayoutAttr::get(ctx, shapeAttr, strideAttr); + auto layoutType = LayoutType::get(ctx, layoutAttr); + MlirType wrappedType = wrap(layoutType); + return wrappedType; + }, + nb::arg("context"), nb::arg("shape"), nb::arg("stride")); + + m.def("rank", &rank, nb::arg("int_or_tuple")); + m.def("depth", &depth, nb::arg("int_or_tuple")); + + //===--------------------------------------------------------------------===// + // Fly Type Classes with static get() methods + //===--------------------------------------------------------------------===// + + nb::class_(m, "PointerType") + .def_static( + "get", + [](MlirType elemTy, int32_t addressSpace, std::optional alignment) { + mlir::Type unwrappedElemTy = unwrap(elemTy); + MLIRContext *ctx = unwrappedElemTy.getContext(); + + AddressSpaceAttr addrSpaceAttr = + 
AddressSpaceAttr::get(ctx, static_cast(addressSpace)); + + fly::PointerType ptrType; + if (alignment.has_value()) { + AlignAttr alignAttr = AlignAttr::get(ctx, alignment.value()); + ptrType = fly::PointerType::get(ctx, unwrappedElemTy, addrSpaceAttr, alignAttr, + SwizzleAttr::getTrivialSwizzle(ctx)); + } else { + ptrType = fly::PointerType::get(unwrappedElemTy, addrSpaceAttr); + } + return wrap(static_cast(ptrType)); + }, + nb::arg("elem_ty"), nb::arg("address_space"), nb::arg("alignment") = nb::none(), + "Create a PointerType with element type and address space"); + + nb::class_(m, "MemRefType") + .def_static( + "get", + [](MlirType elemTy, int32_t addressSpace, MlirType layoutTy, + std::optional alignment) { + mlir::Type unwrappedElemTy = unwrap(elemTy); + mlir::Type unwrappedLayoutTy = unwrap(layoutTy); + MLIRContext *ctx = unwrappedElemTy.getContext(); + + auto layoutType = dyn_cast(unwrappedLayoutTy); + if (!layoutType) { + throw std::invalid_argument("layout must be a LayoutType"); + } + + AddressSpaceAttr addrSpaceAttr = + AddressSpaceAttr::get(ctx, static_cast(addressSpace)); + LayoutAttr layoutAttr = layoutType.getAttr(); + + fly::MemRefType memrefType; + if (alignment.has_value()) { + AlignAttr alignAttr = AlignAttr::get(ctx, alignment.value()); + memrefType = fly::MemRefType::get(ctx, unwrappedElemTy, addrSpaceAttr, layoutAttr, + alignAttr, SwizzleAttr::getTrivialSwizzle(ctx)); + } else { + memrefType = fly::MemRefType::get(unwrappedElemTy, addrSpaceAttr, layoutAttr); + } + return wrap(static_cast(memrefType)); + }, + nb::arg("elem_ty"), nb::arg("address_space"), nb::arg("layout"), + nb::arg("alignment") = nb::none(), + "Create a MemRefType with element type, address space and layout"); + + nb::class_(m, "LayoutType") + .def_static( + "get", + [](MlirContext context, nb::handle shape, nb::handle stride) { + MLIRContext *ctx = unwrap(context); + IntTupleAttrBuilder builder{ctx}; + IntTupleAttr shapeAttr = builder(shape); + IntTupleAttr strideAttr = 
builder(stride); + auto layoutAttr = LayoutAttr::get(ctx, shapeAttr, strideAttr); + auto layoutType = LayoutType::get(ctx, layoutAttr); + return wrap(static_cast(layoutType)); + }, + nb::arg("context"), nb::arg("shape"), nb::arg("stride"), + "Create a LayoutType with shape and stride"); + + // IntTupleType class + nb::class_(m, "IntTupleType") + .def_static( + "get", + [](MlirContext context, nb::handle int_or_tuple) { + MLIRContext *ctx = unwrap(context); + IntTupleAttrBuilder builder{ctx}; + IntTupleAttr attr = builder(int_or_tuple); + auto intTupleType = IntTupleType::get(attr); + return std::make_pair(wrap(static_cast(intTupleType)), builder.dyncElems); + }, + nb::arg("context"), nb::arg("int_or_tuple"), + "Create an IntTupleType from Python int or tuple"); +} diff --git a/lib/CAPI/CMakeLists.txt b/lib/CAPI/CMakeLists.txt new file mode 100644 index 00000000..2c17c103 --- /dev/null +++ b/lib/CAPI/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_public_c_api_library(MLIRCPIFly + FlyDialect.cpp + LINK_LIBS PUBLIC + MLIRFlyDialect + MLIRFlyToROCDL +) diff --git a/lib/CAPI/FlyDialect.cpp b/lib/CAPI/FlyDialect.cpp new file mode 100644 index 00000000..ce18bfb6 --- /dev/null +++ b/lib/CAPI/FlyDialect.cpp @@ -0,0 +1,6 @@ +#include "flydsl-c/FlyDialect.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "mlir/CAPI/Registration.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Fly, fly, mlir::fly::FlyDialect) diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt new file mode 100644 index 00000000..8426d229 --- /dev/null +++ b/lib/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Conversion) +add_subdirectory(Dialect) +add_subdirectory(CAPI) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt new file mode 100644 index 00000000..be704b62 --- /dev/null +++ b/lib/Conversion/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(FlyToROCDL) diff --git a/lib/Conversion/FlyToROCDL/CMakeLists.txt b/lib/Conversion/FlyToROCDL/CMakeLists.txt new file mode 100644 index 
00000000..9eb29bd5 --- /dev/null +++ b/lib/Conversion/FlyToROCDL/CMakeLists.txt @@ -0,0 +1,22 @@ +add_mlir_conversion_library(MLIRFlyToROCDL + FlyToROCDL.cpp + + DEPENDS + MLIRFlyIncGen + FlyConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRFlyDialect + + MLIRAffineDialect + MLIRAffineTransforms + MLIRAffineUtils + MLIRArithDialect + MLIRIR + MLIRLLVMDialect + MLIRMemRefDialect + MLIRPass + MLIRSCFDialect + MLIRTransforms + MLIRVectorDialect +) diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp new file mode 100644 index 00000000..0deed6ae --- /dev/null +++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp @@ -0,0 +1,599 @@ + +#include "flydsl/Dialect/Fly/Utils/IntTupleUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "flydsl/Conversion/FlyToROCDL/FlyToROCDL.h" +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" + +#include +#include +#include +#include + +namespace mlir { +#define GEN_PASS_DEF_FLYTOROCDLCONVERSIONPASS +#include "flydsl/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::fly; + +namespace { + +// Helper to get the flattened size from an IntTupleAttr (product of all elements) +static int64_t getFlattenedSize(IntTupleAttr attr) { + IntTupleBuilder builder(attr.getContext()); + IntAttr product = intTupleProduct(builder, attr).getLeafAsInt(); + if (product.isStatic()) + return product.getValue(); + return 1; +} + +static int64_t getFlattenedSize(LayoutAttr 
attr) { return getFlattenedSize(attr.getShape()); } + +static unsigned mapAddressSpace(AddressSpace space) { + // - Global -> 1 (global) + // - Shared -> 3 (local/LDS/workgroup) + // - Register -> 5 (private) + // Fallback to 0 (generic). + switch (space) { + case AddressSpace::Global: + return 1; + case AddressSpace::Shared: + return 3; + case AddressSpace::Register: + return 5; + } + return 0; +} + +static FailureOr toI64(Value v, Location loc, ConversionPatternRewriter &rewriter) { + Type i64Ty = rewriter.getI64Type(); + if (v.getType() == i64Ty) + return v; + if (v.getType().isIndex()) + return arith::IndexCastOp::create(rewriter, loc, i64Ty, v).getResult(); + if (auto intTy = dyn_cast(v.getType())) { + if (intTy.getWidth() < 64) + return arith::ExtSIOp::create(rewriter, loc, i64Ty, v).getResult(); + if (intTy.getWidth() > 64) + return arith::TruncIOp::create(rewriter, loc, i64Ty, v).getResult(); + } + return failure(); +} + +class MemRefAllocOpLowering : public OpConversionPattern { +public: + MemRefAllocOpLowering(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + + LogicalResult matchAndRewrite(MemRefAllocaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto flyMemRefTy = dyn_cast(op.getResult().getType()); + if (!flyMemRefTy) + return failure(); + + LayoutAttr layoutAttr = flyMemRefTy.getLayout(); + auto elemTy = flyMemRefTy.getElemTy(); + + int64_t totalSize = getFlattenedSize(layoutAttr); + + auto convertedPtrTy = + dyn_cast(getTypeConverter()->convertType(flyMemRefTy)); + if (!convertedPtrTy) + return failure(); + + auto loc = op.getLoc(); + + // Alloca array size is i64. + Value nElems = arith::ConstantIntOp::create(rewriter, loc, totalSize, /*width=*/64).getResult(); + + // `llvm.alloca` takes element type and array size. Keep alignment unspecified. 
+ Value ptr = LLVM::AllocaOp::create(rewriter, loc, convertedPtrTy, elemTy, nElems, + /*alignment=*/0); + rewriter.replaceOp(op, ptr); + return success(); + } +}; + +/// Materialize a scalar index from a non-array `!fly.int_tuple` value. +/// This is used for pointer/memref offset computations. +static FailureOr materializeScalarIndex(Value intTuple, Location loc, + ConversionPatternRewriter &rewriter) { + auto tupleTy = dyn_cast(intTuple.getType()); + if (!tupleTy) + return failure(); + + IntTupleAttr profile = tupleTy.getAttr(); + if (!profile.isLeaf()) + return failure(); + + // Static scalar. + if (auto intAttr = dyn_cast(profile.getValue())) { + if (intAttr.isStatic()) { + Value c = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getValue()); + return c; + } + } + if (profile.getLeafAsInt().isNone()) { + Value c = arith::ConstantIndexOp::create(rewriter, loc, 0); + return c; + } + + // Dynamic scalar: expect it comes from fly.make_int_tuple with exactly one operand. + if (Operation *defOp = intTuple.getDefiningOp()) { + if (defOp->getName().getStringRef() == "fly.make_int_tuple" && defOp->getNumOperands() == 1) { + Value v = defOp->getOperand(0); + if (v.getType().isIndex()) + return v; + // Most Fly scalars are i32; cast to index when needed. + if (v.getType().isSignlessInteger()) + return arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), v).getResult(); + } + } + + return failure(); +} + +/// Lower `fly.get_iter`: convert a Fly memref to a dynamic 1-D memref "pointer view". +class GetIterOpLowering : public OpConversionPattern { +public: + GetIterOpLowering(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + + LogicalResult matchAndRewrite(GetIterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // After type conversion, Fly memref is already a `!llvm.ptr`. 
+ Value mem = adaptor.getMemref(); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getResult().getType())); + if (!resTy) + return failure(); + assert(mem.getType() == resTy); + rewriter.replaceOp(op, mem); + return success(); + } +}; + +/// Lower `fly.add_offset`: produce a subview of the dynamic memref "pointer view". +class AddOffsetOpLowering : public OpConversionPattern { +public: + AddOffsetOpLowering(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + + LogicalResult matchAndRewrite(AddOffsetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value base = adaptor.getPtr(); + auto baseTy = dyn_cast(base.getType()); + if (!baseTy) + return failure(); + + auto offsetIdx = materializeScalarIndex(op.getOffset(), loc, rewriter); + if (failed(offsetIdx)) + return failure(); + + auto resultTy = + dyn_cast(getTypeConverter()->convertType(op.getResult().getType())); + if (!resultTy) + return failure(); + + FailureOr offsetI64 = toI64(*offsetIdx, loc, rewriter); + if (failed(offsetI64)) + return failure(); + + // Pointer arithmetic: gep by element offset. + // Note: GEP element type is the (pointee) element type of the pointer. + auto flyPtrTy = dyn_cast(op.getPtr().getType()); + if (!flyPtrTy) + return failure(); + Type elemTy = flyPtrTy.getElemTy(); + Value gep = LLVM::GEPOp::create(rewriter, loc, resultTy, elemTy, base, ValueRange{*offsetI64}); + rewriter.replaceOp(op, gep); + return success(); + } +}; + +/// Lower `fly.make_view`: take a dynamic pointer view and produce a sized 1-D memref view. 
+class MakeViewOpLowering : public OpConversionPattern { +public: + MakeViewOpLowering(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + + LogicalResult matchAndRewrite(MakeViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value base = adaptor.getIter(); + auto baseTy = dyn_cast(base.getType()); + if (!baseTy) + return failure(); + + auto resultTy = + dyn_cast(getTypeConverter()->convertType(op.getResult().getType())); + if (!resultTy) + return failure(); + if (base.getType() == resultTy) { + rewriter.replaceOp(op, base); + return success(); + } + rewriter.replaceOpWithNewOp(op, resultTy, base); + return success(); + } +}; + +class MemRefLoadVecOpLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(MemRefLoadVecOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value input = adaptor.getMemref(); + + auto ptrTy = dyn_cast(input.getType()); + if (!ptrTy) + return failure(); + + auto resVecTy = dyn_cast(op.getResult().getType()); + if (!resVecTy) + return failure(); + + // Opaque pointers: we can directly load a vector from the base address. 
+ Value loaded = LLVM::LoadOp::create(rewriter, loc, resVecTy, input); + rewriter.replaceOp(op, loaded); + return success(); + } +}; + +class MemRefStoreVecOpLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(MemRefStoreVecOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value dest = adaptor.getMemref(); + Value valueToStore = adaptor.getVector(); + + auto ptrTy = dyn_cast(dest.getType()); + if (!ptrTy) + return failure(); + + auto vecTy = dyn_cast(valueToStore.getType()); + if (!vecTy) + return failure(); + + LLVM::StoreOp::create(rewriter, loc, valueToStore, dest); + rewriter.eraseOp(op); + return success(); + } +}; + +class MemRefLoadOpLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(MemRefLoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value mem = adaptor.getMemref(); + auto ptrTy = dyn_cast(mem.getType()); + if (!ptrTy) + return failure(); + + // `fly.memref.load` takes a scalar int_tuple offset. 
+ auto idxVal = materializeScalarIndex(op.getIndices(), op.getLoc(), rewriter); + if (failed(idxVal)) + return failure(); + FailureOr idxI64 = toI64(*idxVal, op.getLoc(), rewriter); + if (failed(idxI64)) + return failure(); + + auto flyMemRefTy = dyn_cast(op.getMemref().getType()); + if (!flyMemRefTy) + return failure(); + Type elemTy = flyMemRefTy.getElemTy(); + Value gep = LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, elemTy, mem, ValueRange{*idxI64}); + Value loaded = LLVM::LoadOp::create(rewriter, op.getLoc(), op.getResult().getType(), gep); + rewriter.replaceOp(op, loaded); + return success(); + } +}; + +class MemRefStoreOpLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(MemRefStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value mem = adaptor.getMemref(); + auto ptrTy = dyn_cast(mem.getType()); + if (!ptrTy) + return failure(); + + auto idxVal = materializeScalarIndex(op.getIndices(), op.getLoc(), rewriter); + if (failed(idxVal)) + return failure(); + FailureOr idxI64 = toI64(*idxVal, op.getLoc(), rewriter); + if (failed(idxI64)) + return failure(); + auto flyMemRefTy = dyn_cast(op.getMemref().getType()); + if (!flyMemRefTy) + return failure(); + Type elemTy = flyMemRefTy.getElemTy(); + Value gep = LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, elemTy, mem, ValueRange{*idxI64}); + LLVM::StoreOp::create(rewriter, op.getLoc(), adaptor.getValue(), gep); + rewriter.eraseOp(op); + return success(); + } +}; + +class CopyAtomCallLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(CopyAtomCall op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only handle the universal memref-to-memref copy atom here. 
+ if (!isa(adaptor.getCopyAtom().getType())) + return rewriter.notifyMatchFailure(op, "unsupported copy atom (expected universal_copy_32b)"); + + Value src = adaptor.getSrc(); + Value dst = adaptor.getDst(); + + auto srcPtrTy = dyn_cast(src.getType()); + auto dstPtrTy = dyn_cast(dst.getType()); + if (!srcPtrTy || !dstPtrTy) + return rewriter.notifyMatchFailure(op, "src/dst are not llvm.ptr after conversion"); + + // Determine element type + total size from the *original* Fly memref type to avoid + // losing layout information in the lowered pointer type. + auto srcFlyTy = dyn_cast(op.getSrc().getType()); + auto dstFlyTy = dyn_cast(op.getDst().getType()); + if (!srcFlyTy || !dstFlyTy) + return rewriter.notifyMatchFailure(op, "expected Fly memref types on original op"); + + if (srcFlyTy.getElemTy() != dstFlyTy.getElemTy()) + return rewriter.notifyMatchFailure(op, "src/dst element types mismatch"); + + int64_t nElems = getFlattenedSize(srcFlyTy.getLayout()); + if (nElems != getFlattenedSize(dstFlyTy.getLayout())) + return rewriter.notifyMatchFailure(op, "src/dst shapes mismatch"); + + // Lower to LLVM memcpy intrinsic to keep GPU kernel fully in LLVM dialect + // (GpuModuleToBinary requires no SCF/arith/unrealized casts inside the module). 
+ Location loc = op.getLoc(); + Type elemTy = srcFlyTy.getElemTy(); + int64_t elemBytes = 0; + if (auto ft = dyn_cast(elemTy)) + elemBytes = ft.getWidth() / 8; + else if (auto it = dyn_cast(elemTy)) + elemBytes = it.getWidth() / 8; + else + return rewriter.notifyMatchFailure(op, "unsupported element type for memcpy sizing"); + if (elemBytes <= 0) + return rewriter.notifyMatchFailure(op, "invalid element byte width"); + + int64_t totalBytes = nElems * elemBytes; + Value len = arith::ConstantIntOp::create(rewriter, loc, totalBytes, /*width=*/64).getResult(); + + // llvm.intr.memcpy(dst, src, len, isVolatile=false) + LLVM::MemcpyOp::create(rewriter, loc, dst, src, len, /*isVolatile=*/false); + + rewriter.eraseOp(op); + return success(); + } +}; + +class MmaAtomCallLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(MmaAtomCall op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only handle MFMA F32 16x16x4 F32 atom for now. + if (!isa(adaptor.getMmaAtom().getType())) + return rewriter.notifyMatchFailure(op, "unsupported mma atom (expected mfma.f32.16x16x4f32)"); + + Location loc = op.getLoc(); + + // After Fly type conversion, memrefs are lowered to `!llvm.ptr`. + Value dPtr = adaptor.getD(); + Value aPtr = adaptor.getA(); + Value bPtr = adaptor.getB(); + Value cPtr = adaptor.getC(); + + auto dPtrTy = dyn_cast(dPtr.getType()); + auto aPtrTy = dyn_cast(aPtr.getType()); + auto bPtrTy = dyn_cast(bPtr.getType()); + auto cPtrTy = dyn_cast(cPtr.getType()); + if (!dPtrTy || !aPtrTy || !bPtrTy || !cPtrTy) + return rewriter.notifyMatchFailure(op, "expected llvm.ptr operands after type conversion"); + + Type f32Ty = rewriter.getF32Type(); + VectorType accTy = VectorType::get({4}, f32Ty); + + // Load A/B scalars and C accumulator vector from the provided pointers. 
+ Value a = LLVM::LoadOp::create(rewriter, loc, f32Ty, aPtr); + Value b = LLVM::LoadOp::create(rewriter, loc, f32Ty, bPtr); + Value c = LLVM::LoadOp::create(rewriter, loc, accTy, cPtr); + + // MFMA control operands (cbsz, abid, blgp). Default to 0. + Value zeroI32 = arith::ConstantIntOp::create(rewriter, loc, /*value=*/0, /*width=*/32); + + // rocdl.mfma.f32.16x16x4f32 : (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + SmallVector args{a, b, c, zeroI32, zeroI32, zeroI32}; + Value res = ROCDL::mfma_f32_16x16x4f32::create(rewriter, loc, accTy, args).getResult(); + + // Store result back to D pointer. + LLVM::StoreOp::create(rewriter, loc, res, dPtr); + rewriter.eraseOp(op); + return success(); + } +}; + +/// Lower `gpu.launch_func` kernel operands so that any `!fly.memref` values are +/// replaced by their type-converted builtin `memref` values. This prevents +/// `unrealized_conversion_cast` materializations from remaining live after +/// partial conversion (e.g., when the surrounding `func.func` signature has +/// been converted to builtin memrefs). +class GpuLaunchFuncOpLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(gpu::LaunchFuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto kernelRef = adaptor.getKernel(); + + auto grid = + gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(), adaptor.getGridSizeZ()}; + auto block = + gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), adaptor.getBlockSizeZ()}; + + std::optional clusterSize = std::nullopt; + if (adaptor.getClusterSizeX() && adaptor.getClusterSizeY() && adaptor.getClusterSizeZ()) { + clusterSize = gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(), + adaptor.getClusterSizeZ()}; + } + + // Preserve async token result type when present. 
+ Type asyncTokenType = nullptr; + if (Value tok = op.getAsyncToken()) + asyncTokenType = tok.getType(); + + // There are two relevant builder signatures in this MLIR: + // - (kernel, ..., asyncTokenType, asyncDependencies, clusterSize) + // - (kernel, ..., asyncObject, clusterSize) + // Pick the one that matches the original op structure. + if (Value asyncObj = adaptor.getAsyncObject()) { + if (!adaptor.getAsyncDependencies().empty()) + return rewriter.notifyMatchFailure( + op, "launch_func has both asyncObject and asyncDependencies"); + + rewriter.replaceOpWithNewOp( + op, kernelRef, grid, block, adaptor.getDynamicSharedMemorySize(), + adaptor.getKernelOperands(), asyncObj, clusterSize); + return success(); + } + + rewriter.replaceOpWithNewOp( + op, kernelRef, grid, block, adaptor.getDynamicSharedMemorySize(), + adaptor.getKernelOperands(), asyncTokenType, adaptor.getAsyncDependencies(), clusterSize); + return success(); + } +}; + +class FlyTypeConverter : public TypeConverter { +public: + FlyTypeConverter() { + addConversion([](Type type) { return type; }); + + // Convert Fly memref/pointer to a raw LLVM pointer. 
+ addConversion([&](fly::MemRefType flyMemRefTy) -> Type { + unsigned as = mapAddressSpace(flyMemRefTy.getAddressSpace().getValue()); + return LLVM::LLVMPointerType::get(flyMemRefTy.getContext(), as); + }); + addConversion([&](fly::PointerType flyPtrTy) -> Type { + unsigned as = mapAddressSpace(flyPtrTy.getAddressSpace().getValue()); + return LLVM::LLVMPointerType::get(flyPtrTy.getContext(), as); + }); + } +}; + +class FlyToROCDLConversionPass + : public mlir::impl::FlyToROCDLConversionPassBase { +public: + using mlir::impl::FlyToROCDLConversionPassBase< + FlyToROCDLConversionPass>::FlyToROCDLConversionPassBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addIllegalDialect(); + + target.addLegalOp(); + target.addLegalOp(); + + FlyTypeConverter typeConverter; + + // Ensure function signatures are type-converted; otherwise conversions may rely on + // inserted unrealized casts that remain live. + target.addDynamicallyLegalOp( + [&](func::FuncOp op) { return typeConverter.isSignatureLegal(op.getFunctionType()); }); + target.addDynamicallyLegalOp( + [&](gpu::GPUFuncOp op) { return typeConverter.isSignatureLegal(op.getFunctionType()); }); + + // IMPORTANT: `gpu.launch_func` itself is in a legal dialect, but its kernel operands may + // still carry illegal `!fly.memref` types. If we don't mark it dynamically illegal in that + // case, partial conversion won't try to rewrite it, leaving `unrealized_conversion_cast` + // users alive and causing legalization failure. 
+ target.addDynamicallyLegalOp([&](gpu::LaunchFuncOp op) { + auto isValueLegal = [&](Value v) { + if (!v) + return true; + return typeConverter.isLegal(v.getType()); + }; + + for (Value v : op.getKernelOperands()) + if (!isValueLegal(v)) + return false; + + if (!isValueLegal(op.getDynamicSharedMemorySize())) + return false; + + // Async operands are part of the operand list; keep them consistent as well. + for (Value dep : op.getAsyncDependencies()) + if (!isValueLegal(dep)) + return false; + if (!isValueLegal(op.getAsyncObject())) + return false; + + // Dimensions are typically index and already legal; no need to special-case. + return true; + }); + + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); + + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +namespace impl { + +std::unique_ptr<::mlir::Pass> createFlyToROCDLConversionPass() { + return std::make_unique(); +} + +} // namespace impl diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt new file mode 100644 index 00000000..08c0cd63 --- /dev/null +++ b/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Fly) diff --git a/lib/Dialect/Fly/CMakeLists.txt b/lib/Dialect/Fly/CMakeLists.txt new file mode 100644 index 00000000..cd56d524 --- /dev/null +++ b/lib/Dialect/Fly/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_dialect_library(MLIRFlyDialect + IR/FlyDialect.cpp + IR/FlyOps.cpp 
+ IR/FlyTypeDefs.cpp + IR/FlyAttrDefs.cpp + Utils/IntUtils.cpp + Utils/IntTupleUtils.cpp + Utils/NormalForm.cpp + Transforms/LayoutLowering.cpp + Transforms/FlyCanonicalize.cpp + + DEPENDS + MLIRFlyIncGen + FlyTransformPassIncGen + + LINK_LIBS + MLIRIR +) diff --git a/lib/Dialect/Fly/IR/FlyAttrDefs.cpp b/lib/Dialect/Fly/IR/FlyAttrDefs.cpp new file mode 100644 index 00000000..686b2261 --- /dev/null +++ b/lib/Dialect/Fly/IR/FlyAttrDefs.cpp @@ -0,0 +1,455 @@ + +#include "llvm/ADT/TypeSwitch.h" + +#include "mlir/IR/BuiltinAttributes.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/IntUtils.h" + +namespace mlir::fly { + +//===----------------------------------------------------------------------===// +// Class Definitions +//===----------------------------------------------------------------------===// + +IntTupleAttr IntTupleAttr::getLeafNone(MLIRContext *ctx) { return get(ctx, IntAttr::getNone(ctx)); } +IntTupleAttr IntTupleAttr::getLeafStatic(MLIRContext *ctx, int32_t value) { + return get(ctx, IntAttr::getStatic(ctx, value)); +} +IntTupleAttr IntTupleAttr::getLeafDynamic(MLIRContext *ctx, int32_t width, int32_t divisibility) { + return get(ctx, IntAttr::getDynamic(ctx, width, divisibility)); +} +bool IntTupleAttr::isLeafNone() const { + if (this->isLeaf()) { + if (auto intAttr = dyn_cast(this->getValue())) { + return intAttr.isNone(); + } + } + return false; +} +bool IntTupleAttr::isLeafStaticValue(int32_t value) const { + if (this->isLeaf()) { + if (auto intAttr = dyn_cast(this->getValue())) { + return intAttr.isStaticValue(value); + } + } + return false; +} + +IntAttr IntTupleAttr::getLeafAsInt() const { + assert(this->isLeaf() && "Non-leaf attribute cannot be converted to IntAttr"); + return cast(this->getValue()); +} +BasisAttr IntTupleAttr::getLeafAsBasis() const { + assert(this->isLeaf() && "Non-leaf attribute cannot be converted to BasisAttr"); + return cast(this->getValue()); +} + +int32_t IntTupleAttr::dyncLeafCount() 
const { + if (this->isLeaf()) { + return this->isStatic() ? 0 : 1; + } + int32_t count = 0; + for (int32_t i = 0; i < this->rank(); ++i) { + count += this->at(i).dyncLeafCount(); + } + return count; +} + +//===----------------------------------------------------------------------===// +// Interface methods +//===----------------------------------------------------------------------===// + +bool IntAttr::isStatic() const { return getValue() != std::numeric_limits::min(); } + +bool BasisAttr::isStatic() const { return cast(getValue()).isStatic(); } + +int32_t BasisAttr::depth() { return static_cast(getModes().size()); } + +bool IntTupleAttr::isLeaf() const { return !isa(getValue()); } + +bool IntTupleAttr::isStatic() const { + if (auto tupleAttr = dyn_cast(this->getValue())) { + for (int i = 0; i < rank(); ++i) { + if (!at(i).isStatic()) { + return false; + } + } + return true; + } else if (auto basisAttr = dyn_cast(getValue())) { + return basisAttr.isStatic(); + } else if (auto intAttr = dyn_cast(getValue())) { + return intAttr.isStatic(); + } + return true; +} + +int32_t IntTupleAttr::rank() const { + if (auto tupleAttr = dyn_cast(this->getValue())) { + return tupleAttr.size(); + } + return 1; +} +int32_t IntTupleAttr::rank(int32_t idx) const { + if (auto tupleAttr = dyn_cast(this->getValue())) { + return cast(tupleAttr[idx]).rank(); + } + assert(idx == 0); + return 1; +} +int32_t IntTupleAttr::rank(ArrayRef idxs) const { + IntTupleAttr result = *this; + for (int32_t idx : idxs) { + result = result.at(idx); + } + return result.rank(); +} + +int32_t IntTupleAttr::depth() const { + if (auto tupleAttr = dyn_cast(this->getValue())) { + int maxLeafDepth = at(0).depth(); + for (int i = 1; i < rank(); ++i) { + maxLeafDepth = std::max(maxLeafDepth, at(i).depth()); + } + return 1 + maxLeafDepth; + } + return 0; +} +int32_t IntTupleAttr::depth(int32_t idx) const { + if (auto tupleAttr = dyn_cast(this->getValue())) { + return cast(tupleAttr[idx]).depth(); + } + assert(idx == 
0); + return 0; +} +int32_t IntTupleAttr::depth(ArrayRef idxs) const { + IntTupleAttr result = *this; + for (int32_t idx : idxs) { + result = result.at(idx); + } + return result.depth(); +} + +IntTupleAttr IntTupleAttr::at(int32_t idx) const { + if (auto tupleAttr = dyn_cast(this->getValue())) { + return cast(tupleAttr[idx]); + } + assert(idx == 0 && "Index out of bounds for non-array pattern"); + return *this; +} +IntTupleAttr IntTupleAttr::at(ArrayRef idxs) const { + IntTupleAttr result = *this; + for (int32_t idx : idxs) { + result = result.at(idx); + } + return result; +} + +bool LayoutAttr::isStatic() const { return getShape().isStatic() && getStride().isStatic(); } + +bool LayoutAttr::isStaticShape() const { return getShape().isStatic(); } + +bool LayoutAttr::isStaticStride() const { return getStride().isStatic(); } + +bool LayoutAttr::isLeaf() const { return getShape().isLeaf(); } + +int32_t LayoutAttr::rank() const { return getShape().rank(); } +int32_t LayoutAttr::rank(int32_t idx) const { return getShape().rank(idx); } +int32_t LayoutAttr::rank(ArrayRef idxs) const { return getShape().rank(idxs); } + +int32_t LayoutAttr::depth() const { return getShape().depth(); } +int32_t LayoutAttr::depth(int32_t idx) const { return getShape().depth(idx); } +int32_t LayoutAttr::depth(ArrayRef idxs) const { return getShape().depth(idxs); } + +LayoutAttr LayoutAttr::at(int32_t idx) const { + return LayoutAttr::get(getContext(), getShape().at(idx), getStride().at(idx)); +} +LayoutAttr LayoutAttr::at(ArrayRef idxs) const { + return LayoutAttr::get(getContext(), getShape().at(idxs), getStride().at(idxs)); +} + +bool ComposedLayoutAttr::isStatic() const { + return isStaticOuter() && isStaticOffset() && isStaticInner(); +} +bool ComposedLayoutAttr::isStaticOuter() const { return getOuter().isStatic(); } +bool ComposedLayoutAttr::isStaticOffset() const { return getOffset().isStatic(); } +bool ComposedLayoutAttr::isStaticInner() const { + if (auto inner = dyn_cast(getInner())) 
{ + return inner.isStatic(); + } else if (auto layout = dyn_cast(getInner())) { + return layout.isStatic(); + } else if (auto basis = dyn_cast(getInner())) { + return true; + } else { + assert(false && "invalid InnerAttr of ComposedLayoutAttr"); + return false; + } +} + +bool ComposedLayoutAttr::isLeaf() const { return getOuter().isLeaf(); } + +int32_t ComposedLayoutAttr::rank() const { return getOuter().rank(); } +int32_t ComposedLayoutAttr::rank(int32_t idx) const { return getOuter().rank(idx); } +int32_t ComposedLayoutAttr::rank(ArrayRef idxs) const { return getOuter().rank(idxs); } +int32_t ComposedLayoutAttr::depth() const { return getOuter().depth(); } +int32_t ComposedLayoutAttr::depth(int32_t idx) const { return getOuter().depth(idx); } +int32_t ComposedLayoutAttr::depth(ArrayRef idxs) const { return getOuter().depth(idxs); } + +ComposedLayoutAttr ComposedLayoutAttr::at(int32_t idx) const { + return ComposedLayoutAttr::get(getContext(), getInner(), getOffset(), getOuter().at(idx)); +} +ComposedLayoutAttr ComposedLayoutAttr::at(ArrayRef idxs) const { + return ComposedLayoutAttr::get(getContext(), getInner(), getOffset(), getOuter().at(idxs)); +} + +int32_t TileAttr::rank() const { + if (auto arrayAttr = dyn_cast(this->getValue())) { + return arrayAttr.size(); + } + assert(false && "invalid TileAttr"); + return 0; +} + +bool TileAttr::isLeaf() const { return !isa(this->getValue()); } +Attribute TileAttr::at(int32_t idx) const { return cast(this->getValue())[idx]; } +bool TileAttr::isNoneMode() const { + if (!isLeaf()) + return false; + if (auto intAttr = dyn_cast(this->getValue())) + return intAttr.isNone(); + return false; +} +bool TileAttr::isNoneMode(int32_t idx) const { + if (auto intAttr = dyn_cast(at(idx))) + return intAttr.isNone(); + return false; +} + +//===----------------------------------------------------------------------===// +// Parser and Printer +//===----------------------------------------------------------------------===// + +void 
prettyPrintIntAttr(::mlir::AsmPrinter &odsPrinter, IntAttr attr) { + if (attr.isStatic()) { + odsPrinter << attr.getValue(); + } else { + odsPrinter << "?"; + if (attr.getWidth() != 32 || attr.getDivisibility() != 1) { + odsPrinter << "{"; + bool delimiter = false; + if (attr.getWidth() != 32) { + odsPrinter << "i" << attr.getWidth(); + delimiter = true; + } + if (attr.getDivisibility() != 1) { + if (delimiter) { + odsPrinter << " "; + } + odsPrinter << "div=" << attr.getDivisibility(); + } + odsPrinter << "}"; + } + } +} + +::mlir::Attribute IntAttr::parse(::mlir::AsmParser &odsParser, ::mlir::Type odsType) { + auto *ctx = odsParser.getBuilder().getContext(); + + if (odsParser.parseOptionalQuestion().succeeded()) { + int32_t width = 32; + int32_t divisibility = 1; + if (odsParser.parseOptionalLBrace().succeeded()) { + if (odsParser.parseOptionalKeyword("i32")) { + if (odsParser.parseOptionalKeyword("i64").succeeded()) { + width = 64; + } + } + if (odsParser.parseOptionalKeyword("div").succeeded()) { + if (odsParser.parseEqual() || odsParser.parseDecimalInteger(divisibility)) + return {}; + } + if (odsParser.parseRBrace()) + return {}; + } + return IntAttr::get(ctx, width, divisibility); + } + int32_t value; + if (odsParser.parseDecimalInteger(value)) + return {}; + return IntAttr::get(ctx, value); +} + +void IntAttr::print(::mlir::AsmPrinter &odsPrinter) const { prettyPrintIntAttr(odsPrinter, *this); } + +::mlir::Attribute parseLeafAttr(::mlir::AsmParser &odsParser) { + auto *ctx = odsParser.getBuilder().getContext(); + + Attribute valueAttr; + if (odsParser.parseOptionalStar().succeeded()) { + valueAttr = IntAttr::getNone(ctx); + } else if (odsParser.parseOptionalQuestion().succeeded()) { + int32_t width = 32; + int32_t divisibility = 1; + if (odsParser.parseOptionalLBrace().succeeded()) { + if (odsParser.parseOptionalKeyword("i32")) { + if (odsParser.parseOptionalKeyword("i64").succeeded()) { + width = 64; + } + } + if 
(odsParser.parseOptionalKeyword("div").succeeded()) { + if (odsParser.parseEqual() || odsParser.parseDecimalInteger(divisibility)) + return {}; + } + if (odsParser.parseRBrace()) + return {}; + } + valueAttr = IntAttr::get(ctx, width, divisibility); + } else { + int32_t value; + if (odsParser.parseDecimalInteger(value)) + return {}; + valueAttr = IntAttr::get(ctx, value); + } + + SmallString<16> strModes; + StringRef strRefModes; + if (odsParser.parseOptionalKeyword(&strRefModes)) + return valueAttr; + + SmallVector modes; + SmallVector strRefModeList; + + strRefModes.consume_front("E"); + strRefModes.split(strRefModeList, "E"); + for (StringRef strRefMode : strRefModeList) { + int32_t mode; + if (strRefMode.getAsInteger(10, mode)) + return {}; + modes.push_back(mode); + } + return BasisAttr::get(ctx, valueAttr, modes); +} + +::mlir::Attribute BasisAttr::parse(::mlir::AsmParser &odsParser, ::mlir::Type odsType) { + auto valueAttr = parseLeafAttr(odsParser); + if (!isa(valueAttr)) + return {}; + return valueAttr; +} + +void BasisAttr::print(::mlir::AsmPrinter &odsPrinter) const { + if (auto intAttr = dyn_cast(this->getValue())) { + prettyPrintIntAttr(odsPrinter, intAttr); + } else { + llvm_unreachable("invalid BasisAttr value"); + } + for (int32_t mode : getModes()) + odsPrinter << "E" << mode; +} + +::mlir::Attribute IntTupleAttr::parse(::mlir::AsmParser &odsParser, ::mlir::Type odsType) { + auto *ctx = odsParser.getBuilder().getContext(); + if (odsParser.parseOptionalLParen().succeeded()) { + SmallVector elements; + do { + elements.push_back(IntTupleAttr::parse(odsParser, odsType)); + } while (odsParser.parseOptionalComma().succeeded()); + if (odsParser.parseRParen()) + return {}; + return IntTupleAttr::get(ArrayAttr::get(ctx, elements)); + } else { + return IntTupleAttr::get(parseLeafAttr(odsParser)); + } +} + +void IntTupleAttr::print(::mlir::AsmPrinter &odsPrinter) const { + if (auto tupleAttr = dyn_cast(this->getValue())) { + odsPrinter << "("; + 
at(0).print(odsPrinter); + for (int i = 1; i < rank(); ++i) { + odsPrinter << ","; + at(i).print(odsPrinter); + } + odsPrinter << ")"; + } else { + ::llvm::TypeSwitch(this->getValue()) + .Case([&](IntAttr attr) { + if (attr.isNone()) { + odsPrinter << "*"; + } else { + prettyPrintIntAttr(odsPrinter, attr); + } + }) + .Case([&](BasisAttr attr) { attr.print(odsPrinter); }) + .DefaultUnreachable("invalid LeafAttr"); + } +} + +::mlir::Attribute TileAttr::parse(::mlir::AsmParser &odsParser, ::mlir::Type odsType) { + auto *ctx = odsParser.getBuilder().getContext(); + auto parseElement = [&]() -> Attribute { + auto shapeAttr = IntTupleAttr::parse(odsParser, odsType); + if (!shapeAttr) + return {}; + auto shape = cast(shapeAttr); + if (odsParser.parseOptionalColon().succeeded()) { + auto strideAttr = IntTupleAttr::parse(odsParser, odsType); + if (!strideAttr) + return {}; + auto stride = cast(strideAttr); + return LayoutAttr::get(ctx, shape, stride); + } + if (!shape.isLeaf()) + return {}; + Attribute leaf = shape.getValue(); + if (isa(leaf)) + return leaf; + return {}; + }; + + if (odsParser.parseOptionalLSquare().succeeded()) { + SmallVector elements; + do { + Attribute elem = parseElement(); + if (!elem) + return {}; + elements.push_back(elem); + } while (odsParser.parseOptionalVerticalBar().succeeded()); + if (odsParser.parseRSquare()) + return {}; + return TileAttr::get(ArrayAttr::get(ctx, elements)); + } else { + Attribute elem = parseElement(); + if (!elem) + return {}; + return TileAttr::get(elem); + } +} + +void TileAttr::print(::mlir::AsmPrinter &odsPrinter) const { + auto elemPrint = [&](Attribute attr) { + ::llvm::TypeSwitch(attr) + .Case([&](IntAttr attr) { + if (attr.isNone()) { + odsPrinter << "*"; + } else { + prettyPrintIntAttr(odsPrinter, attr); + } + }) + .Case([&](LayoutAttr attr) { attr.print(odsPrinter); }) + .DefaultUnreachable("invalid LayoutAttr"); + }; + if (isLeaf()) { + elemPrint(this->getValue()); + return; + } + odsPrinter << "["; + 
elemPrint(this->at(0)); + for (int i = 1; i < this->rank(); ++i) { + odsPrinter << "|"; + elemPrint(this->at(i)); + } + odsPrinter << "]"; +} + +} // namespace mlir::fly diff --git a/lib/Dialect/Fly/IR/FlyDialect.cpp b/lib/Dialect/Fly/IR/FlyDialect.cpp new file mode 100644 index 00000000..9ea39b1d --- /dev/null +++ b/lib/Dialect/Fly/IR/FlyDialect.cpp @@ -0,0 +1,41 @@ +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" + +using namespace mlir; +using namespace mlir::fly; + +#include "flydsl/Dialect/Fly/IR/FlyDialect.cpp.inc" + +#include "flydsl/Dialect/Fly/IR/FlyEnums.cpp.inc" + +namespace mlir::fly { +#include "flydsl/Dialect/Fly/IR/FlyAttrInterfaces.cpp.inc" +#include "flydsl/Dialect/Fly/IR/FlyTypeInterfaces.cpp.inc" + +#include "flydsl/Dialect/Fly/IR/FlyAttrConstraints.cpp.inc" +#include "flydsl/Dialect/Fly/IR/FlyTypeConstraints.cpp.inc" +} // namespace mlir::fly + +#define GET_TYPEDEF_CLASSES +#include "flydsl/Dialect/Fly/IR/FlyTypeDefs.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "flydsl/Dialect/Fly/IR/FlyAttrDefs.cpp.inc" + +void FlyDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "flydsl/Dialect/Fly/IR/FlyTypeDefs.cpp.inc" + >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "flydsl/Dialect/Fly/IR/FlyAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "flydsl/Dialect/Fly/IR/FlyOps.cpp.inc" + >(); +} diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp new file mode 100644 index 00000000..653e7be7 --- /dev/null +++ b/lib/Dialect/Fly/IR/FlyOps.cpp @@ -0,0 +1,1235 @@ + +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LogicalResult.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/IntTupleUtils.h" +#include "flydsl/Dialect/Fly/Utils/LayoutUtils.h" + 
+#include +#include + +#define GET_OP_CLASSES +#include "flydsl/Dialect/Fly/IR/FlyOps.cpp.inc" + +#include +#include + +using namespace mlir; +using namespace mlir::fly; + +namespace { + +IntTupleAttr makeDynamicLike(IntTupleAttr guide) { + auto *ctx = guide.getContext(); + IntTupleBuilder builder(ctx); + return intTupleTransformLeaf( + builder, [ctx](IntTupleAttr) { return IntTupleAttr::get(IntAttr::getDynamic(ctx)); }, guide); +} + +IntTupleAttr makeCompactStride(IntTupleAttr shapeAttr) { + auto *ctx = shapeAttr.getContext(); + IntAttr running = IntAttr::getStatic(ctx, 1); + + std::function visit = [&](IntTupleAttr shape) -> IntTupleAttr { + if (shape.isLeaf()) { + IntTupleAttr stride = IntTupleAttr::get(running); + running = running * shape.getLeafAsInt(); + return stride; + } + SmallVector elements; + elements.reserve(shape.rank()); + for (int i = 0; i < shape.rank(); ++i) { + elements.push_back(visit(shape.at(i))); + } + return IntTupleAttr::get(ArrayAttr::get(ctx, elements)); + }; + + return visit(shapeAttr); +} + +LayoutAttr makeOrderedLayoutAttr(IntTupleAttr shapeAttr, IntTupleAttr orderAttr) { + auto *ctx = shapeAttr.getContext(); + IntTupleBuilder builder(ctx); + IntTupleAttr flatShape = intTupleFlatten(builder, shapeAttr); + IntTupleAttr flatOrder = intTupleFlatten(builder, orderAttr); + + if (flatShape.isLeaf() || flatOrder.isLeaf() || flatShape.rank() != flatOrder.rank()) { + return LayoutAttr::get(ctx, shapeAttr, makeCompactStride(shapeAttr)); + } + + int32_t rank = flatShape.rank(); + SmallVector strideElems(rank); + IntAttr running = IntAttr::getStatic(ctx, 1); + + for (int i = 0; i < rank; ++i) { + IntAttr orderVal = flatOrder.at(i).getLeafAsInt(); + if (!orderVal.isStatic()) { + return LayoutAttr::get(ctx, shapeAttr, makeCompactStride(shapeAttr)); + } + int64_t idx = orderVal.getValue(); + if (idx < 0 || idx >= rank || strideElems[idx]) { + return LayoutAttr::get(ctx, shapeAttr, makeCompactStride(shapeAttr)); + } + strideElems[idx] = 
IntTupleAttr::get(running); + running = running * flatShape.at(idx).getLeafAsInt(); + } + + for (auto elem : strideElems) { + if (!elem) { + return LayoutAttr::get(ctx, shapeAttr, makeCompactStride(shapeAttr)); + } + } + + IntTupleAttr flatStride = IntTupleAttr::get(ArrayAttr::get(ctx, strideElems)); + IntTupleAttr strideAttr = intTupleUnflatten(builder, flatStride, shapeAttr); + return LayoutAttr::get(ctx, shapeAttr, strideAttr); +} + +} // namespace + +#define FLY_INFER_RETURN_TYPES(OP) \ + llvm::LogicalResult OP::inferReturnTypes( \ + mlir::MLIRContext *context, std::optional<::mlir::Location> location, \ + mlir::ValueRange operands, mlir::DictionaryAttr attributes, \ + mlir::OpaqueProperties properties, mlir::RegionRange regions, \ + llvm::SmallVectorImpl &inferredReturnTypes) + +FLY_INFER_RETURN_TYPES(MakeLayoutOp) { + auto shapeType = dyn_cast(operands[0].getType()); + IntTupleAttr shapeAttr = shapeType.getAttr(); + IntTupleAttr strideAttr; + + if (operands.size() > 1) { + strideAttr = dyn_cast(operands[1].getType()).getAttr(); + } else { + strideAttr = makeCompactStride(shapeAttr); + } + auto layoutAttr = LayoutAttr::get(context, shapeAttr, strideAttr); + inferredReturnTypes.assign({LayoutType::get(context, layoutAttr)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(MakeTileOp) { + SmallVector layouts; + for (auto op : operands) { + auto layoutType = dyn_cast(op.getType()); + if (!layoutType) + return failure(); + layouts.push_back(layoutType.getAttr()); + } + auto tileAttr = TileAttr::get(ArrayAttr::get(context, layouts)); + inferredReturnTypes.assign({TileType::get(context, tileAttr)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(MakeViewOp) { + auto ptrTy = dyn_cast(operands[0].getType()); + auto layoutTy = dyn_cast(operands[1].getType()); + if (!ptrTy || !layoutTy) + return failure(); + inferredReturnTypes.assign( + {MemRefType::get(ptrTy.getElemTy(), ptrTy.getAddressSpace(), layoutTy.getAttr())}); + return success(); +} + 
+FLY_INFER_RETURN_TYPES(MakeLayoutLikeOp) { + if (auto layoutTy = dyn_cast(operands[0].getType())) { + LayoutAttr inferred = layoutTy.getAttr(); + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + return success(); + } + if (auto memrefTy = dyn_cast(operands[0].getType())) { + LayoutAttr inferred = memrefTy.getLayout(); + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + return success(); + } + return failure(); +} + +FLY_INFER_RETURN_TYPES(MakeOrderedLayoutOp) { + auto shapeTy = dyn_cast(operands[0].getType()); + auto orderTy = dyn_cast(operands[1].getType()); + if (!shapeTy || !orderTy) + return failure(); + IntTupleAttr shapeAttr = shapeTy.getAttr(); + LayoutAttr layoutAttr = makeOrderedLayoutAttr(shapeAttr, orderTy.getAttr()); + inferredReturnTypes.assign({LayoutType::get(context, layoutAttr)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(MakeComposedLayoutOp) { + auto offsetTy = dyn_cast(operands[1].getType()); + auto outerTy = dyn_cast(operands[2].getType()); + if (!offsetTy || !outerTy) + return failure(); + Attribute innerAttr = nullptr; + if (auto innerLayoutTy = dyn_cast(operands[0].getType())) { + innerAttr = innerLayoutTy.getAttr(); + } else if (auto innerComposedTy = dyn_cast(operands[0].getType())) { + innerAttr = innerComposedTy.getAttr(); + } else { + return failure(); + } + auto composedAttr = + ComposedLayoutAttr::get(context, innerAttr, offsetTy.getAttr(), outerTy.getAttr()); + inferredReturnTypes.assign({ComposedLayoutType::get(context, composedAttr)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(MakeIdentityLayoutOp) { + auto shapeTy = dyn_cast(operands[0].getType()); + + IntTupleAttr shapeAttr = shapeTy.getAttr(); + IntTupleAttr strideAttr = intTupleMakeBasisLike(shapeAttr); + LayoutAttr layoutAttr = LayoutAttr::get(context, shapeAttr, strideAttr); + inferredReturnTypes.assign({LayoutType::get(context, layoutAttr)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(MakeIdentityTensorOp) { + auto 
shapeTy = dyn_cast(operands[0].getType()); + if (!shapeTy) + return failure(); + + IntTupleAttr shapeAttr = shapeTy.getAttr(); + IntTupleAttr strideAttr = intTupleMakeBasisLike(shapeAttr); + LayoutAttr layoutAttr = LayoutAttr::get(context, shapeAttr, strideAttr); + + IntTupleBuilder builder(context); + IntTupleAttr zeroBaseAttr = intTupleTransformLeaf( + builder, [](IntTupleAttr attr) { return IntTupleAttr::getLeafStatic(attr.getContext(), 0); }, + shapeAttr); + inferredReturnTypes.assign({CoordTensorType::get(context, zeroBaseAttr, layoutAttr)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(MakeFragmentLikeOp) { + inferredReturnTypes.assign({operands[0].getType()}); + return success(); +} + +FLY_INFER_RETURN_TYPES(GetScalarOp) { + auto intTupleType = dyn_cast(operands[0].getType()); + if (!intTupleType) + return failure(); + // Must be a leaf IntTuple + if (!intTupleType.getAttr().isLeaf()) + return failure(); + inferredReturnTypes.assign({IntegerType::get(context, 32)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(GetLeavesOp) { + auto inputTupleTy = dyn_cast(operands[0].getType()); + if (inputTupleTy) { + IntTupleBuilder builder(context); + IntTupleAttr flat = intTupleFlatten(builder, inputTupleTy.getAttr()); + inferredReturnTypes.assign({IntTupleType::get(flat)}); + return success(); + } + auto inputLayoutTy = dyn_cast(operands[0].getType()); + if (!inputLayoutTy) + return failure(); + IntTupleBuilder builder(context); + IntTupleAttr flat = intTupleFlatten(builder, inputLayoutTy.getAttr().getShape()); + inferredReturnTypes.assign({IntTupleType::get(flat)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(GetShapeOp) { + auto layoutType = dyn_cast(operands[0].getType()); + if (!layoutType) + return failure(); + LayoutAttr profile = layoutType.getAttr(); + inferredReturnTypes.assign({IntTupleType::get(profile.getShape())}); + return success(); +} + +FLY_INFER_RETURN_TYPES(GetStrideOp) { + auto layoutType = dyn_cast(operands[0].getType()); + if 
(!layoutType) + return failure(); + LayoutAttr profile = layoutType.getAttr(); + inferredReturnTypes.assign({IntTupleType::get(profile.getStride())}); + return success(); +} + +FLY_INFER_RETURN_TYPES(GetLayoutOp) { + auto memrefTy = dyn_cast(operands[0].getType()); + if (!memrefTy) + return failure(); + inferredReturnTypes.assign({LayoutType::get(context, memrefTy.getLayout())}); + return success(); +} + +FLY_INFER_RETURN_TYPES(GetIterOp) { + auto memrefTy = dyn_cast(operands[0].getType()); + if (!memrefTy) + return failure(); + inferredReturnTypes.assign({PointerType::get(memrefTy.getElemTy(), memrefTy.getAddressSpace())}); + return success(); +} + +FLY_INFER_RETURN_TYPES(GetLeafOp) { + int32_t leafIdx = properties.as()->leaf_idx.getInt(); + + if (auto layoutType = dyn_cast(operands[0].getType())) { + LayoutAttr profile = layoutType.getAttr(); + LayoutAttr leafProfile = profile.at(leafIdx); + inferredReturnTypes.assign({LayoutType::get(context, leafProfile)}); + return success(); + } + + if (auto intTupleType = dyn_cast(operands[0].getType())) { + IntTupleAttr profile = intTupleType.getAttr(); + IntTupleAttr leafProfile = profile.at(leafIdx); + inferredReturnTypes.assign({IntTupleType::get(leafProfile)}); + return success(); + } + + return failure(); +} + +FLY_INFER_RETURN_TYPES(ComposedGetInnerOp) { + auto inputTy = dyn_cast(operands[0].getType()); + if (!inputTy) + return failure(); + auto innerAttr = inputTy.getAttr().getInner(); + if (auto swizzleAttr = dyn_cast(innerAttr)) { + inferredReturnTypes.assign({SwizzleType::get(context, swizzleAttr)}); + return success(); + } else if (auto layoutAttr = dyn_cast(innerAttr)) { + inferredReturnTypes.assign({LayoutType::get(context, layoutAttr)}); + return success(); + } else if (auto composedLayoutAttr = dyn_cast(innerAttr)) { + inferredReturnTypes.assign({ComposedLayoutType::get(context, composedLayoutAttr)}); + return success(); + } + return failure(); +} + +FLY_INFER_RETURN_TYPES(ComposedGetOffsetOp) { + auto 
inputTy = dyn_cast(operands[0].getType()); + if (!inputTy) + return failure(); + inferredReturnTypes.assign({IntTupleType::get(inputTy.getAttr().getOffset())}); + return success(); +} + +FLY_INFER_RETURN_TYPES(ComposedGetOuterOp) { + auto inputTy = dyn_cast(operands[0].getType()); + if (!inputTy) + return failure(); + inferredReturnTypes.assign({LayoutType::get(context, inputTy.getAttr().getOuter())}); + return success(); +} + +FLY_INFER_RETURN_TYPES(IntTupleAddOp) { + auto lhsTy = dyn_cast(operands[0].getType()); + auto rhsTy = dyn_cast(operands[1].getType()); + if (!lhsTy || !rhsTy) + return failure(); + IntTupleBuilder builder(context); + inferredReturnTypes.assign( + {IntTupleType::get(intTupleAdd(builder, lhsTy.getAttr(), rhsTy.getAttr()))}); + return success(); +} + +FLY_INFER_RETURN_TYPES(IntTupleSubOp) { + auto lhsTy = dyn_cast(operands[0].getType()); + auto rhsTy = dyn_cast(operands[1].getType()); + if (!lhsTy || !rhsTy) + return failure(); + IntTupleBuilder builder(context); + inferredReturnTypes.assign( + {IntTupleType::get(intTupleSub(builder, lhsTy.getAttr(), rhsTy.getAttr()))}); + return success(); +} + +FLY_INFER_RETURN_TYPES(IntTupleMulOp) { + auto lhsTy = dyn_cast(operands[0].getType()); + auto rhsTy = dyn_cast(operands[1].getType()); + if (!lhsTy || !rhsTy) + return failure(); + IntTupleBuilder builder(context); + inferredReturnTypes.assign( + {IntTupleType::get(intTupleMul(builder, lhsTy.getAttr(), rhsTy.getAttr()))}); + return success(); +} + +FLY_INFER_RETURN_TYPES(IntTupleDivOp) { + auto lhsTy = dyn_cast(operands[0].getType()); + auto rhsTy = dyn_cast(operands[1].getType()); + if (!lhsTy || !rhsTy) + return failure(); + IntTupleBuilder builder(context); + inferredReturnTypes.assign( + {IntTupleType::get(intTupleDiv(builder, lhsTy.getAttr(), rhsTy.getAttr()))}); + return success(); +} + +FLY_INFER_RETURN_TYPES(IntTupleModOp) { + auto lhsTy = dyn_cast(operands[0].getType()); + auto rhsTy = dyn_cast(operands[1].getType()); + if (!lhsTy || !rhsTy) 
+ return failure(); + inferredReturnTypes.assign({IntTupleType::get(makeDynamicLike(lhsTy.getAttr()))}); + return success(); +} + +FLY_INFER_RETURN_TYPES(IntTupleProductEachOp) { + auto inputTy = dyn_cast(operands[0].getType()); + if (!inputTy) + return failure(); + IntTupleBuilder builder(context); + inferredReturnTypes.assign({IntTupleType::get(intTupleProductEach(builder, inputTy.getAttr()))}); + return success(); +} + +FLY_INFER_RETURN_TYPES(IntTupleProductOp) { + auto inputTy = dyn_cast(operands[0].getType()); + if (!inputTy) + return failure(); + IntTupleBuilder builder(context); + IntTupleAttr size = intTupleProduct(builder, inputTy.getAttr()); + inferredReturnTypes.assign({IntTupleType::get(size)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(ShapeDivOp) { + auto lhsTy = dyn_cast(operands[0].getType()); + auto rhsTy = dyn_cast(operands[1].getType()); + if (!lhsTy || !rhsTy) + return failure(); + IntTupleBuilder builder(context); + inferredReturnTypes.assign( + {IntTupleType::get(intTupleShapeDiv(builder, lhsTy.getAttr(), rhsTy.getAttr()))}); + return success(); +} + +FLY_INFER_RETURN_TYPES(CeilDivOp) { + auto lhsTy = dyn_cast(operands[0].getType()); + auto rhsTy = dyn_cast(operands[1].getType()); + if (!lhsTy || !rhsTy) + return failure(); + IntTupleBuilder builder(context); + inferredReturnTypes.assign( + {IntTupleType::get(intTupleCeilDiv(builder, lhsTy.getAttr(), rhsTy.getAttr()))}); + return success(); +} + +FLY_INFER_RETURN_TYPES(ElemLessOp) { + auto lhsTy = dyn_cast(operands[0].getType()); + auto rhsTy = dyn_cast(operands[1].getType()); + if (!lhsTy || !rhsTy) + return failure(); + IntTupleBuilder builder(context); + IntTupleAttr result = intTupleElemLess(builder, lhsTy.getAttr(), rhsTy.getAttr()); + inferredReturnTypes.assign({IntTupleType::get(result)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(EqualOp) { + auto lhsTy = dyn_cast(operands[0].getType()); + auto rhsTy = dyn_cast(operands[1].getType()); + if (!lhsTy || !rhsTy) + return 
failure(); + bool isCongruent = intTupleIsCongruent(lhsTy.getAttr(), rhsTy.getAttr()); + IntTupleAttr result = IntTupleAttr::getLeafStatic(context, isCongruent ? 1 : 0); + inferredReturnTypes.assign({IntTupleType::get(result)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(AppendOp) { + auto baseLayout = dyn_cast(operands[0].getType()); + auto elemLayout = dyn_cast(operands[1].getType()); + if (!baseLayout || !elemLayout) + return failure(); + + int32_t n = -1; + if (properties) { + auto nAttr = properties.as()->n; + if (nAttr) + n = static_cast(nAttr.getInt()); + } + + LayoutAttr baseAttr = baseLayout.getAttr(); + LayoutAttr elemAttr = elemLayout.getAttr(); + + IntTupleBuilder builder(context); + IntTupleAttr newShape = intTupleAppend(builder, baseAttr.getShape(), elemAttr.getShape(), n); + IntTupleAttr newStride = intTupleAppend(builder, baseAttr.getStride(), elemAttr.getStride(), n); + + inferredReturnTypes.assign( + {LayoutType::get(context, LayoutAttr::get(context, newShape, newStride))}); + return success(); +} + +FLY_INFER_RETURN_TYPES(PrependOp) { + auto baseLayout = dyn_cast(operands[0].getType()); + auto elemLayout = dyn_cast(operands[1].getType()); + if (!baseLayout || !elemLayout) + return failure(); + + int32_t n = -1; + if (properties) { + auto nAttr = properties.as()->n; + if (nAttr) + n = static_cast(nAttr.getInt()); + } + + LayoutAttr baseAttr = baseLayout.getAttr(); + LayoutAttr elemAttr = elemLayout.getAttr(); + + IntTupleBuilder builder(context); + IntTupleAttr newShape = intTuplePrepend(builder, baseAttr.getShape(), elemAttr.getShape(), n); + IntTupleAttr newStride = intTuplePrepend(builder, baseAttr.getStride(), elemAttr.getStride(), n); + + inferredReturnTypes.assign( + {LayoutType::get(context, LayoutAttr::get(context, newShape, newStride))}); + return success(); +} + +FLY_INFER_RETURN_TYPES(SelectOp) { + auto idxArr = properties.as()->indices.asArrayRef(); + SmallVector indices(idxArr.begin(), idxArr.end()); + + Type inputTy = 
operands[0].getType(); + IntTupleBuilder builder(context); + + if (auto tupleTy = dyn_cast(inputTy)) { + IntTupleAttr profile = tupleTy.getAttr(); + IntTupleAttr selected = intTupleSelect(builder, profile, indices); + inferredReturnTypes.assign({IntTupleType::get(selected)}); + return success(); + } + + if (auto layoutTy = dyn_cast(inputTy)) { + LayoutAttr profile = layoutTy.getAttr(); + IntTupleAttr newShape = intTupleSelect(builder, profile.getShape(), indices); + IntTupleAttr newStride = intTupleSelect(builder, profile.getStride(), indices); + inferredReturnTypes.assign( + {LayoutType::get(context, LayoutAttr::get(context, newShape, newStride))}); + return success(); + } + + return failure(); +} + +FLY_INFER_RETURN_TYPES(GroupOp) { + int32_t begin = properties.as()->begin.getInt(); + int32_t end = properties.as()->end.getInt(); + + Type inputTy = operands[0].getType(); + IntTupleBuilder builder(context); + + if (auto tupleTy = dyn_cast(inputTy)) { + IntTupleAttr profile = tupleTy.getAttr(); + IntTupleAttr grouped = intTupleGroup(builder, profile, begin, end); + inferredReturnTypes.assign({IntTupleType::get(grouped)}); + return success(); + } + + if (auto layoutTy = dyn_cast(inputTy)) { + LayoutAttr profile = layoutTy.getAttr(); + IntTupleAttr newShape = intTupleGroup(builder, profile.getShape(), begin, end); + IntTupleAttr newStride = intTupleGroup(builder, profile.getStride(), begin, end); + inferredReturnTypes.assign( + {LayoutType::get(context, LayoutAttr::get(context, newShape, newStride))}); + return success(); + } + + return failure(); +} + +FLY_INFER_RETURN_TYPES(SliceOp) { + Type srcTy = operands[0].getType(); + auto coordTy = dyn_cast(operands[1].getType()); + if (!coordTy) + return failure(); + + IntTupleAttr coordAttr = coordTy.getAttr(); + IntTupleBuilder builder(context); + + if (auto srcTupleTy = dyn_cast(srcTy)) { + IntTupleAttr result = intTupleSlice(builder, srcTupleTy.getAttr(), coordAttr); + 
inferredReturnTypes.assign({IntTupleType::get(result)}); + return success(); + } + if (auto srcLayoutTy = dyn_cast(srcTy)) { + LayoutAttr profile = srcLayoutTy.getAttr(); + IntTupleAttr newShape = intTupleSlice(builder, profile.getShape(), coordAttr); + IntTupleAttr newStride = intTupleSlice(builder, profile.getStride(), coordAttr); + inferredReturnTypes.assign( + {LayoutType::get(context, LayoutAttr::get(context, newShape, newStride))}); + return success(); + } + if (auto srcMemRefTy = dyn_cast(srcTy)) { + LayoutAttr layoutAttr = srcMemRefTy.getLayout(); + IntTupleAttr newShape = intTupleSlice(builder, layoutAttr.getShape(), coordAttr); + IntTupleAttr newStride = intTupleSlice(builder, layoutAttr.getStride(), coordAttr); + auto newLayoutAttr = LayoutAttr::get(context, newShape, newStride); + inferredReturnTypes.assign( + {MemRefType::get(srcMemRefTy.getElemTy(), srcMemRefTy.getAddressSpace(), newLayoutAttr)}); + return success(); + } + + return failure(); +} + +FLY_INFER_RETURN_TYPES(DiceOp) { + Type srcTy = operands[0].getType(); + if (isa(srcTy)) { + inferredReturnTypes.assign({srcTy}); + return success(); + } + return failure(); +} + +FLY_INFER_RETURN_TYPES(SizeOp) { + if (auto intTupleTy = dyn_cast(operands[0].getType())) { + IntTupleBuilder builder(context); + IntTupleAttr size = intTupleProduct(builder, intTupleTy.getAttr()); + inferredReturnTypes.assign({IntTupleType::get(size)}); + return success(); + } + if (auto layoutTy = dyn_cast(operands[0].getType())) { + LayoutBuilder layoutBuilder(context); + IntTupleAttr size = layoutSize(layoutBuilder, layoutTy.getAttr()); + inferredReturnTypes.assign({IntTupleType::get(size)}); + return success(); + } + if (auto memrefTy = dyn_cast(operands[0].getType())) { + LayoutBuilder layoutBuilder(context); + IntTupleAttr size = layoutSize(layoutBuilder, memrefTy.getLayout()); + inferredReturnTypes.assign({IntTupleType::get(size)}); + return success(); + } + return failure(); +} + +FLY_INFER_RETURN_TYPES(CosizeOp) { + auto 
layoutTy = dyn_cast(operands[0].getType()); + if (!layoutTy) + return failure(); + + LayoutBuilder layoutBuilder(context); + IntTupleAttr cosize = layoutCosize(layoutBuilder, layoutTy.getAttr()); + inferredReturnTypes.assign({IntTupleType::get(cosize)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(Crd2IdxOp) { + auto coordTy = dyn_cast(operands[0].getType()); + auto layoutTy = dyn_cast(operands[1].getType()); + if (!coordTy || !layoutTy) + return failure(); + + IntTupleAttr coordAttr = coordTy.getAttr(); + LayoutAttr layoutAttr = layoutTy.getAttr(); + IntTupleBuilder builder(context); + + IntTupleAttr result = + layoutCrd2idx(builder, coordAttr, layoutAttr.getShape(), layoutAttr.getStride()); + inferredReturnTypes.assign({IntTupleType::get(result)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(Idx2CrdOp) { + auto layoutTy = dyn_cast(operands[1].getType()); + if (!layoutTy) + return failure(); + inferredReturnTypes.assign({IntTupleType::get(layoutTy.getAttr().getShape())}); + return success(); +} + +FLY_INFER_RETURN_TYPES(GetFlatCoordOp) { + auto inputTupleTy = dyn_cast(operands[1].getType()); + if (inputTupleTy) { + inferredReturnTypes.assign({inputTupleTy}); + return success(); + } + auto inputLayoutTy = dyn_cast(operands[1].getType()); + if (!inputLayoutTy) + return failure(); + inferredReturnTypes.assign({IntTupleType::get(inputLayoutTy.getAttr().getShape())}); + return success(); +} + +FLY_INFER_RETURN_TYPES(GetHierCoordOp) { + auto inputTupleTy = dyn_cast(operands[1].getType()); + if (inputTupleTy) { + inferredReturnTypes.assign({inputTupleTy}); + return success(); + } + auto inputLayoutTy = dyn_cast(operands[1].getType()); + if (!inputLayoutTy) + return failure(); + inferredReturnTypes.assign({IntTupleType::get(inputLayoutTy.getAttr().getShape())}); + return success(); +} + +FLY_INFER_RETURN_TYPES(CoalesceOp) { + auto layoutTy = dyn_cast(operands[0].getType()); + if (!layoutTy) + return failure(); + + std::optional profileAttr; + if (operands.size() 
> 1 && operands[1]) { + auto profileTy = dyn_cast(operands[1].getType()); + if (!profileTy) + return failure(); + profileAttr = profileTy.getAttr(); + } + + LayoutBuilder layoutBuilder(context); + LayoutAttr inferred = layoutCoalesce(layoutBuilder, layoutTy.getAttr(), profileAttr); + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(CompositionOp) { + auto outerLayoutTy = dyn_cast(operands[0].getType()); + if (!outerLayoutTy) + return failure(); + + LayoutBuilder layoutBuilder(context); + Type innerTy = operands[1].getType(); + if (auto tileTy = dyn_cast(innerTy)) { + LayoutAttr inferred = + layoutComposition(layoutBuilder, outerLayoutTy.getAttr(), tileTy.getAttr()); + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + return success(); + } + if (auto innerLayoutTy = dyn_cast(innerTy)) { + LayoutAttr inferred = + layoutComposition(layoutBuilder, outerLayoutTy.getAttr(), innerLayoutTy.getAttr()); + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + return success(); + } + return failure(); +} + +FLY_INFER_RETURN_TYPES(ComplementOp) { + auto layoutTy = dyn_cast(operands[0].getType()); + if (!layoutTy) + return failure(); + + std::optional codomainSizeAttr; + if (operands.size() > 1 && operands[1]) { + codomainSizeAttr = cast(operands[1].getType()).getAttr(); + } + + LayoutBuilder layoutBuilder(context); + LayoutAttr inferred = layoutComplement(layoutBuilder, layoutTy.getAttr(), codomainSizeAttr); + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(RightInverseOp) { + auto layoutTy = dyn_cast(operands[0].getType()); + if (!layoutTy) + return failure(); + inferredReturnTypes.assign({layoutTy}); + return success(); +} + +FLY_INFER_RETURN_TYPES(LeftInverseOp) { + auto layoutTy = dyn_cast(operands[0].getType()); + if (!layoutTy) + return failure(); + inferredReturnTypes.assign({layoutTy}); + return success(); 
+} + +FLY_INFER_RETURN_TYPES(RecastLayoutOp) { + auto layoutTy = dyn_cast(operands[2].getType()); + if (!layoutTy) + return failure(); + inferredReturnTypes.assign({layoutTy}); + return success(); +} + +FLY_INFER_RETURN_TYPES(LogicalDivideOp) { + LayoutAttr layoutAttr = nullptr; + MemRefType memrefTy = nullptr; + Type lhsTy = operands[0].getType(); + + if (auto layoutTy = dyn_cast(lhsTy)) { + layoutAttr = layoutTy.getAttr(); + } else if ((memrefTy = dyn_cast(lhsTy))) { + layoutAttr = memrefTy.getLayout(); + } else { + return failure(); + } + + Type divisorTy = operands[1].getType(); + LayoutAttr inferred; + + if (auto divisorLayoutTy = dyn_cast(divisorTy)) { + LayoutBuilder layoutBuilder(context); + inferred = layoutLogicalDivide(layoutBuilder, layoutAttr, divisorLayoutTy.getAttr()); + } else if (auto divisorTileTy = dyn_cast(divisorTy)) { + LayoutBuilder layoutBuilder(context); + inferred = layoutLogicalDivide(layoutBuilder, layoutAttr, divisorTileTy.getAttr()); + } else { + return failure(); + } + + if (memrefTy) { + inferredReturnTypes.assign( + {MemRefType::get(memrefTy.getElemTy(), memrefTy.getAddressSpace(), inferred)}); + } else { + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + } + return success(); +} + +FLY_INFER_RETURN_TYPES(ZippedDivideOp) { + LayoutAttr layoutAttr = nullptr; + MemRefType memrefTy = nullptr; + Type lhsTy = operands[0].getType(); + + if (auto layoutTy = dyn_cast(lhsTy)) { + layoutAttr = layoutTy.getAttr(); + } else if ((memrefTy = dyn_cast(lhsTy))) { + layoutAttr = memrefTy.getLayout(); + } else { + return failure(); + } + + Type divisorTy = operands[1].getType(); + LayoutAttr inferred; + + if (auto divisorLayoutTy = dyn_cast(divisorTy)) { + LayoutBuilder layoutBuilder(context); + inferred = layoutZippedDivide(layoutBuilder, layoutAttr, divisorLayoutTy.getAttr()); + } else if (auto divisorTileTy = dyn_cast(divisorTy)) { + LayoutBuilder layoutBuilder(context); + inferred = layoutZippedDivide(layoutBuilder, layoutAttr, 
divisorTileTy.getAttr()); + } else { + return failure(); + } + + if (memrefTy) { + inferredReturnTypes.assign( + {MemRefType::get(memrefTy.getElemTy(), memrefTy.getAddressSpace(), inferred)}); + } else { + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + } + return success(); +} + +FLY_INFER_RETURN_TYPES(TiledDivideOp) { + LayoutAttr layoutAttr = nullptr; + MemRefType memrefTy = nullptr; + Type lhsTy = operands[0].getType(); + + if (auto layoutTy = dyn_cast(lhsTy)) { + layoutAttr = layoutTy.getAttr(); + } else if ((memrefTy = dyn_cast(lhsTy))) { + layoutAttr = memrefTy.getLayout(); + } else { + return failure(); + } + + Type divisorTy = operands[1].getType(); + LayoutAttr inferred; + + if (auto divisorLayoutTy = dyn_cast(divisorTy)) { + LayoutBuilder layoutBuilder(context); + inferred = layoutTiledDivide(layoutBuilder, layoutAttr, divisorLayoutTy.getAttr()); + } else if (auto divisorTileTy = dyn_cast(divisorTy)) { + LayoutBuilder layoutBuilder(context); + inferred = layoutTiledDivide(layoutBuilder, layoutAttr, divisorTileTy.getAttr()); + } else { + return failure(); + } + + if (memrefTy) { + inferredReturnTypes.assign( + {MemRefType::get(memrefTy.getElemTy(), memrefTy.getAddressSpace(), inferred)}); + } else { + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + } + return success(); +} + +FLY_INFER_RETURN_TYPES(FlatDivideOp) { + LayoutAttr layoutAttr = nullptr; + MemRefType memrefTy = nullptr; + Type lhsTy = operands[0].getType(); + + if (auto layoutTy = dyn_cast(lhsTy)) { + layoutAttr = layoutTy.getAttr(); + } else if ((memrefTy = dyn_cast(lhsTy))) { + layoutAttr = memrefTy.getLayout(); + } else { + return failure(); + } + + Type divisorTy = operands[1].getType(); + LayoutAttr inferred; + + if (auto divisorLayoutTy = dyn_cast(divisorTy)) { + LayoutBuilder layoutBuilder(context); + inferred = layoutFlatDivide(layoutBuilder, layoutAttr, divisorLayoutTy.getAttr()); + } else if (auto divisorTileTy = dyn_cast(divisorTy)) { + 
LayoutBuilder layoutBuilder(context); + inferred = layoutFlatDivide(layoutBuilder, layoutAttr, divisorTileTy.getAttr()); + } else { + return failure(); + } + + if (memrefTy) { + inferredReturnTypes.assign( + {MemRefType::get(memrefTy.getElemTy(), memrefTy.getAddressSpace(), inferred)}); + } else { + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + } + return success(); +} + +FLY_INFER_RETURN_TYPES(LogicalProductOp) { + LayoutAttr layoutAttr = nullptr; + MemRefType memrefTy = nullptr; + Type lhsTy = operands[0].getType(); + + if (auto layoutTy = dyn_cast(lhsTy)) { + layoutAttr = layoutTy.getAttr(); + } else if ((memrefTy = dyn_cast(lhsTy))) { + layoutAttr = memrefTy.getLayout(); + } else { + return failure(); + } + + auto tilerTy = dyn_cast(operands[1].getType()); + if (!tilerTy) + return failure(); + + LayoutBuilder layoutBuilder(context); + LayoutAttr inferred = layoutLogicalProduct(layoutBuilder, layoutAttr, tilerTy.getAttr()); + + if (memrefTy) { + inferredReturnTypes.assign( + {MemRefType::get(memrefTy.getElemTy(), memrefTy.getAddressSpace(), inferred)}); + } else { + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + } + return success(); +} + +FLY_INFER_RETURN_TYPES(ZippedProductOp) { + LayoutAttr layoutAttr = nullptr; + MemRefType memrefTy = nullptr; + Type lhsTy = operands[0].getType(); + + if (auto layoutTy = dyn_cast(lhsTy)) { + layoutAttr = layoutTy.getAttr(); + } else if ((memrefTy = dyn_cast(lhsTy))) { + layoutAttr = memrefTy.getLayout(); + } else { + return failure(); + } + + auto tilerTy = dyn_cast(operands[1].getType()); + if (!tilerTy) + return failure(); + + LayoutBuilder layoutBuilder(context); + LayoutAttr logicalProd = layoutLogicalProduct(layoutBuilder, layoutAttr, tilerTy.getAttr()); + + // zip2_by with tiler shape as guide + IntTupleBuilder builder(context); + IntTupleAttr guide = tilerTy.getAttr().getShape(); + IntTupleAttr newShape = intTupleZip2By(builder, logicalProd.getShape(), guide); + IntTupleAttr 
newStride = intTupleZip2By(builder, logicalProd.getStride(), guide); + LayoutAttr inferred = LayoutAttr::get(context, newShape, newStride); + + if (memrefTy) { + inferredReturnTypes.assign( + {MemRefType::get(memrefTy.getElemTy(), memrefTy.getAddressSpace(), inferred)}); + } else { + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + } + return success(); +} + +FLY_INFER_RETURN_TYPES(TiledProductOp) { + LayoutAttr layoutAttr = nullptr; + MemRefType memrefTy = nullptr; + Type lhsTy = operands[0].getType(); + + if (auto layoutTy = dyn_cast(lhsTy)) { + layoutAttr = layoutTy.getAttr(); + } else if ((memrefTy = dyn_cast(lhsTy))) { + layoutAttr = memrefTy.getLayout(); + } else { + return failure(); + } + + auto tilerTy = dyn_cast(operands[1].getType()); + if (!tilerTy) + return failure(); + + LayoutBuilder layoutBuilder(context); + LayoutAttr logicalProd = layoutLogicalProduct(layoutBuilder, layoutAttr, tilerTy.getAttr()); + + IntTupleBuilder builder(context); + IntTupleAttr guide = tilerTy.getAttr().getShape(); + IntTupleAttr zippedShape = intTupleZip2By(builder, logicalProd.getShape(), guide); + IntTupleAttr zippedStride = intTupleZip2By(builder, logicalProd.getStride(), guide); + + // Expand index 1 + // TODO: Implement proper expand logic + LayoutAttr inferred = LayoutAttr::get(context, zippedShape, zippedStride); + + if (memrefTy) { + inferredReturnTypes.assign( + {MemRefType::get(memrefTy.getElemTy(), memrefTy.getAddressSpace(), inferred)}); + } else { + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + } + return success(); +} + +FLY_INFER_RETURN_TYPES(FlatProductOp) { + LayoutAttr layoutAttr = nullptr; + MemRefType memrefTy = nullptr; + Type lhsTy = operands[0].getType(); + + if (auto layoutTy = dyn_cast(lhsTy)) { + layoutAttr = layoutTy.getAttr(); + } else if ((memrefTy = dyn_cast(lhsTy))) { + layoutAttr = memrefTy.getLayout(); + } else { + return failure(); + } + + auto tilerTy = dyn_cast(operands[1].getType()); + if 
(!tilerTy) + return failure(); + + LayoutBuilder layoutBuilder(context); + LayoutAttr logicalProd = layoutLogicalProduct(layoutBuilder, layoutAttr, tilerTy.getAttr()); + + IntTupleBuilder builder(context); + IntTupleAttr guide = tilerTy.getAttr().getShape(); + IntTupleAttr zippedShape = intTupleZip2By(builder, logicalProd.getShape(), guide); + IntTupleAttr zippedStride = intTupleZip2By(builder, logicalProd.getStride(), guide); + + // Expand indices 0 and 1 + // TODO: Implement proper expand logic + LayoutAttr inferred = LayoutAttr::get(context, zippedShape, zippedStride); + + if (memrefTy) { + inferredReturnTypes.assign( + {MemRefType::get(memrefTy.getElemTy(), memrefTy.getAddressSpace(), inferred)}); + } else { + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + } + return success(); +} + +FLY_INFER_RETURN_TYPES(BlockedProductOp) { + LayoutAttr layoutAttr = nullptr; + MemRefType memrefTy = nullptr; + Type lhsTy = operands[0].getType(); + + if (auto layoutTy = dyn_cast(lhsTy)) { + layoutAttr = layoutTy.getAttr(); + } else if ((memrefTy = dyn_cast(lhsTy))) { + layoutAttr = memrefTy.getLayout(); + } else { + return failure(); + } + + auto tilerTy = dyn_cast(operands[1].getType()); + if (!tilerTy) + return failure(); + + LayoutBuilder layoutBuilder(context); + LayoutAttr inferred = layoutBlockedProduct(layoutBuilder, layoutAttr, tilerTy.getAttr()); + + if (memrefTy) { + inferredReturnTypes.assign( + {MemRefType::get(memrefTy.getElemTy(), memrefTy.getAddressSpace(), inferred)}); + } else { + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + } + return success(); +} + +FLY_INFER_RETURN_TYPES(RakedProductOp) { + LayoutAttr layoutAttr = nullptr; + MemRefType memrefTy = nullptr; + Type lhsTy = operands[0].getType(); + + if (auto layoutTy = dyn_cast(lhsTy)) { + layoutAttr = layoutTy.getAttr(); + } else if ((memrefTy = dyn_cast(lhsTy))) { + layoutAttr = memrefTy.getLayout(); + } else { + return failure(); + } + + auto tilerTy = 
dyn_cast(operands[1].getType()); + if (!tilerTy) + return failure(); + + LayoutBuilder layoutBuilder(context); + LayoutAttr inferred = layoutRakedProduct(layoutBuilder, layoutAttr, tilerTy.getAttr()); + + if (memrefTy) { + inferredReturnTypes.assign( + {MemRefType::get(memrefTy.getElemTy(), memrefTy.getAddressSpace(), inferred)}); + } else { + inferredReturnTypes.assign({LayoutType::get(context, inferred)}); + } + return success(); +} + +FLY_INFER_RETURN_TYPES(TileToShapeOp) { + auto shapeTy = dyn_cast(operands[1].getType()); + if (!shapeTy) + return failure(); + IntTupleAttr shapeAttr = shapeTy.getAttr(); + LayoutAttr layoutAttr = LayoutAttr::get(context, shapeAttr, makeDynamicLike(shapeAttr)); + inferredReturnTypes.assign({LayoutType::get(context, layoutAttr)}); + return success(); +} + +FLY_INFER_RETURN_TYPES(MakeTiledCopyOp) { + auto copyAtomTy = operands[0].getType(); + auto layoutTy = dyn_cast(operands[1].getType()); + auto tileTy = dyn_cast(operands[2].getType()); + if (!layoutTy || !tileTy) + return failure(); + + auto tiledCopyTy = TiledCopyType::get(context, copyAtomTy, layoutTy, tileTy); + inferredReturnTypes.assign({tiledCopyTy}); + return success(); +} + +FLY_INFER_RETURN_TYPES(TiledCopyPartitionSrcOp) { + auto memrefTy = dyn_cast(operands[1].getType()); + if (!memrefTy) + return failure(); + inferredReturnTypes.assign({memrefTy}); + return success(); +} + +FLY_INFER_RETURN_TYPES(TiledCopyPartitionDstOp) { + auto memrefTy = dyn_cast(operands[1].getType()); + if (!memrefTy) + return failure(); + inferredReturnTypes.assign({memrefTy}); + return success(); +} + +FLY_INFER_RETURN_TYPES(TiledCopyPartitionDOp) { + auto memrefTy = dyn_cast(operands[1].getType()); + if (!memrefTy) + return failure(); + inferredReturnTypes.assign({memrefTy}); + return success(); +} + +FLY_INFER_RETURN_TYPES(TiledCopyPartitionSOp) { + auto memrefTy = dyn_cast(operands[1].getType()); + if (!memrefTy) + return failure(); + inferredReturnTypes.assign({memrefTy}); + return 
success(); +} + +FLY_INFER_RETURN_TYPES(TiledCopyRetileOp) { + auto memrefTy = dyn_cast(operands[1].getType()); + if (!memrefTy) + return failure(); + inferredReturnTypes.assign({memrefTy}); + return success(); +} + +FLY_INFER_RETURN_TYPES(TiledMmaPartitionOp) { + auto memrefTy = dyn_cast(operands[2].getType()); + if (!memrefTy) + return failure(); + inferredReturnTypes.assign({memrefTy}); + return success(); +} + +FLY_INFER_RETURN_TYPES(TiledMmaPartitionShapeOp) { + auto memrefTy = dyn_cast(operands[2].getType()); + if (!memrefTy) + return failure(); + inferredReturnTypes.assign({IntTupleType::get(memrefTy.getLayout().getShape())}); + return success(); +} + +FLY_INFER_RETURN_TYPES(MemRefAllocSharedOp) { + auto memrefTy = dyn_cast(operands[0].getType()); + if (!memrefTy) + return failure(); + auto sharedSpace = AddressSpaceAttr::get(context, AddressSpace::Shared); + inferredReturnTypes.assign( + {MemRefType::get(memrefTy.getElemTy(), sharedSpace, memrefTy.getLayout())}); + return success(); +} + +FLY_INFER_RETURN_TYPES(MemRefLoadOp) { + auto memrefTy = dyn_cast(operands[0].getType()); + if (!memrefTy) + return failure(); + inferredReturnTypes.push_back(memrefTy.getElemTy()); + return success(); +} + +FLY_INFER_RETURN_TYPES(MemRefLoadVecOp) { + auto memrefTy = dyn_cast(operands[0].getType()); + if (!memrefTy) + return failure(); + + LayoutAttr layoutAttr = memrefTy.getLayout(); + IntTupleBuilder builder(context); + IntAttr size = cast(intTupleProduct(builder, layoutAttr.getShape()).getValue()); + + if (!size.isStatic()) + return failure(); + + inferredReturnTypes.push_back(VectorType::get({size.getValue()}, memrefTy.getElemTy())); + return success(); +} + +FLY_INFER_RETURN_TYPES(RecastIterOp) { + auto ptrTy = dyn_cast(operands[0].getType()); + if (!ptrTy) + return failure(); + inferredReturnTypes.assign({ptrTy}); + return success(); +} + +FLY_INFER_RETURN_TYPES(AddOffsetOp) { + auto ptrTy = dyn_cast(operands[0].getType()); + auto offsetTy = 
dyn_cast(operands[1].getType()); + if (!ptrTy || !offsetTy) + return failure(); + // Offset must be a scalar (leaf) int_tuple + if (!offsetTy.getAttr().isLeaf()) + return failure(); + inferredReturnTypes.assign({ptrTy}); + return success(); +} + +FLY_INFER_RETURN_TYPES(ApplySwizzleOp) { + auto ptrTy = dyn_cast(operands[0].getType()); + if (!ptrTy) + return failure(); + inferredReturnTypes.assign({ptrTy}); + return success(); +} + +#undef FLY_INFER_RETURN_TYPES diff --git a/lib/Dialect/Fly/IR/FlyTypeDefs.cpp b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp new file mode 100644 index 00000000..b21c55c5 --- /dev/null +++ b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp @@ -0,0 +1,78 @@ +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" + +namespace mlir::fly { + +bool BasisType::isStatic() const { return getAttr().isStatic(); } +bool IntTupleType::isStatic() const { return getAttr().isStatic(); } +bool LayoutType::isStatic() const { return getAttr().isStatic(); } +bool ComposedLayoutType::isStatic() const { return getAttr().isStatic(); } +bool CoordTensorType::isStatic() const { return getBase().isStatic() && getLayout().isStatic(); } + +int32_t BasisType::depth() { return getAttr().depth(); } + +bool IntTupleType::isLeaf() const { return getAttr().isLeaf(); } +int32_t IntTupleType::rank() const { return getAttr().rank(); } +int32_t IntTupleType::rank(int32_t idx) const { return getAttr().rank(idx); } +int32_t IntTupleType::rank(ArrayRef idxs) const { return getAttr().rank(idxs); } +int32_t IntTupleType::depth() const { return getAttr().depth(); } +int32_t IntTupleType::depth(int32_t idx) const { return getAttr().depth(idx); } +int32_t IntTupleType::depth(ArrayRef idxs) const { return getAttr().depth(idxs); } + +bool LayoutType::isLeaf() const { return getAttr().isLeaf(); } +int32_t LayoutType::rank() const { return getAttr().rank(); } +int32_t LayoutType::rank(int32_t idx) const { return getAttr().rank(idx); } +int32_t LayoutType::rank(ArrayRef idxs) const { return getAttr().rank(idxs); } 
+int32_t LayoutType::depth() const { return getAttr().depth(); } +int32_t LayoutType::depth(int32_t idx) const { return getAttr().depth(idx); } +int32_t LayoutType::depth(ArrayRef idxs) const { return getAttr().depth(idxs); } +bool LayoutType::isStaticShape() const { return getAttr().isStaticShape(); } +bool LayoutType::isStaticStride() const { return getAttr().isStaticStride(); } + +bool ComposedLayoutType::isLeaf() const { return getAttr().isLeaf(); } +int32_t ComposedLayoutType::rank() const { return getAttr().rank(); } +int32_t ComposedLayoutType::rank(int32_t idx) const { return getAttr().rank(idx); } +int32_t ComposedLayoutType::rank(ArrayRef idxs) const { return getAttr().rank(idxs); } +int32_t ComposedLayoutType::depth() const { return getAttr().depth(); } +int32_t ComposedLayoutType::depth(int32_t idx) const { return getAttr().depth(idx); } +int32_t ComposedLayoutType::depth(ArrayRef idxs) const { return getAttr().depth(idxs); } +bool ComposedLayoutType::isStaticOuter() const { return getAttr().isStaticOuter(); } +bool ComposedLayoutType::isStaticInner() const { return getAttr().isStaticInner(); } +bool ComposedLayoutType::isStaticOffset() const { return getAttr().isStaticOffset(); } + +int32_t TileType::rank() const { return getAttr().rank(); } + +bool CoordTensorType::isLeaf() const { return getLayout().isLeaf(); } +int32_t CoordTensorType::rank() const { return getLayout().rank(); } +int32_t CoordTensorType::rank(int32_t idx) const { return getLayout().rank(idx); } +int32_t CoordTensorType::rank(ArrayRef idxs) const { return getLayout().rank(idxs); } +int32_t CoordTensorType::depth() const { return getLayout().depth(); } +int32_t CoordTensorType::depth(int32_t idx) const { return getLayout().depth(idx); } +int32_t CoordTensorType::depth(ArrayRef idxs) const { return getLayout().depth(idxs); } + +IntTupleType IntTupleType::at(int32_t idx) const { + return IntTupleType::get(getContext(), getAttr().at(idx)); +} +IntTupleType IntTupleType::at(ArrayRef idxs) 
const { + return IntTupleType::get(getContext(), getAttr().at(idxs)); +} +LayoutType LayoutType::at(int32_t idx) const { + return LayoutType::get(getContext(), getAttr().at(idx)); +} +LayoutType LayoutType::at(ArrayRef idxs) const { + return LayoutType::get(getContext(), getAttr().at(idxs)); +} +ComposedLayoutType ComposedLayoutType::at(int32_t idx) const { + return ComposedLayoutType::get(getContext(), getAttr().at(idx)); +} +ComposedLayoutType ComposedLayoutType::at(ArrayRef idxs) const { + return ComposedLayoutType::get(getContext(), getAttr().at(idxs)); +} + +CoordTensorType CoordTensorType::at(int32_t idx) const { + return CoordTensorType::get(getContext(), getBase().at(idx), getLayout().at(idx)); +} +CoordTensorType CoordTensorType::at(ArrayRef idxs) const { + return CoordTensorType::get(getContext(), getBase().at(idxs), getLayout().at(idxs)); +} + +} // namespace mlir::fly diff --git a/lib/Dialect/Fly/Transforms/FlyCanonicalize.cpp b/lib/Dialect/Fly/Transforms/FlyCanonicalize.cpp new file mode 100644 index 00000000..4b5bf6a4 --- /dev/null +++ b/lib/Dialect/Fly/Transforms/FlyCanonicalize.cpp @@ -0,0 +1,104 @@ + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::fly; + +namespace mlir { +namespace fly { +#define GEN_PASS_DEF_FLYCANONICALIZEPASS +#include "flydsl/Dialect/Fly/Transforms/Passes.h.inc" +} // namespace fly +} // namespace mlir + +namespace { + +template +struct RewriteToMakeIntTuple final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IntTupleLikeOp op, PatternRewriter &rewriter) const override { + auto newOp = MakeIntTupleOp::create(rewriter, op.getLoc(), op.getResult().getType(), + op->getOperands(), op->getAttrs()); + rewriter.replaceOp(op, newOp.getResult()); + return 
success(); + } +}; + +class StaticResultLowering : public RewritePattern { +public: + StaticResultLowering(MLIRContext *context, PatternBenefit benefit = 1) + : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} + + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { + // Skip ops that are already in normal form + if (isa(op)) + return failure(); + if (auto makeLayoutOp = dyn_cast(op)) { + if (makeLayoutOp.getShape().getDefiningOp() && + makeLayoutOp.getStride().getDefiningOp()) { + return failure(); + } + } + + // Must have exactly one result + if (op->getNumResults() != 1) + return failure(); + Type resultType = op->getResult(0).getType(); + Location loc = op->getLoc(); + + if (auto intTupleTy = dyn_cast(resultType)) { + IntTupleAttr intTupleAttr = intTupleTy.getAttr(); + if (!intTupleAttr.isStatic()) + return failure(); + rewriter.replaceOpWithNewOp(op, intTupleTy, ValueRange{}); + return success(); + } else if (auto layoutTy = dyn_cast(resultType)) { + LayoutAttr layoutAttr = layoutTy.getAttr(); + if (!layoutAttr.isStatic()) + return failure(); + + Value shape = + MakeIntTupleOp::create(rewriter, loc, IntTupleType::get(layoutAttr.getShape()), {}); + Value stride = + MakeIntTupleOp::create(rewriter, loc, IntTupleType::get(layoutAttr.getStride()), {}); + rewriter.replaceOpWithNewOp(op, layoutTy, shape, stride); + return success(); + } + + return failure(); + } +}; + +class FlyCanonicalizePass + : public mlir::fly::impl::FlyCanonicalizePassBase { +public: + using mlir::fly::impl::FlyCanonicalizePassBase::FlyCanonicalizePassBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + patterns.add, RewriteToMakeIntTuple, + RewriteToMakeIntTuple>(context); + patterns.add(context); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +namespace impl { + +std::unique_ptr<::mlir::Pass> 
createFlyCanonicalizePass() { + return std::make_unique(); +} + +} // namespace impl diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp new file mode 100644 index 00000000..18af34c7 --- /dev/null +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -0,0 +1,1609 @@ + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Transforms/Passes.h" +#include "flydsl/Dialect/Fly/Utils/IntTupleUtils.h" +#include "flydsl/Dialect/Fly/Utils/LayoutUtils.h" +#include "flydsl/Dialect/Fly/Utils/NormalForm.h" + +#include +#include +#include +#include +#include + +#include +#include + +using namespace mlir; +using namespace mlir::fly; + +namespace mlir { +namespace fly { +#define GEN_PASS_DEF_FLYLAYOUTLOWERINGPASS +#include "flydsl/Dialect/Fly/Transforms/Passes.h.inc" +} // namespace fly +} // namespace mlir + +namespace { + +// Helper to check if an operation is a make_int_tuple-like op +static bool isMakeIntTupleLikeOp(Operation *op) { + return isa_and_nonnull(op); +} + +static void collectDynamicLeaves(IntTupleAttr attr, SmallVectorImpl &dynamicLeaves) { + if (attr.isLeaf()) { + if (auto intAttr = dyn_cast(attr.getValue())) { + if (!intAttr.isStatic() && !intAttr.isNone()) { + dynamicLeaves.push_back(intAttr); + } + } + return; + } + for (int i = 0; i < attr.rank(); ++i) { + collectDynamicLeaves(attr.at(i), dynamicLeaves); + } +} + +static std::optional getIntTupleStructType(IntTupleAttr profile, + MLIRContext *ctx) { + SmallVector 
dynamicLeaves; + collectDynamicLeaves(profile, dynamicLeaves); + if (dynamicLeaves.empty()) + return std::nullopt; + + SmallVector fields; + fields.reserve(dynamicLeaves.size()); + for (size_t i = 0; i < dynamicLeaves.size(); ++i) { + if (dynamicLeaves[i].getWidth() == 32) { + fields.push_back(IntegerType::get(ctx, 32)); + } else { + fields.push_back(IntegerType::get(ctx, 64)); + } + } + // Use packed struct to avoid padding between fields + return LLVM::LLVMStructType::getLiteral(ctx, fields, /*isPacked=*/true); +} + +static LLVM::LLVMStructType getIntTupleStructTypeOrEmpty(IntTupleAttr profile, MLIRContext *ctx) { + SmallVector dynamicLeaves; + collectDynamicLeaves(profile, dynamicLeaves); + if (dynamicLeaves.empty()) + return LLVM::LLVMStructType::getLiteral(ctx, {}, /*isPacked=*/true); + + SmallVector fields; + fields.reserve(dynamicLeaves.size()); + for (auto leaf : dynamicLeaves) { + if (leaf.getWidth() == 32) { + fields.push_back(IntegerType::get(ctx, 32)); + } else { + fields.push_back(IntegerType::get(ctx, 64)); + } + } + // Use packed struct to avoid padding between fields + return LLVM::LLVMStructType::getLiteral(ctx, fields, /*isPacked=*/true); +} + +static LLVM::LLVMStructType getLayoutStructTypeOrEmpty(LayoutAttr layoutAttr, MLIRContext *ctx) { + SmallVector fields; + fields.reserve(2); + fields.push_back(getIntTupleStructTypeOrEmpty(layoutAttr.getShape(), ctx)); + fields.push_back(getIntTupleStructTypeOrEmpty(layoutAttr.getStride(), ctx)); + // Use packed struct to avoid padding between fields + return LLVM::LLVMStructType::getLiteral(ctx, fields, /*isPacked=*/true); +} + +static std::optional getLayoutStructType(LayoutAttr layoutAttr, + MLIRContext *ctx) { + SmallVector shapeLeaves; + SmallVector strideLeaves; + collectDynamicLeaves(layoutAttr.getShape(), shapeLeaves); + collectDynamicLeaves(layoutAttr.getStride(), strideLeaves); + if (shapeLeaves.empty() && strideLeaves.empty()) + return std::nullopt; + + SmallVector fields; + fields.reserve(2); + 
fields.push_back(getIntTupleStructTypeOrEmpty(layoutAttr.getShape(), ctx)); + fields.push_back(getIntTupleStructTypeOrEmpty(layoutAttr.getStride(), ctx)); + // Use packed struct to avoid padding between fields + return LLVM::LLVMStructType::getLiteral(ctx, fields, /*isPacked=*/true); +} + +static unsigned mapAddressSpace(AddressSpace space) { + switch (space) { + case AddressSpace::Flat: + return 0; + case AddressSpace::Global: + return 1; + case AddressSpace::Shared: + return 3; + case AddressSpace::Register: + return 5; + } + return 0; +} + +// Get the fly.ptr type for a MemRef +static PointerType getMemRefPtrType(fly::MemRefType memrefTy) { + auto *ctx = memrefTy.getContext(); + return PointerType::get(ctx, memrefTy.getElemTy(), memrefTy.getAddressSpace(), + memrefTy.getAlignment(), memrefTy.getSwizzle()); +} + +// Get the layout struct type for a MemRef (reuses existing layout struct logic) +// Returns nullopt if the layout is fully static (no dynamic elements) +static std::optional getMemRefLayoutStructType(fly::MemRefType memrefTy) { + auto layoutStructTy = getLayoutStructType(memrefTy.getLayout(), memrefTy.getContext()); + return layoutStructTy; // Returns nullopt if fully static +} + +// Check if a MemRef has any dynamic layout elements +static bool memrefHasDynamicLayout(fly::MemRefType memrefTy) { + SmallVector shapeLeaves, strideLeaves; + collectDynamicLeaves(memrefTy.getLayout().getShape(), shapeLeaves); + collectDynamicLeaves(memrefTy.getLayout().getStride(), strideLeaves); + return !shapeLeaves.empty() || !strideLeaves.empty(); +} + +static Value castToFieldType(OpBuilder &builder, Location loc, Value value, Type fieldTy) { + if (value.getType() == fieldTy) + return value; + if (fieldTy.isIndex()) + return arith::IndexCastOp::create(builder, loc, fieldTy, value); + if (auto intTy = dyn_cast(fieldTy)) { + if (value.getType().isIndex()) + return arith::IndexCastOp::create(builder, loc, fieldTy, value); + if (auto srcInt = dyn_cast(value.getType())) { + 
if (srcInt.getWidth() < intTy.getWidth()) + return arith::ExtSIOp::create(builder, loc, fieldTy, value); + if (srcInt.getWidth() > intTy.getWidth()) + return arith::TruncIOp::create(builder, loc, fieldTy, value); + } + } + return nullptr; +} + +static std::optional packIntTupleToStruct(OpBuilder &builder, Location loc, Value tuple, + LLVM::LLVMStructType structTy) { + auto tupleTy = dyn_cast(tuple.getType()); + if (!tupleTy) + return std::nullopt; + + IntTupleAttr profile = tupleTy.getAttr(); + SmallVector dynamicLeaves; + collectDynamicLeaves(profile, dynamicLeaves); + if (dynamicLeaves.empty()) { + return LLVM::UndefOp::create(builder, loc, structTy); + } + + Operation *defOp = tuple.getDefiningOp(); + if (!defOp || !isMakeIntTupleLikeOp(defOp)) + return std::nullopt; + + if (defOp->getNumOperands() != dynamicLeaves.size()) + return std::nullopt; + + Value result = LLVM::UndefOp::create(builder, loc, structTy); + for (size_t i = 0; i < dynamicLeaves.size(); ++i) { + Type valueFieldTy = structTy.getBody()[i]; + Value value = castToFieldType(builder, loc, defOp->getOperand(i), valueFieldTy); + if (!value) + return std::nullopt; + result = LLVM::InsertValueOp::create(builder, loc, structTy, result, value, + llvm::ArrayRef{static_cast(i)}); + } + return result; +} + +static std::optional> collectDynamicOperands(Value tuple, IntTupleAttr profile) { + SmallVector dynamicLeaves; + collectDynamicLeaves(profile, dynamicLeaves); + if (dynamicLeaves.empty()) + return SmallVector{}; + + Operation *defOp = tuple.getDefiningOp(); + if (!defOp || !isMakeIntTupleLikeOp(defOp)) + return std::nullopt; + + if (defOp->getNumOperands() != dynamicLeaves.size()) + return std::nullopt; + + SmallVector operands(defOp->getOperands().begin(), defOp->getOperands().end()); + return operands; +} + +static std::optional packLayoutToStruct(OpBuilder &builder, Location loc, Value layout, + LLVM::LLVMStructType structTy, + LayoutAttr layoutAttr) { + auto layoutOp = layout.getDefiningOp(); + if 
(!layoutOp) { + if (!layoutAttr.isStatic()) + return std::nullopt; + auto shapeStructTy = cast(structTy.getBody()[0]); + auto strideStructTy = cast(structTy.getBody()[1]); + Value shapeStruct = LLVM::UndefOp::create(builder, loc, shapeStructTy); + Value strideStruct = LLVM::UndefOp::create(builder, loc, strideStructTy); + Value result = LLVM::UndefOp::create(builder, loc, structTy); + result = LLVM::InsertValueOp::create(builder, loc, structTy, result, shapeStruct, + llvm::ArrayRef{0}); + result = LLVM::InsertValueOp::create(builder, loc, structTy, result, strideStruct, + llvm::ArrayRef{1}); + return result; + } + + auto shapeOps = collectDynamicOperands(layoutOp.getShape(), layoutAttr.getShape()); + auto strideOps = collectDynamicOperands(layoutOp.getStride(), layoutAttr.getStride()); + if (!shapeOps || !strideOps) + return std::nullopt; + + auto shapeStructTy = cast(structTy.getBody()[0]); + auto strideStructTy = cast(structTy.getBody()[1]); + + Value shapeStruct = LLVM::UndefOp::create(builder, loc, shapeStructTy); + for (size_t i = 0; i < shapeOps->size(); ++i) { + Type fieldTy = shapeStructTy.getBody()[i]; + Value casted = castToFieldType(builder, loc, (*shapeOps)[i], fieldTy); + if (!casted) + return std::nullopt; + shapeStruct = LLVM::InsertValueOp::create(builder, loc, shapeStructTy, shapeStruct, casted, + llvm::ArrayRef{static_cast(i)}); + } + + Value strideStruct = LLVM::UndefOp::create(builder, loc, strideStructTy); + for (size_t i = 0; i < strideOps->size(); ++i) { + Type fieldTy = strideStructTy.getBody()[i]; + Value casted = castToFieldType(builder, loc, (*strideOps)[i], fieldTy); + if (!casted) + return std::nullopt; + strideStruct = LLVM::InsertValueOp::create(builder, loc, strideStructTy, strideStruct, casted, + llvm::ArrayRef{static_cast(i)}); + } + + Value result = LLVM::UndefOp::create(builder, loc, structTy); + result = LLVM::InsertValueOp::create(builder, loc, structTy, result, shapeStruct, + llvm::ArrayRef{0}); + result = 
LLVM::InsertValueOp::create(builder, loc, structTy, result, strideStruct, + llvm::ArrayRef{1}); + return result; +} + +// Pack an IntTuple value to LLVM struct by extracting dynamic leaf values using GetLeafOp +// Uses recursive GetLeafOp calls to navigate nested IntTuple structure +static Value packIntTupleToStructGeneric(OpBuilder &builder, Location loc, Value intTuple, + IntTupleAttr profile, LLVM::LLVMStructType structTy) { + SmallVector dynamicLeaves; + collectDynamicLeaves(profile, dynamicLeaves); + + Value result = LLVM::UndefOp::create(builder, loc, structTy); + + // If no dynamic leaves, return undef struct + if (dynamicLeaves.empty()) + return result; + + // Recursively extract dynamic leaf values + int32_t structIdx = 0; + std::function extractLeaves = [&](Value currentTuple, + IntTupleAttr currentAttr) { + if (currentAttr.isLeaf()) { + if (!currentAttr.isStatic()) { + // Dynamic leaf - extract the scalar value using GetScalarOp + Value scalarVal = GetScalarOp::create(builder, loc, currentTuple); + Type fieldTy = structTy.getBody()[structIdx]; + Value casted = castToFieldType(builder, loc, scalarVal, fieldTy); + result = + LLVM::InsertValueOp::create(builder, loc, structTy, result, casted, + llvm::ArrayRef{static_cast(structIdx)}); + structIdx++; + } + return; + } + // Non-leaf: recurse into children using GetLeafOp + for (int32_t i = 0; i < currentAttr.rank(); ++i) { + Value childTuple = GetLeafOp::create(builder, loc, currentTuple, static_cast(i)); + extractLeaves(childTuple, currentAttr.at(i)); + } + }; + + extractLeaves(intTuple, profile); + return result; +} + +// Pack layout to struct - generic version that works for any layout Value (not just MakeLayoutOp) +static Value packLayoutToStructGeneric(OpBuilder &builder, Location loc, Value layout, + LayoutAttr layoutAttr, LLVM::LLVMStructType structTy) { + auto shapeStructTy = cast(structTy.getBody()[0]); + auto strideStructTy = cast(structTy.getBody()[1]); + + Value shapeValue = nullptr; + Value 
strideValue = nullptr; + + // Try to get shape and stride from MakeLayoutOp directly + if (auto layoutOp = layout.getDefiningOp()) { + shapeValue = layoutOp.getShape(); + strideValue = layoutOp.getStride(); + } else { + // Otherwise, create GetShapeOp and GetStrideOp + shapeValue = GetShapeOp::create(builder, loc, layout); + strideValue = GetStrideOp::create(builder, loc, layout); + } + + Value shapeStruct = + packIntTupleToStructGeneric(builder, loc, shapeValue, layoutAttr.getShape(), shapeStructTy); + Value strideStruct = packIntTupleToStructGeneric(builder, loc, strideValue, + layoutAttr.getStride(), strideStructTy); + + Value result = LLVM::UndefOp::create(builder, loc, structTy); + result = LLVM::InsertValueOp::create(builder, loc, structTy, result, shapeStruct, + llvm::ArrayRef{0}); + result = LLVM::InsertValueOp::create(builder, loc, structTy, result, strideStruct, + llvm::ArrayRef{1}); + return result; +} + +// Extract ptr and layout values from a MemRef, returns {ptr, layoutStruct} +static std::pair unpackMemRefToPtrAndLayout(OpBuilder &builder, Location loc, + Value memref, fly::MemRefType memrefTy) { + Value ptrValue = nullptr; + Value layoutValue = nullptr; + + if (auto makeView = memref.getDefiningOp()) { + ptrValue = makeView.getIter(); + layoutValue = makeView.getLayout(); + } else { + ptrValue = GetIterOp::create(builder, loc, memref); + layoutValue = GetLayoutOp::create(builder, loc, memref); + } + + auto layoutAttr = memrefTy.getLayout(); + auto layoutStructTy = getLayoutStructTypeOrEmpty(layoutAttr, memrefTy.getContext()); + + Value layoutStruct = + packLayoutToStructGeneric(builder, loc, layoutValue, layoutAttr, layoutStructTy); + + return std::make_pair(ptrValue, layoutStruct); +} + +static void lowerGpuLaunchFuncIntTupleOperands(gpu::LaunchFuncOp op) { + auto kernelRef = op.getKernel(); + auto gpuFunc = SymbolTable::lookupNearestSymbolFrom(op, kernelRef); + if (!gpuFunc) + return; + + SmallVector 
oldKernelOperands(op.getKernelOperands().begin(), + op.getKernelOperands().end()); + SmallVector newKernelOperands; + + OpBuilder builder(op); + bool changed = false; + + for (size_t i = 0; i < oldKernelOperands.size(); ++i) { + Value operand = oldKernelOperands[i]; + + if (auto tupleTy = dyn_cast(operand.getType())) { + auto structTy = getIntTupleStructTypeOrEmpty(tupleTy.getAttr(), op.getContext()); + if (auto packed = packIntTupleToStruct(builder, op.getLoc(), operand, structTy)) { + newKernelOperands.push_back(*packed); + changed = true; + } else { + newKernelOperands.push_back(operand); + } + continue; + } + if (auto layoutTy = dyn_cast(operand.getType())) { + auto structTy = getLayoutStructTypeOrEmpty(layoutTy.getAttr(), op.getContext()); + if (auto packed = + packLayoutToStruct(builder, op.getLoc(), operand, structTy, layoutTy.getAttr())) { + newKernelOperands.push_back(*packed); + changed = true; + } else { + newKernelOperands.push_back(operand); + } + continue; + } + if (auto memrefTy = dyn_cast(operand.getType())) { + // MemRef is split into arguments: fly.ptr and optionally layout struct (if dynamic) + auto unpacked = unpackMemRefToPtrAndLayout(builder, op.getLoc(), operand, memrefTy); + newKernelOperands.push_back(unpacked.first); // fly.ptr + // Only add layout struct if layout has dynamic elements + if (memrefHasDynamicLayout(memrefTy)) { + newKernelOperands.push_back(unpacked.second); // layout struct + } + changed = true; + continue; + } + // Other types pass through unchanged + newKernelOperands.push_back(operand); + } + + if (!changed) + return; + + op.getKernelOperandsMutable().assign(newKernelOperands); +} + +static bool lowerGpuFuncIntTupleArgs(gpu::GPUFuncOp op) { + auto funcType = op.getFunctionType(); + SmallVector oldInputs(funcType.getInputs().begin(), funcType.getInputs().end()); + + // First pass: compute new argument types + SmallVector newInputs; + // MemRefStatic: only ptr arg (layout is fully static) + // MemRefDynamic: ptr arg + 
layout struct arg + enum class ArgKind { None, IntTuple, Layout, MemRefStatic, MemRefDynamic }; + SmallVector argKinds; // One per old argument + + bool changed = false; + for (Type oldType : oldInputs) { + if (auto tupleTy = dyn_cast(oldType)) { + auto structTy = getIntTupleStructTypeOrEmpty(tupleTy.getAttr(), op.getContext()); + newInputs.push_back(structTy); + argKinds.push_back(ArgKind::IntTuple); + changed = true; + continue; + } + if (auto layoutTy = dyn_cast(oldType)) { + auto structTy = getLayoutStructTypeOrEmpty(layoutTy.getAttr(), op.getContext()); + newInputs.push_back(structTy); + argKinds.push_back(ArgKind::Layout); + changed = true; + continue; + } + if (auto memrefTy = dyn_cast(oldType)) { + // MemRef splits into args: fly.ptr and optionally layout struct (if dynamic) + auto ptrTy = getMemRefPtrType(memrefTy); + newInputs.push_back(ptrTy); + if (memrefHasDynamicLayout(memrefTy)) { + auto layoutStructTy = *getMemRefLayoutStructType(memrefTy); + newInputs.push_back(layoutStructTy); + argKinds.push_back(ArgKind::MemRefDynamic); + } else { + argKinds.push_back(ArgKind::MemRefStatic); + } + changed = true; + continue; + } + newInputs.push_back(oldType); + argKinds.push_back(ArgKind::None); + } + + if (!changed) + return false; + + // Update function type + auto newFuncType = FunctionType::get(op.getContext(), newInputs, funcType.getResults()); + op.setType(newFuncType); + + Block &entry = op.getBody().front(); + Location loc = op.getLoc(); + + // Transform block arguments: work backwards to handle index shifts from MemRef expansion + for (int i = oldInputs.size() - 1; i >= 0; --i) { + BlockArgument oldArg = entry.getArgument(i); + + if (argKinds[i] == ArgKind::None) { + continue; + } + + if (argKinds[i] == ArgKind::IntTuple || argKinds[i] == ArgKind::Layout) { + // Compute which newInputs index this corresponds to + size_t newIdx = 0; + for (int j = 0; j < i; ++j) { + newIdx++; + if (argKinds[j] == ArgKind::MemRefDynamic) + newIdx++; // MemRefDynamic adds 
an extra arg + } + oldArg.setType(newInputs[newIdx]); + continue; + } + + if (argKinds[i] == ArgKind::MemRefStatic) { + // Static MemRef: only ptr arg, no layout struct + size_t newIdx = 0; + for (int j = 0; j < i; ++j) { + newIdx++; + if (argKinds[j] == ArgKind::MemRefDynamic) + newIdx++; + } + oldArg.setType(newInputs[newIdx]); + continue; + } + + if (argKinds[i] == ArgKind::MemRefDynamic) { + // Dynamic MemRef: ptr arg + layout struct arg + size_t newIdx = 0; + for (int j = 0; j < i; ++j) { + newIdx++; + if (argKinds[j] == ArgKind::MemRefDynamic) + newIdx++; + } + // Change existing arg type to ptr + oldArg.setType(newInputs[newIdx]); + // Insert layout struct arg right after + entry.insertArgument(i + 1, newInputs[newIdx + 1], loc); + } + } + + // Now reconstruct the fly values from the new arguments + OpBuilder builder(&entry, entry.begin()); + + // Compute new argument indices for each old argument + size_t newArgIdx = 0; + for (size_t i = 0; i < oldInputs.size(); ++i) { + if (argKinds[i] == ArgKind::None) { + newArgIdx++; + continue; + } + + if (argKinds[i] == ArgKind::IntTuple) { + auto tupleTy = cast(oldInputs[i]); + auto structTy = cast(newInputs[newArgIdx]); + BlockArgument arg = entry.getArgument(newArgIdx); + + SmallVector dynamicLeaves; + collectDynamicLeaves(tupleTy.getAttr(), dynamicLeaves); + if (dynamicLeaves.empty()) { + Value tuple = StaticOp::create(builder, loc, tupleTy); + arg.replaceAllUsesWith(tuple); + newArgIdx++; + continue; + } + + SmallVector dyncElems; + SmallVector extractOps; + dyncElems.reserve(dynamicLeaves.size()); + + for (size_t j = 0; j < dynamicLeaves.size(); ++j) { + Type fieldTy = structTy.getBody()[j]; + Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, arg, + llvm::ArrayRef{static_cast(j)}); + dyncElems.push_back(val); + extractOps.push_back(val.getDefiningOp()); + } + + Value tuple = MakeIntTupleOp::create(builder, loc, tupleTy, dyncElems); + llvm::SmallPtrSet except(extractOps.begin(), extractOps.end()); + 
arg.replaceAllUsesExcept(tuple, except); + newArgIdx++; + continue; + } + + if (argKinds[i] == ArgKind::Layout) { + auto layoutTy = cast(oldInputs[i]); + auto structTy = cast(newInputs[newArgIdx]); + BlockArgument arg = entry.getArgument(newArgIdx); + LayoutAttr layoutAttr = layoutTy.getAttr(); + + SmallVector shapeLeaves; + SmallVector strideLeaves; + collectDynamicLeaves(layoutAttr.getShape(), shapeLeaves); + collectDynamicLeaves(layoutAttr.getStride(), strideLeaves); + if (shapeLeaves.empty() && strideLeaves.empty()) { + Value layout = StaticOp::create(builder, loc, layoutTy); + arg.replaceAllUsesWith(layout); + newArgIdx++; + continue; + } + + SmallVector shapeElems; + SmallVector strideElems; + SmallVector extractOps; + + auto shapeStructTy = cast(structTy.getBody()[0]); + auto strideStructTy = cast(structTy.getBody()[1]); + Value shapeStruct = LLVM::ExtractValueOp::create(builder, loc, shapeStructTy, arg, + llvm::ArrayRef{0}); + Value strideStruct = LLVM::ExtractValueOp::create(builder, loc, strideStructTy, arg, + llvm::ArrayRef{1}); + extractOps.push_back(shapeStruct.getDefiningOp()); + extractOps.push_back(strideStruct.getDefiningOp()); + + for (size_t j = 0; j < shapeLeaves.size(); ++j) { + Type fieldTy = shapeStructTy.getBody()[j]; + Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, shapeStruct, + llvm::ArrayRef{static_cast(j)}); + shapeElems.push_back(val); + extractOps.push_back(val.getDefiningOp()); + } + for (size_t j = 0; j < strideLeaves.size(); ++j) { + Type fieldTy = strideStructTy.getBody()[j]; + Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, strideStruct, + llvm::ArrayRef{static_cast(j)}); + strideElems.push_back(val); + extractOps.push_back(val.getDefiningOp()); + } + + IntTupleType shapeTy = IntTupleType::get(op.getContext(), layoutAttr.getShape()); + IntTupleType strideTy = IntTupleType::get(op.getContext(), layoutAttr.getStride()); + Value shape = MakeIntTupleOp::create(builder, loc, shapeTy, shapeElems); + 
Value stride = MakeIntTupleOp::create(builder, loc, strideTy, strideElems); + Value layout = MakeLayoutOp::create(builder, loc, layoutTy, shape, stride); + llvm::SmallPtrSet except(extractOps.begin(), extractOps.end()); + arg.replaceAllUsesExcept(layout, except); + newArgIdx++; + continue; + } + + if (argKinds[i] == ArgKind::MemRefStatic) { + // Static MemRef: only ptr arg, layout is fully static + auto memrefTy = cast(oldInputs[i]); + LayoutAttr layoutAttr = memrefTy.getLayout(); + + BlockArgument ptrArg = entry.getArgument(newArgIdx); + + // Create static layout using MakeIntTupleOp and MakeLayoutOp + IntTupleType shapeTy = IntTupleType::get(op.getContext(), layoutAttr.getShape()); + IntTupleType strideTy = IntTupleType::get(op.getContext(), layoutAttr.getStride()); + Value shape = MakeIntTupleOp::create(builder, loc, shapeTy, ValueRange{}); + Value stride = MakeIntTupleOp::create(builder, loc, strideTy, ValueRange{}); + auto layoutTy = LayoutType::get(op.getContext(), layoutAttr); + Value layout = MakeLayoutOp::create(builder, loc, layoutTy, shape, stride); + + // Create the MakeViewOp with fly.ptr directly + Value view = MakeViewOp::create(builder, loc, memrefTy, ptrArg, layout); + + // Replace uses of the ptr arg + llvm::SmallPtrSet except; + except.insert(view.getDefiningOp()); + ptrArg.replaceAllUsesExcept(view, except); + + newArgIdx++; // Static MemRef uses 1 arg + continue; + } + + if (argKinds[i] == ArgKind::MemRefDynamic) { + // Dynamic MemRef: ptr arg + layout struct arg + auto memrefTy = cast(oldInputs[i]); + LayoutAttr layoutAttr = memrefTy.getLayout(); + + BlockArgument ptrArg = entry.getArgument(newArgIdx); + BlockArgument layoutStructArg = entry.getArgument(newArgIdx + 1); + auto layoutStructTy = cast(layoutStructArg.getType()); + + SmallVector shapeLeaves; + SmallVector strideLeaves; + collectDynamicLeaves(layoutAttr.getShape(), shapeLeaves); + collectDynamicLeaves(layoutAttr.getStride(), strideLeaves); + + SmallVector shapeElems; + SmallVector 
strideElems; + SmallVector extractOps; + + auto shapeStructTy = cast(layoutStructTy.getBody()[0]); + auto strideStructTy = cast(layoutStructTy.getBody()[1]); + Value shapeStruct = LLVM::ExtractValueOp::create(builder, loc, shapeStructTy, layoutStructArg, + llvm::ArrayRef{0}); + Value strideStruct = LLVM::ExtractValueOp::create( + builder, loc, strideStructTy, layoutStructArg, llvm::ArrayRef{1}); + extractOps.push_back(shapeStruct.getDefiningOp()); + extractOps.push_back(strideStruct.getDefiningOp()); + + for (size_t j = 0; j < shapeLeaves.size(); ++j) { + Type fieldTy = shapeStructTy.getBody()[j]; + Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, shapeStruct, + llvm::ArrayRef{static_cast(j)}); + shapeElems.push_back(val); + extractOps.push_back(val.getDefiningOp()); + } + for (size_t j = 0; j < strideLeaves.size(); ++j) { + Type fieldTy = strideStructTy.getBody()[j]; + Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, strideStruct, + llvm::ArrayRef{static_cast(j)}); + strideElems.push_back(val); + extractOps.push_back(val.getDefiningOp()); + } + + IntTupleType shapeTy = IntTupleType::get(op.getContext(), layoutAttr.getShape()); + IntTupleType strideTy = IntTupleType::get(op.getContext(), layoutAttr.getStride()); + Value shape = MakeIntTupleOp::create(builder, loc, shapeTy, shapeElems); + Value stride = MakeIntTupleOp::create(builder, loc, strideTy, strideElems); + auto layoutTy = LayoutType::get(op.getContext(), layoutAttr); + Value layout = MakeLayoutOp::create(builder, loc, layoutTy, shape, stride); + + // Create the MakeViewOp with fly.ptr directly + Value view = MakeViewOp::create(builder, loc, memrefTy, ptrArg, layout); + + // Replace uses of the ptr arg (which was the original memref arg) + // Must exclude: extractOps (which use layoutStructArg), and view's defining op (which uses + // ptrArg) + llvm::SmallPtrSet except(extractOps.begin(), extractOps.end()); + except.insert(view.getDefiningOp()); + 
ptrArg.replaceAllUsesExcept(view, except); + + newArgIdx += 2; // Dynamic MemRef uses 2 args + continue; + } + } + + return true; +} + +/// Lower func::FuncOp arguments: IntTupleType, LayoutType, and MemRefType arguments are lowered +/// to LLVM structs, similar to lowerGpuFuncIntTupleArgs but for func.func operations. +static bool lowerFuncOpIntTupleArgs(func::FuncOp op) { + auto funcType = op.getFunctionType(); + SmallVector oldInputs(funcType.getInputs().begin(), funcType.getInputs().end()); + + // First pass: compute new argument types + SmallVector newInputs; + enum class ArgKind { None, IntTuple, Layout, MemRefStatic, MemRefDynamic }; + SmallVector argKinds; + + bool changed = false; + for (Type oldType : oldInputs) { + if (auto tupleTy = dyn_cast(oldType)) { + auto structTy = getIntTupleStructTypeOrEmpty(tupleTy.getAttr(), op.getContext()); + newInputs.push_back(structTy); + argKinds.push_back(ArgKind::IntTuple); + changed = true; + continue; + } + if (auto layoutTy = dyn_cast(oldType)) { + auto structTy = getLayoutStructTypeOrEmpty(layoutTy.getAttr(), op.getContext()); + newInputs.push_back(structTy); + argKinds.push_back(ArgKind::Layout); + changed = true; + continue; + } + if (auto memrefTy = dyn_cast(oldType)) { + auto ptrTy = getMemRefPtrType(memrefTy); + newInputs.push_back(ptrTy); + if (memrefHasDynamicLayout(memrefTy)) { + auto layoutStructTy = *getMemRefLayoutStructType(memrefTy); + newInputs.push_back(layoutStructTy); + argKinds.push_back(ArgKind::MemRefDynamic); + } else { + argKinds.push_back(ArgKind::MemRefStatic); + } + changed = true; + continue; + } + newInputs.push_back(oldType); + argKinds.push_back(ArgKind::None); + } + + if (!changed) + return false; + + // Update function type + auto newFuncType = FunctionType::get(op.getContext(), newInputs, funcType.getResults()); + op.setType(newFuncType); + + // Handle empty function (declaration only) + if (op.getBody().empty()) + return true; + + Block &entry = op.getBody().front(); + Location loc 
= op.getLoc(); + + // Transform block arguments: work backwards to handle index shifts from MemRef expansion + for (int i = oldInputs.size() - 1; i >= 0; --i) { + BlockArgument oldArg = entry.getArgument(i); + + if (argKinds[i] == ArgKind::None) { + continue; + } + + if (argKinds[i] == ArgKind::IntTuple || argKinds[i] == ArgKind::Layout) { + size_t newIdx = 0; + for (int j = 0; j < i; ++j) { + newIdx++; + if (argKinds[j] == ArgKind::MemRefDynamic) + newIdx++; + } + oldArg.setType(newInputs[newIdx]); + continue; + } + + if (argKinds[i] == ArgKind::MemRefStatic) { + size_t newIdx = 0; + for (int j = 0; j < i; ++j) { + newIdx++; + if (argKinds[j] == ArgKind::MemRefDynamic) + newIdx++; + } + oldArg.setType(newInputs[newIdx]); + continue; + } + + if (argKinds[i] == ArgKind::MemRefDynamic) { + size_t newIdx = 0; + for (int j = 0; j < i; ++j) { + newIdx++; + if (argKinds[j] == ArgKind::MemRefDynamic) + newIdx++; + } + oldArg.setType(newInputs[newIdx]); + entry.insertArgument(i + 1, newInputs[newIdx + 1], loc); + } + } + + // Reconstruct fly values from the new arguments + OpBuilder builder(&entry, entry.begin()); + + size_t newArgIdx = 0; + for (size_t i = 0; i < oldInputs.size(); ++i) { + if (argKinds[i] == ArgKind::None) { + newArgIdx++; + continue; + } + + if (argKinds[i] == ArgKind::IntTuple) { + auto tupleTy = cast(oldInputs[i]); + auto structTy = cast(newInputs[newArgIdx]); + BlockArgument arg = entry.getArgument(newArgIdx); + + SmallVector dynamicLeaves; + collectDynamicLeaves(tupleTy.getAttr(), dynamicLeaves); + if (dynamicLeaves.empty()) { + Value tuple = StaticOp::create(builder, loc, tupleTy); + arg.replaceAllUsesWith(tuple); + newArgIdx++; + continue; + } + + SmallVector dyncElems; + SmallVector extractOps; + dyncElems.reserve(dynamicLeaves.size()); + + for (size_t j = 0; j < dynamicLeaves.size(); ++j) { + Type fieldTy = structTy.getBody()[j]; + Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, arg, + llvm::ArrayRef{static_cast(j)}); + 
dyncElems.push_back(val); + extractOps.push_back(val.getDefiningOp()); + } + + Value tuple = MakeIntTupleOp::create(builder, loc, tupleTy, dyncElems); + llvm::SmallPtrSet except(extractOps.begin(), extractOps.end()); + arg.replaceAllUsesExcept(tuple, except); + newArgIdx++; + continue; + } + + if (argKinds[i] == ArgKind::Layout) { + auto layoutTy = cast(oldInputs[i]); + auto structTy = cast(newInputs[newArgIdx]); + BlockArgument arg = entry.getArgument(newArgIdx); + LayoutAttr layoutAttr = layoutTy.getAttr(); + + SmallVector shapeLeaves; + SmallVector strideLeaves; + collectDynamicLeaves(layoutAttr.getShape(), shapeLeaves); + collectDynamicLeaves(layoutAttr.getStride(), strideLeaves); + if (shapeLeaves.empty() && strideLeaves.empty()) { + Value layout = StaticOp::create(builder, loc, layoutTy); + arg.replaceAllUsesWith(layout); + newArgIdx++; + continue; + } + + SmallVector shapeElems; + SmallVector strideElems; + SmallVector extractOps; + + auto shapeStructTy = cast(structTy.getBody()[0]); + auto strideStructTy = cast(structTy.getBody()[1]); + Value shapeStruct = LLVM::ExtractValueOp::create(builder, loc, shapeStructTy, arg, + llvm::ArrayRef{0}); + Value strideStruct = LLVM::ExtractValueOp::create(builder, loc, strideStructTy, arg, + llvm::ArrayRef{1}); + extractOps.push_back(shapeStruct.getDefiningOp()); + extractOps.push_back(strideStruct.getDefiningOp()); + + for (size_t j = 0; j < shapeLeaves.size(); ++j) { + Type fieldTy = shapeStructTy.getBody()[j]; + Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, shapeStruct, + llvm::ArrayRef{static_cast(j)}); + shapeElems.push_back(val); + extractOps.push_back(val.getDefiningOp()); + } + for (size_t j = 0; j < strideLeaves.size(); ++j) { + Type fieldTy = strideStructTy.getBody()[j]; + Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, strideStruct, + llvm::ArrayRef{static_cast(j)}); + strideElems.push_back(val); + extractOps.push_back(val.getDefiningOp()); + } + + IntTupleType shapeTy = 
IntTupleType::get(op.getContext(), layoutAttr.getShape()); + IntTupleType strideTy = IntTupleType::get(op.getContext(), layoutAttr.getStride()); + Value shape = MakeIntTupleOp::create(builder, loc, shapeTy, shapeElems); + Value stride = MakeIntTupleOp::create(builder, loc, strideTy, strideElems); + Value layout = MakeLayoutOp::create(builder, loc, layoutTy, shape, stride); + llvm::SmallPtrSet except(extractOps.begin(), extractOps.end()); + arg.replaceAllUsesExcept(layout, except); + newArgIdx++; + continue; + } + + if (argKinds[i] == ArgKind::MemRefStatic) { + auto memrefTy = cast(oldInputs[i]); + LayoutAttr layoutAttr = memrefTy.getLayout(); + + BlockArgument ptrArg = entry.getArgument(newArgIdx); + + IntTupleType shapeTy = IntTupleType::get(op.getContext(), layoutAttr.getShape()); + IntTupleType strideTy = IntTupleType::get(op.getContext(), layoutAttr.getStride()); + Value shape = MakeIntTupleOp::create(builder, loc, shapeTy, ValueRange{}); + Value stride = MakeIntTupleOp::create(builder, loc, strideTy, ValueRange{}); + auto layoutTy = LayoutType::get(op.getContext(), layoutAttr); + Value layout = MakeLayoutOp::create(builder, loc, layoutTy, shape, stride); + + Value view = MakeViewOp::create(builder, loc, memrefTy, ptrArg, layout); + + llvm::SmallPtrSet except; + except.insert(view.getDefiningOp()); + ptrArg.replaceAllUsesExcept(view, except); + + newArgIdx++; + continue; + } + + if (argKinds[i] == ArgKind::MemRefDynamic) { + auto memrefTy = cast(oldInputs[i]); + LayoutAttr layoutAttr = memrefTy.getLayout(); + + BlockArgument ptrArg = entry.getArgument(newArgIdx); + BlockArgument layoutStructArg = entry.getArgument(newArgIdx + 1); + auto layoutStructTy = cast(layoutStructArg.getType()); + + SmallVector shapeLeaves; + SmallVector strideLeaves; + collectDynamicLeaves(layoutAttr.getShape(), shapeLeaves); + collectDynamicLeaves(layoutAttr.getStride(), strideLeaves); + + SmallVector shapeElems; + SmallVector strideElems; + SmallVector extractOps; + + auto 
shapeStructTy = cast(layoutStructTy.getBody()[0]); + auto strideStructTy = cast(layoutStructTy.getBody()[1]); + Value shapeStruct = LLVM::ExtractValueOp::create(builder, loc, shapeStructTy, layoutStructArg, + llvm::ArrayRef{0}); + Value strideStruct = LLVM::ExtractValueOp::create( + builder, loc, strideStructTy, layoutStructArg, llvm::ArrayRef{1}); + extractOps.push_back(shapeStruct.getDefiningOp()); + extractOps.push_back(strideStruct.getDefiningOp()); + + for (size_t j = 0; j < shapeLeaves.size(); ++j) { + Type fieldTy = shapeStructTy.getBody()[j]; + Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, shapeStruct, + llvm::ArrayRef{static_cast(j)}); + shapeElems.push_back(val); + extractOps.push_back(val.getDefiningOp()); + } + for (size_t j = 0; j < strideLeaves.size(); ++j) { + Type fieldTy = strideStructTy.getBody()[j]; + Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, strideStruct, + llvm::ArrayRef{static_cast(j)}); + strideElems.push_back(val); + extractOps.push_back(val.getDefiningOp()); + } + + IntTupleType shapeTy = IntTupleType::get(op.getContext(), layoutAttr.getShape()); + IntTupleType strideTy = IntTupleType::get(op.getContext(), layoutAttr.getStride()); + Value shape = MakeIntTupleOp::create(builder, loc, shapeTy, shapeElems); + Value stride = MakeIntTupleOp::create(builder, loc, strideTy, strideElems); + auto layoutTy = LayoutType::get(op.getContext(), layoutAttr); + Value layout = MakeLayoutOp::create(builder, loc, layoutTy, shape, stride); + + Value view = MakeViewOp::create(builder, loc, memrefTy, ptrArg, layout); + + llvm::SmallPtrSet except(extractOps.begin(), extractOps.end()); + except.insert(view.getDefiningOp()); + ptrArg.replaceAllUsesExcept(view, except); + + newArgIdx += 2; + continue; + } + } + + return true; +} + +static void collectLeafValues(const IntTupleBuilder &builder, + const IntTupleValueAdaptor &tuple, SmallVectorImpl &out) { + // if (tuple.isLeaf()) { + // out.push_back(tuple.intTupleValue); + // 
return; + // } + for (int i = 0; i < tuple.rank(); ++i) { + collectLeafValues(builder, builder.at(tuple, i), out); + } +} + +static void collectLeafAttrs(IntTupleAttr attr, SmallVectorImpl &out) { + if (attr.isLeaf()) { + out.push_back(attr.getLeafAsInt()); + return; + } + for (int i = 0; i < attr.rank(); ++i) { + collectLeafAttrs(attr.at(i), out); + } +} + +static Value castPrintfArg(PatternRewriter &rewriter, Location loc, Value value, + std::string &format) { + Type type = value.getType(); + if (isa(type)) { + format += "%ld"; + return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), value); + } + if (auto intTy = dyn_cast(type)) { + if (intTy.getWidth() <= 32) { + format += "%d"; + if (intTy.getWidth() < 32) { + return arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), value); + } + return value; + } + format += "%ld"; + if (intTy.getWidth() != 64) { + return arith::ExtSIOp::create(rewriter, loc, rewriter.getI64Type(), value); + } + return value; + } + if (auto floatTy = dyn_cast(type)) { + if (floatTy.getWidth() <= 32) { + format += "%f"; + if (floatTy.getWidth() < 32) { + return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value); + } + return value; + } + format += "%lf"; + if (floatTy.getWidth() != 64) { + return arith::ExtFOp::create(rewriter, loc, rewriter.getF64Type(), value); + } + return value; + } + return nullptr; +} + +static bool appendScalarPrintfArg(PatternRewriter &rewriter, Location loc, Value value, + std::string &format, SmallVectorImpl &args) { + Value casted = castPrintfArg(rewriter, loc, value, format); + if (!casted) { + return false; + } + args.push_back(casted); + return true; +} + +static bool appendIntTuplePrintf(PatternRewriter &rewriter, Location loc, + const IntTupleValueAdaptor &tuple, std::string &format, + SmallVectorImpl &args) { + SmallVector leaves; + IntTupleBuilder builder(rewriter, loc); + collectLeafValues(builder, tuple, leaves); + format += "("; + for (size_t i = 0; i < 
leaves.size(); ++i) { + if (i > 0) { + format += ", "; + } + if (!appendScalarPrintfArg(rewriter, loc, leaves[i], format, args)) { + return false; + } + } + format += ")"; + return true; +} + +static bool appendIntTuplePrintfStatic(IntTupleAttr attr, std::string &format) { + SmallVector leaves; + collectLeafAttrs(attr, leaves); + format += "("; + for (size_t i = 0; i < leaves.size(); ++i) { + if (i > 0) { + format += ", "; + } + if (leaves[i].isStatic()) { + format += std::to_string(leaves[i].getValue()); + continue; + } + format += "?"; + } + format += ")"; + return true; +} + +template +class IntTupleBinaryOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = op.getLhs(); + auto rhs = op.getRhs(); + + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + auto resultTy = dyn_cast(op.getResult().getType()); + if (!lhsTy || !rhsTy || !resultTy) + return failure(); + + // Check if inputs are in normal form (StaticOp or MakeIntTupleOp) + if (!isNormalForm(cast>(lhs)) || + !isNormalForm(cast>(rhs))) + return failure(); + + IntTupleBuilder builder(rewriter, loc); + IntTupleValueAdaptor lhsAdaptor = IntTupleValueAdaptor::create(builder, lhs, lhsTy.getAttr()); + IntTupleValueAdaptor rhsAdaptor = IntTupleValueAdaptor::create(builder, rhs, rhsTy.getAttr()); + + auto result = BinaryOpFn{}(builder, lhsAdaptor, rhsAdaptor); + rewriter.replaceOp(op, builder.finalize(result)); + return success(); + } +}; + +struct IntTupleAddFn { + IntTupleValueAdaptor operator()(IntTupleBuilder &builder, + IntTupleValueAdaptor lhs, IntTupleValueAdaptor rhs) const { + return intTupleAdd(builder, lhs, rhs); + } +}; +struct IntTupleSubFn { + IntTupleValueAdaptor operator()(IntTupleBuilder &builder, + IntTupleValueAdaptor lhs, IntTupleValueAdaptor rhs) const { + return intTupleSub(builder, lhs, rhs); + } +}; 
+struct IntTupleMulFn {
+  IntTupleValueAdaptor operator()(IntTupleBuilder &builder,
+                                  IntTupleValueAdaptor lhs, IntTupleValueAdaptor rhs) const {
+    return intTupleMul(builder, lhs, rhs);
+  }
+};
+struct IntTupleDivFn {
+  IntTupleValueAdaptor operator()(IntTupleBuilder &builder,
+                                  IntTupleValueAdaptor lhs, IntTupleValueAdaptor rhs) const {
+    return intTupleDiv(builder, lhs, rhs);
+  }
+};
+struct IntTupleModFn {
+  IntTupleValueAdaptor operator()(IntTupleBuilder &builder,
+                                  IntTupleValueAdaptor lhs, IntTupleValueAdaptor rhs) const {
+    return intTupleMod(builder, lhs, rhs);
+  }
+};
+
+// NOTE(review): the op class names below were stripped from the original
+// text; AddOp/SubOp/MulOp/DivOp/ModOp follow the dialect's naming of the
+// other ops in this file (SizeOp, SliceOp, ...) — confirm against the TD file.
+using IntTupleAddOpLowering = IntTupleBinaryOpLowering<AddOp, IntTupleAddFn>;
+using IntTupleSubOpLowering = IntTupleBinaryOpLowering<SubOp, IntTupleSubFn>;
+using IntTupleMulOpLowering = IntTupleBinaryOpLowering<MulOp, IntTupleMulFn>;
+using IntTupleDivOpLowering = IntTupleBinaryOpLowering<DivOp, IntTupleDivFn>;
+using IntTupleModOpLowering = IntTupleBinaryOpLowering<ModOp, IntTupleModFn>;
+
+/// Rewrites a reprofiling op over a normal-form tuple into a fresh
+/// MakeIntTupleOp with the result profile, reusing the source's operands.
+template <typename OpTy>
+class IntTupleReprofileOpLowering : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override {
+    auto inputTuple = op.getTuple();
+    if (auto tupleVal = dyn_cast<TypedValue<IntTupleType>>(inputTuple)) {
+      if (isNormalForm(tupleVal)) {
+        rewriter.replaceOp(op, MakeIntTupleOp::create(rewriter, op.getLoc(), tupleVal.getType(),
+                                                      tupleVal.getDefiningOp()->getOperands()));
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// GetShapeOp Lowering
+//===----------------------------------------------------------------------===//
+
+/// Forwards fly.get_shape of a MakeLayoutOp to the layout's shape operand.
+class GetShapeLowering : public OpRewritePattern<GetShapeOp> {
+public:
+  using OpRewritePattern<GetShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GetShapeOp op, PatternRewriter &rewriter) const override {
+    auto layout = op.getLayout();
+    if (auto defOp = layout.getDefiningOp<MakeLayoutOp>()) {
+      rewriter.replaceOp(op, defOp.getShape());
+      return success();
+    }
+    return failure();
+  }
+};
+
+/// Forwards fly.get_stride of a MakeLayoutOp to the layout's stride operand.
+class GetStrideLowering : public OpRewritePattern<GetStrideOp> {
+public:
+  using OpRewritePattern<GetStrideOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GetStrideOp op, PatternRewriter &rewriter) const override {
+    auto layout = op.getLayout();
+    if (auto defOp = layout.getDefiningOp<MakeLayoutOp>()) {
+      rewriter.replaceOp(op, defOp.getStride());
+      return success();
+    }
+    return failure();
+  }
+};
+
+/// Forwards fly.get_layout of a MakeViewOp to the view's layout operand.
+class GetLayoutLowering : public OpRewritePattern<GetLayoutOp> {
+public:
+  using OpRewritePattern<GetLayoutOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GetLayoutOp op, PatternRewriter &rewriter) const override {
+    Value memref = op.getMemref();
+    if (auto makeViewOp = memref.getDefiningOp<MakeViewOp>()) {
+      rewriter.replaceOp(op, makeViewOp.getLayout());
+      return success();
+    }
+    return failure();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// GetScalarOp Lowering
+//===----------------------------------------------------------------------===//
+
+/// Lowers fly.get_scalar of a leaf IntTuple: static leaves become constants,
+/// dynamic leaves forward the producing op's first operand.
+class GetScalarLowering : public OpRewritePattern<GetScalarOp> {
+public:
+  using OpRewritePattern<GetScalarOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GetScalarOp op, PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value intTuple = op.getIntTuple();
+
+    auto intTupleTy = dyn_cast<IntTupleType>(intTuple.getType());
+    if (!intTupleTy)
+      return failure();
+    if (!isNormalForm(cast<TypedValue<IntTupleType>>(intTuple)))
+      return failure();
+
+    IntTupleAttr profile = intTupleTy.getAttr();
+    assert(profile.isLeaf() && "IntTuple must be a leaf");
+
+    Type resultTy = op.getResult().getType();
+    if (auto intAttr = dyn_cast<IntAttr>(profile.getValue())) {
+      if (intAttr.isStatic()) {
+        rewriter.replaceOp(
+            op, arith::ConstantIntOp::create(rewriter, loc, resultTy, intAttr.getValue()));
+        return success();
+      } else {
+        // Dynamic leaf: the single operand of the producing op is the scalar.
+        auto defOp = intTuple.getDefiningOp();
+        if (!defOp)
+          return failure();
+        rewriter.replaceOp(op, defOp->getOperand(0));
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// SizeOp Lowering
+//===----------------------------------------------------------------------===//
+
+/// Lowers fly.size to the product of the tuple's leaves.
+class SizeOpLowering : public OpRewritePattern<SizeOp> {
+public:
+  using OpRewritePattern<SizeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SizeOp op, PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value intTuple = op.getIntTuple();
+
+    auto intTupleTy = dyn_cast<IntTupleType>(intTuple.getType());
+    if (!intTupleTy)
+      return failure();
+    if (!isNormalForm(dyn_cast<TypedValue<IntTupleType>>(intTuple)))
+      return failure();
+
+    auto resultTy = dyn_cast<IntTupleType>(op.getResult().getType());
+    if (!resultTy)
+      return failure();
+
+    // Use intTupleProduct to compute the size.
+    IntTupleBuilder builder(rewriter, loc);
+    IntTupleValueAdaptor inputAdaptor =
+        IntTupleValueAdaptor::create(builder, intTuple, intTupleTy.getAttr());
+    IntTupleValueAdaptor productAdaptor = intTupleProduct(builder, inputAdaptor);
+
+    rewriter.replaceOp(op, builder.finalize(productAdaptor));
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// SliceOp Lowering
+//===----------------------------------------------------------------------===//
+
+/// Lowers fly.slice of a normal-form tuple by a static coordinate profile.
+class SliceLowering : public OpRewritePattern<SliceOp> {
+public:
+  using OpRewritePattern<SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SliceOp op, PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value src = op.getSrc();
+    Value coord = op.getCoord();
+
+    auto srcTy = dyn_cast<IntTupleType>(src.getType());
+    auto coordTy = dyn_cast<IntTupleType>(coord.getType());
+    if (!srcTy || !coordTy)
+      return failure();
+
+    if (!isNormalForm(cast<TypedValue<IntTupleType>>(src)))
+      return failure();
+    if (!isNormalForm(cast<TypedValue<IntTupleType>>(coord)))
+      return failure();
+
+    IntTupleBuilder builder(rewriter, loc);
+    IntTupleValueAdaptor srcAdaptor = IntTupleValueAdaptor::create(builder, src, srcTy.getAttr());
+    IntTupleValueAdaptor result = intTupleSlice(builder, srcAdaptor, coordTy.getAttr());
+
+    rewriter.replaceOp(op, builder.finalize(result));
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===// +// Crd2IdxOp Lowering +//===----------------------------------------------------------------------===// + +class Crd2IdxLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(Crd2IdxOp op, PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto coord = op.getCoord(); + auto layout = op.getLayout(); + + auto coordTy = dyn_cast(coord.getType()); + auto layoutTy = dyn_cast(layout.getType()); + if (!coordTy || !layoutTy) + return failure(); + + // Inputs must be in normal form + if (!isNormalForm(cast>(coord))) + return failure(); + if (!isNormalForm(cast>(layout))) + return failure(); + + IntTupleBuilder builder(rewriter, loc); + + IntTupleValueAdaptor coordAdaptor = + IntTupleValueAdaptor::create(builder, coord, coordTy.getAttr()); + IntTupleValueAdaptor shapeAdaptor = IntTupleValueAdaptor::create( + builder, layout.getDefiningOp()->getOperand(0), layoutTy.getAttr().getShape()); + IntTupleValueAdaptor strideAdaptor = IntTupleValueAdaptor::create( + builder, layout.getDefiningOp()->getOperand(1), layoutTy.getAttr().getStride()); + + IntTupleValueAdaptor result = layoutCrd2idx(builder, coordAdaptor, shapeAdaptor, strideAdaptor); + + rewriter.replaceOp(op, builder.finalize(result)); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Layout Divide Operations Lowering +//===----------------------------------------------------------------------===// + +/// Template for all four layout divide operations: +/// - LogicalDivideOp -> layoutLogicalDivide +/// - ZippedDivideOp -> layoutZippedDivide +/// - TiledDivideOp -> layoutTiledDivide +/// - FlatDivideOp -> layoutFlatDivide +template &, LayoutValueAdaptor, + LayoutValueAdaptor), + LayoutValueAdaptor (*DivideTileFunc)(LayoutBuilder &, + LayoutValueAdaptor, TileAttr)> +class 
LayoutDivideOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value layoutValue = op.getLayout(); + Value divisorValue = op.getDivisor(); + + auto layoutTy = dyn_cast(layoutValue.getType()); + + if (!layoutTy) + return failure(); + if (!isNormalForm(cast>(layoutValue))) + return failure(); + + LayoutBuilder layoutBuilder(rewriter, loc); + LayoutValueAdaptor layoutAdaptor(layoutValue, layoutTy.getAttr()); + + if (auto divisorLayoutTy = dyn_cast(divisorValue.getType())) { + if (!isNormalForm(cast>(divisorValue))) + return failure(); + + LayoutValueAdaptor divisorAdaptor(divisorValue, divisorLayoutTy.getAttr()); + LayoutValueAdaptor result = DivideFunc(layoutBuilder, layoutAdaptor, divisorAdaptor); + + rewriter.replaceOp(op, layoutBuilder.getValue(result)); + return success(); + } + + if (auto divisorTileTy = dyn_cast(divisorValue.getType())) { + TileAttr tileAttr = divisorTileTy.getAttr(); + LayoutValueAdaptor result = DivideTileFunc(layoutBuilder, layoutAdaptor, tileAttr); + + rewriter.replaceOp(op, layoutBuilder.getValue(result)); + return success(); + } + + return failure(); + } +}; + +using LogicalDivideOpLowering = + LayoutDivideOpLowering, + layoutLogicalDivide>; +using ZippedDivideOpLowering = + LayoutDivideOpLowering, + layoutZippedDivide>; +using TiledDivideOpLowering = + LayoutDivideOpLowering, + layoutTiledDivide>; +using FlatDivideOpLowering = + LayoutDivideOpLowering, + layoutFlatDivide>; + +class PrintOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PrintOp op, PatternRewriter &rewriter) const override { + if (!op->getParentOfType()) { + return failure(); + } + + // check all values are in normal form + for (Value val : op.getValues()) { + if (auto intTupleVal = dyn_cast>(val)) { + if (!isNormalForm(intTupleVal)) { + 
return failure(); + } + } else if (auto layoutVal = dyn_cast>(val)) { + if (!isNormalForm(layoutVal)) { + return failure(); + } + } else { + // TODO: handle other types + continue; + } + } + + auto loc = op.getLoc(); + std::string format; + SmallVector args; + bool first = true; + + auto appendSeparator = [&]() { + if (!first) { + format += " "; + } + first = false; + }; + + for (Value val : op.getValues()) { + appendSeparator(); + if (auto tupleTy = dyn_cast(val.getType())) { + if (tupleTy.getAttr().isStatic()) { + if (!appendIntTuplePrintfStatic(tupleTy.getAttr(), format)) { + return failure(); + } + continue; + } + IntTupleBuilder builder(rewriter, loc); + IntTupleValueAdaptor tuple = IntTupleValueAdaptor::create(builder, val, tupleTy.getAttr()); + if (!appendIntTuplePrintf(rewriter, loc, tuple, format, args)) + return failure(); + } else if (auto layoutTy = dyn_cast(val.getType())) { + format += ""; + if (layoutTy.getAttr().isStatic()) { + if (!appendIntTuplePrintfStatic(layoutTy.getAttr().getShape(), format)) { + return failure(); + } + format += ":"; + if (!appendIntTuplePrintfStatic(layoutTy.getAttr().getStride(), format)) { + return failure(); + } + continue; + } + LayoutBuilder layoutBuilder(rewriter, loc); + LayoutValueAdaptor layout(val, layoutTy.getAttr()); + if (!appendIntTuplePrintf(rewriter, loc, layoutBuilder.getShape(layout), format, args)) { + return failure(); + } + format += ":"; + if (!appendIntTuplePrintf(rewriter, loc, layoutBuilder.getStride(layout), format, args)) { + return failure(); + } + continue; + } else { + return failure(); + } + } + + format += "\n"; + gpu::PrintfOp::create(rewriter, loc, rewriter.getStringAttr(format), args); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Generated patterns +//===----------------------------------------------------------------------===// + +#include "flydsl/Dialect/Fly/Transforms/LayoutLowering.cpp.inc" + 
+//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +class FlyLayoutLoweringPass + : public mlir::fly::impl::FlyLayoutLoweringPassBase { +public: + using mlir::fly::impl::FlyLayoutLoweringPassBase< + FlyLayoutLoweringPass>::FlyLayoutLoweringPassBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + getOperation()->walk([&](gpu::GPUFuncOp gpuFunc) { lowerGpuFuncIntTupleArgs(gpuFunc); }); + getOperation()->walk([&](func::FuncOp funcOp) { lowerFuncOpIntTupleArgs(funcOp); }); + getOperation()->walk( + [&](gpu::LaunchFuncOp launchOp) { lowerGpuLaunchFuncIntTupleOperands(launchOp); }); + + RewritePatternSet patterns(context); + + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + + // Layout algebra lowerings + patterns.add(context); + + populateWithGenerated(patterns); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +namespace impl { + +std::unique_ptr<::mlir::Pass> createFlyLayoutLoweringPass() { + return std::make_unique(); +} + +} // namespace impl diff --git a/lib/Dialect/Fly/Utils/IntTupleUtils.cpp b/lib/Dialect/Fly/Utils/IntTupleUtils.cpp new file mode 100644 index 00000000..9a9762fb --- /dev/null +++ b/lib/Dialect/Fly/Utils/IntTupleUtils.cpp @@ -0,0 +1,448 @@ +#include "flydsl/Dialect/Fly/Utils/IntTupleUtils.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" + +namespace mlir::fly { + +bool intTupleHasNone(IntTupleAttr attr) { + if (attr.isLeaf()) { + return attr.isLeafNone(); + } + for (int i = 0; i < attr.rank(); ++i) { + if (intTupleHasNone(attr.at(i))) { + return true; + } + } + return false; +} +bool intTupleAllNone(IntTupleAttr attr) { + if (attr.isLeaf()) { + return attr.isLeafNone(); + } + for (int i = 0; i < attr.rank(); ++i) { + if 
(!intTupleAllNone(attr.at(i))) { + return false; + } + } + return true; +} + +bool intTupleIsCongruent(IntTupleAttr lhs, IntTupleAttr rhs) { + if (lhs.isLeaf() && rhs.isLeaf()) { + return true; + } + if (lhs.isLeaf() != rhs.isLeaf()) { + return false; + } + if (lhs.rank() != rhs.rank()) { + return false; + } + for (int i = 0; i < lhs.rank(); ++i) { + if (!intTupleIsCongruent(lhs.at(i), rhs.at(i))) { + return false; + } + } + return true; +} +bool intTupleIsWeaklyCongruent(IntTupleAttr lhs, IntTupleAttr rhs) { + if (lhs.isLeaf()) { + return true; + } + if (rhs.isLeaf()) { + return false; + } + if (lhs.rank() != rhs.rank()) { + return false; + } + for (int i = 0; i < lhs.rank(); ++i) { + if (!intTupleIsWeaklyCongruent(lhs.at(i), rhs.at(i))) { + return false; + } + } + return true; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::add(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.add(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + return ArithValue{arith::AddIOp::create(builder, loc, extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)) + .getResult(), + retAttr}; +} +IntTupleBuilder::ArithValue +IntTupleBuilder::sub(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.sub(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + return ArithValue{arith::SubIOp::create(builder, loc, extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)) + .getResult(), + retAttr}; +} +IntTupleBuilder::ArithValue +IntTupleBuilder::mul(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.mul(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + return ArithValue{arith::MulIOp::create(builder, loc, extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)) + .getResult(), + retAttr}; +} +IntTupleBuilder::ArithValue +IntTupleBuilder::div(ArithValue lhs, ArithValue rhs) const { + auto retAttr = 
attrBuilder.div(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + return ArithValue{arith::DivSIOp::create(builder, loc, extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)) + .getResult(), + retAttr}; +} +IntTupleBuilder::ArithValue +IntTupleBuilder::mod(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.mod(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + return ArithValue{arith::RemSIOp::create(builder, loc, extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)) + .getResult(), + retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::min(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.min(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + return ArithValue{arith::MinSIOp::create(builder, loc, extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)) + .getResult(), + retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::max(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.max(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + return ArithValue{arith::MaxSIOp::create(builder, loc, extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)) + .getResult(), + retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::ceilDiv(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.ceilDiv(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + return ArithValue{arith::CeilDivSIOp::create(builder, loc, extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)) + .getResult(), + retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::shapeDiv(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.shapeDiv(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + return ArithValue{arith::CeilDivSIOp::create(builder, 
loc, extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)) + .getResult(), + retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::logicalAnd(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.logicalAnd(lhs.attr, rhs.attr); + auto retType = getIntType(retAttr); + // (lhs != 0) && (rhs != 0) + auto lhsBool = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::ne, lhs.value, + arith::ConstantIntOp::create(builder, loc, getIntType(lhs.attr), 0).getResult()); + auto rhsBool = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::ne, rhs.value, + arith::ConstantIntOp::create(builder, loc, getIntType(rhs.attr), 0).getResult()); + auto result = arith::AndIOp::create(builder, loc, lhsBool, rhsBool); + return ArithValue{arith::ExtUIOp::create(builder, loc, retType, result).getResult(), retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::logicalOr(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.logicalOr(lhs.attr, rhs.attr); + auto retType = getIntType(retAttr); + // (lhs != 0) || (rhs != 0) + auto lhsBool = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::ne, lhs.value, + arith::ConstantIntOp::create(builder, loc, getIntType(lhs.attr), 0).getResult()); + auto rhsBool = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::ne, rhs.value, + arith::ConstantIntOp::create(builder, loc, getIntType(rhs.attr), 0).getResult()); + auto result = arith::OrIOp::create(builder, loc, lhsBool, rhsBool); + return ArithValue{arith::ExtUIOp::create(builder, loc, retType, result).getResult(), retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::logicalNot(ArithValue val) const { + auto retAttr = attrBuilder.logicalNot(val.attr); + auto retType = getIntType(retAttr); + auto zero = arith::ConstantIntOp::create(builder, loc, getIntType(val.attr), 0).getResult(); + // !(val) == (val == 0) + auto result = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, val.value, 
zero); + return ArithValue{arith::ExtUIOp::create(builder, loc, retType, result).getResult(), retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::lt(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.lt(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + auto retType = getIntType(retAttr); + auto cmp = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::slt, + extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)); + return ArithValue{arith::ExtUIOp::create(builder, loc, retType, cmp).getResult(), retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::le(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.le(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + auto retType = getIntType(retAttr); + auto cmp = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::sle, + extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)); + return ArithValue{arith::ExtUIOp::create(builder, loc, retType, cmp).getResult(), retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::gt(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.gt(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + auto retType = getIntType(retAttr); + auto cmp = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::sgt, + extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)); + return ArithValue{arith::ExtUIOp::create(builder, loc, retType, cmp).getResult(), retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::ge(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.ge(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + auto retType = getIntType(retAttr); + auto cmp = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::sge, + extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)); + return 
ArithValue{arith::ExtUIOp::create(builder, loc, retType, cmp).getResult(), retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::eq(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.eq(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + auto retType = getIntType(retAttr); + auto cmp = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, + extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)); + return ArithValue{arith::ExtUIOp::create(builder, loc, retType, cmp).getResult(), retAttr}; +} + +IntTupleBuilder::ArithValue +IntTupleBuilder::ne(ArithValue lhs, ArithValue rhs) const { + auto retAttr = attrBuilder.ne(lhs.attr, rhs.attr); + auto cmpType = getCommonIntType(lhs.attr, rhs.attr); + auto retType = getIntType(retAttr); + auto cmp = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, + extendToIntType(lhs.value, cmpType), + extendToIntType(rhs.value, cmpType)); + return ArithValue{arith::ExtUIOp::create(builder, loc, retType, cmp).getResult(), retAttr}; +} + +IntTupleAttr intTupleWrap(const IntTupleBuilder &builder, IntTupleAttr attr) { + if (attr.isLeaf()) { + SmallVector elements; + elements.push_back(attr); + return IntTupleAttr::get(ArrayAttr::get(attr.getContext(), elements)); + } + return attr; +} +IntTupleAttr intTupleUnwrap(const IntTupleBuilder &builder, IntTupleAttr attr) { + if (!attr.isLeaf()) { + if (attr.rank() == 1) { + return intTupleUnwrap(builder, attr.at(0)); + } + return attr; + } + return attr; +} + +namespace detail { + +std::pair> +intTupleUnflattenImpl(ArrayRef flatElements, IntTupleAttr profile) { + if (profile.isLeaf()) { + return {flatElements[0], flatElements.drop_front()}; + } + SmallVector resultElements; + auto remaining = flatElements; + for (int i = 0; i < profile.rank(); ++i) { + auto [subResult, subRemaining] = intTupleUnflattenImpl(remaining, profile.at(i)); + resultElements.push_back(subResult); + remaining = subRemaining; + } + return 
std::pair{IntTupleAttr::get(ArrayAttr::get(profile.getContext(), resultElements)), + remaining}; +} + +} // end namespace detail + +IntTupleAttr intTupleUnflatten(const IntTupleBuilder &builder, IntTupleAttr attr, + IntTupleAttr profile) { + if (attr.isLeaf()) { + return attr; + } + SmallVector flatElements; + for (int i = 0; i < attr.rank(); ++i) { + flatElements.push_back(attr.at(i)); + } + auto [result, remaining] = detail::intTupleUnflattenImpl(flatElements, profile); + assert(remaining.empty() && "flat tuple has more elements than profile requires"); + return result; +} +IntTupleAttr intTupleExpand(const IntTupleBuilder &builder, IntTupleAttr attr, + ArrayRef indices) { + if (attr.isLeaf() || indices.empty()) { + return attr; + } + SmallVector elements; + for (int i = 0; i < attr.rank(); ++i) { + bool shouldExpand = false; + for (int32_t idx : indices) { + if (idx == i) { + shouldExpand = true; + break; + } + } + if (shouldExpand && !attr.at(i).isLeaf()) { + for (int j = 0; j < attr.at(i).rank(); ++j) { + elements.push_back(attr.at(i).at(j)); + } + } else { + elements.push_back(attr.at(i)); + } + } + if (elements.size() == 1) { + return cast(elements[0]); + } + return IntTupleAttr::get(ArrayAttr::get(attr.getContext(), elements)); +} +IntTupleAttr intTupleGroup(const IntTupleBuilder &builder, IntTupleAttr attr, + int32_t begin, int32_t end) { + if (attr.isLeaf()) { + return attr; + } + if (end == -1) { + end = attr.rank(); + } + assert(begin >= 0 && begin <= end && "begin must be <= end"); + + SmallVector result; + for (int i = 0; i < begin; ++i) { + result.push_back(attr.at(i)); + } + if (begin < end) { + SmallVector grouped; + for (int i = begin; i < end; ++i) { + grouped.push_back(attr.at(i)); + } + result.push_back(IntTupleAttr::get(ArrayAttr::get(attr.getContext(), grouped))); + } + for (int i = end; i < attr.rank(); ++i) { + result.push_back(attr.at(i)); + } + return IntTupleAttr::get(ArrayAttr::get(attr.getContext(), result)); +} + 
+//===----------------------------------------------------------------------===// +// Basis operations +//===----------------------------------------------------------------------===// + +IntTupleAttr intTupleExpandBasis(BasisAttr attr) { + auto *ctx = attr.getContext(); + ArrayRef modes = attr.getModes(); + + if (modes.empty()) { + return IntTupleAttr::get(attr.getValue()); + } + + auto zero = IntTupleAttr::get(IntAttr::getStatic(ctx, 0)); + IntTupleAttr result = IntTupleAttr::get(attr.getValue()); + + for (auto it = modes.rbegin(); it != modes.rend(); ++it) { + int32_t n = *it; + SmallVector elements; + for (int32_t i = 0; i < n; ++i) { + elements.push_back(zero); + } + elements.push_back(result); + result = IntTupleAttr::get(ArrayAttr::get(ctx, elements)); + } + return result; +} + +namespace { + +IntTupleAttr intTupleMakeBasisLikeImpl(MLIRContext *ctx, IntTupleAttr profile, + SmallVector &modes) { + if (profile.isLeaf()) { + auto one = IntAttr::getStatic(ctx, 1); + return IntTupleAttr::get(BasisAttr::get(ctx, one, modes)); + } + + SmallVector elements; + for (int32_t i = 0; i < profile.rank(); ++i) { + modes.push_back(i); + elements.push_back(intTupleMakeBasisLikeImpl(ctx, profile.at(i), modes)); + modes.pop_back(); + } + return IntTupleAttr::get(ArrayAttr::get(ctx, elements)); +} + +} // namespace + +IntTupleAttr intTupleMakeBasisLike(IntTupleAttr profile) { + auto *ctx = profile.getContext(); + SmallVector modes; + assert(!profile.isLeaf() && "intTupleMakeBasisLike expects a non-leaf IntTupleAttr"); + return intTupleMakeBasisLikeImpl(ctx, profile, modes); +} + +IntTupleAttr operator+(BasisAttr lhs, BasisAttr rhs) { + IntTupleBuilder builder(lhs.getContext()); + return intTupleAdd(builder, intTupleExpandBasis(lhs), intTupleExpandBasis(rhs)); +} +IntTupleAttr operator+(BasisAttr lhs, IntTupleAttr rhs) { + IntTupleBuilder builder(lhs.getContext()); + return intTupleAdd(builder, intTupleExpandBasis(lhs), rhs); +} +IntTupleAttr operator+(IntTupleAttr lhs, 
BasisAttr rhs) { + IntTupleBuilder builder(lhs.getContext()); + return intTupleAdd(builder, lhs, intTupleExpandBasis(rhs)); +} + +BasisAttr operator*(BasisAttr lhs, IntAttr rhs) { + return BasisAttr::get(lhs.getContext(), cast(lhs.getValue()) * rhs, lhs.getModes()); +} +BasisAttr operator*(IntAttr lhs, BasisAttr rhs) { + return BasisAttr::get(rhs.getContext(), lhs * cast(rhs.getValue()), rhs.getModes()); +} +BasisAttr operator/(BasisAttr lhs, IntAttr rhs) { + return BasisAttr::get(lhs.getContext(), cast(lhs.getValue()) / rhs, lhs.getModes()); +} + +BasisAttr basisSafeDiv(BasisAttr lhs, IntAttr rhs) { + return BasisAttr::get(lhs.getContext(), intSafeDiv(cast(lhs.getValue()), rhs), + lhs.getModes()); +} +BasisAttr basisCeilDiv(BasisAttr lhs, IntAttr rhs) { + return BasisAttr::get(lhs.getContext(), intCeilDiv(cast(lhs.getValue()), rhs), + lhs.getModes()); +} + +} // namespace mlir::fly diff --git a/lib/Dialect/Fly/Utils/IntUtils.cpp b/lib/Dialect/Fly/Utils/IntUtils.cpp new file mode 100644 index 00000000..4e933452 --- /dev/null +++ b/lib/Dialect/Fly/Utils/IntUtils.cpp @@ -0,0 +1,231 @@ +#include "flydsl/Dialect/Fly/Utils/IntUtils.h" + +namespace mlir::fly { + +IntAttr operator+(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, lhs.getValue() + rhs.getValue()); + } + if (lhs.isStaticValue(0)) { + return rhs; + } + if (rhs.isStaticValue(0)) { + return lhs; + } + int32_t width = std::max(lhs.getWidth(), rhs.getWidth()); + int32_t lhsDiv = lhs.isStatic() ? lhs.getValue() : lhs.getDivisibility(); + int32_t rhsDiv = rhs.isStatic() ? 
rhs.getValue() : rhs.getDivisibility(); + return IntAttr::getDynamic(ctx, width, utils::divisibilityAdd(lhsDiv, rhsDiv)); +} + +IntAttr operator-(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, lhs.getValue() - rhs.getValue()); + } + if (lhs.isStaticValue(0)) { + return rhs; + } + if (rhs.isStaticValue(0)) { + return lhs; + } + int32_t width = std::max(lhs.getWidth(), rhs.getWidth()); + int32_t lhsDiv = lhs.isStatic() ? lhs.getValue() : lhs.getDivisibility(); + int32_t rhsDiv = rhs.isStatic() ? rhs.getValue() : rhs.getDivisibility(); + return IntAttr::getDynamic(ctx, width, utils::divisibilitySub(lhsDiv, rhsDiv)); +} + +IntAttr operator*(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, lhs.getValue() * rhs.getValue()); + } + if (lhs.isStaticValue(0)) { + return IntAttr::getStatic(ctx, 0); + } + if (rhs.isStaticValue(0)) { + return IntAttr::getStatic(ctx, 0); + } + int32_t width = std::max(lhs.getWidth(), rhs.getWidth()); + int32_t lhsDiv = lhs.isStatic() ? lhs.getValue() : lhs.getDivisibility(); + int32_t rhsDiv = rhs.isStatic() ? rhs.getValue() : rhs.getDivisibility(); + return IntAttr::getDynamic(ctx, width, utils::divisibilityMul(lhsDiv, rhsDiv)); +} + +IntAttr operator/(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, lhs.getValue() / rhs.getValue()); + } + if (lhs.isStaticValue(0)) { + return IntAttr::getStatic(ctx, 0); + } + int32_t width = std::max(lhs.getWidth(), rhs.getWidth()); + int32_t lhsDiv = lhs.isStatic() ? lhs.getValue() : lhs.getDivisibility(); + int32_t rhsDiv = rhs.isStatic() ? 
rhs.getValue() : rhs.getDivisibility(); + return IntAttr::getDynamic(ctx, width, utils::divisibilityDiv(lhsDiv, rhsDiv)); +} + +IntAttr operator%(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, lhs.getValue() % rhs.getValue()); + } + if (rhs.isStaticValue(1)) { + return IntAttr::getStatic(ctx, 0); + } + if (lhs.isStaticValue(0)) { + return IntAttr::getStatic(ctx, 0); + } + int32_t width = std::max(lhs.getWidth(), rhs.getWidth()); + int32_t lhsDiv = lhs.isStatic() ? lhs.getValue() : lhs.getDivisibility(); + int32_t rhsDiv = rhs.isStatic() ? rhs.getValue() : rhs.getDivisibility(); + return IntAttr::getDynamic(ctx, width, utils::divisibilityModulo(lhsDiv, rhsDiv)); +} + +IntAttr operator&&(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStaticValue(0)) { + return IntAttr::getStatic(ctx, 0); + } + if (rhs.isStaticValue(0)) { + return IntAttr::getStatic(ctx, 0); + } + return IntAttr::getDynamic(ctx, 32, 1); +} + +IntAttr operator||(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && lhs.getValue() != 0) { + return IntAttr::getStatic(ctx, 1); + } + if (rhs.isStatic() && rhs.getValue() != 0) { + return IntAttr::getStatic(ctx, 1); + } + return IntAttr::getDynamic(ctx, 32, 1); +} + +IntAttr operator!(IntAttr val) { + auto *ctx = val.getContext(); + if (val.isStatic()) { + return IntAttr::getStatic(ctx, val.getValue() == 0 ? 1 : 0); + } + return IntAttr::getDynamic(ctx, 32, 1); +} + +IntAttr operator<(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, lhs.getValue() < rhs.getValue() ? 1 : 0); + } + return IntAttr::getDynamic(ctx, 32, 1); +} + +IntAttr operator<=(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, lhs.getValue() <= rhs.getValue() ? 
1 : 0); + } + return IntAttr::getDynamic(ctx, 32, 1); +} + +IntAttr operator>(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, lhs.getValue() > rhs.getValue() ? 1 : 0); + } + return IntAttr::getDynamic(ctx, 32, 1); +} + +IntAttr operator>=(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, lhs.getValue() >= rhs.getValue() ? 1 : 0); + } + return IntAttr::getDynamic(ctx, 32, 1); +} + +IntAttr operator==(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, lhs.getValue() == rhs.getValue() ? 1 : 0); + } + return IntAttr::getDynamic(ctx, 32, 1); +} + +IntAttr operator!=(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, lhs.getValue() != rhs.getValue() ? 1 : 0); + } + return IntAttr::getDynamic(ctx, 32, 1); +} + +IntAttr intMin(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, std::min(lhs.getValue(), rhs.getValue())); + } + int32_t width = std::max(lhs.getWidth(), rhs.getWidth()); + int32_t lhsDiv = lhs.isStatic() ? lhs.getValue() : lhs.getDivisibility(); + int32_t rhsDiv = rhs.isStatic() ? rhs.getValue() : rhs.getDivisibility(); + return IntAttr::getDynamic(ctx, width, utils::divisibilityMin(lhsDiv, rhsDiv)); +} + +IntAttr intMax(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, std::max(lhs.getValue(), rhs.getValue())); + } + int32_t width = std::max(lhs.getWidth(), rhs.getWidth()); + int32_t lhsDiv = lhs.isStatic() ? lhs.getValue() : lhs.getDivisibility(); + int32_t rhsDiv = rhs.isStatic() ? 
rhs.getValue() : rhs.getDivisibility(); + return IntAttr::getDynamic(ctx, width, utils::divisibilityMax(lhsDiv, rhsDiv)); +} + +IntAttr intSafeDiv(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + assert(lhs.getValue() % rhs.getValue() == 0); + return IntAttr::getStatic(ctx, lhs.getValue() / rhs.getValue()); + } + if (lhs.isStaticValue(0)) { + return lhs; + } + int32_t width = std::max(lhs.getWidth(), rhs.getWidth()); + int32_t lhsDiv = lhs.isStatic() ? lhs.getValue() : lhs.getDivisibility(); + int32_t rhsDiv = rhs.isStatic() ? rhs.getValue() : rhs.getDivisibility(); + return IntAttr::getDynamic(ctx, width, utils::divisibilityDiv(lhsDiv, rhsDiv)); +} + +IntAttr intCeilDiv(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + return IntAttr::getStatic(ctx, (lhs.getValue() + rhs.getValue() - 1) / rhs.getValue()); + } + if (lhs.isStaticValue(0) || lhs.isStaticValue(1)) { + return lhs; + } + int32_t width = std::max(lhs.getWidth(), rhs.getWidth()); + int32_t lhsDiv = lhs.isStatic() ? lhs.getValue() : lhs.getDivisibility(); + int32_t rhsDiv = rhs.isStatic() ? rhs.getValue() : rhs.getDivisibility(); + return IntAttr::getDynamic(ctx, width, utils::divisibilityCeilDiv(lhsDiv, rhsDiv)); +} + +IntAttr intShapeDiv(IntAttr lhs, IntAttr rhs) { + auto *ctx = lhs.getContext(); + if (lhs.isStatic() && rhs.isStatic()) { + assert((lhs.getValue() % rhs.getValue() == 0 || rhs.getValue() % lhs.getValue() == 0)); + return IntAttr::getStatic(ctx, (lhs.getValue() + rhs.getValue() - 1) / rhs.getValue()); + } + if (lhs.isStaticValue(0) || lhs.isStaticValue(1)) { + return lhs; + } + int32_t width = std::max(lhs.getWidth(), rhs.getWidth()); + int32_t lhsDiv = lhs.isStatic() ? lhs.getValue() : lhs.getDivisibility(); + int32_t rhsDiv = rhs.isStatic() ? 
rhs.getValue() : rhs.getDivisibility(); + return IntAttr::getDynamic(ctx, width, utils::divisibilityCeilDiv(lhsDiv, rhsDiv)); +} + +} // namespace mlir::fly diff --git a/lib/Dialect/Fly/Utils/NormalForm.cpp b/lib/Dialect/Fly/Utils/NormalForm.cpp new file mode 100644 index 00000000..c9ee4135 --- /dev/null +++ b/lib/Dialect/Fly/Utils/NormalForm.cpp @@ -0,0 +1,220 @@ +#include "flydsl/Dialect/Fly/Utils/NormalForm.h" +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" + +namespace mlir::fly { + +//===----------------------------------------------------------------------===// +// NormalBasis: (StaticOp) +// Note: MakeBasisOp is not currently defined, so only StaticOp is valid +//===----------------------------------------------------------------------===// +bool isNormalForm(TypedValue value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) { + return false; + } + // NormalBasis ::= (StaticOp) + // return isa(defOp); + return false; +} + +//===----------------------------------------------------------------------===// +// NormalIntTuple: (StaticOp) | (MakeIntTupleOp $dyncElems) +//===----------------------------------------------------------------------===// +bool isNormalForm(TypedValue value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) { + auto tupleTy = value.getType(); + return tupleTy.getAttr().isStatic(); + } + // if (isa(defOp)) { + // return true; + // } + if (isa(defOp)) { + return true; + } + return false; +} + +//===----------------------------------------------------------------------===// +// NormalLayout: (StaticOp) | (MakeLayoutOp NormalIntTuple, NormalIntTuple) +//===----------------------------------------------------------------------===// +bool isNormalForm(TypedValue value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) { + auto layoutTy = value.getType(); + return layoutTy.getAttr().isStatic(); + } + // NormalLayout ::= (StaticOp) + // if (isa(defOp)) { + // return true; + // } + // NormalLayout ::= (MakeLayoutOp 
NormalIntTuple, NormalIntTuple) + if (auto makeLayoutOp = dyn_cast(defOp)) { + auto shape = makeLayoutOp.getShape(); + if (!isNormalForm(shape)) { + return false; + } + // Stride is optional + if (auto stride = makeLayoutOp.getStride()) { + if (!isNormalForm(stride)) { + return false; + } + } + return true; + } + return false; +} + +//===----------------------------------------------------------------------===// +// NormalComposedLayout: (StaticOp) +// | (MakeComposedLayoutOp (NormalSwizzle | NormalLayout | NormalComposedLayout), +// NormalIntTuple, NormalLayout) +//===----------------------------------------------------------------------===// + +// Helper: Check if a Value is a valid inner for ComposedLayout +// Inner can be: SwizzleType (always static), LayoutType, or ComposedLayoutType +static bool isNormalInner(Value inner) { + // auto innerType = inner.getType(); + + // SwizzleAttr is embedded in PointerType, not a standalone type + // Check if it's a LayoutType + if (auto layoutTyped = dyn_cast>(inner)) { + return isNormalForm(layoutTyped); + } + // Check if it's a ComposedLayoutType + if (auto composedTyped = dyn_cast>(inner)) { + return isNormalForm(composedTyped); + } + return false; +} + +bool isNormalForm(TypedValue value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) { + return false; + } + // NormalComposedLayout ::= (StaticOp) + if (isa(defOp)) { + return true; + } + // NormalComposedLayout ::= (MakeComposedLayoutOp inner, offset, outer) + if (auto makeComposedOp = dyn_cast(defOp)) { + // Check inner: (NormalSwizzle | NormalLayout | NormalComposedLayout) + if (!isNormalInner(makeComposedOp.getInner())) { + return false; + } + // Check offset: NormalIntTuple + if (!isNormalForm(makeComposedOp.getOffset())) { + return false; + } + // Check outer: NormalLayout + if (!isNormalForm(makeComposedOp.getOuter())) { + return false; + } + return true; + } + return false; +} + 
+//===----------------------------------------------------------------------===// +// NormalTile: (StaticOp) | (MakeTileOp (NormalIntTuple | NormalLayout)+) +//===----------------------------------------------------------------------===// +bool isNormalForm(TypedValue value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) { + return false; + } + // NormalTile ::= (StaticOp) + if (isa(defOp)) { + return true; + } + // NormalTile ::= (MakeTileOp (NormalIntTuple | NormalLayout)+) + if (auto makeTileOp = dyn_cast(defOp)) { + for (Value mode : makeTileOp.getModes()) { + // Each mode can be IntTupleType or LayoutType + if (auto intTupleTyped = dyn_cast>(mode)) { + if (!isNormalForm(intTupleTyped)) { + return false; + } + } else if (auto layoutTyped = dyn_cast>(mode)) { + if (!isNormalForm(layoutTyped)) { + return false; + } + } else { + // Unknown type + return false; + } + } + return true; + } + return false; +} + +//===----------------------------------------------------------------------===// +// NormalCoordTensor: (MakeCoordTensorOp NormalIntTuple, NormalAnyLayout) +// Note: Using MakeIdentityTensorOp as the closest equivalent +//===----------------------------------------------------------------------===// +bool isNormalForm(TypedValue value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) { + return false; + } + // Static CoordTensor + if (isa(defOp)) { + return true; + } + // NormalCoordTensor via MakeIdentityTensorOp + if (auto makeIdentityTensorOp = dyn_cast(defOp)) { + return isNormalForm(makeIdentityTensorOp.getShape()); + } + return false; +} + +//===----------------------------------------------------------------------===// +// NormalPointer and NormalMemRef +// These are typically created via operations and should be static or from +// well-formed construction operations +//===----------------------------------------------------------------------===// +bool isNormalForm(TypedValue value) { + Operation *defOp = value.getDefiningOp(); 
+ if (!defOp) { + // Block arguments are considered normal form for pointers + return true; + } + // StaticOp produces normal form + if (isa(defOp)) { + return true; + } + // Other operations that produce pointers are considered normal + // as long as they don't have structural requirements + return true; +} + +bool isNormalForm(TypedValue value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) { + // Block arguments are considered normal form + return true; + } + // StaticOp produces normal form + if (isa(defOp)) { + return true; + } + // MakeFragmentLikeOp with normal layout source + if (auto makeFragmentOp = dyn_cast(defOp)) { + return isNormalForm(makeFragmentOp.getSrc()); + } + return true; +} + +bool isNormalLayout(Value value) { + if (auto layoutTyped = dyn_cast>(value)) { + return isNormalForm(layoutTyped); + } + if (auto composedTyped = dyn_cast>(value)) { + return isNormalForm(composedTyped); + } + return false; +} + +} // namespace mlir::fly diff --git a/python/flydsl/__init__.py b/python/flydsl/__init__.py new file mode 100644 index 00000000..a5921a09 --- /dev/null +++ b/python/flydsl/__init__.py @@ -0,0 +1 @@ +from .compiler import * diff --git a/python/flydsl/compiler/__init__.py b/python/flydsl/compiler/__init__.py new file mode 100644 index 00000000..236470bc --- /dev/null +++ b/python/flydsl/compiler/__init__.py @@ -0,0 +1,3 @@ +from .compiler import compile + +__all__ = ["compile"] diff --git a/python/flydsl/compiler/compiler.py b/python/flydsl/compiler/compiler.py new file mode 100644 index 00000000..c8116aed --- /dev/null +++ b/python/flydsl/compiler/compiler.py @@ -0,0 +1,148 @@ +from contextlib import ExitStack + +from .._mlir.passmanager import PassManager + +from ..lang import MlirModule +from .executor import Executor + + +def _decode_mlir_escaped_bytes(s: str) -> str: + """Decode MLIR string attr content that uses \\xx hex byte escapes (e.g. \\0A, \\09, \\22). 
+ + This is what gpu-module-to-binary emits for `assembly = "..."` (and often `bin = "..."`). + """ + out_chars = [] + i = 0 + n = len(s) + + def _is_hex(c: str) -> bool: + return ("0" <= c <= "9") or ("a" <= c <= "f") or ("A" <= c <= "F") + + while i < n: + ch = s[i] + if ch != "\\": + out_chars.append(ch) + i += 1 + continue + + # Backslash escape. + if i + 2 < n and _is_hex(s[i + 1]) and _is_hex(s[i + 2]): + byte = int(s[i + 1 : i + 3], 16) + out_chars.append(chr(byte)) + i += 3 + continue + + # Common C-style single-char escapes (rare here, but harmless). + if i + 1 < n: + nxt = s[i + 1] + if nxt == "n": + out_chars.append("\n") + i += 2 + continue + if nxt == "t": + out_chars.append("\t") + i += 2 + continue + if nxt == "r": + out_chars.append("\r") + i += 2 + continue + if nxt in ['"', "\\"]: + out_chars.append(nxt) + i += 2 + continue + # Unknown escape: keep the escaped char as-is. + out_chars.append(nxt) + i += 2 + continue + + # Trailing backslash. + i += 1 + + return "".join(out_chars) + + +def _extract_mlir_string_attr(asm: str, attr_name: str) -> str | None: + """Extract and decode a string attribute like `attr_name = "..."` from an MLIR asm dump.""" + marker = f'{attr_name} = "' + start = asm.find(marker) + if start == -1: + return None + + i = start + len(marker) + # Find the closing quote. Skip over \xx escapes as two hex bytes. + while i < len(asm): + if asm[i] == "\\" and i + 2 < len(asm): + # Skip the escape introducer and two following chars (typically hex digits). + i += 3 + continue + if asm[i] == '"': + end = i + encoded = asm[start + len(marker) : end] + return _decode_mlir_escaped_bytes(encoded) + i += 1 + return None + + +def compile( + fx_module: MlirModule, verify=True, print_after_all=False, output_format="fatbin" +): + # gpu-module-to-binary formats are backend-dependent. For ROCm/ROCDL, "isa" + # is the human-readable assembly/ISA dump and "fatbin" is an object container. 
+ fmt_map = { + "fatbin": "fatbin", + "assembly": "isa", + } + if output_format not in fmt_map: + raise ValueError( + f"Unsupported output_format: {output_format}. Use one of {list(fmt_map)}" + ) + + pipeline = ( + "builtin.module(" + "gpu-kernel-outlining{data-layout-str=}," + "fly-canonicalize," + "fly-layout-lowering," + "convert-fly-to-rocdl," + "canonicalize," + "gpu.module(" + "convert-vector-to-llvm," + "canonicalize," + "convert-gpu-to-rocdl{ chipset=gfx000 index-bitwidth=0 runtime=HIP use-bare-ptr-memref-call-conv=true}" + ")," + "rocdl-attach-target{O=2 abi=600 chip=gfx942 correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}," + "gpu-to-llvm{intersperse-sizes-for-kernels=false use-bare-pointers-for-host=true use-bare-pointers-for-kernels=true}," + "reconcile-unrealized-casts," + f"gpu-module-to-binary{{format={fmt_map[output_format]} opts= section= toolkit=}}" + ")" + ) + mlir_module = fx_module.module + module = mlir_module.parse(mlir_module.operation.get_asm(enable_debug_info=True)) + + try: + with ExitStack() as stack: + stack.enter_context(module.context) + pm = PassManager.parse(pipeline) + pm.enable_verifier(verify) + pm.enable_ir_printing(print_after_all=print_after_all) + + pm.run(module.operation) + except Exception as e: + print(e) + + # Default: produce a runnable executor (requires gpu-module-to-binary to have produced + # a launchable binary container). + if output_format == "fatbin": + return Executor(module) + + # Debug output: return textual assembly/ISA emitted into gpu.binary's `assembly` attribute + # (or `bin` in some toolchains). + # If the toolchain doesn't embed it (or it was elided), fall back to returning the MLIR. 
class Executor:
    """Wraps an MLIR ExecutionEngine around a compiled jit module and exposes
    its functions as callable attributes.

    Kernel arguments are passed through the engine's packed-args ABI: each
    argument is a `void*` slot, so torch tensors are forwarded as their raw
    device/data pointer.
    """

    # Fallback location of the MLIR runtime shared libraries. Overridable via
    # the FLYDSL_MLIR_LIB_DIR environment variable so the executor is usable
    # outside the original developer machine.
    DEFAULT_LIB_DIR = "/root/Projects/llvm-project/build/lib"

    def __init__(self, jit_module):
        # Local import keeps the module's import surface unchanged.
        import os

        lib_dir = os.environ.get("FLYDSL_MLIR_LIB_DIR", self.DEFAULT_LIB_DIR)
        self.engine = ExecutionEngine(
            jit_module,
            opt_level=3,
            shared_libs=[
                f"{lib_dir}/libmlir_rocm_runtime.so",
                f"{lib_dir}/libmlir_runner_utils.so",
            ],
        )
        self.engine.initialize()

    def convert_args(self, args):
        """Convert one Python-side argument into a `void*` packed-args slot.

        Only torch tensors are supported; raises TypeError otherwise.
        """
        if isinstance(args, torch.Tensor):
            return ctypes.cast(
                ctypes.pointer(ctypes.c_void_p(args.data_ptr())), ctypes.c_void_p
            )
        else:
            raise TypeError(f"Unsupported argument type: {type(args)}")

    def __call__(self, *args):
        # Delegate to the jitted symbol named "__call__" via __getattr__.
        return self.__getattr__("__call__")(*args)

    def __getattr__(self, name: str):
        """Look up a jitted function by symbol name and return a caller for it.

        Raises AttributeError when the engine has no such symbol.
        """
        try:
            func_ptr = self.engine.raw_lookup(name)
            func_exe = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(func_ptr)
        except KeyError:
            raise AttributeError(f"No such function: {name}") from None

        def wrapper(*args):
            # First slot is reserved (NULL); remaining slots are the converted
            # argument pointers, passed as one contiguous void* array.
            addresses = [ctypes.c_void_p(0)]
            addresses += [self.convert_args(arg) for arg in args]
            c_args = (ctypes.c_void_p * len(addresses))(*addresses)
            return func_exe(c_args)

        return wrapper
def _binary_op(lhs, rhs, op: str) -> "ArithValue":
    """Build the matching `arith.{Add,Sub,Mul}{F,I}Op` for two like-typed operands.

    The op mnemonic (e.g. "add") is capitalized and suffixed with "F" when both
    operands are floats or "I" when both are integer-like.

    Raises:
        NotImplementedError: when the operands are neither both float nor both
            integer-like (e.g. a mixed int/float pair).
    """
    op = op.capitalize()
    if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
        op += "F"
    elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
        # BUG FIX: previously this re-checked lhs.type, so an int/float mix was
        # silently dispatched to the integer op instead of raising.
        rhs.type
    ):
        op += "I"
    else:
        raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")

    op = getattr(arith, f"{op}Op")
    return op(lhs, rhs).result


@ir.register_value_caster(T.F16Type.static_typeid)
@ir.register_value_caster(T.F32Type.static_typeid)
@ir.register_value_caster(T.F64Type.static_typeid)
@ir.register_value_caster(T.IntegerType.static_typeid)
class ArithValue(ir.Value):
    """ir.Value wrapper that overloads +, -, * to emit arith dialect ops."""

    def __init__(self, v):
        super().__init__(v)

    __add__ = partialmethod(_binary_op, op="add")
    __sub__ = partialmethod(_binary_op, op="sub")
    __mul__ = partialmethod(_binary_op, op="mul")

    def __str__(self):
        # Present the wrapped value under this class's name for readability.
        return super().__str__().replace(ir.Value.__name__, ArithValue.__name__)
_fly_ir.rank(int_or_tuple) + + +def depth(int_or_tuple): + return _fly_ir.depth(int_or_tuple) + + +@dsl_api_wrapper +def int_tuple_add(lhs, rhs, loc=None, ip=None): + return _fly_ir.int_tuple_add(lhs, rhs, loc=loc, ip=ip) + + +@dsl_api_wrapper +def int_tuple_sub(lhs, rhs, loc=None, ip=None): + return _fly_ir.int_tuple_sub(lhs, rhs, loc=loc, ip=ip) + + +@dsl_api_wrapper +def int_tuple_mul(lhs, rhs, loc=None, ip=None): + return _fly_ir.int_tuple_mul(lhs, rhs, loc=loc, ip=ip) + + +@dsl_api_wrapper +def int_tuple_div(lhs, rhs, loc=None, ip=None): + return _fly_ir.int_tuple_div(lhs, rhs, loc=loc, ip=ip) + + +@dsl_api_wrapper +def int_tuple_product(int_tuple, loc=None, ip=None): + return _fly_ir.int_tuple_product(int_tuple, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_identity_tensor(shape, loc=None, ip=None): + return _fly_ir.make_identity_tensor(shape, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_identity_layout(shape, loc=None, ip=None): + return _fly_ir.make_identity_layout(shape, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_shape(*shape, loc=None, ip=None): + IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, shape) + return _fly_ir.make_shape(IntTupleTy, dyncElems, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_stride(*stride, loc=None, ip=None): + IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, stride) + return _fly_ir.make_stride(IntTupleTy, dyncElems, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_coord(*coord, loc=None, ip=None): + IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, coord) + return _fly_ir.make_coord(IntTupleTy, dyncElems, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_int_tuple(elems, loc=None, ip=None): + IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, elems) + return _fly_ir.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_layout(shape, stride, loc=None, ip=None): + if not isinstance(shape, ir.Value): 
+ shapeTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, shape) + shape = _fly_ir.make_shape(shapeTy, dyncElems, loc=loc, ip=ip) + if not isinstance(stride, ir.Value): + strideTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, stride) + stride = _fly_ir.make_stride(strideTy, dyncElems, loc=loc, ip=ip) + return _fly_ir.make_layout(shape, stride=stride, loc=loc, ip=ip) + + +@dsl_api_wrapper +def size(int_tuple, loc=None, ip=None): + return _fly_ir.size(int_tuple, loc=loc, ip=ip) + + +@dsl_api_wrapper +def get_scalar(int_tuple, loc=None, ip=None): + return _fly_ir.get_scalar(int_tuple, loc=loc, ip=ip) + + +@dsl_api_wrapper +def slice(src, coord, loc=None, ip=None): + if not isinstance(coord, ir.Value): + coordTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, coord) + coord = _fly_ir.make_coord(coordTy, dyncElems, loc=loc, ip=ip) + return _fly_ir.slice(src, coord, loc=loc, ip=ip) + + +@dsl_api_wrapper +def crd2idx(crd, layout, loc=None, ip=None): + return _fly_ir.crd2idx(crd, layout, loc=loc, ip=ip) + + +@dsl_api_wrapper +def composition(layout, tiler, loc=None, ip=None): + return _fly_ir.composition(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def complement(layout, codomain_size, loc=None, ip=None): + if not isinstance(codomain_size, ir.Value): + codomain_sizeTy, dyncElems = _fly_ir.infer_int_tuple_type( + ir.Context.current, codomain_size + ) + codomain_size = _fly_ir.make_shape(codomain_sizeTy, dyncElems, loc=loc, ip=ip) + return _fly_ir.complement(layout, codomain_size=codomain_size, loc=loc, ip=ip) + + +@dsl_api_wrapper +def coalesce(layout, pattern=None, loc=None, ip=None): + return _fly_ir.coalesce(layout, pattern=pattern, loc=loc, ip=ip) + + +@dsl_api_wrapper +def zip(lhs, rhs, loc=None, ip=None): + return _fly_ir.zip(lhs, rhs, loc=loc, ip=ip) + + +@dsl_api_wrapper +def select(int_tuple, indices, loc=None, ip=None): + return _fly_ir.select(int_tuple, indices=indices, loc=loc, ip=ip) + + +@dsl_api_wrapper 
+def group(int_tuple, begin: int, end: int, loc=None, ip=None): + return _fly_ir.group(int_tuple, begin=begin, end=end, loc=loc, ip=ip) + + +@dsl_api_wrapper +def append(base, elem, n: int | None = None, loc=None, ip=None): + return _fly_ir.append(base, elem, n=n, loc=loc, ip=ip) + + +@dsl_api_wrapper +def prepend(base, elem, n: int | None = None, loc=None, ip=None): + return _fly_ir.prepend(base, elem, n=n, loc=loc, ip=ip) + + +@dsl_api_wrapper +def logical_divide(layout, divisor, loc=None, ip=None): + return _fly_ir.logical_divide(layout, divisor, loc=loc, ip=ip) + + +@dsl_api_wrapper +def zipped_divide(layout, divisor, loc=None, ip=None): + return _fly_ir.zipped_divide(layout, divisor, loc=loc, ip=ip) + + +@dsl_api_wrapper +def tiled_divide(layout, divisor, loc=None, ip=None): + return _fly_ir.tiled_divide(layout, divisor, loc=loc, ip=ip) + + +@dsl_api_wrapper +def flat_divide(layout, divisor, loc=None, ip=None): + return _fly_ir.flat_divide(layout, divisor, loc=loc, ip=ip) + + +@dsl_api_wrapper +def logical_product(layout, tiler, loc=None, ip=None): + return _fly_ir.logical_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def zipped_product(layout, tiler, loc=None, ip=None): + return _fly_ir.zipped_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def tiled_product(layout, tiler, loc=None, ip=None): + return _fly_ir.tiled_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def flat_product(layout, tiler, loc=None, ip=None): + return _fly_ir.flat_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def block_product(layout, tiler, loc=None, ip=None): + return _fly_ir.block_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def raked_product(layout, tiler, loc=None, ip=None): + return _fly_ir.raked_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_atom(atom_type, loc=None, ip=None): + return _fly_ir.make_atom(atom_type, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_tile(layouts, 
loc=None, ip=None): + return _fly_ir.make_tile(layouts, loc=loc, ip=ip) + + +@dsl_api_wrapper +def mma_atom_call(mma_atom, d, a, b, c, loc=None, ip=None): + return _fly_ir.mma_atom_call(mma_atom, d, a, b, c, loc=loc, ip=ip) + + +@dsl_api_wrapper +def copy_atom_call(copy_atom, src, dst, loc=None, ip=None): + return _fly_ir.copy_atom_call(copy_atom, src, dst, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_tiled_copy(copy_atom, layout_tv, tile_mn, loc=None, ip=None): + return _fly_ir.make_tiled_copy(copy_atom, layout_tv, tile_mn, loc=loc, ip=ip) + + +@dsl_api_wrapper +def memref_alloca(memref_type, layout, loc=None, ip=None): + return _fly_ir.memref_alloca(memref_type, layout, loc=loc, ip=ip) + + +@dsl_api_wrapper +def memref_load(memref, indices, loc=None, ip=None): + # `fly.memref.load` expects `indices` as `!fly.int_tuple` (typically a scalar offset). + # Accept convenience forms: + # - int_tuple Value (pass through) + # - python int / tuple/list (make_int_tuple) + # - index/i32/i64 Value (cast index->i32 then make_int_tuple) + if isinstance(indices, ir.Value): + if str(indices.type).startswith("!fly.int_tuple"): + return _fly_ir.memref_load(memref, indices, loc=loc, ip=ip) + # Common case: user passes `index` as a 1-D coordinate/offset. + if str(indices.type) == "index": + indices = arith.IndexCastOp(T.i32(), indices) + indices = make_int_tuple(indices, loc=loc, ip=ip) + return _fly_ir.memref_load(memref, indices, loc=loc, ip=ip) + + # List/tuple (e.g. [row]) or python int. 
+ indices = make_int_tuple(indices, loc=loc, ip=ip) + return _fly_ir.memref_load(memref, indices, loc=loc, ip=ip) + + +@dsl_api_wrapper +def memref_store(value, memref, indices, loc=None, ip=None): + if isinstance(indices, ir.Value): + if str(indices.type).startswith("!fly.int_tuple"): + return _fly_ir.memref_store(value, memref, indices, loc=loc, ip=ip) + if str(indices.type) == "index": + indices = arith.IndexCastOp(T.i32(), indices) + indices = make_int_tuple(indices, loc=loc, ip=ip) + return _fly_ir.memref_store(value, memref, indices, loc=loc, ip=ip) + + indices = make_int_tuple(indices, loc=loc, ip=ip) + return _fly_ir.memref_store(value, memref, indices, loc=loc, ip=ip) + + +@dsl_api_wrapper +def memref_load_vec(memref, loc=None, ip=None): + return _fly_ir.memref_load_vec(memref, loc=loc, ip=ip) + + +@dsl_api_wrapper +def memref_store_vec(vector, memref, loc=None, ip=None): + return _fly_ir.memref_store_vec(vector, memref, loc=loc, ip=ip) + + +@dsl_api_wrapper +def get_layout(memref, loc=None, ip=None): + return _fly_ir.get_layout(memref, loc=loc, ip=ip) + + +@dsl_api_wrapper +def get_iter(memref, loc=None, ip=None): + return _fly_ir.get_iter(memref, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_view(iter, layout, loc=None, ip=None): + return _fly_ir.make_view(iter, layout, loc=loc, ip=ip) + + +@dsl_api_wrapper +def add_offset(ptr, offset, loc=None, ip=None): + if not isinstance(offset, ir.Value): + offset = make_int_tuple(offset, loc=loc, ip=ip) + return _fly_ir.add_offset(ptr, offset, loc=loc, ip=ip) + + +@dsl_api_wrapper +def cooperative_copy(tiled_copy, partition_idx, src, dst, loc=None, ip=None): + return _fly_ir.cooperative_copy( + tiled_copy, + partition_idx, + src, + dst, + loc=loc, + ip=ip, + ) + + +@dsl_api_wrapper +def print_op(*values, format_str="", loc=None, ip=None): + """ + Print operation for debugging. Supports IntTuple and other value types. + Lowers to printf for host code or gpu.printf for device code. 
+ + Example: + fx.print_op(int_tuple) + fx.print_op(layout) + fx.print_op(value1, value2, value3) + fx.print_op(value1, format_str="v1=%d\n") + """ + return _fly_ir.print_(format_str, list(values), loc=loc, ip=ip) + + +# ============================================================================== +# Fly Type Classes (MLIR-style API) +# ============================================================================== + + +class PointerType: + """ + Fly Pointer Type with MLIR-style static get() method. + + Example: + ptr_ty = PointerType.get(T.f32(), AddressSpace.Global) + ptr_ty = PointerType.get(T.f32(), AddressSpace.Register, alignment=16) + """ + + @staticmethod + def get(elem_ty, address_space, alignment=None): + """ + Create a PointerType. + + Args: + elem_ty: Element type (e.g., T.f32()) + address_space: Address space (AddressSpace.Global, AddressSpace.Shared, AddressSpace.Register) + alignment: Optional alignment value + + Returns: + PointerType as ir.Type + """ + return _fly_ir.PointerType.get(elem_ty, int(address_space), alignment) + + +class MemRefType: + """ + Fly MemRef Type with MLIR-style static get() method. + + Example: + layout_ty = LayoutType.get(ir.Context.current, 16, 1) + memref_ty = MemRefType.get(T.f32(), AddressSpace.Global, layout_ty) + """ + + @staticmethod + def get(elem_ty, address_space, layout, alignment=None): + """ + Create a MemRefType. + + Args: + elem_ty: Element type (e.g., T.f32()) + address_space: Address space (AddressSpace.Global, AddressSpace.Shared, AddressSpace.Register) + layout: Layout type (LayoutType or ir.Type) + alignment: Optional alignment value + + Returns: + MemRefType as ir.Type + """ + # If layout is an ir.Value (from make_layout), get its type + if isinstance(layout, ir.Value): + layout = layout.type + return _fly_ir.MemRefType.get(elem_ty, int(address_space), layout, alignment) + + +class LayoutType: + """ + Fly Layout Type with MLIR-style static get() method. 
+ + Example: + layout_ty = LayoutType.get(ir.Context.current, 16, 1) + layout_ty = LayoutType.get(ir.Context.current, (4, 4), (4, 1)) + """ + + @staticmethod + def get(context, shape, stride): + """ + Create a LayoutType. + + Args: + context: MLIR context + shape: Shape as int or tuple + stride: Stride as int or tuple + + Returns: + LayoutType as ir.Type + """ + return _fly_ir.LayoutType.get(context, shape, stride) + + +class IntTupleType: + """ + Fly IntTuple Type with MLIR-style static get() method. + + Example: + int_tuple_ty = IntTupleType.get(ir.Context.current, (4, 4)) + """ + + @staticmethod + def get(context, int_or_tuple): + """ + Create an IntTupleType. + + Args: + context: MLIR context + int_or_tuple: Python int or tuple + + Returns: + Tuple of (IntTupleType as ir.Type, list of dynamic elements) + """ + return _fly_ir.IntTupleType.get(context, int_or_tuple) diff --git a/python/flydsl/lang/ir/gpu.py b/python/flydsl/lang/ir/gpu.py new file mode 100644 index 00000000..e5dd1964 --- /dev/null +++ b/python/flydsl/lang/ir/gpu.py @@ -0,0 +1,457 @@ +import inspect +from functools import partial +import sys +from pathlib import Path +from functools import wraps +from typing import Any, List, Optional, Tuple, Union, Callable +from typing import Optional, List, Union, TypeVar + +from ..._mlir.dialects._func_ops_gen import FuncOp +from ..._mlir.extras import types as T +from ..._mlir.extras.meta import region_op, op_region_builder + + +from ..._mlir.dialects._ods_common import ( + _cext, + get_default_loc_context, + get_op_result_or_op_results, +) +from ..._mlir.dialects._gpu_ops_gen import _Dialect +from ..._mlir.dialects._gpu_ops_gen import * +from ..._mlir.dialects._gpu_enum_gen import * + + +from ..._mlir.ir import ( + ArrayAttr, + AttrBuilder, + Attribute, + Context, + InsertionPoint, + ShapedType, + Type, + UnitAttr, + Value, + FlatSymbolRefAttr, + FunctionType, + InsertionPoint, + OpView, + Operation, + OpResultList, + Type, + TypeAttr, + Value, + 
register_attribute_builder, +) + +_block_id = block_id +_thread_id = thread_id +_block_dim = block_dim +_grid_dim = grid_dim + + +class classproperty(property): + def __get__(self, owner_self, owner_cls): + return self.fget(owner_cls) + + +class block_idx: + @classproperty + def x(cls): + return _block_id("x") + + @classproperty + def y(cls): + return _block_id("y") + + @classproperty + def z(cls): + return _block_id("z") + + +class block_dim: + @classproperty + def x(cls): + return _block_dim("x") + + @classproperty + def y(cls): + return _block_dim("y") + + @classproperty + def z(cls): + return _block_dim("z") + + +class thread_idx: + @classproperty + def x(cls): + return _thread_id("x") + + @classproperty + def y(cls): + return _thread_id("y") + + @classproperty + def z(cls): + return _thread_id("z") + + +class grid_dim: + @classproperty + def x(cls): + return _grid_dim("x") + + @classproperty + def y(cls): + return _grid_dim("y") + + @classproperty + def z(cls): + return _grid_dim("z") + + +def gpu_attr(mnemonic, attr_value): + return Attribute.parse(f"#gpu.{mnemonic}<{attr_value}>") + + +class ModuleMeta(type): + def __new__(cls, name, bases, classdict, **kwargs): + ip = classdict.pop("ip") + new = super().__new__(cls, name, bases, classdict) + for k, v in classdict.items(): + if callable(v): + v.qualname = name + ip.__exit__(None, None, None) + return new + + +@_cext.register_operation(_Dialect, replace=True) +class GPUModuleOp(GPUModuleOp): + def __init__( + self, sym_name, targets: Optional[List[Attribute]] = None, *, loc=None, ip=None + ): + if targets is None: + targets = [] + for i, t in enumerate(targets): + if isinstance(t, str): + targets[i] = Attribute.parse(t) + _ods_context = get_default_loc_context(loc) + sym_name = ( + sym_name + if ( + issubclass(type(sym_name), Attribute) + or not AttrBuilder.contains("SymbolNameAttr") + ) + else AttrBuilder.get("SymbolNameAttr")(sym_name, context=_ods_context) + ) + super().__init__(sym_name=sym_name, 
targets=ArrayAttr.get(targets), ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + return self.regions[0].blocks[0] + + +module = region_op(GPUModuleOp) + + +class GPUModuleMeta(ModuleMeta): + @classmethod + def __prepare__(cls, name, bases, **kwargs): + loc = kwargs.pop("loc", None) + if loc is None: + loc = get_user_code_loc() + targets = kwargs.pop("targets", None) + gpu_module_op = GPUModuleOp( + sym_name=name, + targets=targets, + ip=kwargs.pop("ip", None), + loc=loc, + ) + ip = InsertionPoint(gpu_module_op.body) + ip.__enter__() + return {"ip": ip, "gpu_module_op": gpu_module_op} + + +@_cext.register_operation(_Dialect, replace=True) +class GPUFuncOp(GPUFuncOp): + def __init__( + self, + sym_name, + function_type, + *, + sym_visibility=None, + arg_attrs=None, + res_attrs=None, + workgroup_attrib_attrs=None, + private_attrib_attrs=None, + loc=None, + ip=None, + ): + super().__init__( + function_type=function_type, + arg_attrs=arg_attrs, + res_attrs=res_attrs, + workgroup_attrib_attrs=workgroup_attrib_attrs, + private_attrib_attrs=private_attrib_attrs, + loc=loc, + ip=ip, + ) + self.operation.attributes["gpu.kernel"] = UnitAttr.get() + _ods_context = get_default_loc_context(loc) + self.operation.attributes["sym_name"] = ( + sym_name + if ( + issubclass(type(sym_name), Attribute) + or not AttrBuilder.contains("SymbolNameAttr") + ) + else AttrBuilder.get("SymbolNameAttr")(sym_name, context=_ods_context) + ) + if sym_visibility is not None: + self.operation.attributes["sym_visibility"] = ( + sym_visibility + if ( + issubclass(type(sym_visibility), Attribute) + or not AttrBuilder.contains("StrAttr") + ) + else AttrBuilder.get("StrAttr")(sym_visibility, context=_ods_context) + ) + + +def isalambda(v): + LAMBDA = lambda: 0 + return isinstance(v, type(LAMBDA)) and v.__name__ == LAMBDA.__name__ + + +def prep_func_types(sig, return_types): + assert not ( + not sig.return_annotation is inspect.Signature.empty and len(return_types) > 0 + ), f"func 
can use return annotation or explicit return_types but not both" + return_types = ( + sig.return_annotation + if not sig.return_annotation is inspect.Signature.empty + else return_types + ) + if not isinstance(return_types, (tuple, list)): + return_types = [return_types] + return_types = list(return_types) + assert all( + isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in return_types + ), f"all return types must be ..._mlir types or strings or TypeVars or lambdas {return_types=}" + + input_types = [ + p.annotation + for p in sig.parameters.values() + if not p.annotation is inspect.Signature.empty + ] + assert all( + isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in input_types + ), f"all input types must be ..._mlir types or strings or TypeVars or lambdas {input_types=}" + user_loc = None + # If ir.Context is none (like for deferred func emit) + if user_loc is None: + user_locs = None + else: + user_locs = [user_loc] * len(sig.parameters) + return input_types, return_types, user_locs + + +@_cext.register_operation(_Dialect, replace=True) +class LaunchFuncOp(LaunchFuncOp): + def __init__( + self, + kernel: List[str], + grid_size: Tuple[Any, Any, Any], + block_size: Tuple[Any, Any, Any], + kernel_operands: List[Value] = None, + async_dependencies=None, + dynamic_shared_memory_size: Optional[Value] = None, + async_object=None, + *, + loc=None, + ip=None, + ): + _ods_context = get_default_loc_context(loc) + if async_dependencies is None: + async_dependencies = [] + async_token = None + grid_size_x, grid_size_y, grid_size_z = grid_size + block_size_x, block_size_y, block_size_z = block_size + + super().__init__( + async_token, + async_dependencies, + kernel, + grid_size_x, + grid_size_y, + grid_size_z, + block_size_x, + block_size_y, + block_size_z, + kernel_operands, + dynamicSharedMemorySize=dynamic_shared_memory_size, + asyncObject=async_object, + loc=loc, + ip=ip, + ) + + +class GPUFunc: + def __init__( + self, + body_builder, + func_op_ctor, + 
return_op_ctor, + call_op_ctor, + *, + return_types=None, + sym_visibility=None, + sym_name=None, + arg_attrs=None, + res_attrs=None, + func_attrs=None, + function_type=None, + generics: List[Union[TypeVar]] = None, + qualname=None, + loc=None, + ip=None, + ): + assert inspect.isfunction(body_builder), body_builder + assert inspect.isclass(func_op_ctor), func_op_ctor + if return_op_ctor is not None: + assert inspect.isclass(return_op_ctor), return_op_ctor + assert inspect.isclass(call_op_ctor), call_op_ctor + + self.body_builder = body_builder + if sym_name is None: + sym_name = self.body_builder.__name__ + self.func_name = sym_name + self.func_op_ctor = func_op_ctor + self.return_op_ctor = return_op_ctor + self.call_op_ctor = call_op_ctor + self.arg_attrs = arg_attrs + self.res_attrs = res_attrs + self.generics = generics + self.loc = loc + self.ip = ip + self._func_op = None + # in case this function lives inside a class + self.qualname = qualname + + self.sym_visibility = sym_visibility + self.func_attrs = func_attrs + if self.func_attrs is None: + self.func_attrs = {} + self.function_type = function_type + + if return_types is None: + return_types = [] + sig = inspect.signature(self.body_builder) + self.input_types, self.return_types, self.arg_locs = prep_func_types( + sig, return_types + ) + + def __str__(self): + return str(f"{self.__class__} {self.__dict__}") + + def emit(self, *call_args, decl=False, force=False): + if self._func_op is None or decl or force: + if self.function_type is None: + if len(call_args) == 0: + input_types = self.input_types[:] + locals = {"T": T} + for i, v in enumerate(input_types): + if isinstance(v, TypeVar): + v = v.__name__ + if isinstance(v, str): + input_types[i] = Type( + eval(v, self.body_builder.__globals__, locals) + ) + elif isalambda(v): + input_types[i] = v() + else: + input_types = [a.type for a in call_args] + + function_type = TypeAttr.get( + FunctionType.get( + inputs=input_types, + results=self.return_types, + ) + 
) + else: + input_types = self.function_type.inputs + function_type = TypeAttr.get(self.function_type) + + self._func_op = self.func_op_ctor( + self.func_name, + function_type, + sym_visibility=self.sym_visibility, + arg_attrs=self.arg_attrs, + res_attrs=self.res_attrs, + loc=self.loc, + ip=self.ip or InsertionPoint.current, + ) + if isinstance(self._func_op, FuncOp): + self._func_op.attributes["llvm.emit_c_interface"] = UnitAttr.get() + for k, v in self.func_attrs.items(): + self._func_op.attributes[k] = v + + self._func_op.regions[0].blocks.append(*input_types, arg_locs=self.arg_locs) + builder_wrapper = op_region_builder( + self._func_op, self._func_op.regions[0], terminator=self.return_op_ctor + ) + + return_types = [] + + def grab_results(*args): + nonlocal return_types + results = self.body_builder(*args) + if isinstance(results, (tuple, list, OpResultList)): + return_types.extend([r.type for r in results]) + elif results is not None: + return_types.append(results.type) + return results + + if self.function_type is None: + builder_wrapper(grab_results) + function_type = FunctionType.get( + inputs=input_types, results=return_types + ) + self._func_op.attributes["function_type"] = TypeAttr.get(function_type) + else: + builder_wrapper(self.body_builder) + + return self._func_op + + +def gpu_func( + f, + *, + sym_visibility=None, + arg_attrs=None, + res_attrs=None, + func_attrs=None, + emit=False, + generics=None, + loc=None, + ip=None, +): + if generics is None and hasattr(f, "__type_params__") and f.__type_params__: + generics = f.__type_params__ + func_ = GPUFunc( + body_builder=f, + func_op_ctor=GPUFuncOp, + return_op_ctor=ReturnOp, + call_op_ctor=LaunchFuncOp, + sym_visibility=sym_visibility, + arg_attrs=arg_attrs, + res_attrs=res_attrs, + func_attrs=func_attrs, + generics=generics, + loc=loc, + ip=ip, + ) + func_.__name__ = f.__name__ + if emit: + func_.emit() + return func_ diff --git a/python/flydsl/lang/ir/module.py b/python/flydsl/lang/ir/module.py new 
class GlobalRAIIMLIRContext:
    """Process-wide MLIR Context/Location pair, entered on construction and
    exited when the object is garbage-collected.

    Entering both managers makes `ir.Context.current` / the unknown Location
    implicitly available to all subsequently-built IR.
    """

    context: ir.Context
    location: ir.Location

    def __init__(self, allow_unregistered_dialects=False):
        self.context = ir.Context()
        if allow_unregistered_dialects:
            self.context.allow_unregistered_dialects = True
        self.context.__enter__()
        self.location = ir.Location.unknown()
        self.location.__enter__()

    def __del__(self):
        # Be defensive: __del__ can run after a partially-failed __init__
        # (attributes missing) or during interpreter shutdown, where raising
        # is noisy and useless. Exit in reverse order of entry.
        try:
            if hasattr(self, "location"):
                self.location.__exit__(None, None, None)
            if hasattr(self, "context"):
                self.context.__exit__(None, None, None)
        except Exception:
            pass
temp_jit_fn.append(value._wrapper) + + # Set class-specific lists + cls.cls_kernel_fn = temp_kernel_fn + cls.cls_jit_fn = temp_jit_fn + cls.cls_kernel_sym = temp_kernel_sym + + def __init__(self): + self.kernel_func_op = {} + for fn in self.cls_jit_fn: + fn(self) + for fn in self.cls_kernel_fn: + fn(self) + + def __repr__(self): + return str(self.module) + + def __getattr__(self, name: str): + if name in self.cls_kernel_sym.keys(): + return ir.SymbolRefAttr.get( + [self.GPU_MODULE_NAME, self.cls_kernel_sym[name]] + ) + raise AttributeError(f"{name} not found in kernel functions.") + + @classmethod + def create_gpu_module(cls, module_attrs=None): + cls.gpu_module = _gpu_ops_gen.module("kernels") + + @classmethod + def create_from_mlir_source(cls, file_path: str): + pass + + @classmethod + def kernel(cls, fn): + def wrapper(self, *args, **kwargs): + if len(self.gpu_module.bodyRegion.blocks) == 0: + self.gpu_module.bodyRegion.blocks.append() + with ir.InsertionPoint.at_block_begin(self.gpu_module.bodyRegion.blocks[0]): + self.kernel_func_op[fn.__name__] = gpu_func(fn, emit=True) + + cls.cls_kernel_fn.append(wrapper) + cls.cls_kernel_sym[fn.__name__] = fn.__name__ + return fn + + @classmethod + def jit(cls, fn): + def wrapper(self): + with ir.InsertionPoint.at_block_begin(self.module.body): + sig = inspect.signature(fn) + input_types, return_types, _ = prep_func_types(sig, []) + func.FuncOp.from_py_func(*input_types)(fn) + + cls.cls_jit_fn.append(wrapper) + return fn + + +class _KernelDescriptor: + """Descriptor that automatically registers kernel to the correct class.""" + + def __init__(self, fn): + self.fn = fn + self.name = fn.__name__ + self._wrapper = None + + def __set_name__(self, owner, name): + """Called when the descriptor is assigned to a class attribute.""" + # Check if owner is a subclass of MlirModule + try: + if issubclass(owner, MlirModule): + # Capture fn in the closure + fn = self.fn + + def wrapper(instance_self, *args, **kwargs): + if 
len(instance_self.gpu_module.bodyRegion.blocks) == 0: + instance_self.gpu_module.bodyRegion.blocks.append() + with ir.InsertionPoint.at_block_begin( + instance_self.gpu_module.bodyRegion.blocks[0] + ): + instance_self.kernel_func_op[fn.__name__] = gpu_func( + fn, emit=True + ) + + # Store the wrapper in the descriptor for later collection + self._wrapper = wrapper + self._name = name + except TypeError: + # owner is not a class, skip + pass + + def __get__(self, obj, objtype=None): + """Return the original function for method access.""" + if obj is None: + return self.fn + return self.fn.__get__(obj, objtype) + + +class _JitDescriptor: + """Descriptor that automatically registers jit function to the correct class.""" + + def __init__(self, fn): + self.fn = fn + self.name = fn.__name__ + self._wrapper = None + + def __set_name__(self, owner, name): + """Called when the descriptor is assigned to a class attribute.""" + # Check if owner is a subclass of MlirModule + try: + if issubclass(owner, MlirModule): + # Capture fn in the closure + fn = self.fn + + def wrapper(instance_self): + with ir.InsertionPoint.at_block_begin(instance_self.module.body): + sig = inspect.signature(fn) + input_types, return_types, _ = prep_func_types(sig, []) + func.FuncOp.from_py_func(*input_types)(fn) + + # Store the wrapper in the descriptor for later collection + self._wrapper = wrapper + except TypeError: + # owner is not a class, skip + pass + + def __get__(self, obj, objtype=None): + """Return the original function for method access.""" + if obj is None: + return self.fn + return self.fn.__get__(obj, objtype) + + +# Use descriptor-based decorators that return descriptors +def kernel(fn): + """Decorator that returns a descriptor for automatic class detection.""" + return _KernelDescriptor(fn) + + +def jit(fn): + """Decorator that returns a descriptor for automatic class detection.""" + return _JitDescriptor(fn) + + +_global_ctx = GlobalRAIIMLIRContext() diff --git 
a/python/flydsl/lang/ir/types.py b/python/flydsl/lang/ir/types.py new file mode 100644 index 00000000..011fa158 --- /dev/null +++ b/python/flydsl/lang/ir/types.py @@ -0,0 +1,5 @@ +# from fly_mlir.extras import types as T + + +class Tensor: + pass diff --git a/python/flydsl/lang/meta.py b/python/flydsl/lang/meta.py new file mode 100644 index 00000000..7c4d52e0 --- /dev/null +++ b/python/flydsl/lang/meta.py @@ -0,0 +1,30 @@ +import inspect +from functools import wraps + +from .._mlir import ir + + +def dsl_api_wrapper(op): + @wraps(op) + def wrapper(*args, **kwargs): + loc = kwargs.pop("loc", None) + if loc is None: + frame = inspect.currentframe().f_back + frameInfo = inspect.getframeinfo(frame) + file_loc = ir.Location.file( + frameInfo.filename, + frameInfo.positions.lineno, + frameInfo.positions.col_offset, + ) + loc = ir.Location.name( + ( + "".join([c.strip() for c in frameInfo.code_context]) + if frameInfo.code_context + else frameInfo.function + ), + childLoc=file_loc, + ) + with loc: + return op(*args, **kwargs) + + return wrapper diff --git a/python/flydsl/lang/typing.py b/python/flydsl/lang/typing.py new file mode 100644 index 00000000..7e437ecc --- /dev/null +++ b/python/flydsl/lang/typing.py @@ -0,0 +1,32 @@ +import ctypes +import numpy as np +import operator +from typing_extensions import deprecated +from functools import reduce +from typing import ( + Generic, + Protocol, + Union, + Any, + List, + Type, + TypeVar, + overload, + runtime_checkable, + get_origin, +) +from types import FunctionType +from dataclasses import dataclass +from abc import ABC, abstractmethod + + +class NumericType: + pass + + +class Int32: + def __init__(self, value): + self.value = value + + def __repr__(self): + return f"Int32({self.value})" diff --git a/python/flydsl/utils/__init__.py b/python/flydsl/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/flydsl/utils/env_manager.py b/python/flydsl/utils/env_manager.py new file mode 100644 index 
00000000..e69de29b diff --git a/python/flydsl/utils/hip_utils.py b/python/flydsl/utils/hip_utils.py new file mode 100644 index 00000000..146515fa --- /dev/null +++ b/python/flydsl/utils/hip_utils.py @@ -0,0 +1,2 @@ +def get_hip_arch(): + pass diff --git a/python/flydsl/utils/logger.py b/python/flydsl/utils/logger.py new file mode 100644 index 00000000..e69de29b diff --git a/python/mlir_flydsl/CMakeLists.txt b/python/mlir_flydsl/CMakeLists.txt new file mode 100644 index 00000000..ba6ed8d7 --- /dev/null +++ b/python/mlir_flydsl/CMakeLists.txt @@ -0,0 +1,129 @@ +include(AddMLIRPython) + +add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=flydsl.${MLIR_PYTHON_PACKAGE_PREFIX}.") + +declare_mlir_python_sources(FlyPythonSources) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT FlyPythonSources + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/" + TD_FILE dialects/FlyOps.td + SOURCES + dialects/fly.py + _mlir_libs/_mlirRegisterEverything/py.typed + DIALECT_NAME fly + GEN_ENUM_BINDINGS +) + + +declare_mlir_python_extension(FlyPythonSources.Core + MODULE_NAME _fly + ADD_TO_PARENT FlyPythonSources + ROOT_DIR "${PROJECT_SOURCE_DIR}/lib/Bindings/Python" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + MainModules.cpp + PRIVATE_LINK_LIBS + LLVMSupport + MLIRFlyDialect + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR +) + + + +declare_mlir_python_extension(FlyPythonSources.RegisterEverything + MODULE_NAME _mlirRegisterEverything + ADD_TO_PARENT FlyPythonSources + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + FlyRegisterEverything.cpp + PRIVATE_LINK_LIBS + LLVMSupport + MLIRFlyToROCDL + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCPIFly + MLIRCAPIArith + MLIRCAPIGPU + MLIRCAPILLVM + MLIRCAPIMath + MLIRCAPIVector + MLIRCAPIConversion + MLIRCAPITransforms + MLIRCAPIRegisterEverything +) + + +set(MLIRFlyDSLSources + FlyPythonSources + MLIRPythonSources.Core + MLIRPythonSources.Dialects.builtin + MLIRPythonSources.Dialects.arith + MLIRPythonSources.Dialects.math + MLIRPythonSources.Dialects.gpu + 
MLIRPythonSources.Dialects.func + MLIRPythonSources.Dialects.cf + MLIRPythonSources.Dialects.scf + MLIRPythonSources.Dialects.rocdl + MLIRPythonSources.Dialects.vector + MLIRPythonSources.Dialects.llvm + MLIRPythonSources.ExecutionEngine +) + +add_mlir_python_common_capi_library(FlyPythonCAPI + INSTALL_COMPONENT FlyPythonModules + INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs" + OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs" + RELATIVE_INSTALL_ROOT "../../../.." + DECLARED_SOURCES ${MLIRFlyDSLSources} +) + + +set(FlyPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}") + +# set(_core_type_stub_sources +# _mlir/__init__.pyi +# _mlir/ir.pyi +# _mlir/passmanager.pyi +# _mlir/rewrite.pyi +# ) +# get_target_property(_core_extension_srcs MLIRPythonExtension.Core INTERFACE_SOURCES) +# mlir_generate_type_stubs( +# MODULE_NAME _mlir +# DEPENDS_TARGETS FlyPythonModules.extension._mlir.dso +# OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs" +# OUTPUTS "${_core_type_stub_sources}" +# DEPENDS_TARGET_SRC_DEPS "${_core_extension_srcs}" +# IMPORT_PATHS "${FlyPythonModules_ROOT_PREFIX}/_mlir_libs" +# VERBOSE +# ) +# set(_mlir_typestub_gen_target "${NB_STUBGEN_CUSTOM_TARGET}") + +# list(TRANSFORM _core_type_stub_sources PREPEND "_mlir_libs/") +# declare_mlir_python_sources( +# FlyPythonExtension.Core.type_stub_gen +# ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs" +# ADD_TO_PARENT FlyPythonSources +# SOURCES "${_core_type_stub_sources}" +# ) + + + +add_mlir_python_modules(FlyPythonModules + ROOT_PREFIX "${FlyPythonModules_ROOT_PREFIX}" + INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}" + DECLARED_SOURCES "${MLIRFlyDSLSources}" + COMMON_CAPI_LINK_LIBS + FlyPythonCAPI +) + +add_custom_target(CopyFlyPythonSources ALL + COMMAND ${CMAKE_COMMAND} -E copy_directory + "${PROJECT_SOURCE_DIR}/python/flydsl" + "${MLIR_BINARY_DIR}/python_packages/flydsl" + COMMENT "Copying 
python/flydsl sources to build/python_packages/flydsl" + DEPENDS FlyPythonModules +) + +add_dependencies(CopyFlyPythonSources FlyPythonModules) diff --git a/python/mlir_flydsl/FlyRegisterEverything.cpp b/python/mlir_flydsl/FlyRegisterEverything.cpp new file mode 100644 index 00000000..94b877a2 --- /dev/null +++ b/python/mlir_flydsl/FlyRegisterEverything.cpp @@ -0,0 +1,33 @@ +#include "mlir-c/RegisterEverything.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +#include "flydsl-c/FlyDialect.h" +#include "flydsl/Conversion/FlyToROCDL/FlyToROCDL.h" +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Transforms/Passes.h" + +namespace mlir { +#define GEN_PASS_REGISTRATION +#include "flydsl/Conversion/Passes.h.inc" +} // namespace mlir + +NB_MODULE(_mlirRegisterEverything, m) { + m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration"; + + m.def("register_dialects", [](MlirDialectRegistry registry) { + mlirRegisterAllDialects(registry); + + MlirDialectHandle flyHandle = mlirGetDialectHandle__fly__(); + mlirDialectHandleInsertDialect(flyHandle, registry); + }); + m.def("register_llvm_translations", + [](MlirContext context) { mlirRegisterAllLLVMTranslations(context); }); + + // Register all passes on load. 
+ mlirRegisterAllPasses(); + + mlir::fly::registerFlyPasses(); + // Register Fly to ROCDL conversion pass + mlir::registerFlyToROCDLConversionPass(); +} diff --git a/python/mlir_flydsl/_mlir_libs/_mlirRegisterEverything/py.typed b/python/mlir_flydsl/_mlir_libs/_mlirRegisterEverything/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/python/mlir_flydsl/dialects/FlyOps.td b/python/mlir_flydsl/dialects/FlyOps.td new file mode 100644 index 00000000..b43270dd --- /dev/null +++ b/python/mlir_flydsl/dialects/FlyOps.td @@ -0,0 +1,7 @@ +#ifndef PYTHON_BINDINGS_FLY_OPS +#define PYTHON_BINDINGS_FLY_OPS + +include "flydsl/Dialect/Fly/IR/FlyOps.td" +include "flydsl/Dialect/Fly/IR/FlyAttrDefs.td" + +#endif // PYTHON_BINDINGS_FLY_OPS diff --git a/python/mlir_flydsl/dialects/fly.py b/python/mlir_flydsl/dialects/fly.py new file mode 100644 index 00000000..e0e5fbf5 --- /dev/null +++ b/python/mlir_flydsl/dialects/fly.py @@ -0,0 +1,4 @@ +from ._fly_enum_gen import * +from ._fly_ops_gen import * + +from .._mlir_libs._fly import * From dd8c6feb902d8303d8ef46babf778ddf2ba1151d Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 26 Jan 2026 06:19:54 +0000 Subject: [PATCH 002/113] update header macro --- include/flydsl-c/FlyDialect.h | 6 +++--- include/flydsl/Conversion/Passes.h | 6 +++--- include/flydsl/Conversion/Passes.td | 6 +++--- include/flydsl/Dialect/Fly/IR/FlyDialect.h | 6 +++--- include/flydsl/Dialect/Fly/Transforms/Passes.h | 4 ++-- include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h | 6 +++--- include/flydsl/Dialect/Fly/Utils/IntUtils.h | 6 +++--- include/flydsl/Dialect/Fly/Utils/LayoutUtils.h | 6 +++--- include/flydsl/Dialect/Fly/Utils/NormalForm.h | 6 +++--- lib/Dialect/Fly/IR/FlyOps.cpp | 2 +- 10 files changed, 27 insertions(+), 27 deletions(-) diff --git a/include/flydsl-c/FlyDialect.h b/include/flydsl-c/FlyDialect.h index fefe8791..d27e0206 100644 --- a/include/flydsl-c/FlyDialect.h +++ b/include/flydsl-c/FlyDialect.h @@ -1,5 +1,5 @@ -#ifndef 
FLY_C_DIALECTS_H -#define FLY_C_DIALECTS_H +#ifndef FLYDSL_C_DIALECTS_H +#define FLYDSL_C_DIALECTS_H #include "mlir-c/IR.h" @@ -13,4 +13,4 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Fly, fly); } #endif -#endif // FLY_C_DIALECTS_H +#endif // FLYDSL_C_DIALECTS_H diff --git a/include/flydsl/Conversion/Passes.h b/include/flydsl/Conversion/Passes.h index 20789e1e..3da57ec1 100644 --- a/include/flydsl/Conversion/Passes.h +++ b/include/flydsl/Conversion/Passes.h @@ -1,6 +1,6 @@ -#ifndef FLY_CONVERSION_PASSES_H -#define FLY_CONVERSION_PASSES_H +#ifndef FLYDSL_CONVERSION_PASSES_H +#define FLYDSL_CONVERSION_PASSES_H #include "flydsl/Conversion/FlyToROCDL/FlyToROCDL.h" @@ -11,4 +11,4 @@ namespace mlir { } // namespace mlir -#endif // FLY_CONVERSION_PASSES_H +#endif // FLYDSL_CONVERSION_PASSES_H diff --git a/include/flydsl/Conversion/Passes.td b/include/flydsl/Conversion/Passes.td index d70ffaff..5b873657 100644 --- a/include/flydsl/Conversion/Passes.td +++ b/include/flydsl/Conversion/Passes.td @@ -1,5 +1,5 @@ -#ifndef FLY_PASSES -#define FLY_PASSES +#ifndef FLYDSL_CONVERSION_PASSES +#define FLYDSL_CONVERSION_PASSES include "mlir/Pass/PassBase.td" @@ -18,4 +18,4 @@ def FlyToROCDLConversionPass : Pass<"convert-fly-to-rocdl"> { ]; } -#endif // FLY_PASSES +#endif // FLYDSL_CONVERSION_PASSES diff --git a/include/flydsl/Dialect/Fly/IR/FlyDialect.h b/include/flydsl/Dialect/Fly/IR/FlyDialect.h index 64de1156..20e78737 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyDialect.h +++ b/include/flydsl/Dialect/Fly/IR/FlyDialect.h @@ -1,5 +1,5 @@ -#ifndef FLY_DIALECT_FLY_IR_DIALECT_H -#define FLY_DIALECT_FLY_IR_DIALECT_H +#ifndef FLYDSL_DIALECT_FLY_IR_DIALECT_H +#define FLYDSL_DIALECT_FLY_IR_DIALECT_H #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Attributes.h" @@ -33,4 +33,4 @@ namespace mlir::fly { #include "flydsl/Dialect/Fly/IR/FlyTypeConstraints.h.inc" } // namespace mlir::fly -#endif // FLY_DIALECT_FLY_IR_DIALECT_H +#endif // FLYDSL_DIALECT_FLY_IR_DIALECT_H diff --git 
a/include/flydsl/Dialect/Fly/Transforms/Passes.h b/include/flydsl/Dialect/Fly/Transforms/Passes.h index ff2319b8..6db3b945 100644 --- a/include/flydsl/Dialect/Fly/Transforms/Passes.h +++ b/include/flydsl/Dialect/Fly/Transforms/Passes.h @@ -1,5 +1,5 @@ -#ifndef FLY_TRANSFORM_H -#define FLY_TRANSFORM_H +#ifndef FLYDSL_TRANSFORM_H +#define FLYDSL_TRANSFORM_H #include "mlir/Pass/Pass.h" diff --git a/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h index 2ffc8121..cd3ec262 100644 --- a/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h @@ -1,5 +1,5 @@ -#ifndef FLY_DIALECT_UTILS_INTTUPLEUTILS_H -#define FLY_DIALECT_UTILS_INTTUPLEUTILS_H +#ifndef FLYDSL_DIALECT_UTILS_INTTUPLEUTILS_H +#define FLYDSL_DIALECT_UTILS_INTTUPLEUTILS_H #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Attributes.h" @@ -1001,4 +1001,4 @@ BasisAttr basisCeilDiv(BasisAttr lhs, IntAttr rhs); } // namespace mlir::fly -#endif // FLY_DIALECT_UTILS_INTTUPLEUTILS_H +#endif // FLYDSL_DIALECT_UTILS_INTTUPLEUTILS_H diff --git a/include/flydsl/Dialect/Fly/Utils/IntUtils.h b/include/flydsl/Dialect/Fly/Utils/IntUtils.h index 3e14fbac..995fea03 100644 --- a/include/flydsl/Dialect/Fly/Utils/IntUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/IntUtils.h @@ -1,5 +1,5 @@ -#ifndef FLY_DIALECT_UTILS_INTUTILS_H -#define FLY_DIALECT_UTILS_INTUTILS_H +#ifndef FLYDSL_DIALECT_UTILS_INTUTILS_H +#define FLYDSL_DIALECT_UTILS_INTUTILS_H #include "mlir/IR/Attributes.h" #include "mlir/Support/LogicalResult.h" @@ -50,4 +50,4 @@ IntAttr intShapeDiv(IntAttr lhs, IntAttr rhs); } // namespace mlir::fly -#endif // FLY_DIALECT_UTILS_INTUTILS_H +#endif // FLYDSL_DIALECT_UTILS_INTUTILS_H diff --git a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h index 39afd472..55c4d2a9 100644 --- a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h +++ 
b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h @@ -1,5 +1,5 @@ -#ifndef FLY_DIALECT_UTILS_LAYOUTATTR_H -#define FLY_DIALECT_UTILS_LAYOUTATTR_H +#ifndef FLYDSL_DIALECT_UTILS_LAYOUTATTR_H +#define FLYDSL_DIALECT_UTILS_LAYOUTATTR_H #include @@ -768,4 +768,4 @@ Layout layoutRakedProduct(LayoutBuilder &builder, Layout blockLayout, La } // namespace mlir::fly -#endif // FLY_DIALECT_UTILS_LAYOUTATTR_H +#endif // FLYDSL_DIALECT_UTILS_LAYOUTATTR_H diff --git a/include/flydsl/Dialect/Fly/Utils/NormalForm.h b/include/flydsl/Dialect/Fly/Utils/NormalForm.h index 1b7b6e3e..dcd29ea1 100644 --- a/include/flydsl/Dialect/Fly/Utils/NormalForm.h +++ b/include/flydsl/Dialect/Fly/Utils/NormalForm.h @@ -1,5 +1,5 @@ -#ifndef FLY_DIALECT_UTILS_NORMALFORM_H -#define FLY_DIALECT_UTILS_NORMALFORM_H +#ifndef FLYDSL_DIALECT_UTILS_NORMALFORM_H +#define FLYDSL_DIALECT_UTILS_NORMALFORM_H #include "mlir/IR/Attributes.h" #include "mlir/Support/LogicalResult.h" @@ -22,4 +22,4 @@ bool isNormalForm(TypedValue value); } // namespace mlir::fly -#endif // FLY_DIALECT_UTILS_NORMALFORM_H +#endif // FLYDSL_DIALECT_UTILS_NORMALFORM_H diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp index 653e7be7..76726aab 100644 --- a/lib/Dialect/Fly/IR/FlyOps.cpp +++ b/lib/Dialect/Fly/IR/FlyOps.cpp @@ -90,7 +90,7 @@ LayoutAttr makeOrderedLayoutAttr(IntTupleAttr shapeAttr, IntTupleAttr orderAttr) } // namespace -#define FLY_INFER_RETURN_TYPES(OP) \ +#define FLY_INFER_RETURN_TYPES(OP) \ llvm::LogicalResult OP::inferReturnTypes( \ mlir::MLIRContext *context, std::optional<::mlir::Location> location, \ mlir::ValueRange operands, mlir::DictionaryAttr attributes, \ From eafa2d6b4b1779ae4a1feb6995d5a822bec1833b Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Tue, 27 Jan 2026 06:27:23 +0000 Subject: [PATCH 003/113] add separate target-specific rocdl dialect --- examples/02-layout_algebra.py | 6 +- examples/03-mma_atom.py | 12 +-- include/flydsl-c/FlyDialect.h | 6 +- include/flydsl-c/FlyROCDLDialect.h | 
16 ++++ include/flydsl/Conversion/Passes.td | 8 +- include/flydsl/Dialect/CMakeLists.txt | 1 + include/flydsl/Dialect/Fly/IR/FlyDialect.h | 2 - .../flydsl/Dialect/Fly/IR/FlyInterfaces.td | 21 ++--- include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td | 12 ++- .../flydsl/Dialect/Fly/Utils/LayoutUtils.h | 2 +- .../Dialect/Fly/Utils/ThrValLayoutMacro.h.inc | 61 ++++++++++++++ .../flydsl/Dialect/FlyROCDL/CMakeLists.txt | 1 + include/flydsl/Dialect/FlyROCDL/IR/Atom.td | 9 +++ .../flydsl/Dialect/FlyROCDL/IR/CMakeLists.txt | 13 +++ .../flydsl/Dialect/FlyROCDL/IR/CopyAtom.td | 8 ++ include/flydsl/Dialect/FlyROCDL/IR/Dialect.h | 33 ++++++++ include/flydsl/Dialect/FlyROCDL/IR/Dialect.td | 32 +++++++- include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td | 33 ++++++++ include/flydsl/Dialect/FlyROCDL/IR/Ops.td | 0 lib/CAPI/CMakeLists.txt | 7 +- lib/CAPI/Dialect/CMakeLists.txt | 2 + lib/CAPI/Dialect/Fly/CMakeLists.txt | 5 ++ lib/CAPI/{ => Dialect/Fly}/FlyDialect.cpp | 0 lib/CAPI/Dialect/FlyROCDL/CMakeLists.txt | 6 ++ lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp | 6 ++ lib/Conversion/FlyToROCDL/CMakeLists.txt | 6 +- lib/Conversion/FlyToROCDL/FlyToROCDL.cpp | 46 +++++------ lib/Dialect/CMakeLists.txt | 1 + lib/Dialect/Fly/IR/FlyDialect.cpp | 2 +- lib/Dialect/Fly/IR/FlyTypeDefs.cpp | 14 ++++ lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp | 20 +++++ lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp | 79 +++++++++++++++++++ lib/Dialect/FlyROCDL/CMakeLists.txt | 8 ++ lib/Dialect/FlyROCDL/Dialect.cpp | 43 ++++++++++ python/flydsl/lang/ir/__init__.py | 3 - python/mlir_flydsl/CMakeLists.txt | 10 +++ python/mlir_flydsl/FlyRegisterEverything.cpp | 4 +- python/mlir_flydsl/dialects/FlyROCDL.td | 6 ++ python/mlir_flydsl/dialects/fly_rocdl.py | 2 + 39 files changed, 470 insertions(+), 76 deletions(-) create mode 100644 include/flydsl-c/FlyROCDLDialect.h create mode 100644 include/flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc create mode 100644 include/flydsl/Dialect/FlyROCDL/CMakeLists.txt create mode 100644 
include/flydsl/Dialect/FlyROCDL/IR/Atom.td create mode 100644 include/flydsl/Dialect/FlyROCDL/IR/CMakeLists.txt create mode 100644 include/flydsl/Dialect/FlyROCDL/IR/CopyAtom.td create mode 100644 include/flydsl/Dialect/FlyROCDL/IR/Dialect.h create mode 100644 include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td create mode 100644 include/flydsl/Dialect/FlyROCDL/IR/Ops.td create mode 100644 lib/CAPI/Dialect/CMakeLists.txt create mode 100644 lib/CAPI/Dialect/Fly/CMakeLists.txt rename lib/CAPI/{ => Dialect/Fly}/FlyDialect.cpp (100%) create mode 100644 lib/CAPI/Dialect/FlyROCDL/CMakeLists.txt create mode 100644 lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp create mode 100644 lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp create mode 100644 lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp create mode 100644 lib/Dialect/FlyROCDL/CMakeLists.txt create mode 100644 lib/Dialect/FlyROCDL/Dialect.cpp create mode 100644 python/mlir_flydsl/dialects/FlyROCDL.td create mode 100644 python/mlir_flydsl/dialects/fly_rocdl.py diff --git a/examples/02-layout_algebra.py b/examples/02-layout_algebra.py index b940aaef..97dead21 100644 --- a/examples/02-layout_algebra.py +++ b/examples/02-layout_algebra.py @@ -16,8 +16,10 @@ def kernel( A: memrefTy, B: memrefTy, ): - tid = fx.arith.IndexCastOp(fx.T.i32(), fx.thread_idx.x) - bid = fx.arith.IndexCastOp(fx.T.i32(), fx.block_idx.x) + tid = fx.arith.index_cast(fx.T.i32(), fx.thread_idx.x) + bid = fx.arith.index_cast(fx.T.i32(), fx.block_idx.x) + + print(type(tid), tid) l16 = fx.make_layout(16, 1) tile = fx.make_tile([l16, l16]) diff --git a/examples/03-mma_atom.py b/examples/03-mma_atom.py index 4862e375..8180afad 100644 --- a/examples/03-mma_atom.py +++ b/examples/03-mma_atom.py @@ -20,14 +20,14 @@ def kernel( B: ABMemRefTy, C: CMemRefTy, ): - tid = fx.arith.IndexCastOp(fx.T.i32(), fx.thread_idx.x) + tid = fx.arith.index_cast(fx.T.i32(), fx.thread_idx.x) rA = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) rB = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) 
- copyAtom = fx.make_atom(fx.ir.Type.parse("!fly.atom.universal_copy_32b")) + copyAtom = fx.make_atom(fx.ir.Type.parse("!fly.atom.universal_copy<32>")) mmaAtom = fx.make_atom( - fx.ir.Type.parse("!fly.atom.amdgpu.mfma.f32.16x16x4f32") + fx.ir.Type.parse("!fly_rocdl.atom.cdna3.mfma<16x16x16, f32 x f32 = f32>") ) tA = fx.logical_divide(A, fx.make_layout(1, 1)) @@ -74,9 +74,9 @@ def __call__( MmaAtom_Module = MmaAtom() print(MmaAtom_Module) -MmaAtom_Executor = flydsl.compile(MmaAtom_Module, print_after_all=True) -MmaAtom_Asm = flydsl.compile(MmaAtom_Module, output_format="assembly") -print(MmaAtom_Asm) +MmaAtom_Executor = flydsl.compile(MmaAtom_Module, print_after_all=False) +# MmaAtom_Asm = flydsl.compile(MmaAtom_Module, output_format="assembly") +# print(MmaAtom_Asm) import torch diff --git a/include/flydsl-c/FlyDialect.h b/include/flydsl-c/FlyDialect.h index d27e0206..a8880f47 100644 --- a/include/flydsl-c/FlyDialect.h +++ b/include/flydsl-c/FlyDialect.h @@ -1,5 +1,5 @@ -#ifndef FLYDSL_C_DIALECTS_H -#define FLYDSL_C_DIALECTS_H +#ifndef FLYDSL_C_FLYDIALECT_H +#define FLYDSL_C_FLYDIALECT_H #include "mlir-c/IR.h" @@ -13,4 +13,4 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Fly, fly); } #endif -#endif // FLYDSL_C_DIALECTS_H +#endif // FLYDSL_C_FLYDIALECT_H diff --git a/include/flydsl-c/FlyROCDLDialect.h b/include/flydsl-c/FlyROCDLDialect.h new file mode 100644 index 00000000..96981fff --- /dev/null +++ b/include/flydsl-c/FlyROCDLDialect.h @@ -0,0 +1,16 @@ +#ifndef FLYDSL_C_FLYROCDLDIALECT_H +#define FLYDSL_C_FLYROCDLDIALECT_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FlyROCDL, fly_rocdl); + +#ifdef __cplusplus +} +#endif + +#endif // FLYDSL_C_FLYROCDLDIALECT_H diff --git a/include/flydsl/Conversion/Passes.td b/include/flydsl/Conversion/Passes.td index 5b873657..801b2c73 100644 --- a/include/flydsl/Conversion/Passes.td +++ b/include/flydsl/Conversion/Passes.td @@ -4,17 +4,13 @@ include 
"mlir/Pass/PassBase.td" def FlyToROCDLConversionPass : Pass<"convert-fly-to-rocdl"> { - let summary = "Lower Fly to MLIR upstream and rocdl dialects "; - let description = [{ - - }]; - + let summary = "Lower Fly to MLIR upstream and rocdl dialects "; let dependentDialects = [ "arith::ArithDialect", "scf::SCFDialect", "vector::VectorDialect", "LLVM::LLVMDialect", - "ROCDL::ROCDLDialect", + "ROCDL::ROCDLDialect" ]; } diff --git a/include/flydsl/Dialect/CMakeLists.txt b/include/flydsl/Dialect/CMakeLists.txt index 08c0cd63..0b152044 100644 --- a/include/flydsl/Dialect/CMakeLists.txt +++ b/include/flydsl/Dialect/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(Fly) +add_subdirectory(FlyROCDL) diff --git a/include/flydsl/Dialect/Fly/IR/FlyDialect.h b/include/flydsl/Dialect/Fly/IR/FlyDialect.h index 20e78737..7a55b35b 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyDialect.h +++ b/include/flydsl/Dialect/Fly/IR/FlyDialect.h @@ -16,10 +16,8 @@ #include "flydsl/Dialect/Fly/IR/FlyDialect.h.inc" #include "flydsl/Dialect/Fly/IR/FlyEnums.h.inc" -namespace mlir::fly { #include "flydsl/Dialect/Fly/IR/FlyAttrInterfaces.h.inc" #include "flydsl/Dialect/Fly/IR/FlyTypeInterfaces.h.inc" -} // namespace mlir::fly #define GET_ATTRDEF_CLASSES #include "flydsl/Dialect/Fly/IR/FlyAttrDefs.h.inc" diff --git a/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td b/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td index c84b2cbc..e8cc2199 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td +++ b/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td @@ -8,10 +8,10 @@ def Fly_NestedInterfaceMethods { InterfaceMethod<"", "bool", "isLeaf", (ins)>, InterfaceMethod<"", "int32_t", "rank", (ins)>, InterfaceMethod<"", "int32_t", "rank", (ins "int32_t":$idx)>, - InterfaceMethod<"", "int32_t", "rank", (ins "ArrayRef":$idxs)>, + InterfaceMethod<"", "int32_t", "rank", (ins "::llvm::ArrayRef":$idxs)>, InterfaceMethod<"", "int32_t", "depth", (ins)>, InterfaceMethod<"", "int32_t", "depth", (ins "int32_t":$idx)>, - 
InterfaceMethod<"", "int32_t", "depth", (ins "ArrayRef":$idxs)> + InterfaceMethod<"", "int32_t", "depth", (ins "::llvm::ArrayRef":$idxs)> ]; } @@ -37,19 +37,22 @@ def Fly_MayStaticTypeInterface : TypeInterface<"MayStaticTypeInterface"> { def Fly_CopyAtomTypeInterface : TypeInterface<"CopyAtomTypeInterface"> { + let cppNamespace = "::mlir::fly"; let methods = [ - InterfaceMethod<"", "Attribute", "getThrSize", (ins)>, - InterfaceMethod<"", "Attribute", "getThrValLayoutSrc", (ins)>, - InterfaceMethod<"", "Attribute", "getThrValLayoutDst", (ins)> + InterfaceMethod<"", "::mlir::Attribute", "getThrSize", (ins)>, + InterfaceMethod<"", "::mlir::Attribute", "getThrValLayoutSrc", (ins)>, + InterfaceMethod<"", "::mlir::Attribute", "getThrValLayoutDst", (ins)>, + InterfaceMethod<"", "::mlir::Attribute", "getThrValLayoutRef", (ins)> ]; } def Fly_MmaAtomTypeInterface : TypeInterface<"MmaAtomTypeInterface"> { + let cppNamespace = "::mlir::fly"; let methods = [ - InterfaceMethod<"", "Attribute", "getThrSize", (ins)>, - InterfaceMethod<"", "Attribute", "getThrValLayoutA", (ins)>, - InterfaceMethod<"", "Attribute", "getThrValLayoutB", (ins)>, - InterfaceMethod<"", "Attribute", "getThrValLayoutC", (ins)> + InterfaceMethod<"", "::mlir::Attribute", "getThrSize", (ins)>, + InterfaceMethod<"", "::mlir::Attribute", "getThrValLayoutA", (ins)>, + InterfaceMethod<"", "::mlir::Attribute", "getThrValLayoutB", (ins)>, + InterfaceMethod<"", "::mlir::Attribute", "getThrValLayoutC", (ins)> ]; } diff --git a/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td index 301f20bc..0e4a2c1a 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td +++ b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td @@ -187,10 +187,14 @@ def Fly_TiledMma : Fly_Type<"TiledMma", "tiled_mma", []> { let assemblyFormat = "`<` $mmaAtom `,` $atomLayout `,` $permutation `>`"; } +def Fly_CopyAtomUniversalCopy : Fly_Type<"CopyAtomUniversalCopy", "atom.universal_copy", [ + 
DeclareTypeInterfaceMethods +]> { + let parameters = (ins + "int32_t":$bitSize + ); + let assemblyFormat = "`<` $bitSize `>`"; +} -def Fly_CopyAtomGlobalLoad4B : Fly_Type<"CopyAtomGlobalLoad4B", "atom.global_load_4B", []> {} -def Fly_CopyAtomUniversalCopy32b : Fly_Type<"CopyAtomUniversalCopy32b", "atom.universal_copy_32b", []> {} - -def Fly_MmaAtomMFMA_F32_16x16x4F32 : Fly_Type<"MmaAtomMFMA_F32_16x16x4F32", "atom.amdgpu.mfma.f32.16x16x4f32", []> {} #endif // FLY_TYPEDEFS diff --git a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h index 55c4d2a9..ed9b7530 100644 --- a/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/LayoutUtils.h @@ -189,7 +189,7 @@ typename LayoutBuilder::IntTuple layoutCosize(LayoutBuilder &bui ArithValue one = builder.materializeConstantArith(1); ArithValue s = builder.getArithValue(flatShapeLeaves[0]); ArithValue d = builder.getArithValue(flatStrideLeaves[0]); - ArithValue cosize = builder.mul(builder.sub(s, one), d); + ArithValue cosize = builder.add(one, builder.mul(builder.sub(s, one), d)); for (size_t i = 1; i < flatShapeLeaves.size(); ++i) { ArithValue s = builder.getArithValue(flatShapeLeaves[i]); diff --git a/include/flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc b/include/flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc new file mode 100644 index 00000000..f852d237 --- /dev/null +++ b/include/flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc @@ -0,0 +1,61 @@ +// clang-format off +#define FxVA_NARGS_IMPL(_1, _2, _3, _4, _5, _6, _7, _8, N, ...) N +#define FxVA_NARGS(...) FxVA_NARGS_IMPL(__VA_ARGS__, 8, 7, 6, 5, 4, 3, 2, 1) + +#define FxVA_DISPATCH(MACRO, ...) 
FxVA_DISPATCH_IMPL(MACRO, FxVA_NARGS(__VA_ARGS__))(__VA_ARGS__) +#define FxVA_DISPATCH_IMPL(MACRO, N) FxVA_DISPATCH_CONCAT(MACRO, N) +#define FxVA_DISPATCH_CONCAT(MACRO, N) MACRO##N + +#define FxC(val) IntTupleAttr::getLeafStatic(getContext(), val) +#define FxT_1(a) IntTupleAttr::get(ArrayAttr::get(getContext(), {a})) +#define FxT_2(a, b) IntTupleAttr::get(ArrayAttr::get(getContext(), {a, b})) +#define FxT_3(a, b, c) IntTupleAttr::get(ArrayAttr::get(getContext(), {a, b, c})) +#define FxT_4(a, b, c, d) IntTupleAttr::get(ArrayAttr::get(getContext(), {a, b, c, d})) +#define FxT_5(a, b, c, d, e) IntTupleAttr::get(ArrayAttr::get(getContext(), {a, b, c, d, e})) +#define FxT_6(a, b, c, d, e, f) IntTupleAttr::get(ArrayAttr::get(getContext(), {a, b, c, d, e, f})) +#define FxT_7(a, b, c, d, e, f, g) IntTupleAttr::get(ArrayAttr::get(getContext(), {a, b, c, d, e, f, g})) +#define FxT_8(a, b, c, d, e, f, g, h) IntTupleAttr::get(ArrayAttr::get(getContext(), {a, b, c, d, e, f, g, h})) +#define FxT(...) FxVA_DISPATCH(FxT_, __VA_ARGS__) + +#define FxThr1(a) FxT_1(FxC(a)) +#define FxThr2(a, b) FxT_2(FxC(a), FxC(b)) +#define FxThr3(a, b, c) FxT_3(FxC(a), FxC(b), FxC(c)) +#define FxThr4(a, b, c, d) FxT_4(FxC(a), FxC(b), FxC(c), FxC(d)) +#define FxThr5(a, b, c, d, e) FxT_5(FxC(a), FxC(b), FxC(c), FxC(d), FxC(e)) +#define FxThr6(a, b, c, d, e, f) FxT_6(FxC(a), FxC(b), FxC(c), FxC(d), FxC(e), FxC(f)) +#define FxThr7(a, b, c, d, e, f, g) FxT_7(FxC(a), FxC(b), FxC(c), FxC(d), FxC(e), FxC(f), FxC(g)) +#define FxThr8(a, b, c, d, e, f, g, h) FxT_8(FxC(a), FxC(b), FxC(c), FxC(d), FxC(e), FxC(f), FxC(g), FxC(h)) +#define FxThr(...) 
FxVA_DISPATCH(FxThr, __VA_ARGS__) + +#define FxVal1(a) FxT_1(FxC(a)) +#define FxVal2(a, b) FxT_2(FxC(a), FxC(b)) +#define FxVal3(a, b, c) FxT_3(FxC(a), FxC(b), FxC(c)) +#define FxVal4(a, b, c, d) FxT_4(FxC(a), FxC(b), FxC(c), FxC(d)) +#define FxVal5(a, b, c, d, e) FxT_5(FxC(a), FxC(b), FxC(c), FxC(d), FxC(e)) +#define FxVal6(a, b, c, d, e, f) FxT_6(FxC(a), FxC(b), FxC(c), FxC(d), FxC(e), FxC(f)) +#define FxVal7(a, b, c, d, e, f, g) FxT_7(FxC(a), FxC(b), FxC(c), FxC(d), FxC(e), FxC(f), FxC(g)) +#define FxVal8(a, b, c, d, e, f, g, h) FxT_8(FxC(a), FxC(b), FxC(c), FxC(d), FxC(e), FxC(f), FxC(g), FxC(h)) +#define FxVal(...) FxVA_DISPATCH(FxVal, __VA_ARGS__) + +#define FxShape1(a) FxC(a) +#define FxShape2(a, b) FxT_2(a, b) +#define FxShape3(a, b, c) FxT_3(a, b, c) +#define FxShape4(a, b, c, d) FxT_4(a, b, c, d) +#define FxShape5(a, b, c, d, e) FxT_5(a, b, c, d, e) +#define FxShape6(a, b, c, d, e, f) FxT_6(a, b, c, d, e, f) +#define FxShape7(a, b, c, d, e, f, g) FxT_7(a, b, c, d, e, f, g) +#define FxShape8(a, b, c, d, e, f, g, h) FxT_8(a, b, c, d, e, f, g, h) +#define FxShape(...) FxVA_DISPATCH(FxShape, __VA_ARGS__) + +#define FxStride1(a) FxC(a) +#define FxStride2(a, b) FxT_2(a, b) +#define FxStride3(a, b, c) FxT_3(a, b, c) +#define FxStride4(a, b, c, d) FxT_4(a, b, c, d) +#define FxStride5(a, b, c, d, e) FxT_5(a, b, c, d, e) +#define FxStride6(a, b, c, d, e, f) FxT_6(a, b, c, d, e, f) +#define FxStride7(a, b, c, d, e, f, g) FxT_7(a, b, c, d, e, f, g) +#define FxStride8(a, b, c, d, e, f, g, h) FxT_8(a, b, c, d, e, f, g, h) +#define FxStride(...) 
FxVA_DISPATCH(FxStride, __VA_ARGS__) + +#define FxLayout(shape, stride) LayoutAttr::get(shape, stride) +// clang-format on diff --git a/include/flydsl/Dialect/FlyROCDL/CMakeLists.txt b/include/flydsl/Dialect/FlyROCDL/CMakeLists.txt new file mode 100644 index 00000000..f33061b2 --- /dev/null +++ b/include/flydsl/Dialect/FlyROCDL/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/include/flydsl/Dialect/FlyROCDL/IR/Atom.td b/include/flydsl/Dialect/FlyROCDL/IR/Atom.td new file mode 100644 index 00000000..f6f57397 --- /dev/null +++ b/include/flydsl/Dialect/FlyROCDL/IR/Atom.td @@ -0,0 +1,9 @@ +#ifndef FLYROCDL_ATOM +#define FLYROCDL_ATOM + +include "flydsl/Dialect/FlyROCDL/IR/Dialect.td" + +include "flydsl/Dialect/FlyROCDL/IR/CopyAtom.td" +include "flydsl/Dialect/FlyROCDL/IR/MmaAtom.td" + +#endif // FLYROCDL_ATOM diff --git a/include/flydsl/Dialect/FlyROCDL/IR/CMakeLists.txt b/include/flydsl/Dialect/FlyROCDL/IR/CMakeLists.txt new file mode 100644 index 00000000..dafddc61 --- /dev/null +++ b/include/flydsl/Dialect/FlyROCDL/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +set(LLVM_TARGET_DEFINITIONS Dialect.td) + +mlir_tablegen(Dialect.h.inc -gen-dialect-decls) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) + +mlir_tablegen(Enums.h.inc -gen-enum-decls -typedefs-dialect=fly_rocdl) +mlir_tablegen(Enums.cpp.inc -gen-enum-defs -typedefs-dialect=fly_rocdl) + +set(LLVM_TARGET_DEFINITIONS Atom.td) +mlir_tablegen(Atom.h.inc -gen-typedef-decls -typedefs-dialect=fly_rocdl) +mlir_tablegen(Atom.cpp.inc -gen-typedef-defs -typedefs-dialect=fly_rocdl) + +add_public_tablegen_target(MLIRFlyROCDLIncGen) diff --git a/include/flydsl/Dialect/FlyROCDL/IR/CopyAtom.td b/include/flydsl/Dialect/FlyROCDL/IR/CopyAtom.td new file mode 100644 index 00000000..47718890 --- /dev/null +++ b/include/flydsl/Dialect/FlyROCDL/IR/CopyAtom.td @@ -0,0 +1,8 @@ +#ifndef FLYROCDL_COPYATOM +#define FLYROCDL_COPYATOM + +include "flydsl/Dialect/FlyROCDL/IR/Dialect.td" + +def FlyROCDL_CopyAtom_BufferLSA : 
FlyxROCL_CopyAtom<"CopyAtom_CDNA3_BufferLSA", "atom.cdna3.buffer_lsa",[]> {} + +#endif // FLYROCDL_COPYATOM diff --git a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.h b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.h new file mode 100644 index 00000000..eac5ec1f --- /dev/null +++ b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.h @@ -0,0 +1,33 @@ +#ifndef FLYDSL_DIALECT_FLYROCDL_IR_DIALECT_H +#define FLYDSL_DIALECT_FLYROCDL_IR_DIALECT_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" + +#include "flydsl/Dialect/FlyROCDL/IR/Dialect.h.inc" +#include "flydsl/Dialect/FlyROCDL/IR/Enums.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "flydsl/Dialect/FlyROCDL/IR/Atom.h.inc" + +namespace mlir::fly_rocdl { + +ParseResult parseMNKDimensionList(AsmParser &parser, int32_t &m, int32_t &n, int32_t &k); + +void printMNKDimensionList(AsmPrinter &printer, int32_t m, int32_t n, int32_t k); + +} // namespace mlir::fly_rocdl + +#endif // FLYDSL_DIALECT_FLYROCDL_IR_DIALECT_H diff --git a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td index 3a86fdb9..b6f5ba6b 100644 --- a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td +++ b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td @@ -11,12 +11,40 @@ def FlyROCDL_Dialect : Dialect { let name = "fly_rocdl"; let cppNamespace = "::mlir::fly_rocdl"; + let dependentDialects = [ + "ROCDL::ROCDLDialect" + ]; + + let useDefaultTypePrinterParser = 1; let usePropertiesForAttributes = 1; } -class FlyROCDL_Type traits = []> - : TypeDef { +class 
FlyxROCL_MmaAtom traits = []> + : TypeDef])> { + let mnemonic = typeMnemonic; +} + +class FlyxROCL_CopyAtom traits = []> + : TypeDef])> { let mnemonic = typeMnemonic; } +def FlyROCDL_SchedGroup : I32BitEnumAttr<"SchedGroup", "", [ + I32BitEnumAttrCaseNone<"None">, + I32BitEnumAttrCaseBit<"NonSideEffect", 0>, + I32BitEnumAttrCaseBit<"VALU", 1>, + I32BitEnumAttrCaseBit<"SALU", 2>, + I32BitEnumAttrCaseBit<"MFMA", 3>, + I32BitEnumAttrCaseBit<"VMEM", 4>, + I32BitEnumAttrCaseBit<"VMEM_READ", 5>, + I32BitEnumAttrCaseBit<"VMEM_WRITE", 6>, + I32BitEnumAttrCaseBit<"LDS", 7>, + I32BitEnumAttrCaseBit<"DS_READ", 8>, + I32BitEnumAttrCaseBit<"DS_WRITE", 9>, + I32BitEnumAttrCaseBit<"Transcendental", 10> +]> { + let genSpecializedAttr = 0; + let cppNamespace = FlyROCDL_Dialect.cppNamespace; +} + #endif // FLYROCDL_DIALECT diff --git a/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td new file mode 100644 index 00000000..ebdc507b --- /dev/null +++ b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td @@ -0,0 +1,33 @@ +#ifndef FLYROCDL_MMAATOM +#define FLYROCDL_MMAATOM + +include "flydsl/Dialect/FlyROCDL/IR/Dialect.td" + +//===----------------------------------------------------------------------===// +// MmaAtom CDNA3 +//===----------------------------------------------------------------------===// + +def FlyROCDL_MmaAtomCDNA3_MFMA : FlyxROCL_MmaAtom<"MmaAtomCDNA3_MFMA", "atom.cdna3.mfma", []> { + let parameters = (ins + "int32_t":$m, + "int32_t":$n, + "int32_t":$k, + "Type":$elemTyA, + "Type":$elemTyB, + "Type":$elemTyAcc + ); + let assemblyFormat = "`<` custom($m, $n, $k) `,` $elemTyA `x` $elemTyB `=` $elemTyAcc `>`"; + + let builders = [ + TypeBuilderWithInferredContext<(ins "int32_t":$m, "int32_t":$n, "int32_t":$k, "Type":$elemTyA, "Type":$elemTyB, "Type":$elemTyAcc), [{ + return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc); + }]> + ]; + let genVerifyDecl = 1; +} + 
+//===----------------------------------------------------------------------===// +// MmaAtom CDNA4 +//===----------------------------------------------------------------------===// + +#endif // FLYROCDL_MMAATOM diff --git a/include/flydsl/Dialect/FlyROCDL/IR/Ops.td b/include/flydsl/Dialect/FlyROCDL/IR/Ops.td new file mode 100644 index 00000000..e69de29b diff --git a/lib/CAPI/CMakeLists.txt b/lib/CAPI/CMakeLists.txt index 2c17c103..0ca0f41c 100644 --- a/lib/CAPI/CMakeLists.txt +++ b/lib/CAPI/CMakeLists.txt @@ -1,6 +1 @@ -add_mlir_public_c_api_library(MLIRCPIFly - FlyDialect.cpp - LINK_LIBS PUBLIC - MLIRFlyDialect - MLIRFlyToROCDL -) +add_subdirectory(Dialect) diff --git a/lib/CAPI/Dialect/CMakeLists.txt b/lib/CAPI/Dialect/CMakeLists.txt new file mode 100644 index 00000000..0b152044 --- /dev/null +++ b/lib/CAPI/Dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Fly) +add_subdirectory(FlyROCDL) diff --git a/lib/CAPI/Dialect/Fly/CMakeLists.txt b/lib/CAPI/Dialect/Fly/CMakeLists.txt new file mode 100644 index 00000000..2ec143a1 --- /dev/null +++ b/lib/CAPI/Dialect/Fly/CMakeLists.txt @@ -0,0 +1,5 @@ +add_mlir_public_c_api_library(MLIRCPIFly + FlyDialect.cpp + LINK_LIBS PUBLIC + MLIRFlyDialect +) diff --git a/lib/CAPI/FlyDialect.cpp b/lib/CAPI/Dialect/Fly/FlyDialect.cpp similarity index 100% rename from lib/CAPI/FlyDialect.cpp rename to lib/CAPI/Dialect/Fly/FlyDialect.cpp diff --git a/lib/CAPI/Dialect/FlyROCDL/CMakeLists.txt b/lib/CAPI/Dialect/FlyROCDL/CMakeLists.txt new file mode 100644 index 00000000..c30d361d --- /dev/null +++ b/lib/CAPI/Dialect/FlyROCDL/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_public_c_api_library(MLIRCPIFlyROCDL + FlyROCDLDialect.cpp + LINK_LIBS PUBLIC + MLIRFlyROCDLDialect + MLIRFlyToROCDL +) diff --git a/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp b/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp new file mode 100644 index 00000000..08512ee2 --- /dev/null +++ b/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp @@ -0,0 +1,6 @@ +#include 
"flydsl-c/FlyROCDLDialect.h" + +#include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" +#include "mlir/CAPI/Registration.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FlyROCDL, fly_rocdl, mlir::fly_rocdl::FlyROCDLDialect) diff --git a/lib/Conversion/FlyToROCDL/CMakeLists.txt b/lib/Conversion/FlyToROCDL/CMakeLists.txt index 9eb29bd5..e277f0a0 100644 --- a/lib/Conversion/FlyToROCDL/CMakeLists.txt +++ b/lib/Conversion/FlyToROCDL/CMakeLists.txt @@ -3,18 +3,16 @@ add_mlir_conversion_library(MLIRFlyToROCDL DEPENDS MLIRFlyIncGen + MLIRFlyROCDLIncGen FlyConversionPassIncGen LINK_LIBS PUBLIC MLIRFlyDialect + MLIRFlyROCDLDialect - MLIRAffineDialect - MLIRAffineTransforms - MLIRAffineUtils MLIRArithDialect MLIRIR MLIRLLVMDialect - MLIRMemRefDialect MLIRPass MLIRSCFDialect MLIRTransforms diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp index 0deed6ae..5052f71e 100644 --- a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp +++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp @@ -1,5 +1,4 @@ -#include "flydsl/Dialect/Fly/Utils/IntTupleUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -16,11 +15,9 @@ #include "flydsl/Conversion/FlyToROCDL/FlyToROCDL.h" #include "flydsl/Dialect/Fly/IR/FlyDialect.h" - -#include -#include -#include -#include +#include "flydsl/Dialect/Fly/Utils/IntTupleUtils.h" +#include "flydsl/Dialect/Fly/Utils/LayoutUtils.h" +#include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" namespace mlir { #define GEN_PASS_DEF_FLYTOROCDLCONVERSIONPASS @@ -32,17 +29,6 @@ using namespace mlir::fly; namespace { -// Helper to get the flattened size from an IntTupleAttr (product of all elements) -static int64_t getFlattenedSize(IntTupleAttr attr) { - IntTupleBuilder builder(attr.getContext()); - IntAttr product = intTupleProduct(builder, attr).getLeafAsInt(); - if (product.isStatic()) - return product.getValue(); - return 1; -} - -static int64_t getFlattenedSize(LayoutAttr 
attr) { return getFlattenedSize(attr.getShape()); } - static unsigned mapAddressSpace(AddressSpace space) { // - Global -> 1 (global) // - Shared -> 3 (local/LDS/workgroup) @@ -88,7 +74,10 @@ class MemRefAllocOpLowering : public OpConversionPattern { LayoutAttr layoutAttr = flyMemRefTy.getLayout(); auto elemTy = flyMemRefTy.getElemTy(); - int64_t totalSize = getFlattenedSize(layoutAttr); + LayoutBuilder builder(rewriter.getContext()); + IntTupleAttr totalSize = layoutCosize(builder, layoutAttr); + + assert(totalSize.isStatic() && totalSize.isLeaf()); auto convertedPtrTy = dyn_cast(getTypeConverter()->convertType(flyMemRefTy)); @@ -98,7 +87,9 @@ class MemRefAllocOpLowering : public OpConversionPattern { auto loc = op.getLoc(); // Alloca array size is i64. - Value nElems = arith::ConstantIntOp::create(rewriter, loc, totalSize, /*width=*/64).getResult(); + Value nElems = arith::ConstantIntOp::create(rewriter, loc, totalSize.getLeafAsInt().getValue(), + /*width=*/64) + .getResult(); // `llvm.alloca` takes element type and array size. Keep alignment unspecified. Value ptr = LLVM::AllocaOp::create(rewriter, loc, convertedPtrTy, elemTy, nElems, @@ -345,8 +336,8 @@ class CopyAtomCallLowering : public OpConversionPattern { LogicalResult matchAndRewrite(CopyAtomCall op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only handle the universal memref-to-memref copy atom here. 
- if (!isa(adaptor.getCopyAtom().getType())) - return rewriter.notifyMatchFailure(op, "unsupported copy atom (expected universal_copy_32b)"); + if (!isa(adaptor.getCopyAtom().getType())) + return rewriter.notifyMatchFailure(op, "unsupported copy atom (expected universal_copy)"); Value src = adaptor.getSrc(); Value dst = adaptor.getDst(); @@ -366,8 +357,10 @@ class CopyAtomCallLowering : public OpConversionPattern { if (srcFlyTy.getElemTy() != dstFlyTy.getElemTy()) return rewriter.notifyMatchFailure(op, "src/dst element types mismatch"); - int64_t nElems = getFlattenedSize(srcFlyTy.getLayout()); - if (nElems != getFlattenedSize(dstFlyTy.getLayout())) + LayoutBuilder builder(rewriter.getContext()); + IntTupleAttr totalSize = layoutCosize(builder, srcFlyTy.getLayout()); + + if (totalSize != layoutCosize(builder, dstFlyTy.getLayout())) return rewriter.notifyMatchFailure(op, "src/dst shapes mismatch"); // Lower to LLVM memcpy intrinsic to keep GPU kernel fully in LLVM dialect @@ -384,7 +377,7 @@ class CopyAtomCallLowering : public OpConversionPattern { if (elemBytes <= 0) return rewriter.notifyMatchFailure(op, "invalid element byte width"); - int64_t totalBytes = nElems * elemBytes; + int64_t totalBytes = totalSize.getLeafAsInt().getValue() * elemBytes; Value len = arith::ConstantIntOp::create(rewriter, loc, totalBytes, /*width=*/64).getResult(); // llvm.intr.memcpy(dst, src, len, isVolatile=false) @@ -402,8 +395,9 @@ class MmaAtomCallLowering : public OpConversionPattern { LogicalResult matchAndRewrite(MmaAtomCall op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only handle MFMA F32 16x16x4 F32 atom for now. 
- if (!isa(adaptor.getMmaAtom().getType())) - return rewriter.notifyMatchFailure(op, "unsupported mma atom (expected mfma.f32.16x16x4f32)"); + + auto mmaAtomTy = dyn_cast(adaptor.getMmaAtom().getType()); + assert(mmaAtomTy); Location loc = op.getLoc(); diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 08c0cd63..0b152044 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(Fly) +add_subdirectory(FlyROCDL) diff --git a/lib/Dialect/Fly/IR/FlyDialect.cpp b/lib/Dialect/Fly/IR/FlyDialect.cpp index 9ea39b1d..49ba2fb5 100644 --- a/lib/Dialect/Fly/IR/FlyDialect.cpp +++ b/lib/Dialect/Fly/IR/FlyDialect.cpp @@ -11,10 +11,10 @@ using namespace mlir::fly; #include "flydsl/Dialect/Fly/IR/FlyEnums.cpp.inc" -namespace mlir::fly { #include "flydsl/Dialect/Fly/IR/FlyAttrInterfaces.cpp.inc" #include "flydsl/Dialect/Fly/IR/FlyTypeInterfaces.cpp.inc" +namespace mlir::fly { #include "flydsl/Dialect/Fly/IR/FlyAttrConstraints.cpp.inc" #include "flydsl/Dialect/Fly/IR/FlyTypeConstraints.cpp.inc" } // namespace mlir::fly diff --git a/lib/Dialect/Fly/IR/FlyTypeDefs.cpp b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp index b21c55c5..920a49c7 100644 --- a/lib/Dialect/Fly/IR/FlyTypeDefs.cpp +++ b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp @@ -75,4 +75,18 @@ CoordTensorType CoordTensorType::at(ArrayRef idxs) const { return CoordTensorType::get(getContext(), getBase().at(idxs), getLayout().at(idxs)); } +#include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc" + +Attribute CopyAtomUniversalCopyType::getThrSize() const { return FxC(1); } + +Attribute CopyAtomUniversalCopyType::getThrValLayoutSrc() const { + return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(0), FxC(1))); +} +Attribute CopyAtomUniversalCopyType::getThrValLayoutDst() const { + return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(0), FxC(1))); +} +Attribute CopyAtomUniversalCopyType::getThrValLayoutRef() const { + return FxLayout(FxShape(FxC(1), 
FxC(getBitSize())), FxStride(FxC(0), FxC(1))); +} + } // namespace mlir::fly diff --git a/lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp b/lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp new file mode 100644 index 00000000..eb85b54b --- /dev/null +++ b/lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp @@ -0,0 +1,20 @@ +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc" + +using namespace mlir; +using namespace mlir::fly; + +namespace mlir::fly_rocdl { + +Attribute CopyAtom_CDNA3_BufferLSAType::getThrSize() const { + auto ctx = getContext(); + return FxC(1); +} +Attribute CopyAtom_CDNA3_BufferLSAType::getThrValLayoutSrc() const { return {}; } +Attribute CopyAtom_CDNA3_BufferLSAType::getThrValLayoutDst() const { return {}; } +Attribute CopyAtom_CDNA3_BufferLSAType::getThrValLayoutRef() const { return {}; } + +} // namespace mlir::fly_rocdl diff --git a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp new file mode 100644 index 00000000..0cdc0f26 --- /dev/null +++ b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp @@ -0,0 +1,79 @@ +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc" + +using namespace mlir; +using namespace mlir::fly; + +namespace cdna3 { + +LayoutAttr getThrValLayoutAB(MLIRContext *ctx, int32_t M, int32_t N, int32_t K, Type elemTyA, + Type elemTyB, Type elemTyAcc) { + auto getContext = [&]() { return ctx; }; + + int MN = M; + int TyBit = elemTyA.getIntOrFloatBitWidth(); + assert(TyBit == int(elemTyB.getIntOrFloatBitWidth()) && + "Element types must have the same bit width"); + assert(M == N && "M and N must be equal"); + + int GroupK = 64 / MN; + int KPerThread = K / GroupK; + + return FxLayout( + FxShape(FxThr(MN, GroupK), FxVal(TyBit, KPerThread)), + FxStride(FxThr(TyBit, 
TyBit * MN * KPerThread), FxVal(1, TyBit * MN * KPerThread))); +} + +} // namespace cdna3 + +namespace cdna4 {} + +namespace mlir::fly_rocdl { + +Attribute MmaAtomCDNA3_MFMAType::getThrSize() const { return FxC(64); } +Attribute MmaAtomCDNA3_MFMAType::getThrValLayoutA() const { + return cdna3::getThrValLayoutAB(getContext(), getM(), getN(), getK(), getElemTyA(), getElemTyB(), + getElemTyAcc()); +} +Attribute MmaAtomCDNA3_MFMAType::getThrValLayoutB() const { + return cdna3::getThrValLayoutAB(getContext(), getM(), getN(), getK(), getElemTyA(), getElemTyB(), + getElemTyAcc()); +} +Attribute MmaAtomCDNA3_MFMAType::getThrValLayoutC() const { + int M = getM(); + int N = getN(); + + int GroupM = 64 / N; + int ValM0 = 4; + int ValM1 = M / 4 / GroupM; + int TyBitAcc = 32; + + return FxLayout( + FxShape(FxThr(N, GroupM), FxVal(TyBitAcc * ValM0, ValM1)), + FxStride(FxThr(M * TyBitAcc, TyBitAcc * ValM0), FxVal(1, TyBitAcc * ValM0 * GroupM))); +} + +LogicalResult MmaAtomCDNA3_MFMAType::verify(function_ref emitError, int32_t m, + int32_t n, int32_t k, Type elemTyA, Type elemTyB, + Type elemTyAcc) { + assert(m == n && "M and N must be equal"); + if (m != n) { + return emitError() << "invalid MNK dimensions for CDNA3 MFMA: " << m << "x" << n << "x" << k; + } + if (!elemTyAcc.isF32()) + return emitError() << "elemTyAcc must be f32, got " << elemTyAcc; + + auto isValidElemType = [](Type ty) { return ty.isF16() || ty.isBF16() || ty.isF32(); }; + if (!isValidElemType(elemTyA)) { + return emitError() << "elemTyA must be f16, bf16, f32, got " << elemTyA; + } + if (!isValidElemType(elemTyB)) { + return emitError() << "elemTyB must be f16, bf16, f32, got " << elemTyB; + } + return success(); +} + +} // namespace mlir::fly_rocdl diff --git a/lib/Dialect/FlyROCDL/CMakeLists.txt b/lib/Dialect/FlyROCDL/CMakeLists.txt new file mode 100644 index 00000000..3652d2c4 --- /dev/null +++ b/lib/Dialect/FlyROCDL/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_dialect_library(MLIRFlyROCDLDialect + Dialect.cpp 
+ CDNA3/MmaAtom.cpp + CDNA3/CopyAtom.cpp + + DEPENDS + MLIRFlyROCDLIncGen +) diff --git a/lib/Dialect/FlyROCDL/Dialect.cpp b/lib/Dialect/FlyROCDL/Dialect.cpp new file mode 100644 index 00000000..3ae4d5d5 --- /dev/null +++ b/lib/Dialect/FlyROCDL/Dialect.cpp @@ -0,0 +1,43 @@ +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::fly_rocdl; + +#include "flydsl/Dialect/FlyROCDL/IR/Dialect.cpp.inc" +#include "flydsl/Dialect/FlyROCDL/IR/Enums.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "flydsl/Dialect/FlyROCDL/IR/Atom.cpp.inc" + +namespace mlir::fly_rocdl { + +ParseResult parseMNKDimensionList(AsmParser &parser, int32_t &m, int32_t &n, int32_t &k) { + SmallVector dimensions; + if (parser.parseDimensionList(dimensions, false, false)) + return failure(); + if (dimensions.size() != 3) + return parser.emitError(parser.getCurrentLocation()) + << "expected 3 dimensions in MNK dimension list"; + m = dimensions[0]; + n = dimensions[1]; + k = dimensions[2]; + return success(); +} + +void printMNKDimensionList(AsmPrinter &printer, int32_t m, int32_t n, int32_t k) { + printer.printDimensionList(ArrayRef{m, n, k}); +} + +} // namespace mlir::fly_rocdl + +void FlyROCDLDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "flydsl/Dialect/FlyROCDL/IR/Atom.cpp.inc" + >(); +} diff --git a/python/flydsl/lang/ir/__init__.py b/python/flydsl/lang/ir/__init__.py index 6d2eea43..3c1fa38c 100644 --- a/python/flydsl/lang/ir/__init__.py +++ b/python/flydsl/lang/ir/__init__.py @@ -4,6 +4,3 @@ from .module import * # from .gpu import * - -# Export MLIR IR types like Type, Value, etc. 
-from ..._mlir.ir import Type, Value, Context, Location, Module, Attribute, InsertionPoint diff --git a/python/mlir_flydsl/CMakeLists.txt b/python/mlir_flydsl/CMakeLists.txt index ba6ed8d7..addc5f4e 100644 --- a/python/mlir_flydsl/CMakeLists.txt +++ b/python/mlir_flydsl/CMakeLists.txt @@ -15,6 +15,15 @@ declare_mlir_dialect_python_bindings( GEN_ENUM_BINDINGS ) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT FlyPythonSources + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/" + TD_FILE dialects/FlyROCDL.td + SOURCES + dialects/fly_rocdl.py + DIALECT_NAME fly_rocdl + GEN_ENUM_BINDINGS +) declare_mlir_python_extension(FlyPythonSources.Core MODULE_NAME _fly @@ -44,6 +53,7 @@ declare_mlir_python_extension(FlyPythonSources.RegisterEverything EMBED_CAPI_LINK_LIBS MLIRCAPIIR MLIRCPIFly + MLIRCPIFlyROCDL MLIRCAPIArith MLIRCAPIGPU MLIRCAPILLVM diff --git a/python/mlir_flydsl/FlyRegisterEverything.cpp b/python/mlir_flydsl/FlyRegisterEverything.cpp index 94b877a2..a0b6cda9 100644 --- a/python/mlir_flydsl/FlyRegisterEverything.cpp +++ b/python/mlir_flydsl/FlyRegisterEverything.cpp @@ -3,8 +3,8 @@ #include "mlir/Bindings/Python/NanobindAdaptors.h" #include "flydsl-c/FlyDialect.h" +#include "flydsl-c/FlyROCDLDialect.h" #include "flydsl/Conversion/FlyToROCDL/FlyToROCDL.h" -#include "flydsl/Dialect/Fly/IR/FlyDialect.h" #include "flydsl/Dialect/Fly/Transforms/Passes.h" namespace mlir { @@ -20,6 +20,8 @@ NB_MODULE(_mlirRegisterEverything, m) { MlirDialectHandle flyHandle = mlirGetDialectHandle__fly__(); mlirDialectHandleInsertDialect(flyHandle, registry); + MlirDialectHandle flyROCDLHandle = mlirGetDialectHandle__fly_rocdl__(); + mlirDialectHandleInsertDialect(flyROCDLHandle, registry); }); m.def("register_llvm_translations", [](MlirContext context) { mlirRegisterAllLLVMTranslations(context); }); diff --git a/python/mlir_flydsl/dialects/FlyROCDL.td b/python/mlir_flydsl/dialects/FlyROCDL.td new file mode 100644 index 00000000..7c5da756 --- /dev/null +++ 
b/python/mlir_flydsl/dialects/FlyROCDL.td @@ -0,0 +1,6 @@ +#ifndef PYTHON_BINDINGS_FLYROCDL_OPS +#define PYTHON_BINDINGS_FLYROCDL_OPS + +include "flydsl/Dialect/FlyROCDL/IR/Atom.td" + +#endif // PYTHON_BINDINGS_FLYROCDL_OPS diff --git a/python/mlir_flydsl/dialects/fly_rocdl.py b/python/mlir_flydsl/dialects/fly_rocdl.py new file mode 100644 index 00000000..caf90541 --- /dev/null +++ b/python/mlir_flydsl/dialects/fly_rocdl.py @@ -0,0 +1,2 @@ +from ._fly_rocdl_enum_gen import * +from ._fly_rocdl_ops_gen import * From 3bd415055e5327ea953131817da0abb1c683c3e2 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Tue, 27 Jan 2026 18:06:21 +0000 Subject: [PATCH 004/113] Add utility nbmodules --- examples/01-vector-add.py | 16 +- examples/02-layout_algebra.py | 4 +- examples/03-mma_atom.py | 28 +- include/flydsl-c/FlyDialect.h | 98 +++++ include/flydsl-c/FlyROCDLDialect.h | 8 + include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td | 7 +- include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td | 29 +- lib/Bindings/Python/FlyExtension.cpp | 337 ++++++++++++++++++ lib/Bindings/Python/FlyROCDLExtension.cpp | 43 +++ lib/Bindings/Python/MainModules.cpp | 228 ------------ lib/CAPI/Dialect/Fly/FlyDialect.cpp | 188 ++++++++++ lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp | 16 + lib/Dialect/Fly/Transforms/LayoutLowering.cpp | 8 +- python/flydsl/lang/ir/core.py | 133 ++----- python/mlir_flydsl/CMakeLists.txt | 79 ++-- python/mlir_flydsl/dialects/fly_rocdl.py | 2 + 16 files changed, 825 insertions(+), 399 deletions(-) create mode 100644 lib/Bindings/Python/FlyExtension.cpp create mode 100644 lib/Bindings/Python/FlyROCDLExtension.cpp delete mode 100644 lib/Bindings/Python/MainModules.cpp diff --git a/examples/01-vector-add.py b/examples/01-vector-add.py index bdc1f323..2b5bca72 100644 --- a/examples/01-vector-add.py +++ b/examples/01-vector-add.py @@ -2,7 +2,9 @@ from flydsl import lang as fx N = 64 -memrefTy = fx.ir.Type.parse(f"!fly.memref") +memrefTy = fx.MemRefType.get( + fx.T.f32(), fx.LayoutType.get(64, 
1), fx.AddressSpace.Global +) class VecAdd(fx.MlirModule): @@ -16,8 +18,8 @@ def kernel( B: memrefTy, C: memrefTy, ): - tid = fx.arith.IndexCastOp(fx.T.i32(), fx.thread_idx.x) - bid = fx.arith.IndexCastOp(fx.T.i32(), fx.block_idx.x) + tid = fx.arith.index_cast(fx.T.i32(), fx.thread_idx.x) + bid = fx.arith.index_cast(fx.T.i32(), fx.block_idx.x) tA = fx.logical_divide(A, fx.make_layout(16, 1)) tB = fx.logical_divide(B, fx.make_layout(16, 1)) @@ -30,8 +32,10 @@ def kernel( tB = fx.logical_divide(tB, fx.make_layout(1, 1)) tC = fx.logical_divide(tC, fx.make_layout(1, 1)) - RABMemRefTy = fx.ir.Type.parse(f"!fly.memref") - copyAtom = fx.make_atom(fx.ir.Type.parse("!fly.atom.universal_copy_32b")) + RABMemRefTy = fx.MemRefType.get( + fx.T.f32(), fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) + copyAtom = fx.make_atom(fx.CopyAtomUniversalCopyType.get(32)) rA = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) rB = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) rC = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) @@ -75,7 +79,7 @@ def __call__( print(VecAdd_Module) -VecAdd_Executor = flydsl.compile(VecAdd_Module, print_after_all=True) +VecAdd_Executor = flydsl.compile(VecAdd_Module, print_after_all=False) # VecAdd_Asm = flydsl.compile(VecAdd_Module, output_format="assembly") # print(VecAdd_Asm) diff --git a/examples/02-layout_algebra.py b/examples/02-layout_algebra.py index 97dead21..d2821ee7 100644 --- a/examples/02-layout_algebra.py +++ b/examples/02-layout_algebra.py @@ -3,7 +3,9 @@ M = 16 N = 32 -memrefTy = fx.ir.Type.parse(f"!fly.memref") +memrefTy = fx.MemRefType.get( + fx.T.f32(), fx.LayoutType.get(16, 32), fx.AddressSpace.Global +) class VecCopy(fx.MlirModule): diff --git a/examples/03-mma_atom.py b/examples/03-mma_atom.py index 8180afad..9f41b0b9 100644 --- a/examples/03-mma_atom.py +++ b/examples/03-mma_atom.py @@ -3,10 +3,12 @@ MN = 16 K = 4 -ABMemRefTy = fx.ir.Type.parse(f"!fly.memref") -CMemRefTy = fx.ir.Type.parse(f"!fly.memref") -RABMemRefTy = 
fx.ir.Type.parse(f"!fly.memref") -RCMemRefTy = fx.ir.Type.parse(f"!fly.memref") +ABMemRefTy = fx.MemRefType.get( + fx.T.f32(), fx.LayoutType.get((MN, K), (K, 1)), fx.AddressSpace.Global +) +CMemRefTy = fx.MemRefType.get( + fx.T.f32(), fx.LayoutType.get((MN, MN), (1, MN)), fx.AddressSpace.Global +) class MmaAtom(fx.MlirModule): @@ -22,12 +24,18 @@ def kernel( ): tid = fx.arith.index_cast(fx.T.i32(), fx.thread_idx.x) - rA = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) - rB = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) + rA = fx.memref_alloca( + fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(1, 1)), + fx.make_layout(1, 1), + ) + rB = fx.memref_alloca( + fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(1, 1)), + fx.make_layout(1, 1), + ) - copyAtom = fx.make_atom(fx.ir.Type.parse("!fly.atom.universal_copy<32>")) + copyAtom = fx.make_atom(fx.CopyAtomUniversalCopyType.get(32)) mmaAtom = fx.make_atom( - fx.ir.Type.parse("!fly_rocdl.atom.cdna3.mfma<16x16x16, f32 x f32 = f32>") + fx.MmaAtomCDNA3_MFMAType.get(16, 16, 16, fx.T.f32(), fx.T.f32(), fx.T.f32()) ) tA = fx.logical_divide(A, fx.make_layout(1, 1)) @@ -35,7 +43,9 @@ def kernel( fx.copy_atom_call(copyAtom, fx.slice(tA, (None, tid)), rA) fx.copy_atom_call(copyAtom, fx.slice(tB, (None, tid)), rB) - rAcc = fx.memref_alloca(RCMemRefTy, fx.make_layout(4, 1)) + rAcc = fx.memref_alloca( + fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(4, 1)), fx.make_layout(4, 1) + ) f0 = fx.arith.constant(fx.T.f32(), 0.0) fx.memref_store(f0, rAcc, 0) fx.memref_store(f0, rAcc, 1) diff --git a/include/flydsl-c/FlyDialect.h b/include/flydsl-c/FlyDialect.h index a8880f47..9e6c0436 100644 --- a/include/flydsl-c/FlyDialect.h +++ b/include/flydsl-c/FlyDialect.h @@ -2,6 +2,7 @@ #define FLYDSL_C_FLYDIALECT_H #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { @@ -9,6 +10,103 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Fly, fly); 
+//===----------------------------------------------------------------------===// +// IntTupleType +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAFlyIntTupleType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirFlyIntTupleTypeGetTypeID(void); + +// Accessors +MLIR_CAPI_EXPORTED bool mlirFlyIntTupleTypeIsLeaf(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyIntTupleTypeGetRank(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyIntTupleTypeGetDepth(MlirType type); +MLIR_CAPI_EXPORTED bool mlirFlyIntTupleTypeIsStatic(MlirType type); + +//===----------------------------------------------------------------------===// +// LayoutType +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAFlyLayoutType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirFlyLayoutTypeGetTypeID(void); + +// Constructor +MLIR_CAPI_EXPORTED MlirType mlirFlyLayoutTypeGet(MlirType shape, MlirType stride); + +// Accessors +MLIR_CAPI_EXPORTED MlirType mlirFlyLayoutTypeGetShape(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyLayoutTypeGetStride(MlirType type); +MLIR_CAPI_EXPORTED bool mlirFlyLayoutTypeIsLeaf(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyLayoutTypeGetRank(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyLayoutTypeGetDepth(MlirType type); +MLIR_CAPI_EXPORTED bool mlirFlyLayoutTypeIsStatic(MlirType type); +MLIR_CAPI_EXPORTED bool mlirFlyLayoutTypeIsStaticShape(MlirType type); +MLIR_CAPI_EXPORTED bool mlirFlyLayoutTypeIsStaticStride(MlirType type); + +//===----------------------------------------------------------------------===// +// SwizzleType +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAFlySwizzleType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirFlySwizzleTypeGetTypeID(void); + +// Constructor +MLIR_CAPI_EXPORTED MlirType 
mlirFlySwizzleTypeGet(MlirContext ctx, int32_t mask, int32_t base, + int32_t shift); + +// Accessors +MLIR_CAPI_EXPORTED int32_t mlirFlySwizzleTypeGetMask(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlySwizzleTypeGetBase(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlySwizzleTypeGetShift(MlirType type); + +//===----------------------------------------------------------------------===// +// PointerType +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAFlyPointerType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirFlyPointerTypeGetTypeID(void); + +// Constructor +MLIR_CAPI_EXPORTED MlirType mlirFlyPointerTypeGet(MlirType elemType, int32_t addressSpace, + int32_t alignment); + +// Accessors +MLIR_CAPI_EXPORTED MlirType mlirFlyPointerTypeGetElementType(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyPointerTypeGetAddressSpace(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyPointerTypeGetAlignment(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyPointerTypeGetSwizzle(MlirType type); + +//===----------------------------------------------------------------------===// +// MemRefType +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAFlyMemRefType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirFlyMemRefTypeGetTypeID(void); + +// Constructor - layout must be LayoutType +MLIR_CAPI_EXPORTED MlirType mlirFlyMemRefTypeGet(MlirType elemType, MlirType layout, + int32_t addressSpace, int32_t alignment); + +// Accessors +MLIR_CAPI_EXPORTED MlirType mlirFlyMemRefTypeGetElementType(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyMemRefTypeGetLayout(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyMemRefTypeGetAddressSpace(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyMemRefTypeGetAlignment(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyMemRefTypeGetSwizzle(MlirType type); + 
+//===----------------------------------------------------------------------===// +// CopyAtomUniversalCopyType +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAFlyCopyAtomUniversalCopyType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirFlyCopyAtomUniversalCopyTypeGetTypeID(void); + +// Constructor +MLIR_CAPI_EXPORTED MlirType mlirFlyCopyAtomUniversalCopyTypeGet(MlirContext ctx, int32_t bitSize); + +// Accessors +MLIR_CAPI_EXPORTED int32_t mlirFlyCopyAtomUniversalCopyTypeGetBitSize(MlirType type); + #ifdef __cplusplus } #endif diff --git a/include/flydsl-c/FlyROCDLDialect.h b/include/flydsl-c/FlyROCDLDialect.h index 96981fff..9f7f3a3c 100644 --- a/include/flydsl-c/FlyROCDLDialect.h +++ b/include/flydsl-c/FlyROCDLDialect.h @@ -2,6 +2,7 @@ #define FLYDSL_C_FLYROCDLDIALECT_H #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { @@ -9,6 +10,13 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FlyROCDL, fly_rocdl); +//===----------------------------------------------------------------------===// +// MmaAtomCDNA3_MFMAType +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAFlyROCDLMmaAtomCDNA3_MFMAType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetTypeID(void); + #ifdef __cplusplus } #endif diff --git a/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td b/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td index 61e9fd71..ded0372b 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td +++ b/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td @@ -20,10 +20,9 @@ def Fly_CachePolicy : I32EnumAttr<"CachePolicy", "", [ def Fly_CachePolicyAttr : EnumAttr {} def Fly_AddressSpace : I32EnumAttr<"AddressSpace", "", [ - I32EnumAttrCase<"Flat", 0, "flat">, - I32EnumAttrCase<"Global", 1, "global">, - I32EnumAttrCase<"Shared", 2, "shared">, - I32EnumAttrCase<"Register", 3, "register"> + 
I32EnumAttrCase<"Global", 0, "global">, + I32EnumAttrCase<"Shared", 1, "shared">, + I32EnumAttrCase<"Register", 2, "register"> ]> { let genSpecializedAttr = 0; let cppNamespace = Fly_Dialect.cppNamespace; diff --git a/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td index 0e4a2c1a..0caa42ab 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td +++ b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td @@ -65,6 +65,12 @@ def Fly_Layout : Fly_Type<"Layout", "layout", [ def Fly_Swizzle : Fly_Type<"Swizzle", "swizzle", []> { let parameters = (ins Fly_SwizzleAttr:$attr); let assemblyFormat = "`<` $attr `>`"; + + let builders = [ + TypeBuilderWithInferredContext<(ins "SwizzleAttr":$attr), [{ + return $_get(attr.getContext(), attr); + }]> + ]; } def Fly_ComposedLayout : Fly_Type<"ComposedLayout", "composed_layout", [ @@ -119,6 +125,13 @@ def Fly_Pointer : Fly_Type<"Pointer", "ptr", []> { return $_get(elemTy.getContext(), elemTy, addressSpace, AlignAttr::getTrivialAlignment(elemTy.getContext()), SwizzleAttr::getTrivialSwizzle(elemTy.getContext())); + }]>, + TypeBuilderWithInferredContext<(ins "Type":$elemTy, "AddressSpaceAttr":$addressSpace, "AlignAttr":$alignment), [{ + return $_get(elemTy.getContext(), elemTy, addressSpace, alignment, + SwizzleAttr::getTrivialSwizzle(elemTy.getContext())); + }]>, + TypeBuilderWithInferredContext<(ins "Type":$elemTy, "AddressSpaceAttr":$addressSpace, "AlignAttr":$alignment, "SwizzleAttr":$swizzle), [{ + return $_get(elemTy.getContext(), elemTy, addressSpace, alignment, swizzle); }]> ]; let extraClassDeclaration = [{}]; @@ -154,13 +167,25 @@ def Fly_MemRef : Fly_Type<"MemRef", "memref", []> { let assemblyFormat = "`<` $elemTy `,` `` $addressSpace `,` $layout (`,` $alignment^)? (`,` $swizzle^)? 
`>`"; let builders = [ - AttrBuilderWithInferredContext<(ins "Type":$elemTy, "AddressSpaceAttr":$addressSpace, "LayoutAttr":$layout), [{ + TypeBuilderWithInferredContext<(ins "Type":$elemTy, "LayoutAttr":$layout), [{ + // default address space is Global + return $_get(elemTy.getContext(), elemTy, AddressSpaceAttr::get(elemTy.getContext(), static_cast<::mlir::fly::AddressSpace>(0)), layout, + AlignAttr::getTrivialAlignment(elemTy.getContext()), + SwizzleAttr::getTrivialSwizzle(elemTy.getContext())); + }]>, + TypeBuilderWithInferredContext<(ins "Type":$elemTy, "AddressSpaceAttr":$addressSpace, "LayoutAttr":$layout), [{ return $_get(elemTy.getContext(), elemTy, addressSpace, layout, AlignAttr::getTrivialAlignment(elemTy.getContext()), SwizzleAttr::getTrivialSwizzle(elemTy.getContext())); + }]>, + TypeBuilderWithInferredContext<(ins "Type":$elemTy, "AddressSpaceAttr":$addressSpace, "LayoutAttr":$layout, "AlignAttr":$alignment), [{ + return $_get(elemTy.getContext(), elemTy, addressSpace, layout, alignment, + SwizzleAttr::getTrivialSwizzle(elemTy.getContext())); + }]>, + TypeBuilderWithInferredContext<(ins "Type":$elemTy, "AddressSpaceAttr":$addressSpace, "LayoutAttr":$layout, "AlignAttr":$alignment, "SwizzleAttr":$swizzle), [{ + return $_get(elemTy.getContext(), elemTy, addressSpace, layout, alignment, swizzle); }]> ]; - let extraClassDeclaration = [{}]; } def IteratorLikeType : AnyTypeOf<[Fly_IntTuple, Fly_Pointer]>; diff --git a/lib/Bindings/Python/FlyExtension.cpp b/lib/Bindings/Python/FlyExtension.cpp new file mode 100644 index 00000000..3feb8581 --- /dev/null +++ b/lib/Bindings/Python/FlyExtension.cpp @@ -0,0 +1,337 @@ +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Dialect/LLVM.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Wrap.h" + +#include +#include +#include + +#include "flydsl-c/FlyDialect.h" 
+#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/IntUtils.h" + +namespace nb = nanobind; +using namespace nb::literals; +using namespace mlir; +using namespace mlir::fly; +using namespace mlir::python::nanobind_adaptors; + +namespace { + +struct IntTupleAttrBuilder { + MLIRContext *ctx; + std::vector dyncElems{}; + + IntTupleAttrBuilder(MLIRContext *ctx) : ctx(ctx) {} + + void clear() { dyncElems.clear(); } + + IntTupleAttr operator()(nb::handle args) { + if (PyTuple_Check(args.ptr())) { + SmallVector elements; + for (auto item : args) { + elements.push_back((*this)(item)); + } + return IntTupleAttr::get(ArrayAttr::get(ctx, elements)); + } else if (PyLong_Check(args.ptr())) { + int32_t cInt = PyLong_AsLong(args.ptr()); + return IntTupleAttr::get(IntAttr::getStatic(ctx, cInt)); + } else if (args.is_none()) { + return IntTupleAttr::getLeafNone(ctx); + } else { + if (!nb::hasattr(args, "_CAPIPtr")) { + throw std::invalid_argument("Expected I32, got: " + + std::string(nb::str(nb::type_name(args)).c_str())); + } + // dynamic value, default as i32 + dyncElems.push_back(args); + return IntTupleAttr::get(IntAttr::getDynamic(ctx)); + } + } +}; + +} // namespace + +int32_t rank(MlirValue int_or_tuple) { + mlir::Value val = unwrap(int_or_tuple); + mlir::Type ty = val.getType(); + if (auto intTupleTy = dyn_cast(ty)) { + return intTupleTy.getAttr().rank(); + } else if (auto layoutTy = dyn_cast(ty)) { + return layoutTy.getAttr().rank(); + } else if (auto composedLayoutTy = dyn_cast(ty)) { + return composedLayoutTy.getAttr().rank(); + } else if (auto coordTensorTy = dyn_cast(ty)) { + return coordTensorTy.getLayout().rank(); + } else if (auto memRefTy = dyn_cast(ty)) { + return memRefTy.getLayout().rank(); + } else { + throw std::invalid_argument("Unsupported type: "); + ty.dump(); + return 0; + } +} + +int32_t depth(MlirValue int_or_tuple) { + mlir::Value val = unwrap(int_or_tuple); + mlir::Type ty = val.getType(); + if (auto intTupleTy = 
dyn_cast(ty)) { + return intTupleTy.getAttr().depth(); + } else if (auto layoutTy = dyn_cast(ty)) { + return layoutTy.getAttr().depth(); + } else if (auto composedLayoutTy = dyn_cast(ty)) { + return composedLayoutTy.getAttr().depth(); + } else if (auto coordTensorTy = dyn_cast(ty)) { + return coordTensorTy.getLayout().depth(); + } else if (auto memRefTy = dyn_cast(ty)) { + return memRefTy.getLayout().depth(); + } else { + throw std::invalid_argument("Unsupported type: "); + ty.dump(); + return 0; + } +} + +// nb::object getFlyTypingModule() { +// static nb::object typing = nb::steal(nb::module_::import_("fly.lang.typing")); +// return typing; +// } + +// nb::object make_int32(int value) { +// static nb::object int32_cls = getFlyTypingModule().attr("Int32"); + +// return int32_cls(value); +// } + +// nb::object make_int32_tuple(int value) { +// static nb::object int32_cls = getFlyTypingModule().attr("Int32"); + +// nb::list subList; +// subList.append(int32_cls(value + 1)); +// nb::tuple subTuple = nb::tuple(subList); + +// nb::list retList; +// retList.append(int32_cls(value)); +// retList.append(subTuple); +// retList.append(nb::int_(0)); + +// return nb::tuple(retList); +// } + +NB_MODULE(_fly, m) { + m.doc() = "MLIR Python FlyDSL Extension"; + + m.def( + "infer_int_tuple_type", + [](nb::handle int_or_tuple, MlirContext context) { + MLIRContext *ctx = unwrap(context); + IntTupleAttrBuilder builder{ctx}; + IntTupleAttr attr = builder(int_or_tuple); + return std::make_pair(wrap(IntTupleType::get(attr)), builder.dyncElems); + }, + "int_or_tuple"_a, "context"_a = nb::none(), + // clang-format off + nb::sig("def infer_int_tuple_type(int_or_tuple, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None)"), + // clang-format on + "infer IntTupleType for given input"); + + m.def("rank", &rank, "int_or_tuple"_a, + nb::sig("def rank(int_or_tuple: " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ") -> int")); + m.def("depth", &depth, "int_or_tuple"_a, + nb::sig("def 
depth(int_or_tuple: " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ") -> int")); + + //===--------------------------------------------------------------------===// + // Core Types + //===--------------------------------------------------------------------===// + + mlir_type_subclass(m, "IntTupleType", mlirTypeIsAFlyIntTupleType, mlirFlyIntTupleTypeGetTypeID) + .def_classmethod( + "get", + [](const nb::object &cls, nb::handle int_or_tuple, MlirContext context) { + MLIRContext *ctx = unwrap(context); + IntTupleAttrBuilder builder{ctx}; + IntTupleAttr attr = builder(int_or_tuple); + return cls(wrap(IntTupleType::get(attr))); + }, + "cls"_a, "int_or_tuple"_a, "context"_a = nb::none(), + // clang-format off + nb::sig("def get(cls, int_or_tuple: int | tuple, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> IntTupleType"), + // clang-format on + "Create an IntTupleType from Python int or tuple") + .def_property_readonly("rank", [](MlirType self) { return mlirFlyIntTupleTypeGetRank(self); }) + .def_property_readonly("depth", + [](MlirType self) { return mlirFlyIntTupleTypeGetDepth(self); }) + .def_property_readonly("is_leaf", + [](MlirType self) { return mlirFlyIntTupleTypeIsLeaf(self); }) + .def_property_readonly("is_static", + [](MlirType self) { return mlirFlyIntTupleTypeIsStatic(self); }); + + mlir_type_subclass(m, "LayoutType", mlirTypeIsAFlyLayoutType, mlirFlyLayoutTypeGetTypeID) + .def_classmethod( + "get", + [](const nb::object &cls, nb::handle shape, nb::handle stride, MlirContext context) { + MLIRContext *ctx = unwrap(context); + auto toIntTupleAttr = [ctx](nb::handle h) -> IntTupleAttr { + if (nb::hasattr(h, "_CAPIPtr")) { + auto capsule = nb::cast(h.attr(MLIR_PYTHON_CAPI_PTR_ATTR)); + MlirType mlirTy = mlirPythonCapsuleToType(capsule.ptr()); + auto intTupleType = dyn_cast(unwrap(mlirTy)); + if (!intTupleType) { + throw std::invalid_argument("Expected IntTupleType, got other MlirType"); + } + return intTupleType.getAttr(); + } + 
IntTupleAttrBuilder builder{ctx}; + return builder(h); + }; + + IntTupleAttr shapeAttr = toIntTupleAttr(shape); + IntTupleAttr strideAttr = toIntTupleAttr(stride); + auto layoutAttr = LayoutAttr::get(ctx, shapeAttr, strideAttr); + return cls(wrap(LayoutType::get(layoutAttr))); + }, + "cls"_a, "shape"_a, "stride"_a, "context"_a = nb::none(), + // clang-format off + nb::sig("def get(cls, shape: int | tuple | IntTupleType, stride: int | tuple | IntTupleType, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> LayoutType"), + // clang-format on + "Create a LayoutType with shape and stride") + .def_property_readonly("shape", [](MlirType self) { return mlirFlyLayoutTypeGetShape(self); }) + .def_property_readonly("stride", + [](MlirType self) { return mlirFlyLayoutTypeGetStride(self); }) + .def_property_readonly("rank", [](MlirType self) { return mlirFlyLayoutTypeGetRank(self); }) + .def_property_readonly("depth", [](MlirType self) { return mlirFlyLayoutTypeGetDepth(self); }) + .def_property_readonly("is_leaf", [](MlirType self) { return mlirFlyLayoutTypeIsLeaf(self); }) + .def_property_readonly("is_static", + [](MlirType self) { return mlirFlyLayoutTypeIsStatic(self); }) + .def_property_readonly("is_static_shape", + [](MlirType self) { return mlirFlyLayoutTypeIsStaticShape(self); }) + .def_property_readonly("is_static_stride", + [](MlirType self) { return mlirFlyLayoutTypeIsStaticStride(self); }); + + mlir_type_subclass(m, "SwizzleType", mlirTypeIsAFlySwizzleType, mlirFlySwizzleTypeGetTypeID) + .def_classmethod( + "get", + [](const nb::object &cls, int32_t mask, int32_t base, int32_t shift, + MlirContext context) { + MLIRContext *ctx = unwrap(context); + SwizzleAttr attr = SwizzleAttr::get(ctx, mask, base, shift); + return cls(wrap(SwizzleType::get(attr))); + }, + "cls"_a, "mask"_a, "base"_a, "shift"_a, "context"_a = nb::none(), + // clang-format off + nb::sig("def get(cls, mask: int, base: int, shift: int, context: " 
MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> SwizzleType"), + // clang-format on + "Create a SwizzleType") + .def_property_readonly("mask", [](MlirType self) { return mlirFlySwizzleTypeGetMask(self); }) + .def_property_readonly("base", [](MlirType self) { return mlirFlySwizzleTypeGetBase(self); }) + .def_property_readonly("shift", + [](MlirType self) { return mlirFlySwizzleTypeGetShift(self); }); + + mlir_type_subclass(m, "PointerType", mlirTypeIsAFlyPointerType, mlirFlyPointerTypeGetTypeID) + .def_classmethod( + "get", + [](const nb::object &cls, nb::object elemTyObj, std::optional addressSpace, + std::optional alignment, MlirContext context) { + MLIRContext *ctx = unwrap(context); + + // Manual type conversion from nb::object to MlirType + auto capsule = nb::cast(elemTyObj.attr(MLIR_PYTHON_CAPI_PTR_ATTR)); + MlirType elemTy = mlirPythonCapsuleToType(capsule.ptr()); + + // default address space is Register + AddressSpace addr = AddressSpace::Register; + if (addressSpace.has_value()) { + addr = static_cast(addressSpace.value()); + } + int32_t alignSize = 1; + if (alignment.has_value()) { + alignSize = alignment.value(); + } + assert(alignSize > 0 && "alignment must be positive"); + + return cls(wrap(fly::PointerType::get(unwrap(elemTy), AddressSpaceAttr::get(ctx, addr), + AlignAttr::get(ctx, alignSize)))); + }, + "cls"_a, "elem_ty"_a, "address_space"_a = nb::none(), "alignment"_a = nb::none(), + "context"_a = nb::none(), + // clang-format off + nb::sig("def get(cls, elem_ty: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ", address_space: int = 0, alignment: int | None = None, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> PointerType"), + // clang-format on + "Create a PointerType with element type and address space") + .def_property_readonly("element_type", + [](MlirType self) { return mlirFlyPointerTypeGetElementType(self); }) + .def_property_readonly("address_space", + [](MlirType self) { return 
mlirFlyPointerTypeGetAddressSpace(self); }) + .def_property_readonly("alignment", + [](MlirType self) { return mlirFlyPointerTypeGetAlignment(self); }) + .def_property_readonly("swizzle", + [](MlirType self) { return mlirFlyPointerTypeGetSwizzle(self); }); + + mlir_type_subclass(m, "MemRefType", mlirTypeIsAFlyMemRefType, mlirFlyMemRefTypeGetTypeID) + .def_classmethod( + "get", + [](const nb::object &cls, MlirType elemTy, MlirType layoutMlirTy, + std::optional addressSpace, std::optional alignment, + MlirContext context) { + MLIRContext *ctx = unwrap(context); + + auto layoutType = dyn_cast(unwrap(layoutMlirTy)); + if (!layoutType) { + throw std::invalid_argument("layout must be a LayoutType"); + } + + // default address space is Register + AddressSpace addr = AddressSpace::Register; + if (addressSpace.has_value()) { + addr = static_cast(addressSpace.value()); + } + + int32_t alignSize = 1; + if (alignment.has_value()) { + alignSize = alignment.value(); + } + assert(alignSize > 0 && "alignment must be positive"); + + return cls( + wrap(fly::MemRefType::get(unwrap(elemTy), AddressSpaceAttr::get(ctx, addr), + layoutType.getAttr(), AlignAttr::get(ctx, alignSize)))); + }, + "cls"_a, "elem_ty"_a, "layout"_a, "address_space"_a = 0, "alignment"_a = nb::none(), + "context"_a = nb::none(), + // clang-format off + nb::sig("def get(cls, elem_ty: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ", layout: LayoutType, address_space: int = 0, alignment: int | None = None, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> MemRefType"), + // clang-format on + "Create a MemRefType with element type, layout, address space and alignment") + .def_property_readonly("element_type", + [](MlirType self) { return mlirFlyMemRefTypeGetElementType(self); }) + .def_property_readonly("layout", + [](MlirType self) { return mlirFlyMemRefTypeGetLayout(self); }) + .def_property_readonly("address_space", + [](MlirType self) { return mlirFlyMemRefTypeGetAddressSpace(self); }) + 
.def_property_readonly("alignment", + [](MlirType self) { return mlirFlyMemRefTypeGetAlignment(self); }) + .def_property_readonly("swizzle", + [](MlirType self) { return mlirFlyMemRefTypeGetSwizzle(self); }); + + mlir_type_subclass(m, "CopyAtomUniversalCopyType", mlirTypeIsAFlyCopyAtomUniversalCopyType, + mlirFlyCopyAtomUniversalCopyTypeGetTypeID) + .def_classmethod( + "get", + [](const nb::object &cls, int32_t bitSize, MlirContext context) { + MLIRContext *ctx = unwrap(context); + return cls(wrap(CopyAtomUniversalCopyType::get(ctx, bitSize))); + }, + "cls"_a, "bitSize"_a, "context"_a = nb::none(), + // clang-format off + nb::sig("def get(cls, bitSize: int, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> CopyAtomUniversalCopyType"), + // clang-format on + "Create a CopyAtomUniversalCopyType with bit size") + .def_property_readonly("bit_size", [](MlirType self) { + return mlirFlyCopyAtomUniversalCopyTypeGetBitSize(self); + }); +} diff --git a/lib/Bindings/Python/FlyROCDLExtension.cpp b/lib/Bindings/Python/FlyROCDLExtension.cpp new file mode 100644 index 00000000..98d5c83b --- /dev/null +++ b/lib/Bindings/Python/FlyROCDLExtension.cpp @@ -0,0 +1,43 @@ +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Dialect/LLVM.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Wrap.h" + +#include +#include +#include + +#include "flydsl-c/FlyROCDLDialect.h" +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" + +namespace nb = nanobind; +using namespace nb::literals; +using namespace mlir; +using namespace mlir::fly; +using namespace mlir::fly_rocdl; +using namespace mlir::python::nanobind_adaptors; + +NB_MODULE(_fly_rocdl, m) { + m.doc() = "MLIR Python FlyROCDL Extension"; + + mlir_type_subclass(m, "MmaAtomCDNA3_MFMAType", 
mlirTypeIsAFlyROCDLMmaAtomCDNA3_MFMAType, + mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetTypeID) + .def_classmethod( + "get", + [](const nb::object &cls, int32_t m, int32_t n, int32_t k, MlirType elemTyA, + MlirType elemTyB, MlirType elemTyAcc, MlirContext context) { + return cls(wrap(MmaAtomCDNA3_MFMAType::get(m, n, k, unwrap(elemTyA), unwrap(elemTyB), + unwrap(elemTyAcc)))); + }, + "cls"_a, "m"_a, "n"_a, "k"_a, "elem_ty_a"_a, "elem_ty_b"_a, "elem_ty_acc"_a, + "context"_a = nb::none(), + // clang-format off + nb::sig("def get(cls, m: int, n: int, k: int, elem_ty_a: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ", elem_ty_b: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ", elem_ty_acc: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ", context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> MmaAtomCDNA3_MFMAType"), + // clang-format on + "Create a MmaAtomCDNA3_MFMAType with m, n, k dimensions and element types"); +} diff --git a/lib/Bindings/Python/MainModules.cpp b/lib/Bindings/Python/MainModules.cpp deleted file mode 100644 index a94fe75f..00000000 --- a/lib/Bindings/Python/MainModules.cpp +++ /dev/null @@ -1,228 +0,0 @@ -#include "mlir-c/Bindings/Python/Interop.h" -#include "mlir-c/Dialect/LLVM.h" -#include "mlir-c/IR.h" -#include "mlir-c/Support.h" -#include "mlir/Bindings/Python/Nanobind.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" -#include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Wrap.h" - -#include -#include -#include - -#include "flydsl/Dialect/Fly/IR/FlyDialect.h" -#include "flydsl/Dialect/Fly/Utils/IntUtils.h" - -namespace nb = nanobind; -using namespace nb::literals; -using namespace mlir; -using namespace mlir::fly; - -// ----------------------------------------------------------------------------- -// Module initialization. 
-// ----------------------------------------------------------------------------- - -namespace { - -/// Helper to convert Python value to IntTupleAttr -struct IntTupleAttrBuilder { - MLIRContext *ctx; - std::vector dyncElems{}; - - IntTupleAttrBuilder(MLIRContext *ctx) : ctx(ctx) {} - - IntTupleAttr operator()(nb::handle args) { - if (PyTuple_Check(args.ptr())) { - SmallVector elements; - for (auto item : args) { - elements.push_back((*this)(item)); - } - return IntTupleAttr::get(ArrayAttr::get(ctx, elements)); - } else if (PyLong_Check(args.ptr())) { - int32_t cInt = PyLong_AsLong(args.ptr()); - return IntTupleAttr::get(IntAttr::getStatic(ctx, cInt)); - } else if (args.is_none()) { - return IntTupleAttr::getLeafNone(ctx); - } else { - // Dynamic value - for now treat as dynamic - dyncElems.push_back(args); - return IntTupleAttr::get(IntAttr::getDynamic(ctx)); - } - } -}; - -} // namespace - -int32_t rank(nb::handle int_or_tuple) { - nb::object capsule = int_or_tuple.attr("_CAPIPtr"); - MlirValue mlirVal = mlirPythonCapsuleToValue(capsule.ptr()); - mlir::Value val = unwrap(mlirVal); - mlir::Type ty = val.getType(); - if (auto intTupleTy = dyn_cast(ty)) { - return intTupleTy.getAttr().rank(); - } else if (auto layoutTy = dyn_cast(ty)) { - return layoutTy.getAttr().rank(); - } - return 1; -} - -int32_t depth(nb::handle int_or_tuple) { - nb::object capsule = int_or_tuple.attr("_CAPIPtr"); - MlirValue mlirVal = mlirPythonCapsuleToValue(capsule.ptr()); - mlir::Value val = unwrap(mlirVal); - mlir::Type ty = val.getType(); - if (auto intTupleTy = dyn_cast(ty)) { - return intTupleTy.getAttr().depth(); - } else if (auto layoutTy = dyn_cast(ty)) { - return layoutTy.getAttr().depth(); - } - return 0; -} - -// nb::object getFlyTypingModule() { -// static nb::object typing = nb::steal(nb::module_::import_("fly.lang.typing")); -// return typing; -// } - -// nb::object make_int32(int value) { -// static nb::object int32_cls = getFlyTypingModule().attr("Int32"); - -// return 
int32_cls(value); -// } - -// nb::object make_int32_tuple(int value) { -// static nb::object int32_cls = getFlyTypingModule().attr("Int32"); - -// nb::list subList; -// subList.append(int32_cls(value + 1)); -// nb::tuple subTuple = nb::tuple(subList); - -// nb::list retList; -// retList.append(int32_cls(value)); -// retList.append(subTuple); -// retList.append(nb::int_(0)); - -// return nb::tuple(retList); -// } - -NB_MODULE(_fly, m) { - m.doc() = "MLIR Python FlyDSL Extension"; - - m.def( - "infer_int_tuple_type", - [](MlirContext context, nb::handle int_or_tuple) { - MLIRContext *ctx = unwrap(context); - IntTupleAttrBuilder builder{ctx}; - IntTupleAttr attr = builder(int_or_tuple); - auto intTupleType = IntTupleType::get(attr); - MlirType wrappedType = wrap(intTupleType); - return std::make_pair(wrappedType, builder.dyncElems); - }, - nb::arg("context"), nb::arg("int_or_tuple")); - - m.def( - "infer_layout_type", - [](MlirContext context, nb::handle shape, nb::handle stride) { - MLIRContext *ctx = unwrap(context); - IntTupleAttrBuilder builder{ctx}; - IntTupleAttr shapeAttr = builder(shape); - IntTupleAttr strideAttr = builder(stride); - auto layoutAttr = LayoutAttr::get(ctx, shapeAttr, strideAttr); - auto layoutType = LayoutType::get(ctx, layoutAttr); - MlirType wrappedType = wrap(layoutType); - return wrappedType; - }, - nb::arg("context"), nb::arg("shape"), nb::arg("stride")); - - m.def("rank", &rank, nb::arg("int_or_tuple")); - m.def("depth", &depth, nb::arg("int_or_tuple")); - - //===--------------------------------------------------------------------===// - // Fly Type Classes with static get() methods - //===--------------------------------------------------------------------===// - - nb::class_(m, "PointerType") - .def_static( - "get", - [](MlirType elemTy, int32_t addressSpace, std::optional alignment) { - mlir::Type unwrappedElemTy = unwrap(elemTy); - MLIRContext *ctx = unwrappedElemTy.getContext(); - - AddressSpaceAttr addrSpaceAttr = - 
AddressSpaceAttr::get(ctx, static_cast(addressSpace)); - - fly::PointerType ptrType; - if (alignment.has_value()) { - AlignAttr alignAttr = AlignAttr::get(ctx, alignment.value()); - ptrType = fly::PointerType::get(ctx, unwrappedElemTy, addrSpaceAttr, alignAttr, - SwizzleAttr::getTrivialSwizzle(ctx)); - } else { - ptrType = fly::PointerType::get(unwrappedElemTy, addrSpaceAttr); - } - return wrap(static_cast(ptrType)); - }, - nb::arg("elem_ty"), nb::arg("address_space"), nb::arg("alignment") = nb::none(), - "Create a PointerType with element type and address space"); - - nb::class_(m, "MemRefType") - .def_static( - "get", - [](MlirType elemTy, int32_t addressSpace, MlirType layoutTy, - std::optional alignment) { - mlir::Type unwrappedElemTy = unwrap(elemTy); - mlir::Type unwrappedLayoutTy = unwrap(layoutTy); - MLIRContext *ctx = unwrappedElemTy.getContext(); - - auto layoutType = dyn_cast(unwrappedLayoutTy); - if (!layoutType) { - throw std::invalid_argument("layout must be a LayoutType"); - } - - AddressSpaceAttr addrSpaceAttr = - AddressSpaceAttr::get(ctx, static_cast(addressSpace)); - LayoutAttr layoutAttr = layoutType.getAttr(); - - fly::MemRefType memrefType; - if (alignment.has_value()) { - AlignAttr alignAttr = AlignAttr::get(ctx, alignment.value()); - memrefType = fly::MemRefType::get(ctx, unwrappedElemTy, addrSpaceAttr, layoutAttr, - alignAttr, SwizzleAttr::getTrivialSwizzle(ctx)); - } else { - memrefType = fly::MemRefType::get(unwrappedElemTy, addrSpaceAttr, layoutAttr); - } - return wrap(static_cast(memrefType)); - }, - nb::arg("elem_ty"), nb::arg("address_space"), nb::arg("layout"), - nb::arg("alignment") = nb::none(), - "Create a MemRefType with element type, address space and layout"); - - nb::class_(m, "LayoutType") - .def_static( - "get", - [](MlirContext context, nb::handle shape, nb::handle stride) { - MLIRContext *ctx = unwrap(context); - IntTupleAttrBuilder builder{ctx}; - IntTupleAttr shapeAttr = builder(shape); - IntTupleAttr strideAttr = 
builder(stride); - auto layoutAttr = LayoutAttr::get(ctx, shapeAttr, strideAttr); - auto layoutType = LayoutType::get(ctx, layoutAttr); - return wrap(static_cast(layoutType)); - }, - nb::arg("context"), nb::arg("shape"), nb::arg("stride"), - "Create a LayoutType with shape and stride"); - - // IntTupleType class - nb::class_(m, "IntTupleType") - .def_static( - "get", - [](MlirContext context, nb::handle int_or_tuple) { - MLIRContext *ctx = unwrap(context); - IntTupleAttrBuilder builder{ctx}; - IntTupleAttr attr = builder(int_or_tuple); - auto intTupleType = IntTupleType::get(attr); - return std::make_pair(wrap(static_cast(intTupleType)), builder.dyncElems); - }, - nb::arg("context"), nb::arg("int_or_tuple"), - "Create an IntTupleType from Python int or tuple"); -} diff --git a/lib/CAPI/Dialect/Fly/FlyDialect.cpp b/lib/CAPI/Dialect/Fly/FlyDialect.cpp index ce18bfb6..60e74a80 100644 --- a/lib/CAPI/Dialect/Fly/FlyDialect.cpp +++ b/lib/CAPI/Dialect/Fly/FlyDialect.cpp @@ -1,6 +1,194 @@ #include "flydsl-c/FlyDialect.h" #include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "mlir/CAPI/IR.h" #include "mlir/CAPI/Registration.h" +using namespace mlir; +using namespace mlir::fly; + MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Fly, fly, mlir::fly::FlyDialect) + +//===----------------------------------------------------------------------===// +// IntTupleType +//===----------------------------------------------------------------------===// + +bool mlirTypeIsAFlyIntTupleType(MlirType type) { return isa(unwrap(type)); } + +MlirTypeID mlirFlyIntTupleTypeGetTypeID(void) { return wrap(IntTupleType::getTypeID()); } + +bool mlirFlyIntTupleTypeIsLeaf(MlirType type) { return cast(unwrap(type)).isLeaf(); } + +int32_t mlirFlyIntTupleTypeGetRank(MlirType type) { + return cast(unwrap(type)).rank(); +} + +int32_t mlirFlyIntTupleTypeGetDepth(MlirType type) { + return cast(unwrap(type)).depth(); +} + +bool mlirFlyIntTupleTypeIsStatic(MlirType type) { + return cast(unwrap(type)).isStatic(); +} + 
+//===----------------------------------------------------------------------===// +// LayoutType +//===----------------------------------------------------------------------===// + +bool mlirTypeIsAFlyLayoutType(MlirType type) { return isa(unwrap(type)); } + +MlirTypeID mlirFlyLayoutTypeGetTypeID(void) { return wrap(LayoutType::getTypeID()); } + +MlirType mlirFlyLayoutTypeGet(MlirType shape, MlirType stride) { + auto shapeType = cast(unwrap(shape)); + auto strideType = cast(unwrap(stride)); + LayoutAttr attr = LayoutAttr::get(shapeType.getAttr(), strideType.getAttr()); + return wrap(LayoutType::get(attr)); +} + +MlirType mlirFlyLayoutTypeGetShape(MlirType type) { + auto layoutType = cast(unwrap(type)); + IntTupleAttr shapeAttr = layoutType.getAttr().getShape(); + return wrap(IntTupleType::get(shapeAttr)); +} + +MlirType mlirFlyLayoutTypeGetStride(MlirType type) { + auto layoutType = cast(unwrap(type)); + IntTupleAttr strideAttr = layoutType.getAttr().getStride(); + return wrap(IntTupleType::get(strideAttr)); +} + +bool mlirFlyLayoutTypeIsLeaf(MlirType type) { return cast(unwrap(type)).isLeaf(); } + +int32_t mlirFlyLayoutTypeGetRank(MlirType type) { return cast(unwrap(type)).rank(); } + +int32_t mlirFlyLayoutTypeGetDepth(MlirType type) { return cast(unwrap(type)).depth(); } + +bool mlirFlyLayoutTypeIsStatic(MlirType type) { return cast(unwrap(type)).isStatic(); } + +bool mlirFlyLayoutTypeIsStaticShape(MlirType type) { + return cast(unwrap(type)).isStaticShape(); +} + +bool mlirFlyLayoutTypeIsStaticStride(MlirType type) { + return cast(unwrap(type)).isStaticStride(); +} + +//===----------------------------------------------------------------------===// +// SwizzleType +//===----------------------------------------------------------------------===// + +bool mlirTypeIsAFlySwizzleType(MlirType type) { return isa(unwrap(type)); } + +MlirTypeID mlirFlySwizzleTypeGetTypeID(void) { return wrap(SwizzleType::getTypeID()); } + +MlirType mlirFlySwizzleTypeGet(MlirContext ctx, 
int32_t mask, int32_t base, int32_t shift) { + MLIRContext *context = unwrap(ctx); + SwizzleAttr attr = SwizzleAttr::get(context, mask, base, shift); + return wrap(SwizzleType::get(attr)); +} + +int32_t mlirFlySwizzleTypeGetMask(MlirType type) { + return cast(unwrap(type)).getAttr().getMask(); +} + +int32_t mlirFlySwizzleTypeGetBase(MlirType type) { + return cast(unwrap(type)).getAttr().getBase(); +} + +int32_t mlirFlySwizzleTypeGetShift(MlirType type) { + return cast(unwrap(type)).getAttr().getShift(); +} + +//===----------------------------------------------------------------------===// +// PointerType +//===----------------------------------------------------------------------===// + +bool mlirTypeIsAFlyPointerType(MlirType type) { return isa(unwrap(type)); } + +MlirTypeID mlirFlyPointerTypeGetTypeID(void) { return wrap(fly::PointerType::getTypeID()); } + +MlirType mlirFlyPointerTypeGet(MlirType elemType, int32_t addressSpace, int32_t alignment) { + Type elemTy = unwrap(elemType); + MLIRContext *ctx = elemTy.getContext(); + AddressSpaceAttr addrSpaceAttr = + AddressSpaceAttr::get(ctx, static_cast(addressSpace)); + AlignAttr alignAttr = AlignAttr::get(ctx, alignment); + return wrap(fly::PointerType::get(elemTy, addrSpaceAttr, alignAttr)); +} + +MlirType mlirFlyPointerTypeGetElementType(MlirType type) { + return wrap(cast(unwrap(type)).getElemTy()); +} + +int32_t mlirFlyPointerTypeGetAddressSpace(MlirType type) { + return static_cast(cast(unwrap(type)).getAddressSpace().getValue()); +} + +int32_t mlirFlyPointerTypeGetAlignment(MlirType type) { + return cast(unwrap(type)).getAlignment().getAlignment(); +} + +MlirType mlirFlyPointerTypeGetSwizzle(MlirType type) { + return wrap(SwizzleType::get(cast(unwrap(type)).getSwizzle())); +} + +//===----------------------------------------------------------------------===// +// MemRefType +//===----------------------------------------------------------------------===// + +bool mlirTypeIsAFlyMemRefType(MlirType type) { return 
isa(unwrap(type)); } + +MlirTypeID mlirFlyMemRefTypeGetTypeID(void) { return wrap(fly::MemRefType::getTypeID()); } + +MlirType mlirFlyMemRefTypeGet(MlirType elemType, MlirType layout, int32_t addressSpace, + int32_t alignment) { + Type elemTy = unwrap(elemType); + auto layoutType = cast(unwrap(layout)); + MLIRContext *ctx = elemTy.getContext(); + AddressSpaceAttr addrSpaceAttr = + AddressSpaceAttr::get(ctx, static_cast(addressSpace)); + AlignAttr alignAttr = AlignAttr::get(ctx, alignment); + return wrap(fly::MemRefType::get(elemTy, addrSpaceAttr, layoutType.getAttr(), alignAttr)); +} + +MlirType mlirFlyMemRefTypeGetElementType(MlirType type) { + return wrap(cast(unwrap(type)).getElemTy()); +} + +MlirType mlirFlyMemRefTypeGetLayout(MlirType type) { + auto memrefType = cast(unwrap(type)); + return wrap(LayoutType::get(memrefType.getLayout())); +} + +int32_t mlirFlyMemRefTypeGetAddressSpace(MlirType type) { + return static_cast(cast(unwrap(type)).getAddressSpace().getValue()); +} + +int32_t mlirFlyMemRefTypeGetAlignment(MlirType type) { + return cast(unwrap(type)).getAlignment().getAlignment(); +} + +MlirType mlirFlyMemRefTypeGetSwizzle(MlirType type) { + return wrap(SwizzleType::get(cast(unwrap(type)).getSwizzle())); +} + +//===----------------------------------------------------------------------===// +// CopyAtomUniversalCopyType +//===----------------------------------------------------------------------===// + +bool mlirTypeIsAFlyCopyAtomUniversalCopyType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirFlyCopyAtomUniversalCopyTypeGetTypeID(void) { + return wrap(CopyAtomUniversalCopyType::getTypeID()); +} + +MlirType mlirFlyCopyAtomUniversalCopyTypeGet(MlirContext ctx, int32_t bitSize) { + return wrap(CopyAtomUniversalCopyType::get(unwrap(ctx), bitSize)); +} + +int32_t mlirFlyCopyAtomUniversalCopyTypeGetBitSize(MlirType type) { + return cast(unwrap(type)).getBitSize(); +} diff --git a/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp 
b/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp index 08512ee2..48f2063b 100644 --- a/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp +++ b/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp @@ -1,6 +1,22 @@ #include "flydsl-c/FlyROCDLDialect.h" #include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" +#include "mlir/CAPI/IR.h" #include "mlir/CAPI/Registration.h" +using namespace mlir; +using namespace mlir::fly_rocdl; + MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FlyROCDL, fly_rocdl, mlir::fly_rocdl::FlyROCDLDialect) + +//===----------------------------------------------------------------------===// +// MmaAtomCDNA3_MFMAType +//===----------------------------------------------------------------------===// + +bool mlirTypeIsAFlyROCDLMmaAtomCDNA3_MFMAType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetTypeID(void) { + return wrap(MmaAtomCDNA3_MFMAType::getTypeID()); +} diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp index 18af34c7..ea110049 100644 --- a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -124,14 +124,12 @@ static std::optional getLayoutStructType(LayoutAttr layout static unsigned mapAddressSpace(AddressSpace space) { switch (space) { - case AddressSpace::Flat: - return 0; case AddressSpace::Global: - return 1; + return 0; case AddressSpace::Shared: - return 3; + return 1; case AddressSpace::Register: - return 5; + return 2; } return 0; } diff --git a/python/flydsl/lang/ir/core.py b/python/flydsl/lang/ir/core.py index c83659e7..34ee3d42 100644 --- a/python/flydsl/lang/ir/core.py +++ b/python/flydsl/lang/ir/core.py @@ -13,6 +13,19 @@ from ..._mlir.dialects import arith from ..._mlir.extras import types as T +from ..._mlir.dialects.fly import ( + IntTupleType, + LayoutType, + SwizzleType, + PointerType, + MemRefType, + CopyAtomUniversalCopyType, +) + +from ..._mlir.dialects.fly_rocdl import ( + MmaAtomCDNA3_MFMAType, 
+) + def _binary_op(lhs, rhs, op: str) -> "ArithValue": op = op.capitalize() @@ -98,35 +111,35 @@ def make_identity_layout(shape, loc=None, ip=None): @dsl_api_wrapper def make_shape(*shape, loc=None, ip=None): - IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, shape) + IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(shape) return _fly_ir.make_shape(IntTupleTy, dyncElems, loc=loc, ip=ip) @dsl_api_wrapper def make_stride(*stride, loc=None, ip=None): - IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, stride) + IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(stride) return _fly_ir.make_stride(IntTupleTy, dyncElems, loc=loc, ip=ip) @dsl_api_wrapper def make_coord(*coord, loc=None, ip=None): - IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, coord) + IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(coord) return _fly_ir.make_coord(IntTupleTy, dyncElems, loc=loc, ip=ip) @dsl_api_wrapper def make_int_tuple(elems, loc=None, ip=None): - IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, elems) + IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(elems) return _fly_ir.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) @dsl_api_wrapper def make_layout(shape, stride, loc=None, ip=None): if not isinstance(shape, ir.Value): - shapeTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, shape) + shapeTy, dyncElems = _fly_ir.infer_int_tuple_type(shape) shape = _fly_ir.make_shape(shapeTy, dyncElems, loc=loc, ip=ip) if not isinstance(stride, ir.Value): - strideTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, stride) + strideTy, dyncElems = _fly_ir.infer_int_tuple_type(stride) stride = _fly_ir.make_stride(strideTy, dyncElems, loc=loc, ip=ip) return _fly_ir.make_layout(shape, stride=stride, loc=loc, ip=ip) @@ -144,7 +157,7 @@ def get_scalar(int_tuple, loc=None, ip=None): @dsl_api_wrapper def slice(src, coord, loc=None, ip=None): if not 
isinstance(coord, ir.Value): - coordTy, dyncElems = _fly_ir.infer_int_tuple_type(ir.Context.current, coord) + coordTy, dyncElems = _fly_ir.infer_int_tuple_type(coord) coord = _fly_ir.make_coord(coordTy, dyncElems, loc=loc, ip=ip) return _fly_ir.slice(src, coord, loc=loc, ip=ip) @@ -162,9 +175,7 @@ def composition(layout, tiler, loc=None, ip=None): @dsl_api_wrapper def complement(layout, codomain_size, loc=None, ip=None): if not isinstance(codomain_size, ir.Value): - codomain_sizeTy, dyncElems = _fly_ir.infer_int_tuple_type( - ir.Context.current, codomain_size - ) + codomain_sizeTy, dyncElems = _fly_ir.infer_int_tuple_type(codomain_size) codomain_size = _fly_ir.make_shape(codomain_sizeTy, dyncElems, loc=loc, ip=ip) return _fly_ir.complement(layout, codomain_size=codomain_size, loc=loc, ip=ip) @@ -376,105 +387,3 @@ def print_op(*values, format_str="", loc=None, ip=None): # ============================================================================== # Fly Type Classes (MLIR-style API) # ============================================================================== - - -class PointerType: - """ - Fly Pointer Type with MLIR-style static get() method. - - Example: - ptr_ty = PointerType.get(T.f32(), AddressSpace.Global) - ptr_ty = PointerType.get(T.f32(), AddressSpace.Register, alignment=16) - """ - - @staticmethod - def get(elem_ty, address_space, alignment=None): - """ - Create a PointerType. - - Args: - elem_ty: Element type (e.g., T.f32()) - address_space: Address space (AddressSpace.Global, AddressSpace.Shared, AddressSpace.Register) - alignment: Optional alignment value - - Returns: - PointerType as ir.Type - """ - return _fly_ir.PointerType.get(elem_ty, int(address_space), alignment) - - -class MemRefType: - """ - Fly MemRef Type with MLIR-style static get() method. 
- - Example: - layout_ty = LayoutType.get(ir.Context.current, 16, 1) - memref_ty = MemRefType.get(T.f32(), AddressSpace.Global, layout_ty) - """ - - @staticmethod - def get(elem_ty, address_space, layout, alignment=None): - """ - Create a MemRefType. - - Args: - elem_ty: Element type (e.g., T.f32()) - address_space: Address space (AddressSpace.Global, AddressSpace.Shared, AddressSpace.Register) - layout: Layout type (LayoutType or ir.Type) - alignment: Optional alignment value - - Returns: - MemRefType as ir.Type - """ - # If layout is an ir.Value (from make_layout), get its type - if isinstance(layout, ir.Value): - layout = layout.type - return _fly_ir.MemRefType.get(elem_ty, int(address_space), layout, alignment) - - -class LayoutType: - """ - Fly Layout Type with MLIR-style static get() method. - - Example: - layout_ty = LayoutType.get(ir.Context.current, 16, 1) - layout_ty = LayoutType.get(ir.Context.current, (4, 4), (4, 1)) - """ - - @staticmethod - def get(context, shape, stride): - """ - Create a LayoutType. - - Args: - context: MLIR context - shape: Shape as int or tuple - stride: Stride as int or tuple - - Returns: - LayoutType as ir.Type - """ - return _fly_ir.LayoutType.get(context, shape, stride) - - -class IntTupleType: - """ - Fly IntTuple Type with MLIR-style static get() method. - - Example: - int_tuple_ty = IntTupleType.get(ir.Context.current, (4, 4)) - """ - - @staticmethod - def get(context, int_or_tuple): - """ - Create an IntTupleType. 
- - Args: - context: MLIR context - int_or_tuple: Python int or tuple - - Returns: - Tuple of (IntTupleType as ir.Type, list of dynamic elements) - """ - return _fly_ir.IntTupleType.get(context, int_or_tuple) diff --git a/python/mlir_flydsl/CMakeLists.txt b/python/mlir_flydsl/CMakeLists.txt index addc5f4e..a3011d17 100644 --- a/python/mlir_flydsl/CMakeLists.txt +++ b/python/mlir_flydsl/CMakeLists.txt @@ -25,21 +25,29 @@ declare_mlir_dialect_python_bindings( GEN_ENUM_BINDINGS ) -declare_mlir_python_extension(FlyPythonSources.Core +declare_mlir_python_extension(FlyPythonSources.Core.fly MODULE_NAME _fly ADD_TO_PARENT FlyPythonSources ROOT_DIR "${PROJECT_SOURCE_DIR}/lib/Bindings/Python" PYTHON_BINDINGS_LIBRARY nanobind SOURCES - MainModules.cpp + FlyExtension.cpp PRIVATE_LINK_LIBS LLVMSupport MLIRFlyDialect - EMBED_CAPI_LINK_LIBS - MLIRCAPIIR ) - +declare_mlir_python_extension(FlyPythonSources.Core.fly_rocdl + MODULE_NAME _fly_rocdl + ADD_TO_PARENT FlyPythonSources + ROOT_DIR "${PROJECT_SOURCE_DIR}/lib/Bindings/Python" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + FlyROCDLExtension.cpp + PRIVATE_LINK_LIBS + LLVMSupport + MLIRFlyROCDLDialect +) declare_mlir_python_extension(FlyPythonSources.RegisterEverything MODULE_NAME _mlirRegisterEverything @@ -92,33 +100,9 @@ add_mlir_python_common_capi_library(FlyPythonCAPI set(FlyPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}") -# set(_core_type_stub_sources -# _mlir/__init__.pyi -# _mlir/ir.pyi -# _mlir/passmanager.pyi -# _mlir/rewrite.pyi -# ) -# get_target_property(_core_extension_srcs MLIRPythonExtension.Core INTERFACE_SOURCES) -# mlir_generate_type_stubs( -# MODULE_NAME _mlir -# DEPENDS_TARGETS FlyPythonModules.extension._mlir.dso -# OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs" -# OUTPUTS "${_core_type_stub_sources}" -# DEPENDS_TARGET_SRC_DEPS "${_core_extension_srcs}" -# IMPORT_PATHS "${FlyPythonModules_ROOT_PREFIX}/_mlir_libs" -# VERBOSE -# ) -# 
set(_mlir_typestub_gen_target "${NB_STUBGEN_CUSTOM_TARGET}") - -# list(TRANSFORM _core_type_stub_sources PREPEND "_mlir_libs/") -# declare_mlir_python_sources( -# FlyPythonExtension.Core.type_stub_gen -# ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs" -# ADD_TO_PARENT FlyPythonSources -# SOURCES "${_core_type_stub_sources}" -# ) - - +################################################################################ +# Python Modules +################################################################################ add_mlir_python_modules(FlyPythonModules ROOT_PREFIX "${FlyPythonModules_ROOT_PREFIX}" @@ -128,6 +112,37 @@ add_mlir_python_modules(FlyPythonModules FlyPythonCAPI ) +################################################################################ +# Type Stubs Generation +################################################################################ + +set(_FLYDSL_PYTHON_PACKAGES_DIR "${MLIR_BINARY_DIR}/python_packages") +set(_MLIR_LIBS_DIR "${FlyPythonModules_ROOT_PREFIX}/_mlir_libs") +set(_STUB_MARKER_FILE "${_MLIR_LIBS_DIR}/.stubs_generated") + +add_custom_command( + OUTPUT "${_STUB_MARKER_FILE}" + COMMAND ${CMAKE_COMMAND} -E env + "PYTHONPATH=${_FLYDSL_PYTHON_PACKAGES_DIR}" + ${Python3_EXECUTABLE} -m nanobind.stubgen + -q + -r + -m flydsl._mlir._mlir_libs._mlir + -m flydsl._mlir._mlir_libs._fly + -m flydsl._mlir._mlir_libs._fly_rocdl + -m flydsl._mlir._mlir_libs._mlirDialectsGPU + -m flydsl._mlir._mlir_libs._mlirDialectsLLVM + -O "${_MLIR_LIBS_DIR}" + COMMAND ${CMAKE_COMMAND} -E touch "${_STUB_MARKER_FILE}" + DEPENDS CopyFlyPythonSources + COMMENT "Generating Python stub files for all extension modules" + VERBATIM +) + +add_custom_target(FlyPythonStubs ALL + DEPENDS "${_STUB_MARKER_FILE}" +) + add_custom_target(CopyFlyPythonSources ALL COMMAND ${CMAKE_COMMAND} -E copy_directory "${PROJECT_SOURCE_DIR}/python/flydsl" diff --git a/python/mlir_flydsl/dialects/fly_rocdl.py b/python/mlir_flydsl/dialects/fly_rocdl.py index caf90541..8ea4bbe5 100644 --- 
a/python/mlir_flydsl/dialects/fly_rocdl.py +++ b/python/mlir_flydsl/dialects/fly_rocdl.py @@ -1,2 +1,4 @@ from ._fly_rocdl_enum_gen import * from ._fly_rocdl_ops_gen import * + +from .._mlir_libs._fly_rocdl import * From 51e45a584d64b06231ebf5ddec6cabbfc1913939 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Wed, 28 Jan 2026 11:54:12 +0000 Subject: [PATCH 005/113] Add universalMma Atom --- include/flydsl-c/FlyDialect.h | 13 + include/flydsl-c/FlyROCDLDialect.h | 17 + include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td | 17 +- include/flydsl/Dialect/Fly/IR/FlyDialect.h | 4 + include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td | 12 + .../flydsl/Dialect/Fly/Utils/IntTupleUtils.h | 16 + include/flydsl/Dialect/FlyROCDL/IR/Dialect.h | 8 - include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td | 2 +- lib/Bindings/Python/FlyExtension.cpp | 15 + lib/Bindings/Python/FlyROCDLExtension.cpp | 25 +- lib/CAPI/Dialect/Fly/FlyDialect.cpp | 20 ++ lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp | 34 ++ lib/Dialect/Fly/IR/FlyAttrDefs.cpp | 8 +- lib/Dialect/Fly/IR/FlyDialect.cpp | 18 + lib/Dialect/Fly/IR/FlyOps.cpp | 1 + lib/Dialect/Fly/IR/FlyTypeDefs.cpp | 52 ++- lib/Dialect/Fly/Transforms/LayoutLowering.cpp | 312 ++---------------- lib/Dialect/FlyROCDL/Dialect.cpp | 22 +- python/flydsl/lang/ir/core.py | 1 + 19 files changed, 253 insertions(+), 344 deletions(-) diff --git a/include/flydsl-c/FlyDialect.h b/include/flydsl-c/FlyDialect.h index 9e6c0436..164194d3 100644 --- a/include/flydsl-c/FlyDialect.h +++ b/include/flydsl-c/FlyDialect.h @@ -107,6 +107,19 @@ MLIR_CAPI_EXPORTED MlirType mlirFlyCopyAtomUniversalCopyTypeGet(MlirContext ctx, // Accessors MLIR_CAPI_EXPORTED int32_t mlirFlyCopyAtomUniversalCopyTypeGetBitSize(MlirType type); +//===----------------------------------------------------------------------===// +// MmaAtomUniversalFMAType +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAFlyMmaAtomUniversalFMAType(MlirType type); 
+MLIR_CAPI_EXPORTED MlirTypeID mlirFlyMmaAtomUniversalFMATypeGetTypeID(void); + +// Constructor +MLIR_CAPI_EXPORTED MlirType mlirFlyMmaAtomUniversalFMATypeGet(MlirContext ctx, MlirType elemTy); + +// Accessors +MLIR_CAPI_EXPORTED MlirType mlirFlyMmaAtomUniversalFMATypeGetElemTy(MlirType type); + #ifdef __cplusplus } #endif diff --git a/include/flydsl-c/FlyROCDLDialect.h b/include/flydsl-c/FlyROCDLDialect.h index 9f7f3a3c..e2da228f 100644 --- a/include/flydsl-c/FlyROCDLDialect.h +++ b/include/flydsl-c/FlyROCDLDialect.h @@ -17,6 +17,23 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FlyROCDL, fly_rocdl); MLIR_CAPI_EXPORTED bool mlirTypeIsAFlyROCDLMmaAtomCDNA3_MFMAType(MlirType type); MLIR_CAPI_EXPORTED MlirTypeID mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetTypeID(void); +// Constructor +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGet(int32_t m, int32_t n, int32_t k, + MlirType elemTyA, MlirType elemTyB, + MlirType elemTyAcc); + +// Accessors +MLIR_CAPI_EXPORTED int32_t mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetM(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetN(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetK(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyA(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyB(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyAcc(MlirType type); + +//===----------------------------------------------------------------------===// +// CopyAtom_CDNA3_BufferLSAType +//===----------------------------------------------------------------------===// + #ifdef __cplusplus } #endif diff --git a/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td b/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td index ded0372b..c63299bb 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td +++ b/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td @@ -56,20 +56,11 @@ def Fly_IntAttr : Fly_Attr<"Int", 
"int", [ DeclareAttrInterfaceMethods ]> { let parameters = (ins - "int32_t":$value, + "int32_t":$value, // needs to be int64_t to support large integers in the future DefaultValuedParameter<"int32_t", "32">:$width, DefaultValuedParameter<"int32_t", "1">:$divisibility); let hasCustomAssemblyFormat = 1; - let builders = [ - AttrBuilder<(ins "int32_t":$value), [{ - return $_get($_ctxt, value, 32, 1); - }]>, - AttrBuilder<(ins "int32_t":$width, "int32_t":$divisibility), [{ - return $_get($_ctxt, std::numeric_limits::min(), width, divisibility); - }]> - ]; - let extraClassDeclaration = [{ bool isNone() const; // value can't be INT32_MIN here @@ -90,10 +81,10 @@ def Fly_IntAttr : Fly_Attr<"Int", "int", [ return get(ctx, 0, 0, 0); } IntAttr $cppClass::getStatic(MLIRContext *ctx, int32_t value) { - return get(ctx, value); + return get(ctx, value, 32, value == 0 ? 1 : value); } IntAttr $cppClass::getDynamic(MLIRContext *ctx, int32_t width, int32_t divisibility) { - return get(ctx, width, divisibility); + return get(ctx, std::numeric_limits::min(), width, divisibility); } }]; } @@ -150,7 +141,7 @@ def Fly_IntTupleAttr : Fly_Attr<"IntTuple", "int_tuple", [ let builders = [ AttrBuilder<(ins "int32_t":$value), [{ - return $_get($_ctxt, IntAttr::get($_ctxt, value)); + return $_get($_ctxt, IntAttr::getStatic($_ctxt, value)); }]>, AttrBuilderWithInferredContext<(ins "Attribute":$value), [{ return $_get(value.getContext(), value); diff --git a/include/flydsl/Dialect/Fly/IR/FlyDialect.h b/include/flydsl/Dialect/Fly/IR/FlyDialect.h index 7a55b35b..8dbe23a1 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyDialect.h +++ b/include/flydsl/Dialect/Fly/IR/FlyDialect.h @@ -29,6 +29,10 @@ namespace mlir::fly { #include "flydsl/Dialect/Fly/IR/FlyAttrConstraints.h.inc" #include "flydsl/Dialect/Fly/IR/FlyTypeConstraints.h.inc" + +ParseResult parseMNKDimensionList(AsmParser &parser, int32_t &m, int32_t &n, int32_t &k); +void printMNKDimensionList(AsmPrinter &printer, int32_t m, int32_t n, int32_t 
k); + } // namespace mlir::fly #endif // FLYDSL_DIALECT_FLY_IR_DIALECT_H diff --git a/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td index 0caa42ab..517e9645 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td +++ b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td @@ -221,5 +221,17 @@ def Fly_CopyAtomUniversalCopy : Fly_Type<"CopyAtomUniversalCopy", "atom.universa let assemblyFormat = "`<` $bitSize `>`"; } +def Fly_MmaAtomUniversalFMA : Fly_Type<"MmaAtomUniversalFMA", "atom.universal_fma", [ + DeclareTypeInterfaceMethods +]> { + let parameters = (ins "Type":$elemTy); + + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elemTy), [{ + return $_get(elemTy.getContext(), elemTy); + }]> + ]; + let hasCustomAssemblyFormat = 1; +} #endif // FLY_TYPEDEFS diff --git a/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h index cd3ec262..919815b9 100644 --- a/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h @@ -118,6 +118,13 @@ template <> class IntTupleBuilder { ArithValue materializeConstantArith(int32_t value) const { return IntAttr::getStatic(ctx, value); } + ArithValue materializeConstantArith(int64_t value) const; + + ArithValue materializeConstantArith(IntAttr value) const { + assert(value.isStatic() && "Value must be static"); + return value; + } + IntTupleAttr materializeConstantTuple(IntTupleAttr attr) const { assert(attr.isStatic() && "Tuple must be static"); return attr; @@ -229,6 +236,15 @@ template <> class IntTupleBuilder { return ArithValue{arith::ConstantIntOp::create(builder, loc, value, 32).getResult(), attrBuilder.materializeConstantArith(value)}; } + ArithValue materializeConstantArith(int64_t value) const; + + ArithValue materializeConstantArith(IntAttr value) const { + assert(value.isStatic() && "Value must be static"); + return ArithValue{ + arith::ConstantIntOp::create(builder, loc, 
value.getValue(), value.getWidth()).getResult(), + value}; + } + IntTupleValueAdaptor materializeConstantTuple(IntTupleAttr attr) const { assert(attr.isStatic() && "Tuple must be static"); if (attr.isLeaf()) { diff --git a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.h b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.h index eac5ec1f..3f03d380 100644 --- a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.h +++ b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.h @@ -22,12 +22,4 @@ #define GET_TYPEDEF_CLASSES #include "flydsl/Dialect/FlyROCDL/IR/Atom.h.inc" -namespace mlir::fly_rocdl { - -ParseResult parseMNKDimensionList(AsmParser &parser, int32_t &m, int32_t &n, int32_t &k); - -void printMNKDimensionList(AsmPrinter &printer, int32_t m, int32_t n, int32_t k); - -} // namespace mlir::fly_rocdl - #endif // FLYDSL_DIALECT_FLYROCDL_IR_DIALECT_H diff --git a/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td index ebdc507b..e3275f06 100644 --- a/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td +++ b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td @@ -16,7 +16,7 @@ def FlyROCDL_MmaAtomCDNA3_MFMA : FlyxROCL_MmaAtom<"MmaAtomCDNA3_MFMA", "atom.cdn "Type":$elemTyB, "Type":$elemTyAcc ); - let assemblyFormat = "`<` custom($m, $n, $k) `,` $elemTyA `x` $elemTyB `=` $elemTyAcc `>`"; + let assemblyFormat = "`<` custom($m, $n, $k) `,` `(` $elemTyA `,` $elemTyB `)` `->` $elemTyAcc `>`"; let builders = [ TypeBuilderWithInferredContext<(ins "int32_t":$m, "int32_t":$n, "int32_t":$k, "Type":$elemTyA, "Type":$elemTyB, "Type":$elemTyAcc), [{ diff --git a/lib/Bindings/Python/FlyExtension.cpp b/lib/Bindings/Python/FlyExtension.cpp index 3feb8581..fae145ac 100644 --- a/lib/Bindings/Python/FlyExtension.cpp +++ b/lib/Bindings/Python/FlyExtension.cpp @@ -334,4 +334,19 @@ NB_MODULE(_fly, m) { .def_property_readonly("bit_size", [](MlirType self) { return mlirFlyCopyAtomUniversalCopyTypeGetBitSize(self); }); + + mlir_type_subclass(m, "MmaAtomUniversalFMAType", 
mlirTypeIsAFlyMmaAtomUniversalFMAType, + mlirFlyMmaAtomUniversalFMATypeGetTypeID) + .def_classmethod( + "get", + [](const nb::object &cls, MlirType elemTy, MlirContext context) { + return cls(wrap(MmaAtomUniversalFMAType::get(unwrap(elemTy)))); + }, + "cls"_a, "elem_ty"_a, "context"_a = nb::none(), + // clang-format off + nb::sig("def get(cls, elem_ty: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ", context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> MmaAtomUniversalFMAType"), + // clang-format on + "Create a MmaAtomUniversalFMAType with element type") + .def_property_readonly( + "elem_ty", [](MlirType self) { return mlirFlyMmaAtomUniversalFMATypeGetElemTy(self); }); } diff --git a/lib/Bindings/Python/FlyROCDLExtension.cpp b/lib/Bindings/Python/FlyROCDLExtension.cpp index 98d5c83b..e005237f 100644 --- a/lib/Bindings/Python/FlyROCDLExtension.cpp +++ b/lib/Bindings/Python/FlyROCDLExtension.cpp @@ -25,6 +25,10 @@ using namespace mlir::python::nanobind_adaptors; NB_MODULE(_fly_rocdl, m) { m.doc() = "MLIR Python FlyROCDL Extension"; + //===--------------------------------------------------------------------===// + // MmaAtomCDNA3_MFMAType + //===--------------------------------------------------------------------===// + mlir_type_subclass(m, "MmaAtomCDNA3_MFMAType", mlirTypeIsAFlyROCDLMmaAtomCDNA3_MFMAType, mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetTypeID) .def_classmethod( @@ -39,5 +43,24 @@ NB_MODULE(_fly_rocdl, m) { // clang-format off nb::sig("def get(cls, m: int, n: int, k: int, elem_ty_a: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ", elem_ty_b: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ", elem_ty_acc: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ", context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> MmaAtomCDNA3_MFMAType"), // clang-format on - "Create a MmaAtomCDNA3_MFMAType with m, n, k dimensions and element types"); + "Create a MmaAtomCDNA3_MFMAType with m, n, k dimensions and element types") + .def_property_readonly( + "m", [](MlirType self) { 
return mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetM(self); }) + .def_property_readonly( + "n", [](MlirType self) { return mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetN(self); }) + .def_property_readonly( + "k", [](MlirType self) { return mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetK(self); }) + .def_property_readonly( + "elem_ty_a", + [](MlirType self) { return mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyA(self); }) + .def_property_readonly( + "elem_ty_b", + [](MlirType self) { return mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyB(self); }) + .def_property_readonly("elem_ty_acc", [](MlirType self) { + return mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyAcc(self); + }); + + //===--------------------------------------------------------------------===// + // CopyAtom_CDNA3_BufferLSAType + //===--------------------------------------------------------------------===// } diff --git a/lib/CAPI/Dialect/Fly/FlyDialect.cpp b/lib/CAPI/Dialect/Fly/FlyDialect.cpp index 60e74a80..9fc74fad 100644 --- a/lib/CAPI/Dialect/Fly/FlyDialect.cpp +++ b/lib/CAPI/Dialect/Fly/FlyDialect.cpp @@ -192,3 +192,23 @@ MlirType mlirFlyCopyAtomUniversalCopyTypeGet(MlirContext ctx, int32_t bitSize) { int32_t mlirFlyCopyAtomUniversalCopyTypeGetBitSize(MlirType type) { return cast(unwrap(type)).getBitSize(); } + +//===----------------------------------------------------------------------===// +// MmaAtomUniversalFMAType +//===----------------------------------------------------------------------===// + +bool mlirTypeIsAFlyMmaAtomUniversalFMAType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirFlyMmaAtomUniversalFMATypeGetTypeID(void) { + return wrap(MmaAtomUniversalFMAType::getTypeID()); +} + +MlirType mlirFlyMmaAtomUniversalFMATypeGet(MlirContext ctx, MlirType elemTy) { + return wrap(MmaAtomUniversalFMAType::get(unwrap(ctx), unwrap(elemTy))); +} + +MlirType mlirFlyMmaAtomUniversalFMATypeGetElemTy(MlirType type) { + return wrap(cast(unwrap(type)).getElemTy()); +} diff --git 
a/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp b/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp index 48f2063b..0ac89c22 100644 --- a/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp +++ b/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp @@ -20,3 +20,37 @@ bool mlirTypeIsAFlyROCDLMmaAtomCDNA3_MFMAType(MlirType type) { MlirTypeID mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetTypeID(void) { return wrap(MmaAtomCDNA3_MFMAType::getTypeID()); } + +MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGet(int32_t m, int32_t n, int32_t k, MlirType elemTyA, + MlirType elemTyB, MlirType elemTyAcc) { + return wrap( + MmaAtomCDNA3_MFMAType::get(m, n, k, unwrap(elemTyA), unwrap(elemTyB), unwrap(elemTyAcc))); +} + +int32_t mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetM(MlirType type) { + return cast(unwrap(type)).getM(); +} + +int32_t mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetN(MlirType type) { + return cast(unwrap(type)).getN(); +} + +int32_t mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetK(MlirType type) { + return cast(unwrap(type)).getK(); +} + +MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyA(MlirType type) { + return wrap(cast(unwrap(type)).getElemTyA()); +} + +MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyB(MlirType type) { + return wrap(cast(unwrap(type)).getElemTyB()); +} + +MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyAcc(MlirType type) { + return wrap(cast(unwrap(type)).getElemTyAcc()); +} + +//===----------------------------------------------------------------------===// +// CopyAtom_CDNA3_BufferLSAType +//===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Fly/IR/FlyAttrDefs.cpp b/lib/Dialect/Fly/IR/FlyAttrDefs.cpp index 686b2261..d5821693 100644 --- a/lib/Dialect/Fly/IR/FlyAttrDefs.cpp +++ b/lib/Dialect/Fly/IR/FlyAttrDefs.cpp @@ -271,12 +271,12 @@ ::mlir::Attribute IntAttr::parse(::mlir::AsmParser &odsParser, ::mlir::Type odsT if (odsParser.parseRBrace()) return {}; } - return IntAttr::get(ctx, width, divisibility); + return IntAttr::getDynamic(ctx, 
width, divisibility); } int32_t value; if (odsParser.parseDecimalInteger(value)) return {}; - return IntAttr::get(ctx, value); + return IntAttr::getStatic(ctx, value); } void IntAttr::print(::mlir::AsmPrinter &odsPrinter) const { prettyPrintIntAttr(odsPrinter, *this); } @@ -303,12 +303,12 @@ ::mlir::Attribute parseLeafAttr(::mlir::AsmParser &odsParser) { if (odsParser.parseRBrace()) return {}; } - valueAttr = IntAttr::get(ctx, width, divisibility); + valueAttr = IntAttr::getDynamic(ctx, width, divisibility); } else { int32_t value; if (odsParser.parseDecimalInteger(value)) return {}; - valueAttr = IntAttr::get(ctx, value); + valueAttr = IntAttr::getStatic(ctx, value); } SmallString<16> strModes; diff --git a/lib/Dialect/Fly/IR/FlyDialect.cpp b/lib/Dialect/Fly/IR/FlyDialect.cpp index 49ba2fb5..9066b22f 100644 --- a/lib/Dialect/Fly/IR/FlyDialect.cpp +++ b/lib/Dialect/Fly/IR/FlyDialect.cpp @@ -17,6 +17,24 @@ using namespace mlir::fly; namespace mlir::fly { #include "flydsl/Dialect/Fly/IR/FlyAttrConstraints.cpp.inc" #include "flydsl/Dialect/Fly/IR/FlyTypeConstraints.cpp.inc" + +ParseResult parseMNKDimensionList(AsmParser &parser, int32_t &m, int32_t &n, int32_t &k) { + SmallVector dimensions; + if (parser.parseDimensionList(dimensions, false, false)) + return failure(); + if (dimensions.size() != 3) + return parser.emitError(parser.getCurrentLocation()) + << "expected 3 dimensions in MNK dimension list"; + m = dimensions[0]; + n = dimensions[1]; + k = dimensions[2]; + return success(); +} + +void printMNKDimensionList(AsmPrinter &printer, int32_t m, int32_t n, int32_t k) { + printer.printDimensionList(ArrayRef{m, n, k}); +} + } // namespace mlir::fly #define GET_TYPEDEF_CLASSES diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp index 76726aab..f27e5a4d 100644 --- a/lib/Dialect/Fly/IR/FlyOps.cpp +++ b/lib/Dialect/Fly/IR/FlyOps.cpp @@ -1220,6 +1220,7 @@ FLY_INFER_RETURN_TYPES(AddOffsetOp) { // Offset must be a scalar (leaf) int_tuple if 
(!offsetTy.getAttr().isLeaf()) return failure(); + // todo: alignment of the return pointer should be the gcd(offset, original alignment) inferredReturnTypes.assign({ptrTy}); return success(); } diff --git a/lib/Dialect/Fly/IR/FlyTypeDefs.cpp b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp index 920a49c7..6b7a876b 100644 --- a/lib/Dialect/Fly/IR/FlyTypeDefs.cpp +++ b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp @@ -1,3 +1,5 @@ +#include "mlir/IR/DialectImplementation.h" + #include "flydsl/Dialect/Fly/IR/FlyDialect.h" namespace mlir::fly { @@ -80,13 +82,57 @@ CoordTensorType CoordTensorType::at(ArrayRef idxs) const { Attribute CopyAtomUniversalCopyType::getThrSize() const { return FxC(1); } Attribute CopyAtomUniversalCopyType::getThrValLayoutSrc() const { - return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(0), FxC(1))); + return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1))); } Attribute CopyAtomUniversalCopyType::getThrValLayoutDst() const { - return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(0), FxC(1))); + return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1))); } Attribute CopyAtomUniversalCopyType::getThrValLayoutRef() const { - return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(0), FxC(1))); + return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1))); +} + +Attribute MmaAtomUniversalFMAType::getThrSize() const { return FxC(1); } + +Attribute MmaAtomUniversalFMAType::getThrValLayoutA() const { + return FxLayout(FxShape(FxC(1), FxC(1)), FxStride(FxC(1), FxC(1))); +} +Attribute MmaAtomUniversalFMAType::getThrValLayoutB() const { + return FxLayout(FxShape(FxC(1), FxC(1)), FxStride(FxC(1), FxC(1))); +} +Attribute MmaAtomUniversalFMAType::getThrValLayoutC() const { + return FxLayout(FxShape(FxC(1), FxC(1)), FxStride(FxC(1), FxC(1))); +} + +Type MmaAtomUniversalFMAType::parse(AsmParser &parser) { + Type elemTyA, elemTyB, elemTyC; + if (parser.parseLess()) + return {}; + int32_t 
m, n, k; + if (parseMNKDimensionList(parser, m, n, k)) + return {}; + if (m != 1 || n != 1 || k != 1) { + parser.emitError(parser.getCurrentLocation()) + << "expected 1x1x1 dimensions for universal FMA, got " << m << "x" << n << "x" << k; + return {}; + } + // Parse ", (elemTy, elemTy) -> elemTy>" + if (parser.parseComma() || parser.parseLParen() || parser.parseType(elemTyA) || + parser.parseComma() || parser.parseType(elemTyB) || parser.parseRParen() || + parser.parseArrow() || parser.parseType(elemTyC) || parser.parseGreater()) + return {}; + // For universal FMA, all element types should be the same + if (elemTyA != elemTyB || elemTyB != elemTyC) { + parser.emitError(parser.getCurrentLocation()) + << "expected all element types to be the same for universal FMA"; + return {}; + } + return get(parser.getContext(), elemTyA); +} + +void MmaAtomUniversalFMAType::print(AsmPrinter &printer) const { + printer << "<"; + printMNKDimensionList(printer, 1, 1, 1); + printer << ", (" << getElemTy() << ", " << getElemTy() << ") -> " << getElemTy() << ">"; } } // namespace mlir::fly diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp index ea110049..ba07971a 100644 --- a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -7,6 +7,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -432,9 +433,11 @@ static void lowerGpuLaunchFuncIntTupleOperands(gpu::LaunchFuncOp op) { op.getKernelOperandsMutable().assign(newKernelOperands); } -static bool lowerGpuFuncIntTupleArgs(gpu::GPUFuncOp op) { - auto funcType = op.getFunctionType(); - SmallVector oldInputs(funcType.getInputs().begin(), funcType.getInputs().end()); +/// Lower function 
arguments: IntTupleType, LayoutType, and MemRefType arguments are lowered +/// to LLVM structs. Works with any operation implementing FunctionOpInterface. +static bool lowerFuncIntTupleArgs(FunctionOpInterface op) { + ArrayRef argTypes = op.getArgumentTypes(); + SmallVector oldInputs(argTypes.begin(), argTypes.end()); // First pass: compute new argument types SmallVector newInputs; @@ -481,10 +484,14 @@ static bool lowerGpuFuncIntTupleArgs(gpu::GPUFuncOp op) { return false; // Update function type - auto newFuncType = FunctionType::get(op.getContext(), newInputs, funcType.getResults()); + auto newFuncType = FunctionType::get(op.getContext(), newInputs, op.getResultTypes()); op.setType(newFuncType); - Block &entry = op.getBody().front(); + // Handle empty function (declaration only) + if (op.getFunctionBody().empty()) + return true; + + Block &entry = op.getFunctionBody().front(); Location loc = op.getLoc(); // Transform block arguments: work backwards to handle index shifts from MemRef expansion @@ -553,7 +560,7 @@ static bool lowerGpuFuncIntTupleArgs(gpu::GPUFuncOp op) { SmallVector dynamicLeaves; collectDynamicLeaves(tupleTy.getAttr(), dynamicLeaves); if (dynamicLeaves.empty()) { - Value tuple = StaticOp::create(builder, loc, tupleTy); + Value tuple = MakeIntTupleOp::create(builder, loc, tupleTy, {}); arg.replaceAllUsesWith(tuple); newArgIdx++; continue; @@ -589,7 +596,11 @@ static bool lowerGpuFuncIntTupleArgs(gpu::GPUFuncOp op) { collectDynamicLeaves(layoutAttr.getShape(), shapeLeaves); collectDynamicLeaves(layoutAttr.getStride(), strideLeaves); if (shapeLeaves.empty() && strideLeaves.empty()) { - Value layout = StaticOp::create(builder, loc, layoutTy); + Value Shape = + MakeIntTupleOp::create(builder, loc, IntTupleType::get(layoutAttr.getShape()), {}); + Value Stride = + MakeIntTupleOp::create(builder, loc, IntTupleType::get(layoutAttr.getStride()), {}); + Value layout = MakeLayoutOp::create(builder, loc, layoutTy, Shape, Stride); 
arg.replaceAllUsesWith(layout); newArgIdx++; continue; @@ -728,290 +739,6 @@ static bool lowerGpuFuncIntTupleArgs(gpu::GPUFuncOp op) { return true; } -/// Lower func::FuncOp arguments: IntTupleType, LayoutType, and MemRefType arguments are lowered -/// to LLVM structs, similar to lowerGpuFuncIntTupleArgs but for func.func operations. -static bool lowerFuncOpIntTupleArgs(func::FuncOp op) { - auto funcType = op.getFunctionType(); - SmallVector oldInputs(funcType.getInputs().begin(), funcType.getInputs().end()); - - // First pass: compute new argument types - SmallVector newInputs; - enum class ArgKind { None, IntTuple, Layout, MemRefStatic, MemRefDynamic }; - SmallVector argKinds; - - bool changed = false; - for (Type oldType : oldInputs) { - if (auto tupleTy = dyn_cast(oldType)) { - auto structTy = getIntTupleStructTypeOrEmpty(tupleTy.getAttr(), op.getContext()); - newInputs.push_back(structTy); - argKinds.push_back(ArgKind::IntTuple); - changed = true; - continue; - } - if (auto layoutTy = dyn_cast(oldType)) { - auto structTy = getLayoutStructTypeOrEmpty(layoutTy.getAttr(), op.getContext()); - newInputs.push_back(structTy); - argKinds.push_back(ArgKind::Layout); - changed = true; - continue; - } - if (auto memrefTy = dyn_cast(oldType)) { - auto ptrTy = getMemRefPtrType(memrefTy); - newInputs.push_back(ptrTy); - if (memrefHasDynamicLayout(memrefTy)) { - auto layoutStructTy = *getMemRefLayoutStructType(memrefTy); - newInputs.push_back(layoutStructTy); - argKinds.push_back(ArgKind::MemRefDynamic); - } else { - argKinds.push_back(ArgKind::MemRefStatic); - } - changed = true; - continue; - } - newInputs.push_back(oldType); - argKinds.push_back(ArgKind::None); - } - - if (!changed) - return false; - - // Update function type - auto newFuncType = FunctionType::get(op.getContext(), newInputs, funcType.getResults()); - op.setType(newFuncType); - - // Handle empty function (declaration only) - if (op.getBody().empty()) - return true; - - Block &entry = op.getBody().front(); 
- Location loc = op.getLoc(); - - // Transform block arguments: work backwards to handle index shifts from MemRef expansion - for (int i = oldInputs.size() - 1; i >= 0; --i) { - BlockArgument oldArg = entry.getArgument(i); - - if (argKinds[i] == ArgKind::None) { - continue; - } - - if (argKinds[i] == ArgKind::IntTuple || argKinds[i] == ArgKind::Layout) { - size_t newIdx = 0; - for (int j = 0; j < i; ++j) { - newIdx++; - if (argKinds[j] == ArgKind::MemRefDynamic) - newIdx++; - } - oldArg.setType(newInputs[newIdx]); - continue; - } - - if (argKinds[i] == ArgKind::MemRefStatic) { - size_t newIdx = 0; - for (int j = 0; j < i; ++j) { - newIdx++; - if (argKinds[j] == ArgKind::MemRefDynamic) - newIdx++; - } - oldArg.setType(newInputs[newIdx]); - continue; - } - - if (argKinds[i] == ArgKind::MemRefDynamic) { - size_t newIdx = 0; - for (int j = 0; j < i; ++j) { - newIdx++; - if (argKinds[j] == ArgKind::MemRefDynamic) - newIdx++; - } - oldArg.setType(newInputs[newIdx]); - entry.insertArgument(i + 1, newInputs[newIdx + 1], loc); - } - } - - // Reconstruct fly values from the new arguments - OpBuilder builder(&entry, entry.begin()); - - size_t newArgIdx = 0; - for (size_t i = 0; i < oldInputs.size(); ++i) { - if (argKinds[i] == ArgKind::None) { - newArgIdx++; - continue; - } - - if (argKinds[i] == ArgKind::IntTuple) { - auto tupleTy = cast(oldInputs[i]); - auto structTy = cast(newInputs[newArgIdx]); - BlockArgument arg = entry.getArgument(newArgIdx); - - SmallVector dynamicLeaves; - collectDynamicLeaves(tupleTy.getAttr(), dynamicLeaves); - if (dynamicLeaves.empty()) { - Value tuple = StaticOp::create(builder, loc, tupleTy); - arg.replaceAllUsesWith(tuple); - newArgIdx++; - continue; - } - - SmallVector dyncElems; - SmallVector extractOps; - dyncElems.reserve(dynamicLeaves.size()); - - for (size_t j = 0; j < dynamicLeaves.size(); ++j) { - Type fieldTy = structTy.getBody()[j]; - Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, arg, - 
llvm::ArrayRef{static_cast(j)}); - dyncElems.push_back(val); - extractOps.push_back(val.getDefiningOp()); - } - - Value tuple = MakeIntTupleOp::create(builder, loc, tupleTy, dyncElems); - llvm::SmallPtrSet except(extractOps.begin(), extractOps.end()); - arg.replaceAllUsesExcept(tuple, except); - newArgIdx++; - continue; - } - - if (argKinds[i] == ArgKind::Layout) { - auto layoutTy = cast(oldInputs[i]); - auto structTy = cast(newInputs[newArgIdx]); - BlockArgument arg = entry.getArgument(newArgIdx); - LayoutAttr layoutAttr = layoutTy.getAttr(); - - SmallVector shapeLeaves; - SmallVector strideLeaves; - collectDynamicLeaves(layoutAttr.getShape(), shapeLeaves); - collectDynamicLeaves(layoutAttr.getStride(), strideLeaves); - if (shapeLeaves.empty() && strideLeaves.empty()) { - Value layout = StaticOp::create(builder, loc, layoutTy); - arg.replaceAllUsesWith(layout); - newArgIdx++; - continue; - } - - SmallVector shapeElems; - SmallVector strideElems; - SmallVector extractOps; - - auto shapeStructTy = cast(structTy.getBody()[0]); - auto strideStructTy = cast(structTy.getBody()[1]); - Value shapeStruct = LLVM::ExtractValueOp::create(builder, loc, shapeStructTy, arg, - llvm::ArrayRef{0}); - Value strideStruct = LLVM::ExtractValueOp::create(builder, loc, strideStructTy, arg, - llvm::ArrayRef{1}); - extractOps.push_back(shapeStruct.getDefiningOp()); - extractOps.push_back(strideStruct.getDefiningOp()); - - for (size_t j = 0; j < shapeLeaves.size(); ++j) { - Type fieldTy = shapeStructTy.getBody()[j]; - Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, shapeStruct, - llvm::ArrayRef{static_cast(j)}); - shapeElems.push_back(val); - extractOps.push_back(val.getDefiningOp()); - } - for (size_t j = 0; j < strideLeaves.size(); ++j) { - Type fieldTy = strideStructTy.getBody()[j]; - Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, strideStruct, - llvm::ArrayRef{static_cast(j)}); - strideElems.push_back(val); - extractOps.push_back(val.getDefiningOp()); - 
} - - IntTupleType shapeTy = IntTupleType::get(op.getContext(), layoutAttr.getShape()); - IntTupleType strideTy = IntTupleType::get(op.getContext(), layoutAttr.getStride()); - Value shape = MakeIntTupleOp::create(builder, loc, shapeTy, shapeElems); - Value stride = MakeIntTupleOp::create(builder, loc, strideTy, strideElems); - Value layout = MakeLayoutOp::create(builder, loc, layoutTy, shape, stride); - llvm::SmallPtrSet except(extractOps.begin(), extractOps.end()); - arg.replaceAllUsesExcept(layout, except); - newArgIdx++; - continue; - } - - if (argKinds[i] == ArgKind::MemRefStatic) { - auto memrefTy = cast(oldInputs[i]); - LayoutAttr layoutAttr = memrefTy.getLayout(); - - BlockArgument ptrArg = entry.getArgument(newArgIdx); - - IntTupleType shapeTy = IntTupleType::get(op.getContext(), layoutAttr.getShape()); - IntTupleType strideTy = IntTupleType::get(op.getContext(), layoutAttr.getStride()); - Value shape = MakeIntTupleOp::create(builder, loc, shapeTy, ValueRange{}); - Value stride = MakeIntTupleOp::create(builder, loc, strideTy, ValueRange{}); - auto layoutTy = LayoutType::get(op.getContext(), layoutAttr); - Value layout = MakeLayoutOp::create(builder, loc, layoutTy, shape, stride); - - Value view = MakeViewOp::create(builder, loc, memrefTy, ptrArg, layout); - - llvm::SmallPtrSet except; - except.insert(view.getDefiningOp()); - ptrArg.replaceAllUsesExcept(view, except); - - newArgIdx++; - continue; - } - - if (argKinds[i] == ArgKind::MemRefDynamic) { - auto memrefTy = cast(oldInputs[i]); - LayoutAttr layoutAttr = memrefTy.getLayout(); - - BlockArgument ptrArg = entry.getArgument(newArgIdx); - BlockArgument layoutStructArg = entry.getArgument(newArgIdx + 1); - auto layoutStructTy = cast(layoutStructArg.getType()); - - SmallVector shapeLeaves; - SmallVector strideLeaves; - collectDynamicLeaves(layoutAttr.getShape(), shapeLeaves); - collectDynamicLeaves(layoutAttr.getStride(), strideLeaves); - - SmallVector shapeElems; - SmallVector strideElems; - SmallVector 
extractOps; - - auto shapeStructTy = cast(layoutStructTy.getBody()[0]); - auto strideStructTy = cast(layoutStructTy.getBody()[1]); - Value shapeStruct = LLVM::ExtractValueOp::create(builder, loc, shapeStructTy, layoutStructArg, - llvm::ArrayRef{0}); - Value strideStruct = LLVM::ExtractValueOp::create( - builder, loc, strideStructTy, layoutStructArg, llvm::ArrayRef{1}); - extractOps.push_back(shapeStruct.getDefiningOp()); - extractOps.push_back(strideStruct.getDefiningOp()); - - for (size_t j = 0; j < shapeLeaves.size(); ++j) { - Type fieldTy = shapeStructTy.getBody()[j]; - Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, shapeStruct, - llvm::ArrayRef{static_cast(j)}); - shapeElems.push_back(val); - extractOps.push_back(val.getDefiningOp()); - } - for (size_t j = 0; j < strideLeaves.size(); ++j) { - Type fieldTy = strideStructTy.getBody()[j]; - Value val = LLVM::ExtractValueOp::create(builder, loc, fieldTy, strideStruct, - llvm::ArrayRef{static_cast(j)}); - strideElems.push_back(val); - extractOps.push_back(val.getDefiningOp()); - } - - IntTupleType shapeTy = IntTupleType::get(op.getContext(), layoutAttr.getShape()); - IntTupleType strideTy = IntTupleType::get(op.getContext(), layoutAttr.getStride()); - Value shape = MakeIntTupleOp::create(builder, loc, shapeTy, shapeElems); - Value stride = MakeIntTupleOp::create(builder, loc, strideTy, strideElems); - auto layoutTy = LayoutType::get(op.getContext(), layoutAttr); - Value layout = MakeLayoutOp::create(builder, loc, layoutTy, shape, stride); - - Value view = MakeViewOp::create(builder, loc, memrefTy, ptrArg, layout); - - llvm::SmallPtrSet except(extractOps.begin(), extractOps.end()); - except.insert(view.getDefiningOp()); - ptrArg.replaceAllUsesExcept(view, except); - - newArgIdx += 2; - continue; - } - } - - return true; -} - static void collectLeafValues(const IntTupleBuilder &builder, const IntTupleValueAdaptor &tuple, SmallVectorImpl &out) { // if (tuple.isLeaf()) { @@ -1571,8 +1298,7 @@ class 
FlyLayoutLoweringPass void runOnOperation() override { MLIRContext *context = &getContext(); - getOperation()->walk([&](gpu::GPUFuncOp gpuFunc) { lowerGpuFuncIntTupleArgs(gpuFunc); }); - getOperation()->walk([&](func::FuncOp funcOp) { lowerFuncOpIntTupleArgs(funcOp); }); + getOperation()->walk([&](FunctionOpInterface funcOp) { lowerFuncIntTupleArgs(funcOp); }); getOperation()->walk( [&](gpu::LaunchFuncOp launchOp) { lowerGpuLaunchFuncIntTupleOperands(launchOp); }); diff --git a/lib/Dialect/FlyROCDL/Dialect.cpp b/lib/Dialect/FlyROCDL/Dialect.cpp index 3ae4d5d5..9387bdfa 100644 --- a/lib/Dialect/FlyROCDL/Dialect.cpp +++ b/lib/Dialect/FlyROCDL/Dialect.cpp @@ -6,6 +6,7 @@ #include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" using namespace mlir; +using namespace mlir::fly; using namespace mlir::fly_rocdl; #include "flydsl/Dialect/FlyROCDL/IR/Dialect.cpp.inc" @@ -14,27 +15,6 @@ using namespace mlir::fly_rocdl; #define GET_TYPEDEF_CLASSES #include "flydsl/Dialect/FlyROCDL/IR/Atom.cpp.inc" -namespace mlir::fly_rocdl { - -ParseResult parseMNKDimensionList(AsmParser &parser, int32_t &m, int32_t &n, int32_t &k) { - SmallVector dimensions; - if (parser.parseDimensionList(dimensions, false, false)) - return failure(); - if (dimensions.size() != 3) - return parser.emitError(parser.getCurrentLocation()) - << "expected 3 dimensions in MNK dimension list"; - m = dimensions[0]; - n = dimensions[1]; - k = dimensions[2]; - return success(); -} - -void printMNKDimensionList(AsmPrinter &printer, int32_t m, int32_t n, int32_t k) { - printer.printDimensionList(ArrayRef{m, n, k}); -} - -} // namespace mlir::fly_rocdl - void FlyROCDLDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST diff --git a/python/flydsl/lang/ir/core.py b/python/flydsl/lang/ir/core.py index 34ee3d42..adf8d0c0 100644 --- a/python/flydsl/lang/ir/core.py +++ b/python/flydsl/lang/ir/core.py @@ -20,6 +20,7 @@ PointerType, MemRefType, CopyAtomUniversalCopyType, + MmaAtomUniversalFMAType, ) from 
..._mlir.dialects.fly_rocdl import ( From 866e2ed112124eaa8b2529a2a40d8fdde6e133fb Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Thu, 29 Jan 2026 06:02:25 +0000 Subject: [PATCH 006/113] fix example02 --- examples/02-layout_algebra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/02-layout_algebra.py b/examples/02-layout_algebra.py index d2821ee7..2e8eaa44 100644 --- a/examples/02-layout_algebra.py +++ b/examples/02-layout_algebra.py @@ -4,7 +4,7 @@ M = 16 N = 32 memrefTy = fx.MemRefType.get( - fx.T.f32(), fx.LayoutType.get(16, 32), fx.AddressSpace.Global + fx.T.f32(), fx.LayoutType.get((M, N), (N, 1)), fx.AddressSpace.Global ) From 8b29d66b3192b3a5673bf16baeae564d8bd12eb6 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Sat, 31 Jan 2026 12:32:45 +0000 Subject: [PATCH 007/113] Add DLTensorAdaptor for torch Tensor support --- .gitmodules | 6 + CMakeLists.txt | 3 + .../Dialect/Fly/Transforms/LayoutLowering.td | 2 +- lib/Bindings/Python/DLTensorAdaptor.h | 331 ++++++++++++++++++ lib/Bindings/Python/FlyExtension.cpp | 29 ++ lib/Dialect/Fly/Transforms/LayoutLowering.cpp | 323 +++++++++++++---- lib/Dialect/Fly/Utils/NormalForm.cpp | 6 +- python/flydsl/lang/ir/core.py | 20 +- thirdparty/dlpack | 1 + thirdparty/tvm-ffi | 1 + 10 files changed, 642 insertions(+), 80 deletions(-) create mode 100644 .gitmodules create mode 100644 lib/Bindings/Python/DLTensorAdaptor.h create mode 160000 thirdparty/dlpack create mode 160000 thirdparty/tvm-ffi diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..59b67345 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "thirdparty/dlpack"] + path = thirdparty/dlpack + url = https://github.com/dmlc/dlpack.git +[submodule "thirdparty/tvm-ffi"] + path = thirdparty/tvm-ffi + url = https://github.com/apache/tvm-ffi.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 2e4efb8f..dad694e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,9 @@ include_directories(SYSTEM 
${MLIR_INCLUDE_DIRS}) include_directories(${PROJECT_SOURCE_DIR}/include) include_directories(${PROJECT_BINARY_DIR}/include) +include_directories(${PROJECT_SOURCE_DIR}/thirdparty/dlpack/include) +include_directories(${PROJECT_SOURCE_DIR}/thirdparty/tvm-ffi/include) + link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) diff --git a/include/flydsl/Dialect/Fly/Transforms/LayoutLowering.td b/include/flydsl/Dialect/Fly/Transforms/LayoutLowering.td index 43b1cef0..53a2b02f 100644 --- a/include/flydsl/Dialect/Fly/Transforms/LayoutLowering.td +++ b/include/flydsl/Dialect/Fly/Transforms/LayoutLowering.td @@ -31,7 +31,7 @@ def : Pat<(Fly_SliceOp Fly_MemRef:$memref, Fly_IntTuple:$coord), def : Pat<(Fly_SizeOp Fly_Layout:$layout), (Fly_SizeOp (Fly_GetShapeOp $layout))>; def : Pat<(Fly_SizeOp Fly_MemRef:$memref), - (Fly_SizeOp (Fly_GetLayoutOp $memref))>; + (Fly_SizeOp (Fly_GetShapeOp (Fly_GetLayoutOp $memref)))>; def : Pat<(Fly_SelectOp Fly_Layout:$layout, DenseI32ArrayAttr:$indices), (Fly_MakeLayoutOp diff --git a/lib/Bindings/Python/DLTensorAdaptor.h b/lib/Bindings/Python/DLTensorAdaptor.h new file mode 100644 index 00000000..866ee344 --- /dev/null +++ b/lib/Bindings/Python/DLTensorAdaptor.h @@ -0,0 +1,331 @@ +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Wrap.h" + +#include +#include + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/IntUtils.h" + +#include "dlpack/dlpack.h" + +#include +#include +#include + +namespace nb = nanobind; +using namespace nb::literals; + +using namespace mlir; +using namespace mlir::fly; +using namespace mlir::python::nanobind_adaptors; + +namespace mlir::fly::utils { + +class DLTensorAdaptor { +private: + struct DimInfo { + int64_t dimSize = 0; + int32_t divisibility = 1; + bool isDynamic = 
false; + + DimInfo() = default; + DimInfo(int64_t dimSize) : dimSize(dimSize), divisibility(dimSize) {} + + DimInfo &setDynamic(int32_t divisibility = 1) { + isDynamic = true; + this->divisibility = divisibility; + return *this; + } + + IntTupleAttr getIntAttr(MLIRContext *ctx_, bool use32bitDynamic = false) const { + if (isDynamic) { + return IntTupleAttr::getLeafDynamic(ctx_, use32bitDynamic ? 32 : 64, divisibility); + } else { + return IntTupleAttr::getLeafStatic(ctx_, dimSize); + } + } + }; + + struct MemRefDescriptor { + Type memrefType = nullptr; + void *dataPtr = nullptr; + std::vector layoutBuffer; + }; + +public: + DLTensorAdaptor(nb::object dlpackCapsule, int32_t alignment, bool use32BitStride, + MlirContext context) + : dlpackCapsule_(dlpackCapsule), alignment_(alignment), use32BitStride_(use32BitStride), + ctx_(unwrap(context)) { + DLManagedTensor *managed = + static_cast(PyCapsule_GetPointer(dlpackCapsule.ptr(), "dltensor")); + if (!managed) { + throw std::runtime_error("Invalid DLPack capsule: expected 'dltensor'"); + } + tensor_ = &managed->dl_tensor; + + ndim_ = tensor_->ndim; + if (ndim_ == 0) { + throw std::runtime_error("DLTensor must have at least one dimension"); + } + + shape_.resize(ndim_); + stride_.resize(ndim_); + for (int i = 0; i < ndim_; ++i) { + shape_[i] = DimInfo(tensor_->shape[i]); + } + for (int i = 0; i < ndim_; ++i) { + stride_[i] = DimInfo(tensor_->strides[i]); + } + } + + nb::tuple getShape() const { + nb::list result; + for (const auto &s : shape_) { + result.append(nb::int_(s.dimSize)); + } + return nb::tuple(result); + } + + nb::tuple getStride() const { + nb::list result; + for (const auto &s : stride_) { + result.append(nb::int_(s.dimSize)); + } + return nb::tuple(result); + } + + int64_t getDataPtr() const { + return reinterpret_cast(static_cast(tensor_->data) + tensor_->byte_offset); + } + + int64_t getSizeInBytes() const { + int64_t numElements = 1; + for (const auto &s : shape_) { + numElements *= s.dimSize; + } + 
int64_t bitsPerElem = tensor_->dtype.bits * tensor_->dtype.lanes; + return (numElements * bitsPerElem + 7) / 8; + } + + int getAddressSpace() const { + switch (tensor_->device.device_type) { + case kDLCPU: + return 0; // Host + case kDLCUDA: + [[fallthrough]]; + case kDLCUDAHost: + [[fallthrough]]; + case kDLCUDAManaged: + [[fallthrough]]; + case kDLROCM: + [[fallthrough]]; + case kDLROCMHost: + return 1; // Global (device memory) + default: + return 0; + } + } + + Type getElementType() const { + DLDataType dtype = tensor_->dtype; + + switch (dtype.code) { + case kDLFloat: + switch (dtype.bits) { + case 16: + return Float16Type::get(ctx_); + case 32: + return Float32Type::get(ctx_); + case 64: + return Float64Type::get(ctx_); + default: + throw std::runtime_error("Unsupported float bit width: " + std::to_string(dtype.bits)); + } + case kDLInt: + return IntegerType::get(ctx_, dtype.bits, IntegerType::Signed); + case kDLUInt: + return IntegerType::get(ctx_, dtype.bits, IntegerType::Unsigned); + case kDLBfloat: + return BFloat16Type::get(ctx_); + case kDLBool: + return IntegerType::get(ctx_, 1); + case kDLFloat8_e5m2: + return Float8E5M2Type::get(ctx_); + case kDLFloat8_e4m3fn: + return Float8E4M3FNType::get(ctx_); + case kDLFloat8_e5m2fnuz: + return Float8E5M2FNUZType::get(ctx_); + case kDLFloat8_e4m3fnuz: + return Float8E4M3FNUZType::get(ctx_); + case kDLFloat8_e4m3b11fnuz: + return Float8E4M3B11FNUZType::get(ctx_); + case kDLComplex: + switch (dtype.bits) { + case 64: + return ComplexType::get(Float32Type::get(ctx_)); + case 128: + return ComplexType::get(Float64Type::get(ctx_)); + default: + throw std::runtime_error("Unsupported complex bit width: " + std::to_string(dtype.bits)); + } + default: + throw std::runtime_error("Unsupported DLPack dtype code: " + std::to_string(dtype.code)); + } + } + + void buildMemRefDesc() { + if (!isMemrefStale_) { + return; + } + + SmallVector shapeLeaves, strideLeaves; + shapeLeaves.resize(ndim_); + strideLeaves.resize(ndim_); + + 
size_t shapeDyncCount = 0; + size_t strideDyncCount = 0; + for (int i = 0; i < ndim_; ++i) { + shapeLeaves[i] = shape_[i].getIntAttr(ctx_, true); + strideLeaves[i] = stride_[i].getIntAttr(ctx_, use32BitStride_); + + if (shape_[i].isDynamic) + shapeDyncCount++; + if (stride_[i].isDynamic) + strideDyncCount++; + } + + IntTupleAttr shapeAttr = IntTupleAttr::get(ArrayAttr::get(ctx_, shapeLeaves)); + IntTupleAttr strideAttr = IntTupleAttr::get(ArrayAttr::get(ctx_, strideLeaves)); + LayoutAttr layoutAttr = LayoutAttr::get(ctx_, shapeAttr, strideAttr); + + if (getAddressSpace() != 1) { + throw std::runtime_error("Only device address space is supported"); + } + AddressSpaceAttr addrSpaceAttr = AddressSpaceAttr::get(ctx_, AddressSpace::Global); + + assert(alignment_ > 0 && "alignment must be positive"); + AlignAttr alignAttr = AlignAttr::get(ctx_, alignment_); + + memrefDesc_.memrefType = + fly::MemRefType::get(getElementType(), addrSpaceAttr, layoutAttr, alignAttr); + + // Get data pointer (with byte offset applied) + memrefDesc_.dataPtr = + static_cast(static_cast(tensor_->data) + tensor_->byte_offset); + + // Build packed layout struct for dynamic elements + // Layout: [shape_dync_elems (i32)...][stride_dync_elems (i32 or i64)...] + size_t strideElemSize = use32BitStride_ ? 
sizeof(int32_t) : sizeof(int64_t); + size_t layoutSize = shapeDyncCount * sizeof(int32_t) + strideDyncCount * strideElemSize; + + if (layoutSize > 0) { + memrefDesc_.layoutBuffer.resize(layoutSize); + char *ptr = memrefDesc_.layoutBuffer.data(); + + // Pack dynamic shape elements (i32) + for (int i = 0; i < ndim_; ++i) { + if (shape_[i].isDynamic) { + int32_t val = static_cast(shape_[i].dimSize); + std::memcpy(ptr, &val, sizeof(int32_t)); + ptr += sizeof(int32_t); + } + } + // Pack dynamic stride elements (i32 or i64) + for (int i = 0; i < ndim_; ++i) { + if (stride_[i].isDynamic) { + if (use32BitStride_) { + int32_t val = static_cast(stride_[i].dimSize); + std::memcpy(ptr, &val, sizeof(int32_t)); + ptr += sizeof(int32_t); + } else { + int64_t val = stride_[i].dimSize; + std::memcpy(ptr, &val, sizeof(int64_t)); + ptr += sizeof(int64_t); + } + } + } + } + + isMemrefStale_ = false; + } + + MlirType getMemRefType() { + if (isMemrefStale_) { + throw std::runtime_error("Memref descriptor is stale"); + } + return wrap(memrefDesc_.memrefType); + } + + nb::list getCPointers() const { + if (isMemrefStale_) { + throw std::runtime_error("Memref descriptor is stale"); + } + nb::list result; + // Add data pointer as integer + result.append(nb::int_(reinterpret_cast(&memrefDesc_.dataPtr))); + // If layout has dynamic elements, add layout struct pointer + if (!memrefDesc_.layoutBuffer.empty()) { + result.append(nb::int_(reinterpret_cast(memrefDesc_.layoutBuffer.data()))); + } + return result; + } + + DLTensorAdaptor &markLayoutDynamic(int leadingDim = -1, int divisibility = 1) { + int ndim_ = static_cast(shape_.size()); + if (leadingDim == -1) { + for (int i = 0; i < ndim_; ++i) { + if (stride_[i].dimSize == 1) { + if (leadingDim != -1) { + throw std::runtime_error("Multiple dimensions have stride 1"); + } + leadingDim = i; + } + } + } + if (leadingDim < 0 || leadingDim >= ndim_) { + throw std::runtime_error("Cannot determine leading dimension"); + } + + isMemrefStale_ = true; + 
for (int i = 0; i < ndim_; ++i) { + shape_[i].setDynamic(); + } + for (int i = 0; i < ndim_; ++i) { + if (i != leadingDim) { + stride_[i].setDynamic(divisibility); + } + } + return *this; + } + + DLTensorAdaptor &use32BitStride(bool use32BitStride) { + if (use32BitStride_ == use32BitStride) { + return *this; + } + isMemrefStale_ = true; + use32BitStride_ = use32BitStride; + return *this; + } + +private: + nb::object dlpackCapsule_; + int32_t alignment_; + bool use32BitStride_; + MLIRContext *ctx_; + + DLTensor *tensor_; + int32_t ndim_; + std::vector shape_; + std::vector stride_; + MemRefDescriptor memrefDesc_; + bool isMemrefStale_{true}; +}; + +} // namespace mlir::fly::utils diff --git a/lib/Bindings/Python/FlyExtension.cpp b/lib/Bindings/Python/FlyExtension.cpp index fae145ac..1bbc0bd5 100644 --- a/lib/Bindings/Python/FlyExtension.cpp +++ b/lib/Bindings/Python/FlyExtension.cpp @@ -8,6 +8,7 @@ #include "mlir/CAPI/Wrap.h" #include +#include #include #include @@ -15,8 +16,14 @@ #include "flydsl/Dialect/Fly/IR/FlyDialect.h" #include "flydsl/Dialect/Fly/Utils/IntUtils.h" +#include "DLTensorAdaptor.h" + +#include +#include + namespace nb = nanobind; using namespace nb::literals; + using namespace mlir; using namespace mlir::fly; using namespace mlir::python::nanobind_adaptors; @@ -126,6 +133,28 @@ int32_t depth(MlirValue int_or_tuple) { NB_MODULE(_fly, m) { m.doc() = "MLIR Python FlyDSL Extension"; + using DLTensorAdaptor = utils::DLTensorAdaptor; + + nb::class_(m, "DLTensorAdaptor") + .def(nb::init(), "dlpack_capsule"_a, + "alignment"_a = 1, "use_32bit_stride"_a = false, "context"_a, + "Create a DLTensorAdaptor from a DLPack capsule") + .def_prop_ro("shape", &DLTensorAdaptor::getShape, "Get tensor shape as tuple") + .def_prop_ro("stride", &DLTensorAdaptor::getStride, "Get tensor stride as tuple") + .def_prop_ro("data_ptr", &DLTensorAdaptor::getDataPtr, "Get data pointer as int64") + .def_prop_ro("address_space", &DLTensorAdaptor::getAddressSpace, + "Get address 
space (0=host, 1=device)") + .def("size_in_bytes", &DLTensorAdaptor::getSizeInBytes, "Get total size in bytes") + .def("build_memref_desc", &DLTensorAdaptor::buildMemRefDesc, + "Build memref descriptor based on current dynamic marks") + .def("get_memref_type", &DLTensorAdaptor::getMemRefType, + "Get fly.memref MLIR type based on current dynamic marks") + .def("get_c_pointers", &DLTensorAdaptor::getCPointers, "Get list of c pointers") + .def("mark_layout_dynamic", &DLTensorAdaptor::markLayoutDynamic, "leading_dim"_a = -1, + "divisibility"_a = 1, "Mark entire layout as dynamic except leading dim stride") + .def("use_32bit_stride", &DLTensorAdaptor::use32BitStride, "use_32bit_stride"_a, + "Decide whether to use 32-bit stride"); + m.def( "infer_int_tuple_type", [](nb::handle int_or_tuple, MlirContext context) { diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp index ba07971a..1f7dc710 100644 --- a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -741,10 +741,10 @@ static bool lowerFuncIntTupleArgs(FunctionOpInterface op) { static void collectLeafValues(const IntTupleBuilder &builder, const IntTupleValueAdaptor &tuple, SmallVectorImpl &out) { - // if (tuple.isLeaf()) { - // out.push_back(tuple.intTupleValue); - // return; - // } + if (tuple.isLeaf()) { + out.push_back(builder.getArithValue(tuple).value); + return; + } for (int i = 0; i < tuple.rank(); ++i) { collectLeafValues(builder, builder.at(tuple, i), out); } @@ -811,6 +811,13 @@ static bool appendScalarPrintfArg(PatternRewriter &rewriter, Location loc, Value static bool appendIntTuplePrintf(PatternRewriter &rewriter, Location loc, const IntTupleValueAdaptor &tuple, std::string &format, SmallVectorImpl &args) { + // For single leaf values, don't add parentheses + if (tuple.isLeaf()) { + IntTupleBuilder builder(rewriter, loc); + Value leafValue = builder.getArithValue(tuple).value; + return 
appendScalarPrintfArg(rewriter, loc, leafValue, format, args); + } + SmallVector leaves; IntTupleBuilder builder(rewriter, loc); collectLeafValues(builder, tuple, leaves); @@ -828,6 +835,16 @@ static bool appendIntTuplePrintf(PatternRewriter &rewriter, Location loc, } static bool appendIntTuplePrintfStatic(IntTupleAttr attr, std::string &format) { + // For single leaf values, don't add parentheses + if (attr.isLeaf()) { + if (attr.getLeafAsInt().isStatic()) { + format += std::to_string(attr.getLeafAsInt().getValue()); + } else { + format += "?"; + } + return true; + } + SmallVector leaves; collectLeafAttrs(attr, leaves); format += "("; @@ -941,6 +958,8 @@ class GetShapeLowering : public OpRewritePattern { LogicalResult matchAndRewrite(GetShapeOp op, PatternRewriter &rewriter) const override { auto layout = op.getLayout(); + if (!isNormalForm(cast>(layout))) + return failure(); if (auto defOp = layout.getDefiningOp()) { rewriter.replaceOp(op, defOp.getShape()); return success(); @@ -956,6 +975,8 @@ class GetStrideLowering : public OpRewritePattern { LogicalResult matchAndRewrite(GetStrideOp op, PatternRewriter &rewriter) const override { auto layout = op.getLayout(); + if (!isNormalForm(cast>(layout))) + return failure(); if (auto defOp = layout.getDefiningOp()) { rewriter.replaceOp(op, defOp.getStride()); return success(); @@ -979,6 +1000,87 @@ class GetLayoutLowering : public OpRewritePattern { } }; +//===----------------------------------------------------------------------===// +// GetLeafOp Lowering +//===----------------------------------------------------------------------===// + +class GetLeafOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GetLeafOp op, PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value tuple = op.getTuple(); + int32_t leafIdx = op.getLeafIdx(); + + // Handle IntTuple case + if (auto intTupleTy = dyn_cast(tuple.getType())) { + if 
(!isNormalForm(cast>(tuple))) + return failure(); + + auto defOp = tuple.getDefiningOp(); + if (!defOp) + return failure(); + + IntTupleAttr profile = intTupleTy.getAttr(); + IntTupleAttr leafProfile = profile.at(leafIdx); + IntTupleType leafTy = IntTupleType::get(leafProfile); + + // Calculate the dynamic element offset for this leaf + int32_t dyncOffset = 0; + for (int32_t i = 0; i < leafIdx; ++i) { + dyncOffset += profile.at(i).dyncLeafCount(); + } + int32_t leafDyncCount = leafProfile.dyncLeafCount(); + + // Extract the dynamic elements for this leaf + SmallVector leafDyncElems; + for (int32_t i = 0; i < leafDyncCount; ++i) { + leafDyncElems.push_back(defOp.getDyncElems()[dyncOffset + i]); + } + + Value newTuple = MakeIntTupleOp::create(rewriter, loc, leafTy, leafDyncElems); + rewriter.replaceOp(op, newTuple); + return success(); + } + + // Handle Layout case + if (auto layoutTy = dyn_cast(tuple.getType())) { + if (!isNormalForm(cast>(tuple))) + return failure(); + + auto defOp = tuple.getDefiningOp(); + if (!defOp) + return failure(); + + LayoutAttr profile = layoutTy.getAttr(); + LayoutAttr leafProfile = profile.at(leafIdx); + LayoutType leafTy = LayoutType::get(op.getContext(), leafProfile); + + // Get shape and stride from the defining MakeLayoutOp + Value shape = defOp.getShape(); + Value stride = defOp.getStride(); + + auto shapeTy = cast(shape.getType()); + auto strideTy = cast(stride.getType()); + + // Extract leaf from shape + IntTupleAttr shapeLeafProfile = shapeTy.getAttr().at(leafIdx); + Value shapeLeaf = GetLeafOp::create(rewriter, loc, shape, leafIdx); + + // Extract leaf from stride + IntTupleAttr strideLeafProfile = strideTy.getAttr().at(leafIdx); + Value strideLeaf = GetLeafOp::create(rewriter, loc, stride, leafIdx); + + Value newLayout = MakeLayoutOp::create(rewriter, loc, leafTy, shapeLeaf, strideLeaf); + rewriter.replaceOp(op, newLayout); + return success(); + } + + return failure(); + } +}; + 
//===----------------------------------------------------------------------===// // GetScalarOp Lowering //===----------------------------------------------------------------------===// @@ -1029,27 +1131,53 @@ class SizeOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(SizeOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value intTuple = op.getIntTuple(); + Value input = op.getIntTuple(); - auto intTupleTy = dyn_cast(intTuple.getType()); - if (!intTupleTy) - return failure(); - if (!isNormalForm(dyn_cast>(intTuple))) { - return failure(); + if (auto intTupleTy = dyn_cast(input.getType())) { + if (!isNormalForm(dyn_cast>(input))) { + return failure(); + } + + auto resultTy = dyn_cast(op.getResult().getType()); + if (!resultTy) + return failure(); + + // Use intTupleProduct to compute the size + IntTupleBuilder builder(rewriter, loc); + IntTupleValueAdaptor inputAdaptor = + IntTupleValueAdaptor::create(builder, input, intTupleTy.getAttr()); + IntTupleValueAdaptor productAdaptor = intTupleProduct(builder, inputAdaptor); + + rewriter.replaceOp(op, builder.finalize(productAdaptor)); + return success(); } - auto resultTy = dyn_cast(op.getResult().getType()); - if (!resultTy) - return failure(); + if (auto layoutTy = dyn_cast(input.getType())) { + Value shape = nullptr; + if (auto layoutVal = dyn_cast>(input)) { + if (isNormalForm(layoutVal)) { + if (auto layoutOp = input.getDefiningOp()) { + shape = layoutOp.getShape(); + } + } + } + if (!shape) { + shape = GetShapeOp::create(rewriter, loc, input); + } + Value size = SizeOp::create(rewriter, loc, shape); + rewriter.replaceOp(op, size); + return success(); + } - // Use intTupleProduct to compute the size - IntTupleBuilder builder(rewriter, loc); - IntTupleValueAdaptor inputAdaptor = - IntTupleValueAdaptor::create(builder, intTuple, intTupleTy.getAttr()); - IntTupleValueAdaptor productAdaptor = intTupleProduct(builder, inputAdaptor); + if (auto memrefTy = 
dyn_cast(input.getType())) { + Value layout = GetLayoutOp::create(rewriter, loc, input); + Value shape = GetShapeOp::create(rewriter, loc, layout); + Value size = SizeOp::create(rewriter, loc, shape); + rewriter.replaceOp(op, size); + return success(); + } - rewriter.replaceOp(op, builder.finalize(productAdaptor)); - return success(); + return failure(); } }; @@ -1201,11 +1329,8 @@ class PrintOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrintOp op, PatternRewriter &rewriter) const override { - if (!op->getParentOfType()) { - return failure(); - } + bool isGpuContext = op->getParentOfType() != nullptr; - // check all values are in normal form for (Value val : op.getValues()) { if (auto intTupleVal = dyn_cast>(val)) { if (!isNormalForm(intTupleVal)) { @@ -1216,65 +1341,137 @@ class PrintOpLowering : public OpRewritePattern { return failure(); } } else { - // TODO: handle other types continue; } } auto loc = op.getLoc(); + std::string userFormat = op.getFormat().str(); std::string format; SmallVector args; - bool first = true; - - auto appendSeparator = [&]() { - if (!first) { - format += " "; - } - first = false; - }; - for (Value val : op.getValues()) { - appendSeparator(); + auto formatValueToString = [&](Value val) -> std::string { + std::string valFormat; if (auto tupleTy = dyn_cast(val.getType())) { if (tupleTy.getAttr().isStatic()) { - if (!appendIntTuplePrintfStatic(tupleTy.getAttr(), format)) { - return failure(); - } - continue; + appendIntTuplePrintfStatic(tupleTy.getAttr(), valFormat); + } else { + IntTupleBuilder builder(rewriter, loc); + IntTupleValueAdaptor tuple = + IntTupleValueAdaptor::create(builder, val, tupleTy.getAttr()); + appendIntTuplePrintf(rewriter, loc, tuple, valFormat, args); } - IntTupleBuilder builder(rewriter, loc); - IntTupleValueAdaptor tuple = IntTupleValueAdaptor::create(builder, val, tupleTy.getAttr()); - if (!appendIntTuplePrintf(rewriter, loc, tuple, format, 
args)) - return failure(); } else if (auto layoutTy = dyn_cast(val.getType())) { - format += ""; if (layoutTy.getAttr().isStatic()) { - if (!appendIntTuplePrintfStatic(layoutTy.getAttr().getShape(), format)) { - return failure(); + appendIntTuplePrintfStatic(layoutTy.getAttr().getShape(), valFormat); + valFormat += ":"; + appendIntTuplePrintfStatic(layoutTy.getAttr().getStride(), valFormat); + } else { + LayoutBuilder layoutBuilder(rewriter, loc); + LayoutValueAdaptor layout(val, layoutTy.getAttr()); + appendIntTuplePrintf(rewriter, loc, layoutBuilder.getShape(layout), valFormat, args); + valFormat += ":"; + appendIntTuplePrintf(rewriter, loc, layoutBuilder.getStride(layout), valFormat, args); + } + } else { + appendScalarPrintfArg(rewriter, loc, val, valFormat, args); + } + return valFormat; + }; + + // For CPU context, we need to interleave text and values + // Collect text segments and argument indices + struct PrintSegment { + std::string text; + int argIndex = -1; // -1 means text only + }; + SmallVector segments; + + if (!userFormat.empty()) { + size_t valueIdx = 0; + size_t argIdx = 0; + size_t pos = 0; + while (pos < userFormat.size()) { + size_t placeholderPos = userFormat.find("{}", pos); + if (placeholderPos == std::string::npos) { + segments.push_back({userFormat.substr(pos), -1}); + break; + } + if (placeholderPos > pos) { + segments.push_back({userFormat.substr(pos, placeholderPos - pos), -1}); + } + if (valueIdx < op.getValues().size()) { + size_t argStartIdx = args.size(); + std::string staticFormat = formatValueToString(op.getValues()[valueIdx]); + size_t numArgsAdded = args.size() - argStartIdx; + if (numArgsAdded == 0 && !staticFormat.empty()) { + // Static value: add as text segment + segments.push_back({staticFormat, -1}); + } else { + for (size_t i = 0; i < numArgsAdded; ++i) { + if (i > 0) { + segments.push_back({", ", -1}); + } + segments.push_back({"", static_cast(argStartIdx + i)}); + } } - format += ":"; - if 
(!appendIntTuplePrintfStatic(layoutTy.getAttr().getStride(), format)) { - return failure(); + valueIdx++; + } + pos = placeholderPos + 2; + } + } else { + bool first = true; + for (Value val : op.getValues()) { + if (!first) { + segments.push_back({" ", -1}); + } + first = false; + size_t argStartIdx = args.size(); + std::string staticFormat = formatValueToString(val); + size_t numArgsAdded = args.size() - argStartIdx; + if (numArgsAdded == 0 && !staticFormat.empty()) { + // Static value: add as text segment + segments.push_back({staticFormat, -1}); + } else { + for (size_t i = 0; i < numArgsAdded; ++i) { + if (i > 0) { + segments.push_back({", ", -1}); + } + segments.push_back({"", static_cast(argStartIdx + i)}); } - continue; } - LayoutBuilder layoutBuilder(rewriter, loc); - LayoutValueAdaptor layout(val, layoutTy.getAttr()); - if (!appendIntTuplePrintf(rewriter, loc, layoutBuilder.getShape(layout), format, args)) { - return failure(); + } + } + + if (isGpuContext) { + // For GPU, build printf format string + for (const auto &seg : segments) { + if (seg.argIndex >= 0) { + castPrintfArg(rewriter, loc, args[seg.argIndex], format); + } else { + format += seg.text; } - format += ":"; - if (!appendIntTuplePrintf(rewriter, loc, layoutBuilder.getStride(layout), format, args)) { - return failure(); + } + format += "\n"; + gpu::PrintfOp::create(rewriter, loc, rewriter.getStringAttr(format), args); + } else { + // For CPU, print segments in order + for (size_t i = 0; i < segments.size(); ++i) { + const auto &seg = segments[i]; + if (seg.argIndex >= 0) { + bool isLast = (i == segments.size() - 1); + auto punctuation = + isLast ? 
vector::PrintPunctuation::NewLine : vector::PrintPunctuation::NoPunctuation; + vector::PrintOp::create(rewriter, loc, args[seg.argIndex], punctuation); + } else if (!seg.text.empty()) { + vector::PrintOp::create(rewriter, loc, seg.text); } - continue; - } else { - return failure(); + } + if (segments.empty() || segments.back().argIndex < 0) { + vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine); } } - format += "\n"; - gpu::PrintfOp::create(rewriter, loc, rewriter.getStringAttr(format), args); rewriter.eraseOp(op); return success(); } @@ -1305,7 +1502,7 @@ class FlyLayoutLoweringPass RewritePatternSet patterns(context); patterns.add(context); - patterns.add(context); + patterns.add(context); patterns.add(context); patterns.add(context); diff --git a/lib/Dialect/Fly/Utils/NormalForm.cpp b/lib/Dialect/Fly/Utils/NormalForm.cpp index c9ee4135..6756d267 100644 --- a/lib/Dialect/Fly/Utils/NormalForm.cpp +++ b/lib/Dialect/Fly/Utils/NormalForm.cpp @@ -23,8 +23,7 @@ bool isNormalForm(TypedValue value) { bool isNormalForm(TypedValue value) { Operation *defOp = value.getDefiningOp(); if (!defOp) { - auto tupleTy = value.getType(); - return tupleTy.getAttr().isStatic(); + return false; } // if (isa(defOp)) { // return true; @@ -41,8 +40,7 @@ bool isNormalForm(TypedValue value) { bool isNormalForm(TypedValue value) { Operation *defOp = value.getDefiningOp(); if (!defOp) { - auto layoutTy = value.getType(); - return layoutTy.getAttr().isStatic(); + return false; } // NormalLayout ::= (StaticOp) // if (isa(defOp)) { diff --git a/python/flydsl/lang/ir/core.py b/python/flydsl/lang/ir/core.py index adf8d0c0..1f8b0b32 100644 --- a/python/flydsl/lang/ir/core.py +++ b/python/flydsl/lang/ir/core.py @@ -21,6 +21,7 @@ MemRefType, CopyAtomUniversalCopyType, MmaAtomUniversalFMAType, + DLTensorAdaptor, ) from ..._mlir.dialects.fly_rocdl import ( @@ -371,18 +372,13 @@ def cooperative_copy(tiled_copy, partition_idx, src, dst, loc=None, ip=None): @dsl_api_wrapper -def 
print_op(*values, format_str="", loc=None, ip=None): - """ - Print operation for debugging. Supports IntTuple and other value types. - Lowers to printf for host code or gpu.printf for device code. - - Example: - fx.print_op(int_tuple) - fx.print_op(layout) - fx.print_op(value1, value2, value3) - fx.print_op(value1, format_str="v1=%d\n") - """ - return _fly_ir.print_(format_str, list(values), loc=loc, ip=ip) +def print_op(*args, format_str="", loc=None, ip=None): + if len(args) > 0 and isinstance(args[0], str): + format_str = args[0] + values = list(args[1:]) + else: + values = list(args) + return _fly_ir.print_(format_str, values, loc=loc, ip=ip) # ============================================================================== diff --git a/thirdparty/dlpack b/thirdparty/dlpack new file mode 160000 index 00000000..84d107bf --- /dev/null +++ b/thirdparty/dlpack @@ -0,0 +1 @@ +Subproject commit 84d107bf416c6bab9ae68ad285876600d230490d diff --git a/thirdparty/tvm-ffi b/thirdparty/tvm-ffi new file mode 160000 index 00000000..ed067c17 --- /dev/null +++ b/thirdparty/tvm-ffi @@ -0,0 +1 @@ +Subproject commit ed067c17b259774f4ddc23ba7de937a90642bbb1 From 44c6a3baed99bc2e897c85af31e1f7854912bf6a Mon Sep 17 00:00:00 2001 From: jli Date: Mon, 2 Feb 2026 15:25:28 +0800 Subject: [PATCH 008/113] Fix Python compatibility and remove hardcoded paths (#78) - Fix Python version compatibility in meta.py: add support for Python < 3.11 by checking for positions attribute availability - Replace hardcoded MLIR library paths in executor.py with environment variable MLIR_PATH, with clear error message when not set - Update LLVM commit hash and enable ROCM runner in build script --- python/flydsl/compiler/executor.py | 18 ++++++++++++++++-- python/flydsl/lang/meta.py | 11 +++++++++-- scripts/build_llvm.sh | 4 +++- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/python/flydsl/compiler/executor.py b/python/flydsl/compiler/executor.py index 1d1dcd49..552ef15e 100644 --- 
a/python/flydsl/compiler/executor.py +++ b/python/flydsl/compiler/executor.py @@ -1,4 +1,5 @@ import ctypes +import os import torch @@ -7,12 +8,25 @@ class Executor: def __init__(self, jit_module): + # Get MLIR_PATH from environment variable + mlir_path = os.environ.get('MLIR_PATH') + if not mlir_path: + raise RuntimeError( + "Environment variable MLIR_PATH is not set!\n" + "Please set MLIR_PATH before running the program, for example:\n" + " export MLIR_PATH=/path/to/llvm-project/buildmlir\n" + "Or source the configuration script:\n" + " source pre_build.sh" + ) + + lib_dir = os.path.join(mlir_path, 'lib') + self.engine = ExecutionEngine( jit_module, opt_level=3, shared_libs=[ - "/root/Projects/llvm-project/build/lib/libmlir_rocm_runtime.so", - "/root/Projects/llvm-project/build/lib/libmlir_runner_utils.so", + os.path.join(lib_dir, "libmlir_rocm_runtime.so"), + os.path.join(lib_dir, "libmlir_runner_utils.so"), ], ) self.engine.initialize() diff --git a/python/flydsl/lang/meta.py b/python/flydsl/lang/meta.py index 7c4d52e0..ed52f839 100644 --- a/python/flydsl/lang/meta.py +++ b/python/flydsl/lang/meta.py @@ -11,10 +11,17 @@ def wrapper(*args, **kwargs): if loc is None: frame = inspect.currentframe().f_back frameInfo = inspect.getframeinfo(frame) + # Compatible with different Python versions: positions attribute is available in Python 3.11+ + if hasattr(frameInfo, 'positions') and frameInfo.positions: + lineno = frameInfo.positions.lineno + col_offset = frameInfo.positions.col_offset + else: + lineno = frameInfo.lineno + col_offset = 0 # Older versions don't provide column offset information file_loc = ir.Location.file( frameInfo.filename, - frameInfo.positions.lineno, - frameInfo.positions.col_offset, + lineno, + col_offset, ) loc = ir.Location.name( ( diff --git a/scripts/build_llvm.sh b/scripts/build_llvm.sh index 4c4fa28e..28e290b3 100755 --- a/scripts/build_llvm.sh +++ b/scripts/build_llvm.sh @@ -10,7 +10,8 @@ LLVM_BUILD_DIR="$LLVM_SRC_DIR/buildmlir" 
LLVM_INSTALL_DIR="${LLVM_INSTALL_DIR:-$LLVM_SRC_DIR/mlir_install}" LLVM_INSTALL_TGZ="${LLVM_INSTALL_TGZ:-$LLVM_SRC_DIR/mlir_install.tgz}" LLVM_PACKAGE_INSTALL="${LLVM_PACKAGE_INSTALL:-1}" -LLVM_COMMIT="${LLVM_COMMIT:-04f968b02917}" +# LLVM_COMMIT="${LLVM_COMMIT:-04f968b02917}" +LLVM_COMMIT="${LLVM_COMMIT:-edf06d742821f34060f924dd9db5e01bed90c030}" echo "Base directory: $BASE_DIR" echo "LLVM Source: $LLVM_SRC_DIR" @@ -72,6 +73,7 @@ cmake -G "$GENERATOR" \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_INSTALL_UTILS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DMLIR_ENABLE_ROCM_RUNNER=ON \ -DPython3_EXECUTABLE=$(which python3) \ -Dnanobind_DIR="$NANOBIND_DIR" \ -DBUILD_SHARED_LIBS=OFF \ From ce61b2beff87a6e0fa7736ed33f31b7734267cbb Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Wed, 4 Feb 2026 10:53:08 +0000 Subject: [PATCH 009/113] Add logger and EnvManager --- python/flydsl/utils/__init__.py | 7 + python/flydsl/utils/env.py | 218 +++++++++++++++++++++++++++++ python/flydsl/utils/env_manager.py | 0 python/flydsl/utils/hip_utils.py | 2 - python/flydsl/utils/logger.py | 47 +++++++ 5 files changed, 272 insertions(+), 2 deletions(-) create mode 100644 python/flydsl/utils/env.py delete mode 100644 python/flydsl/utils/env_manager.py delete mode 100644 python/flydsl/utils/hip_utils.py diff --git a/python/flydsl/utils/__init__.py b/python/flydsl/utils/__init__.py index e69de29b..5d7e2d98 100644 --- a/python/flydsl/utils/__init__.py +++ b/python/flydsl/utils/__init__.py @@ -0,0 +1,7 @@ +from . 
import env +from .logger import log + +__all__ = [ + "env", + "log", +] diff --git a/python/flydsl/utils/env.py b/python/flydsl/utils/env.py new file mode 100644 index 00000000..82441a36 --- /dev/null +++ b/python/flydsl/utils/env.py @@ -0,0 +1,218 @@ +import os +import re +from pathlib import Path +from typing import Any, Callable, Dict, Generic, Optional, TypeVar + +T = TypeVar("T") + + +class EnvOption(Generic[T]): + def __init__( + self, + default: T, + env_var: Optional[str] = None, + description: str = "", + validator: Optional[Callable[[T], bool]] = None, + ): + self.default = default + self.env_var = env_var + self.description = description + self.validator = validator + self.name: Optional[str] = None + + def __set_name__(self, owner: type, name: str): + self.name = name + + def parse_value(self, raw: str) -> T: + raise NotImplementedError + + def __get__(self, obj: Optional[object], objtype: Optional[type] = None) -> T: + if obj is None: + return self # type: ignore + + if self.env_var is None: + raise RuntimeError( + f"EnvOption '{self.name or ''}' has no env_var set. " + "EnvOption must be used as a class attribute in an EnvManager subclass." 
+ ) + + raw = os.environ.get(self.env_var) + if raw is None: + return self.default + + try: + value = self.parse_value(raw) + except (ValueError, TypeError) as e: + raise ValueError(f"Failed to parse environment variable {self.env_var}={raw!r}: {e}") from e + + if self.validator is not None and not self.validator(value): + raise ValueError(f"Invalid value for environment variable {self.env_var}: {value!r}") + + return value + + +class OptBool(EnvOption[bool]): + def __init__( + self, + default: bool = False, + env_var: Optional[str] = None, + description: str = "", + ): + super().__init__(default, env_var, description) + + def parse_value(self, raw: str) -> bool: + return raw.lower() in ("1", "true", "yes", "on") + + +class OptInt(EnvOption[int]): + def __init__( + self, + default: int = 0, + env_var: Optional[str] = None, + description: str = "", + min_value: Optional[int] = None, + max_value: Optional[int] = None, + ): + validator = None + if min_value is not None or max_value is not None: + + def validator(v: int) -> bool: + if min_value is not None and v < min_value: + return False + if max_value is not None and v > max_value: + return False + return True + + super().__init__(default, env_var, description, validator) + self.min_value = min_value + self.max_value = max_value + + def parse_value(self, raw: str) -> int: + return int(raw) + + +class OptStr(EnvOption[str]): + def __init__( + self, + default: str = "", + env_var: Optional[str] = None, + description: str = "", + choices: Optional[list[str]] = None, + ): + validator = None + if choices is not None: + + def validator(v: str) -> bool: + return v in choices + + super().__init__(default, env_var, description, validator) + self.choices = choices + + def parse_value(self, raw: str) -> str: + return raw + + +E = TypeVar("E", int, str) + + +class OptList(EnvOption[list[E]]): + def __init__( + self, + default: Optional[list[E]] = None, + env_var: Optional[str] = None, + description: str = "", + separator: str = 
",", + element_type: type = str, + ): + super().__init__(default or [], env_var, description) + self.separator = separator + self.element_type = element_type + + def parse_value(self, raw: str) -> list[E]: + if not raw: + return [] + items = [s.strip() for s in raw.split(self.separator)] + if self.element_type is int: + return [int(s) for s in items] + return items + + +class EnvManagerMeta(type): + def __new__(mcs, name: str, bases: tuple, namespace: dict, **kwargs): + parent_prefix = None + env_bases = [b for b in bases if hasattr(b, "env_prefix")] + if len(env_bases) > 1: + raise TypeError(f"EnvManager subclass '{name}' can only inherit from one EnvManager parent") + + parent_prefix = env_bases[0].env_prefix if env_bases else None + + if "env_prefix" in namespace: + child_prefix = namespace["env_prefix"] + if parent_prefix: + namespace["env_prefix"] = f"{parent_prefix}_{child_prefix}" + elif parent_prefix: + namespace["env_prefix"] = parent_prefix + + cls = super().__new__(mcs, name, bases, namespace) + + options: Dict[str, EnvOption] = {} + for key, value in namespace.items(): + if isinstance(value, EnvOption): + if value.env_var is None: + upper_key = re.sub(r"([a-z])([A-Z])", r"\1_\2", key).upper() + value.env_var = f"{cls.env_prefix}_{upper_key}" + options[key] = value + + cls.options = options + return cls + + +class EnvManager(metaclass=EnvManagerMeta): + env_prefix: str = "FLYDSL" + options: Dict[str, EnvOption] + + def to_dict(self) -> Dict[str, Any]: + return {name: getattr(self, name) for name in self.options} + + @classmethod + def help(cls) -> str: + lines = [f"{cls.__name__} Options:", ""] + for name, opt in cls.options.items(): + desc = opt.description or "No description" + lines.append(f" {name}:") + lines.append(f" Environment: {opt.env_var}") + lines.append(f" Default: {opt.default!r}") + lines.append(f" Description: {desc}") + lines.append("") + return "\n".join(lines) + + def __repr__(self) -> str: + return self.help() + + +class 
CompileEnvManager(EnvManager): + env_prefix = "COMPILE" + + opt_level = OptInt(2, min_value=0, max_value=3, description="Optimization level") + enable_debug_info = OptBool(True, description="Generate debug info in compiled code") + enable_verifier = OptBool(False, description="Verify IR module") + print_after_all = OptBool(False, description="Print IR after each MLIR pass") + + +class RuntimeEnvManager(EnvManager): + env_prefix = "RUNTIME" + + log_level = OptStr("WARNING", choices=["DEBUG", "INFO", "WARNING", "ERROR"], description="Logging level") + log_to_file = OptStr("", description="Log file path, empty to disable file logging") + log_to_console = OptBool(False, description="Enable console logging") + cache_dir = OptStr(str(Path.home() / ".flydsl" / "cache"), description="Directory for caching compiled kernels") + enable_cache = OptBool(True, description="Enable kernel caching") + + +compile = CompileEnvManager() +runtime = RuntimeEnvManager() + +__all__ = [ + "compile", + "runtime", +] diff --git a/python/flydsl/utils/env_manager.py b/python/flydsl/utils/env_manager.py deleted file mode 100644 index e69de29b..00000000 diff --git a/python/flydsl/utils/hip_utils.py b/python/flydsl/utils/hip_utils.py deleted file mode 100644 index 146515fa..00000000 --- a/python/flydsl/utils/hip_utils.py +++ /dev/null @@ -1,2 +0,0 @@ -def get_hip_arch(): - pass diff --git a/python/flydsl/utils/logger.py b/python/flydsl/utils/logger.py index e69de29b..5dc1c832 100644 --- a/python/flydsl/utils/logger.py +++ b/python/flydsl/utils/logger.py @@ -0,0 +1,47 @@ +import logging +import sys + +__all__ = ["log"] + +_FORMAT = "%(asctime)s - %(levelname)-8s - [%(funcName)s] - %(message)s" +_FORMAT_SIMPLE = "| %(levelname)-8s - [%(funcName)s] - %(message)s" + +_logger: logging.Logger = None +_initialized = False + + +def _init_logger(): + global _logger, _initialized + if _initialized: + return + + from .env import runtime + + _logger = logging.getLogger("flydsl") + 
_logger.setLevel(logging.DEBUG) + _logger.propagate = False + + level = getattr(logging, runtime.log_level) + + if runtime.log_to_console: + console_handler = logging.StreamHandler(sys.stderr) + console_handler.setFormatter(logging.Formatter(_FORMAT_SIMPLE)) + console_handler.setLevel(level) + _logger.addHandler(console_handler) + + if runtime.log_to_file: + file_handler = logging.FileHandler(runtime.log_to_file, mode="a", encoding="utf-8") + file_handler.setFormatter(logging.Formatter(_FORMAT)) + file_handler.setLevel(level) + _logger.addHandler(file_handler) + + if not _logger.handlers: + _logger.addHandler(logging.NullHandler()) + + _initialized = True + + +def log() -> logging.Logger: + if not _initialized: + _init_logger() + return _logger From bba542210f94bbe2305874fa24c8db0e57ca8ace Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Thu, 5 Feb 2026 08:23:52 +0000 Subject: [PATCH 010/113] Refact Python module --- cmake/llvm-hash.txt | 1 + examples/01-vector-add.py | 106 ----- examples/01-vectorAdd.py | 72 ++++ examples/02-layout_algebra.py | 67 ---- examples/03-mma_atom.py | 124 ------ lib/Bindings/Python/DLTensorAdaptor.h | 90 +++-- lib/Bindings/Python/FlyExtension.cpp | 7 +- lib/Dialect/Fly/IR/FlyOps.cpp | 4 +- pyproject.toml | 19 + python/flydsl/__init__.py | 2 + python/flydsl/compiler/__init__.py | 11 +- python/flydsl/compiler/ast_rewriter.py | 17 + python/flydsl/compiler/compiler.py | 148 ------- python/flydsl/compiler/executor.py | 58 --- python/flydsl/compiler/jit_argument.py | 144 +++++++ python/flydsl/compiler/jit_executor.py | 112 ++++++ python/flydsl/compiler/jit_function.py | 252 ++++++++++++ python/flydsl/compiler/kernel_function.py | 396 +++++++++++++++++++ python/flydsl/compiler/protocol.py | 61 +++ python/flydsl/expr/primitive.py | 437 +++++++++++++++++++++ python/flydsl/expr/typing.py | 66 ++++ python/flydsl/lang/__init__.py | 2 - python/flydsl/lang/ir/__init__.py | 6 - python/flydsl/lang/ir/core.py | 386 ------------------ 
python/flydsl/lang/ir/gpu.py | 457 ---------------------- python/flydsl/lang/ir/module.py | 212 ---------- python/flydsl/lang/ir/types.py | 5 - python/flydsl/lang/meta.py | 37 -- python/flydsl/lang/typing.py | 32 -- python/flydsl/utils/env.py | 24 +- python/flydsl/utils/logger.py | 12 +- python/mlir_flydsl/CMakeLists.txt | 26 ++ scripts/build_llvm.sh | 131 +++---- scripts/dumpir.sh | 6 - 34 files changed, 1771 insertions(+), 1759 deletions(-) create mode 100644 cmake/llvm-hash.txt delete mode 100644 examples/01-vector-add.py create mode 100644 examples/01-vectorAdd.py delete mode 100644 examples/02-layout_algebra.py delete mode 100644 examples/03-mma_atom.py create mode 100644 python/flydsl/compiler/ast_rewriter.py delete mode 100644 python/flydsl/compiler/compiler.py delete mode 100644 python/flydsl/compiler/executor.py create mode 100644 python/flydsl/compiler/jit_argument.py create mode 100644 python/flydsl/compiler/jit_executor.py create mode 100644 python/flydsl/compiler/jit_function.py create mode 100644 python/flydsl/compiler/kernel_function.py create mode 100644 python/flydsl/compiler/protocol.py create mode 100644 python/flydsl/expr/primitive.py create mode 100644 python/flydsl/expr/typing.py delete mode 100644 python/flydsl/lang/__init__.py delete mode 100644 python/flydsl/lang/ir/__init__.py delete mode 100644 python/flydsl/lang/ir/core.py delete mode 100644 python/flydsl/lang/ir/gpu.py delete mode 100644 python/flydsl/lang/ir/module.py delete mode 100644 python/flydsl/lang/ir/types.py delete mode 100644 python/flydsl/lang/meta.py delete mode 100644 python/flydsl/lang/typing.py delete mode 100755 scripts/dumpir.sh diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt new file mode 100644 index 00000000..7b94a6d8 --- /dev/null +++ b/cmake/llvm-hash.txt @@ -0,0 +1 @@ +edf06d742821 diff --git a/examples/01-vector-add.py b/examples/01-vector-add.py deleted file mode 100644 index 2b5bca72..00000000 --- a/examples/01-vector-add.py +++ /dev/null @@ -1,106 +0,0 
@@ -import flydsl -from flydsl import lang as fx - -N = 64 -memrefTy = fx.MemRefType.get( - fx.T.f32(), fx.LayoutType.get(64, 1), fx.AddressSpace.Global -) - - -class VecAdd(fx.MlirModule): - def __init__(self): - super().__init__() - - @fx.kernel - def kernel( - self: fx.T.i64(), - A: memrefTy, - B: memrefTy, - C: memrefTy, - ): - tid = fx.arith.index_cast(fx.T.i32(), fx.thread_idx.x) - bid = fx.arith.index_cast(fx.T.i32(), fx.block_idx.x) - - tA = fx.logical_divide(A, fx.make_layout(16, 1)) - tB = fx.logical_divide(B, fx.make_layout(16, 1)) - tC = fx.logical_divide(C, fx.make_layout(16, 1)) - - tA = fx.slice(tA, (None, bid)) - tB = fx.slice(tB, (None, bid)) - tC = fx.slice(tC, (None, bid)) - tA = fx.logical_divide(tA, fx.make_layout(1, 1)) - tB = fx.logical_divide(tB, fx.make_layout(1, 1)) - tC = fx.logical_divide(tC, fx.make_layout(1, 1)) - - RABMemRefTy = fx.MemRefType.get( - fx.T.f32(), fx.LayoutType.get(1, 1), fx.AddressSpace.Register - ) - copyAtom = fx.make_atom(fx.CopyAtomUniversalCopyType.get(32)) - rA = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) - rB = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) - rC = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) - - fx.copy_atom_call(copyAtom, fx.slice(tA, (None, tid)), rA) - fx.copy_atom_call(copyAtom, fx.slice(tB, (None, tid)), rB) - - vC = fx.arith.addf(fx.memref_load_vec(rA), fx.memref_load_vec(rB)) - fx.memref_store_vec(vC, rC) - - fx.copy_atom_call(copyAtom, rC, fx.slice(tC, (None, tid))) - - @fx.jit - def __call__( - self: fx.T.i64(), - A: memrefTy, - B: memrefTy, - C: memrefTy, - ): - size = fx.size(A) - - size = fx.get_scalar(size) - - x = fx.arith.constant(fx.T.i64(), 16) - c1 = fx.arith.constant(fx.T.index(), 1) - c16 = fx.arith.constant(fx.T.index(), 16) - - gN = fx.arith.ceildivsi(size, fx.arith.constant(fx.T.i32(), 16)) - gN = fx.arith.IndexCastOp(fx.T.index(), gN) - - kernel_sym = fx.ir.SymbolRefAttr.get(["kernels", "kernel"]) - fx.LaunchFuncOp( - kernel_sym, - grid_size=[gN, c1, 
c1], - block_size=[c16, c1, c1], - kernel_operands=[x, A, B, C], - ) - - -VecAdd_Module = VecAdd() -print(VecAdd_Module) - - -VecAdd_Executor = flydsl.compile(VecAdd_Module, print_after_all=False) -# VecAdd_Asm = flydsl.compile(VecAdd_Module, output_format="assembly") -# print(VecAdd_Asm) - -import torch - -tA = torch.randint(0, 10, (N,), dtype=torch.float32, device="cuda") - -tB = torch.randint(0, 10, (N,), dtype=torch.float32, device="cuda") -tC = torch.randint(0, 10, (N,), dtype=torch.float32, device="cuda") - -tAmk = torch.randint(0, 10, (N, N), dtype=torch.float32, device="cuda") - -VecAdd_Executor(tA, tB, tC) -is_closed = torch.allclose(tC, tA + tB) -print("Result correct:", is_closed) - - -if not is_closed: - print("tA:", tA[:32]) - print("tB:", tB[:32]) - print("tC:", tC[:32]) - - -print("Hello, Fly!") diff --git a/examples/01-vectorAdd.py b/examples/01-vectorAdd.py new file mode 100644 index 00000000..c7f59f7a --- /dev/null +++ b/examples/01-vectorAdd.py @@ -0,0 +1,72 @@ +import torch + +import flydsl.compiler as fmc +import flydsl.expr as fx + + +@fmc.kernel +def vectorAddKernel(A: fx.Tensor, B: fx.Tensor, C: fx.Tensor, block_dim: fx.Constexpr[int]): + tid = fx.arith.index_cast(fx.T.i32(), fx.thread_idx.x) + bid = fx.arith.index_cast(fx.T.i32(), fx.block_idx.x) + + tA = fx.logical_divide(A.value, fx.make_layout(block_dim, 1)) + tB = fx.logical_divide(B.value, fx.make_layout(block_dim, 1)) + tC = fx.logical_divide(C.value, fx.make_layout(block_dim, 1)) + + tA = fx.slice(tA, (None, bid)) + tB = fx.slice(tB, (None, bid)) + tC = fx.slice(tC, (None, bid)) + tA = fx.logical_divide(tA, fx.make_layout(1, 1)) + tB = fx.logical_divide(tB, fx.make_layout(1, 1)) + tC = fx.logical_divide(tC, fx.make_layout(1, 1)) + + RABMemRefTy = fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + copyAtom = fx.make_atom(fx.CopyAtomUniversalCopyType.get(32)) + rA = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) + rB = fx.memref_alloca(RABMemRefTy, 
fx.make_layout(1, 1)) + rC = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) + + fx.copy_atom_call(copyAtom, fx.slice(tA, (None, tid)), rA) + fx.copy_atom_call(copyAtom, fx.slice(tB, (None, tid)), rB) + + vC = fx.arith.addf(fx.memref_load_vec(rA), fx.memref_load_vec(rB)) + fx.memref_store_vec(vC, rC) + + fx.copy_atom_call(copyAtom, rC, fx.slice(tC, (None, tid))) + + +@fmc.jit +def vectorAdd( + A: fx.Tensor, + B: fx.Tensor, + C, # omitted for auto induction + n: fx.Int32, + const_n: fx.Constexpr[int], + stream: fx.Stream = fx.Stream(None), +): + print("> Runtime: n={} const_n={}", n.value, const_n) + fx.printf("> Runtime: n={} const_n={}", n.value, const_n) + + block_dim = 64 + c64 = fx.arith.constant(fx.T.i32(), block_dim) + grid_x = fx.arith.ceildivsi(n.value, c64) + vectorAddKernel(A, B, C, block_dim).launch(grid=(grid_x, 1, 1), block=[block_dim, 1, 1], stream=stream.value) + + +n = 128 +A = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() +B = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() +C = torch.zeros(n, dtype=torch.float32).cuda() + +tA = fmc.from_dlpack(A).mark_layout_dynamic(leading_dim=0, divisibility=4) +vectorAdd(tA, B, C, n, n, stream=torch.cuda.Stream()) + +torch.cuda.synchronize() + +is_closed = torch.allclose(C, A + B) +print("Result correct:", is_closed) +if not is_closed: + print("tA:", A[:32]) + print("tB:", B[:32]) + print("tC:", C[:32]) +print("Hello, Fly!") diff --git a/examples/02-layout_algebra.py b/examples/02-layout_algebra.py deleted file mode 100644 index 2e8eaa44..00000000 --- a/examples/02-layout_algebra.py +++ /dev/null @@ -1,67 +0,0 @@ -import flydsl -from flydsl import lang as fx - -M = 16 -N = 32 -memrefTy = fx.MemRefType.get( - fx.T.f32(), fx.LayoutType.get((M, N), (N, 1)), fx.AddressSpace.Global -) - - -class VecCopy(fx.MlirModule): - def __init__(self, thr_dim, val_dim): - super().__init__() - - @fx.kernel - def kernel( - self: fx.T.i64(), - A: memrefTy, - B: memrefTy, - ): - tid = 
fx.arith.index_cast(fx.T.i32(), fx.thread_idx.x) - bid = fx.arith.index_cast(fx.T.i32(), fx.block_idx.x) - - print(type(tid), tid) - - l16 = fx.make_layout(16, 1) - tile = fx.make_tile([l16, l16]) - - tA = fx.logical_divide(A, tile) - tB = fx.logical_divide(B, tile) - - tA = fx.zipped_divide(A, tile) - tB = fx.zipped_divide(B, tile) - - tA = fx.slice(tA, ((None, None), bid)) - tB = fx.slice(tB, ((None, None), bid)) - - vec = fx.memref_load(tA, tid) - fx.memref_store(vec, tB, tid) - - @fx.jit - def __call__( - self: fx.T.i64(), - A: memrefTy, - B: memrefTy, - ): - x = fx.arith.constant(fx.T.i64(), 16) - c1 = fx.arith.constant(fx.T.index(), 1) - c256 = fx.arith.constant(fx.T.index(), 256) - gN = fx.arith.constant(fx.T.index(), N // 16) - - kernel_sym = fx.ir.SymbolRefAttr.get(["kernels", "kernel"]) - fx.LaunchFuncOp( - kernel_sym, - grid_size=[gN, c1, c1], - block_size=[c256, c1, c1], - kernel_operands=[x, A, B], - ) - - -ThrPerBlock = 256 -ValPerThr = 8 - -VecCopy_Module = VecCopy(thr_dim=ThrPerBlock, val_dim=ValPerThr) -print(VecCopy_Module) - -VecCopy_Executor = flydsl.compile(VecCopy_Module, print_after_all=False) diff --git a/examples/03-mma_atom.py b/examples/03-mma_atom.py deleted file mode 100644 index 9f41b0b9..00000000 --- a/examples/03-mma_atom.py +++ /dev/null @@ -1,124 +0,0 @@ -import flydsl -from flydsl import lang as fx - -MN = 16 -K = 4 -ABMemRefTy = fx.MemRefType.get( - fx.T.f32(), fx.LayoutType.get((MN, K), (K, 1)), fx.AddressSpace.Global -) -CMemRefTy = fx.MemRefType.get( - fx.T.f32(), fx.LayoutType.get((MN, MN), (1, MN)), fx.AddressSpace.Global -) - - -class MmaAtom(fx.MlirModule): - def __init__(self): - super().__init__() - - @fx.kernel - def kernel( - self: fx.T.i64(), - A: ABMemRefTy, - B: ABMemRefTy, - C: CMemRefTy, - ): - tid = fx.arith.index_cast(fx.T.i32(), fx.thread_idx.x) - - rA = fx.memref_alloca( - fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(1, 1)), - fx.make_layout(1, 1), - ) - rB = fx.memref_alloca( - 
fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(1, 1)), - fx.make_layout(1, 1), - ) - - copyAtom = fx.make_atom(fx.CopyAtomUniversalCopyType.get(32)) - mmaAtom = fx.make_atom( - fx.MmaAtomCDNA3_MFMAType.get(16, 16, 16, fx.T.f32(), fx.T.f32(), fx.T.f32()) - ) - - tA = fx.logical_divide(A, fx.make_layout(1, 1)) - tB = fx.logical_divide(B, fx.make_layout(1, 1)) - fx.copy_atom_call(copyAtom, fx.slice(tA, (None, tid)), rA) - fx.copy_atom_call(copyAtom, fx.slice(tB, (None, tid)), rB) - - rAcc = fx.memref_alloca( - fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(4, 1)), fx.make_layout(4, 1) - ) - f0 = fx.arith.constant(fx.T.f32(), 0.0) - fx.memref_store(f0, rAcc, 0) - fx.memref_store(f0, rAcc, 1) - fx.memref_store(f0, rAcc, 2) - fx.memref_store(f0, rAcc, 3) - fx.mma_atom_call(mmaAtom, rAcc, rA, rB, rAcc) - - tC = fx.zipped_divide( - C, fx.make_tile([fx.make_layout(4, 1), fx.make_layout(1, 1)]) - ) - permutation_tile = fx.make_tile([fx.make_layout(1, 1), fx.make_layout(16, 4)]) - tC = fx.logical_divide(tC, permutation_tile) - - fx.copy_atom_call(copyAtom, rAcc, fx.slice(tC, (None, tid))) - - @fx.jit - def __call__( - self: fx.T.i64(), - A: ABMemRefTy, - B: ABMemRefTy, - C: CMemRefTy, - ): - x = fx.arith.constant(fx.T.i64(), 16) - c1 = fx.arith.constant(fx.T.index(), 1) - c64 = fx.arith.constant(fx.T.index(), 64) - - kernel_sym = fx.ir.SymbolRefAttr.get(["kernels", "kernel"]) - fx.LaunchFuncOp( - kernel_sym, - grid_size=[c1, c1, c1], - block_size=[c64, c1, c1], - kernel_operands=[x, A, B, C], - ) - - -MmaAtom_Module = MmaAtom() -print(MmaAtom_Module) - -MmaAtom_Executor = flydsl.compile(MmaAtom_Module, print_after_all=False) -# MmaAtom_Asm = flydsl.compile(MmaAtom_Module, output_format="assembly") -# print(MmaAtom_Asm) - -import torch - -tA = torch.randint( - 0, - 10, - (MN, K), - dtype=torch.float32, - device="cuda", -) -tB = torch.randint( - 0, - 10, - (MN, K), - dtype=torch.float32, - device="cuda", -) -tC = torch.empty( - (MN, MN), - dtype=torch.float32, - 
device="cuda", -) -tC_ref = tA @ tB.T - -MmaAtom_Executor(tA, tB, tC) -is_closed = torch.allclose(tC.T, tC_ref) -print("Result correct:", is_closed) - -if not is_closed: - print("tA:", tA) - print("tB:", tB) - print("tC:", tC.T) - print("tC:", tC_ref) - -print("Hello, Fly!") diff --git a/lib/Bindings/Python/DLTensorAdaptor.h b/lib/Bindings/Python/DLTensorAdaptor.h index 866ee344..12c6088e 100644 --- a/lib/Bindings/Python/DLTensorAdaptor.h +++ b/lib/Bindings/Python/DLTensorAdaptor.h @@ -16,6 +16,7 @@ #include #include +#include #include namespace nb = nanobind; @@ -27,6 +28,19 @@ using namespace mlir::python::nanobind_adaptors; namespace mlir::fly::utils { +inline MLIRContext *getCurrentContext() { + nb::object currentCtx = mlir::python::irModule().attr("Context").attr("current"); + if (currentCtx.is_none()) { + throw std::runtime_error("No MLIR context available. Either pass a context explicitly or " + "call within an active ir.Context (using 'with context:')"); + } + auto capsule = mlirApiObjectToCapsule(currentCtx); + if (!capsule) { + throw std::runtime_error("Invalid MLIR context capsule"); + } + return unwrap(mlirPythonCapsuleToContext(capsule->ptr())); +} + class DLTensorAdaptor { private: struct DimInfo { @@ -59,10 +73,8 @@ class DLTensorAdaptor { }; public: - DLTensorAdaptor(nb::object dlpackCapsule, int32_t alignment, bool use32BitStride, - MlirContext context) - : dlpackCapsule_(dlpackCapsule), alignment_(alignment), use32BitStride_(use32BitStride), - ctx_(unwrap(context)) { + DLTensorAdaptor(nb::object dlpackCapsule, std::optional alignment, bool use32BitStride) + : dlpackCapsule_(dlpackCapsule), use32BitStride_(use32BitStride) { DLManagedTensor *managed = static_cast(PyCapsule_GetPointer(dlpackCapsule.ptr(), "dltensor")); if (!managed) { @@ -70,6 +82,16 @@ class DLTensorAdaptor { } tensor_ = &managed->dl_tensor; + // Calculate element size in bytes (minimum 1 byte) + int32_t bitsPerElem = tensor_->dtype.bits * tensor_->dtype.lanes; + int32_t 
bytesPerElem = (bitsPerElem + 7) / 8; + + // Set alignment: use provided value or default to element size + alignment_ = alignment.value_or(bytesPerElem); + if (alignment_ < 1) { + throw std::runtime_error("Alignment must be at least 1"); + } + ndim_ = tensor_->ndim; if (ndim_ == 0) { throw std::runtime_error("DLTensor must have at least one dimension"); @@ -133,45 +155,46 @@ class DLTensorAdaptor { } } - Type getElementType() const { + Type getElementType() { DLDataType dtype = tensor_->dtype; + MLIRContext *ctx = getCurrentContext(); switch (dtype.code) { case kDLFloat: switch (dtype.bits) { case 16: - return Float16Type::get(ctx_); + return Float16Type::get(ctx); case 32: - return Float32Type::get(ctx_); + return Float32Type::get(ctx); case 64: - return Float64Type::get(ctx_); + return Float64Type::get(ctx); default: throw std::runtime_error("Unsupported float bit width: " + std::to_string(dtype.bits)); } case kDLInt: - return IntegerType::get(ctx_, dtype.bits, IntegerType::Signed); + return IntegerType::get(ctx, dtype.bits, IntegerType::Signed); case kDLUInt: - return IntegerType::get(ctx_, dtype.bits, IntegerType::Unsigned); + return IntegerType::get(ctx, dtype.bits, IntegerType::Unsigned); case kDLBfloat: - return BFloat16Type::get(ctx_); + return BFloat16Type::get(ctx); case kDLBool: - return IntegerType::get(ctx_, 1); + return IntegerType::get(ctx, 1); case kDLFloat8_e5m2: - return Float8E5M2Type::get(ctx_); + return Float8E5M2Type::get(ctx); case kDLFloat8_e4m3fn: - return Float8E4M3FNType::get(ctx_); + return Float8E4M3FNType::get(ctx); case kDLFloat8_e5m2fnuz: - return Float8E5M2FNUZType::get(ctx_); + return Float8E5M2FNUZType::get(ctx); case kDLFloat8_e4m3fnuz: - return Float8E4M3FNUZType::get(ctx_); + return Float8E4M3FNUZType::get(ctx); case kDLFloat8_e4m3b11fnuz: - return Float8E4M3B11FNUZType::get(ctx_); + return Float8E4M3B11FNUZType::get(ctx); case kDLComplex: switch (dtype.bits) { case 64: - return ComplexType::get(Float32Type::get(ctx_)); + 
return ComplexType::get(Float32Type::get(ctx)); case 128: - return ComplexType::get(Float64Type::get(ctx_)); + return ComplexType::get(Float64Type::get(ctx)); default: throw std::runtime_error("Unsupported complex bit width: " + std::to_string(dtype.bits)); } @@ -185,6 +208,7 @@ class DLTensorAdaptor { return; } + MLIRContext *ctx = getCurrentContext(); SmallVector shapeLeaves, strideLeaves; shapeLeaves.resize(ndim_); strideLeaves.resize(ndim_); @@ -192,8 +216,8 @@ class DLTensorAdaptor { size_t shapeDyncCount = 0; size_t strideDyncCount = 0; for (int i = 0; i < ndim_; ++i) { - shapeLeaves[i] = shape_[i].getIntAttr(ctx_, true); - strideLeaves[i] = stride_[i].getIntAttr(ctx_, use32BitStride_); + shapeLeaves[i] = shape_[i].getIntAttr(ctx, true); + strideLeaves[i] = stride_[i].getIntAttr(ctx, use32BitStride_); if (shape_[i].isDynamic) shapeDyncCount++; @@ -201,17 +225,27 @@ class DLTensorAdaptor { strideDyncCount++; } - IntTupleAttr shapeAttr = IntTupleAttr::get(ArrayAttr::get(ctx_, shapeLeaves)); - IntTupleAttr strideAttr = IntTupleAttr::get(ArrayAttr::get(ctx_, strideLeaves)); - LayoutAttr layoutAttr = LayoutAttr::get(ctx_, shapeAttr, strideAttr); + IntTupleAttr shapeAttr, strideAttr; + if (shapeLeaves.size() == 1) { + shapeAttr = cast(shapeLeaves[0]); + } else { + shapeAttr = IntTupleAttr::get(ArrayAttr::get(ctx, shapeLeaves)); + } + if (strideLeaves.size() == 1) { + strideAttr = cast(strideLeaves[0]); + } else { + strideAttr = IntTupleAttr::get(ArrayAttr::get(ctx, strideLeaves)); + } + + LayoutAttr layoutAttr = LayoutAttr::get(ctx, shapeAttr, strideAttr); if (getAddressSpace() != 1) { throw std::runtime_error("Only device address space is supported"); } - AddressSpaceAttr addrSpaceAttr = AddressSpaceAttr::get(ctx_, AddressSpace::Global); + AddressSpaceAttr addrSpaceAttr = AddressSpaceAttr::get(ctx, AddressSpace::Global); assert(alignment_ > 0 && "alignment must be positive"); - AlignAttr alignAttr = AlignAttr::get(ctx_, alignment_); + AlignAttr alignAttr = 
AlignAttr::get(ctx, alignment_); memrefDesc_.memrefType = fly::MemRefType::get(getElementType(), addrSpaceAttr, layoutAttr, alignAttr); @@ -292,6 +326,9 @@ class DLTensorAdaptor { if (leadingDim < 0 || leadingDim >= ndim_) { throw std::runtime_error("Cannot determine leading dimension"); } + if (stride_[leadingDim].dimSize != 1) { + throw std::runtime_error("Leading dimension must have stride 1"); + } isMemrefStale_ = true; for (int i = 0; i < ndim_; ++i) { @@ -318,7 +355,6 @@ class DLTensorAdaptor { nb::object dlpackCapsule_; int32_t alignment_; bool use32BitStride_; - MLIRContext *ctx_; DLTensor *tensor_; int32_t ndim_; diff --git a/lib/Bindings/Python/FlyExtension.cpp b/lib/Bindings/Python/FlyExtension.cpp index 1bbc0bd5..5747d465 100644 --- a/lib/Bindings/Python/FlyExtension.cpp +++ b/lib/Bindings/Python/FlyExtension.cpp @@ -136,9 +136,10 @@ NB_MODULE(_fly, m) { using DLTensorAdaptor = utils::DLTensorAdaptor; nb::class_(m, "DLTensorAdaptor") - .def(nb::init(), "dlpack_capsule"_a, - "alignment"_a = 1, "use_32bit_stride"_a = false, "context"_a, - "Create a DLTensorAdaptor from a DLPack capsule") + .def(nb::init, bool>(), "dlpack_capsule"_a, + "alignment"_a = nb::none(), "use_32bit_stride"_a = false, + "Create a DLTensorAdaptor from a DLPack capsule. " + "If alignment is None, defaults to element size in bytes (minimum 1). 
") .def_prop_ro("shape", &DLTensorAdaptor::getShape, "Get tensor shape as tuple") .def_prop_ro("stride", &DLTensorAdaptor::getStride, "Get tensor stride as tuple") .def_prop_ro("data_ptr", &DLTensorAdaptor::getDataPtr, "Get data pointer as int64") diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp index f27e5a4d..e1350371 100644 --- a/lib/Dialect/Fly/IR/FlyOps.cpp +++ b/lib/Dialect/Fly/IR/FlyOps.cpp @@ -130,8 +130,8 @@ FLY_INFER_RETURN_TYPES(MakeViewOp) { auto layoutTy = dyn_cast(operands[1].getType()); if (!ptrTy || !layoutTy) return failure(); - inferredReturnTypes.assign( - {MemRefType::get(ptrTy.getElemTy(), ptrTy.getAddressSpace(), layoutTy.getAttr())}); + inferredReturnTypes.assign({MemRefType::get(ptrTy.getElemTy(), ptrTy.getAddressSpace(), + layoutTy.getAttr(), ptrTy.getAlignment())}); return success(); } diff --git a/pyproject.toml b/pyproject.toml index f57a27ca..265d7110 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,4 +11,23 @@ requires = [ ] build-backend = "setuptools.build_meta" +[tool.black] +line-length = 120 +[tool.ruff] +line-length = 120 +target-version = "py310" + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort (import sorting) +] +ignore = [ + "E501", # line too long (handled by formatter) +] + +[tool.ruff.lint.isort] +known-first-party = ["flydsl"] diff --git a/python/flydsl/__init__.py b/python/flydsl/__init__.py index a5921a09..59668c52 100644 --- a/python/flydsl/__init__.py +++ b/python/flydsl/__init__.py @@ -1 +1,3 @@ from .compiler import * + +__version__ = "0.1.0" diff --git a/python/flydsl/compiler/__init__.py b/python/flydsl/compiler/__init__.py index 236470bc..bcc31b4c 100644 --- a/python/flydsl/compiler/__init__.py +++ b/python/flydsl/compiler/__init__.py @@ -1,3 +1,12 @@ from .compiler import compile +from .jit_argument import JitArgumentRegistry, from_dlpack +from .jit_function import jit +from .kernel_function import kernel 
-__all__ = ["compile"] +__all__ = [ + "compile", + "from_dlpack", + "JitArgumentRegistry", + "jit", + "kernel", +] diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py new file mode 100644 index 00000000..d6baff15 --- /dev/null +++ b/python/flydsl/compiler/ast_rewriter.py @@ -0,0 +1,17 @@ +import ast +from typing import List + + +class ASTRewriter: + transform_manager: List[ast.NodeTransformer] = [] + + @classmethod + def register(cls, transformer: ast.NodeTransformer): + cls.transform_manager.append(transformer) + + @classmethod + def transform(self, func): + pass + # for transformer in self.transform_manager: + # node = transformer.visit(node) + # return node diff --git a/python/flydsl/compiler/compiler.py b/python/flydsl/compiler/compiler.py deleted file mode 100644 index c8116aed..00000000 --- a/python/flydsl/compiler/compiler.py +++ /dev/null @@ -1,148 +0,0 @@ -from contextlib import ExitStack - -from .._mlir.passmanager import PassManager - -from ..lang import MlirModule -from .executor import Executor - - -def _decode_mlir_escaped_bytes(s: str) -> str: - """Decode MLIR string attr content that uses \\xx hex byte escapes (e.g. \\0A, \\09, \\22). - - This is what gpu-module-to-binary emits for `assembly = "..."` (and often `bin = "..."`). - """ - out_chars = [] - i = 0 - n = len(s) - - def _is_hex(c: str) -> bool: - return ("0" <= c <= "9") or ("a" <= c <= "f") or ("A" <= c <= "F") - - while i < n: - ch = s[i] - if ch != "\\": - out_chars.append(ch) - i += 1 - continue - - # Backslash escape. - if i + 2 < n and _is_hex(s[i + 1]) and _is_hex(s[i + 2]): - byte = int(s[i + 1 : i + 3], 16) - out_chars.append(chr(byte)) - i += 3 - continue - - # Common C-style single-char escapes (rare here, but harmless). 
- if i + 1 < n: - nxt = s[i + 1] - if nxt == "n": - out_chars.append("\n") - i += 2 - continue - if nxt == "t": - out_chars.append("\t") - i += 2 - continue - if nxt == "r": - out_chars.append("\r") - i += 2 - continue - if nxt in ['"', "\\"]: - out_chars.append(nxt) - i += 2 - continue - # Unknown escape: keep the escaped char as-is. - out_chars.append(nxt) - i += 2 - continue - - # Trailing backslash. - i += 1 - - return "".join(out_chars) - - -def _extract_mlir_string_attr(asm: str, attr_name: str) -> str | None: - """Extract and decode a string attribute like `attr_name = "..."` from an MLIR asm dump.""" - marker = f'{attr_name} = "' - start = asm.find(marker) - if start == -1: - return None - - i = start + len(marker) - # Find the closing quote. Skip over \xx escapes as two hex bytes. - while i < len(asm): - if asm[i] == "\\" and i + 2 < len(asm): - # Skip the escape introducer and two following chars (typically hex digits). - i += 3 - continue - if asm[i] == '"': - end = i - encoded = asm[start + len(marker) : end] - return _decode_mlir_escaped_bytes(encoded) - i += 1 - return None - - -def compile( - fx_module: MlirModule, verify=True, print_after_all=False, output_format="fatbin" -): - # gpu-module-to-binary formats are backend-dependent. For ROCm/ROCDL, "isa" - # is the human-readable assembly/ISA dump and "fatbin" is an object container. - fmt_map = { - "fatbin": "fatbin", - "assembly": "isa", - } - if output_format not in fmt_map: - raise ValueError( - f"Unsupported output_format: {output_format}. 
Use one of {list(fmt_map)}" - ) - - pipeline = ( - "builtin.module(" - "gpu-kernel-outlining{data-layout-str=}," - "fly-canonicalize," - "fly-layout-lowering," - "convert-fly-to-rocdl," - "canonicalize," - "gpu.module(" - "convert-vector-to-llvm," - "canonicalize," - "convert-gpu-to-rocdl{ chipset=gfx000 index-bitwidth=0 runtime=HIP use-bare-ptr-memref-call-conv=true}" - ")," - "rocdl-attach-target{O=2 abi=600 chip=gfx942 correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}," - "gpu-to-llvm{intersperse-sizes-for-kernels=false use-bare-pointers-for-host=true use-bare-pointers-for-kernels=true}," - "reconcile-unrealized-casts," - f"gpu-module-to-binary{{format={fmt_map[output_format]} opts= section= toolkit=}}" - ")" - ) - mlir_module = fx_module.module - module = mlir_module.parse(mlir_module.operation.get_asm(enable_debug_info=True)) - - try: - with ExitStack() as stack: - stack.enter_context(module.context) - pm = PassManager.parse(pipeline) - pm.enable_verifier(verify) - pm.enable_ir_printing(print_after_all=print_after_all) - - pm.run(module.operation) - except Exception as e: - print(e) - - # Default: produce a runnable executor (requires gpu-module-to-binary to have produced - # a launchable binary container). - if output_format == "fatbin": - return Executor(module) - - # Debug output: return textual assembly/ISA emitted into gpu.binary's `assembly` attribute - # (or `bin` in some toolchains). - # If the toolchain doesn't embed it (or it was elided), fall back to returning the MLIR. 
- asm = module.operation.get_asm(enable_debug_info=True, large_elements_limit=1 << 30) - text = _extract_mlir_string_attr(asm, "assembly") - if text is not None: - return text - text = _extract_mlir_string_attr(asm, "bin") - if text is not None: - return text - return asm diff --git a/python/flydsl/compiler/executor.py b/python/flydsl/compiler/executor.py deleted file mode 100644 index 552ef15e..00000000 --- a/python/flydsl/compiler/executor.py +++ /dev/null @@ -1,58 +0,0 @@ -import ctypes -import os -import torch - - -from .._mlir.execution_engine import ExecutionEngine - - -class Executor: - def __init__(self, jit_module): - # Get MLIR_PATH from environment variable - mlir_path = os.environ.get('MLIR_PATH') - if not mlir_path: - raise RuntimeError( - "Environment variable MLIR_PATH is not set!\n" - "Please set MLIR_PATH before running the program, for example:\n" - " export MLIR_PATH=/path/to/llvm-project/buildmlir\n" - "Or source the configuration script:\n" - " source pre_build.sh" - ) - - lib_dir = os.path.join(mlir_path, 'lib') - - self.engine = ExecutionEngine( - jit_module, - opt_level=3, - shared_libs=[ - os.path.join(lib_dir, "libmlir_rocm_runtime.so"), - os.path.join(lib_dir, "libmlir_runner_utils.so"), - ], - ) - self.engine.initialize() - - def convert_args(self, args): - if isinstance(args, torch.Tensor): - return ctypes.cast( - ctypes.pointer(ctypes.c_void_p(args.data_ptr())), ctypes.c_void_p - ) - else: - raise TypeError(f"Unsupported argument type: {type(args)}") - - def __call__(self, *args): - return self.__getattr__("__call__")(*args) - - def __getattr__(self, name: str): - try: - func_ptr = self.engine.raw_lookup(name) - func_exe = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(func_ptr) - except KeyError: - raise AttributeError(f"No such function: {name}") from None - - def wrapper(*args): - addresses = [ctypes.c_void_p(0)] - addresses += [self.convert_args(arg) for arg in args] - c_args = (ctypes.c_void_p * len(addresses))(*addresses) - return 
func_exe(c_args) - - return wrapper diff --git a/python/flydsl/compiler/jit_argument.py b/python/flydsl/compiler/jit_argument.py new file mode 100644 index 00000000..897f5b96 --- /dev/null +++ b/python/flydsl/compiler/jit_argument.py @@ -0,0 +1,144 @@ +import inspect +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Type, get_origin + +import torch + +from .._mlir._mlir_libs._fly import DLTensorAdaptor +from ..expr.typing import Constexpr, Int32, Stream, Tensor +from .protocol import DslType, JitArgument + + +class JitArgumentRegistry: + registry: Dict[type, Tuple[Callable, Type[DslType]]] = {} + jit_arg2dsl_type: Dict[type, Type[DslType]] = {} + + @classmethod + def register(cls, py_type: type, *, dsl_type: Type[DslType] = None): + def decorator(jit_arg_constructor: Callable): + if py_type in cls.registry: + raise ValueError(f"JitArgumentConstructor for {py_type} already registered") + + if dsl_type is not None: + dest_dsl_type = dsl_type + elif isinstance(jit_arg_constructor, type) and isinstance(jit_arg_constructor, DslType): + dest_dsl_type = jit_arg_constructor + else: + raise ValueError(f"Invalid dsl_type for {py_type}: {dsl_type}") + + cls.registry[py_type] = (jit_arg_constructor, dest_dsl_type) + cls.jit_arg2dsl_type[jit_arg_constructor] = dest_dsl_type + return jit_arg_constructor + + return decorator + + @classmethod + def register_jit_arg(cls, jit_arg: type, dsl_type: Type[DslType]): + if not issubclass(jit_arg, JitArgument): + raise ValueError(f"JitArgument must implement JitArgument protocol, got {jit_arg}") + if jit_arg in cls.jit_arg2dsl_type: + raise ValueError(f"JitArgument {jit_arg} already registered") + cls.jit_arg2dsl_type[jit_arg] = dsl_type + + @classmethod + def get(cls, py_type: type) -> Optional[Tuple[Callable, Type[DslType]]]: + return cls.registry.get(py_type, (None, None)) + + @classmethod + def get_dsl_type(cls, jit_arg_type: type) -> Type[DslType]: + return cls.jit_arg2dsl_type[jit_arg_type] 
+ + +def _is_constexpr_annotation(annotation) -> bool: + """Check if annotation is Constexpr or Constexpr[T].""" + if annotation is Constexpr: + return True + return get_origin(annotation) is Constexpr + + +def convert_to_jit_arguments( + sig: inspect.Signature, bound +) -> tuple[List[str], List[JitArgument], List[DslType], dict[str, any]]: + param_names: List[str] = [] + jit_args: List[JitArgument] = [] + dsl_types: List[DslType] = [] + constexpr_values: dict[str, any] = {} + + for param_name, value in bound.arguments.items(): + param = sig.parameters[param_name] + annotation = param.annotation + + if annotation is not inspect.Parameter.empty and _is_constexpr_annotation(annotation): + constexpr_values[param_name] = value + continue + + if isinstance(value, JitArgument) and isinstance(value, DslType): + jit_arg = value + dsl_type = type(value) + elif isinstance(value, JitArgument): + jit_arg = value + dsl_type = JitArgumentRegistry.get_dsl_type(type(value)) + if dsl_type is None: + raise TypeError( + f"No DslType registered for JitArgument type {type(value).__name__} (parameter '{param_name}')" + ) + else: + jit_arg_constructor, dsl_type = JitArgumentRegistry.get(type(value)) + if jit_arg_constructor is None: + raise TypeError(f"No JitArgument registered for type {type(value).__name__} (parameter '{param_name}')") + try: + jit_arg = jit_arg_constructor(value) + except Exception as e: + raise TypeError(f"Failed to construct JitArgument for parameter '{param_name}': {e}") from e + + param_names.append(param_name) + jit_args.append(jit_arg) + dsl_types.append(dsl_type) + return param_names, jit_args, dsl_types, constexpr_values + + +# ================================ Common useful JitArguments ================================ + + +@JitArgumentRegistry.register(torch.Tensor, dsl_type=Tensor) +class TensorAdaptor: + def __init__( + self, + tensor: torch.Tensor, + assumed_align: Optional[int] = None, + use_32bit_stride: bool = False, + ): + self.tensor_adaptor = 
DLTensorAdaptor(tensor.__dlpack__(), assumed_align, use_32bit_stride) + self.assumed_align = assumed_align + self.use_32bit_stride = use_32bit_stride + + def requires_memref_desc(func): + def wrapper(self, *args, **kwargs): + self.tensor_adaptor.build_memref_desc() + return func(self, *args, **kwargs) + + return wrapper + + @requires_memref_desc + def __ir_types__(self): + return [self.tensor_adaptor.get_memref_type()] + + @requires_memref_desc + def __c_pointers__(self): + return self.tensor_adaptor.get_c_pointers() + + def mark_layout_dynamic(self, leading_dim: Optional[int] = None, divisibility: int = 1): + if leading_dim is None: + leading_dim = -1 # automatically determine leading dimension + self.tensor_adaptor.mark_layout_dynamic(leading_dim, divisibility) + return self + + +def from_dlpack( + tensor: torch.Tensor, *, assumed_align: Optional[int] = None, use_32bit_stride: bool = False +) -> TensorAdaptor: + return TensorAdaptor(tensor, assumed_align, use_32bit_stride) + + +JitArgumentRegistry.register(int)(Int32) +JitArgumentRegistry.register(torch.cuda.Stream)(Stream) diff --git a/python/flydsl/compiler/jit_executor.py b/python/flydsl/compiler/jit_executor.py new file mode 100644 index 00000000..a570227d --- /dev/null +++ b/python/flydsl/compiler/jit_executor.py @@ -0,0 +1,112 @@ +import ctypes +import threading +from functools import lru_cache +from pathlib import Path +from typing import List + +from .._mlir import ir +from .._mlir.execution_engine import ExecutionEngine +from .protocol import get_c_pointers + + +@lru_cache(maxsize=1) +def _get_mlir_runtime_libs() -> List[str]: + mlir_libs_dir = Path(__file__).resolve().parent.parent / "_mlir" / "_mlir_libs" + lib_names = ["libmlir_rocm_runtime.so", "libmlir_c_runner_utils.so"] + return [str(mlir_libs_dir / name) for name in lib_names] + + +class JitCompiledFunction: + def __init__( + self, + compiled_module: ir.Module, + func_name: str, + original_ir: str = None, + ): + self._compiled_ir = 
str(compiled_module) + self._func_name = func_name + self._original_ir = original_ir + self._module = None + self._engine = None + self._engine_lock = threading.Lock() + self._tls = threading.local() + + def __getstate__(self): + return { + "compiled_ir": self._compiled_ir, + "func_name": self._func_name, + "original_ir": self._original_ir, + } + + def __setstate__(self, state): + self._compiled_ir = state["compiled_ir"] + self._func_name = state["func_name"] + self._original_ir = state["original_ir"] + self._module = None + self._engine = None + self._engine_lock = threading.Lock() + self._tls = threading.local() + + def _init_engine(self): + with self._engine_lock: + if self._engine is not None: + return + + with ir.Context(): + self._module = ir.Module.parse(self._compiled_ir) + self._engine = ExecutionEngine( + self._module, + opt_level=3, + shared_libs=_get_mlir_runtime_libs(), + ) + self._engine.initialize() + + def _get_packed_args_buffer(self, size: int): + buf = getattr(self._tls, "packed_args", None) + capacity = getattr(self._tls, "capacity", 0) + if buf is None or capacity < size: + buf = (ctypes.c_void_p * size)() + self._tls.packed_args = buf + self._tls.capacity = size + return buf + + def __call__(self, *args, **kwargs): + if self._engine is None: + self._init_engine() + + all_c_ptrs: List[ctypes.c_void_p] = [] + for arg in args: + all_c_ptrs.extend(get_c_pointers(arg)) + + func_ptr = self._engine.raw_lookup(self._func_name) + func_exe = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(func_ptr) + + num_args = len(all_c_ptrs) + packed_args = self._get_packed_args_buffer(num_args) + for i, ptr in enumerate(all_c_ptrs): + packed_args[i] = ptr + + return func_exe(packed_args) + + def print_ir(self, compiled: bool = True): + if compiled: + print("=" * 60) + print("Compiled MLIR IR:") + print("=" * 60) + print(self._compiled_ir) + else: + if self._original_ir is None: + print("Original IR not available") + else: + print("=" * 60) + print("Original MLIR IR:") + 
print("=" * 60) + print(self._original_ir) + + @property + def ir(self) -> str: + return self._compiled_ir + + @property + def original_ir(self) -> str: + return self._original_ir diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py new file mode 100644 index 00000000..83665033 --- /dev/null +++ b/python/flydsl/compiler/jit_function.py @@ -0,0 +1,252 @@ +import hashlib +import inspect +import pickle +from functools import lru_cache +from pathlib import Path +from typing import Any, Callable, Dict, Optional + +import flydsl + +from .._mlir import ir +from .._mlir.dialects import func +from .._mlir.passmanager import PassManager +from ..utils import env, log +from .jit_argument import convert_to_jit_arguments +from .jit_executor import JitCompiledFunction +from .kernel_function import ( + CompilationContext, + FuncLocationTracker, + create_gpu_module, + get_gpu_module_body, +) +from .protocol import get_ir_types, new_from_ir_values + + +@lru_cache(maxsize=1) +def _get_llvm_version() -> str: + _FLYDSL_ROOT = Path(__file__).resolve().parents[4] + llvm_hash_file = _FLYDSL_ROOT / "cmake" / "llvm-hash.txt" + if llvm_hash_file.exists(): + llvm_hash = llvm_hash_file.read_text() + else: + llvm_hash = "release_version" + log().debug(f"LLVM version: {llvm_hash}") + return llvm_hash + + +@lru_cache(maxsize=1) +def _flydsl_verison_key() -> str: + return f"flydsl:{flydsl.__version__}|llvm:{_get_llvm_version()}" + + +def _jit_function_cache_key(func: Callable) -> str: + parts = [] + parts.append(_flydsl_verison_key()) + try: + source = inspect.getsource(func) + except OSError: + # Fallback to bytecode if source unavailable + source = func.__code__.co_code.hex() + parts.append(source) + combined = "\n".join(parts) + return hashlib.sha256(combined.encode()).hexdigest()[:32] + + +class MlirCompiler: + PIPELINE = ( + "builtin.module(" + "gpu-kernel-outlining{data-layout-str=}," + "fly-canonicalize," + "fly-layout-lowering," + 
"convert-fly-to-rocdl," + "canonicalize," + "gpu.module(" + "convert-vector-to-llvm," + "canonicalize," + "convert-gpu-to-rocdl{chipset=gfx942 index-bitwidth=0 runtime=HIP use-bare-ptr-memref-call-conv=true}" + ")," + "rocdl-attach-target{O=2 abi=600 chip=gfx942 correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}," + "gpu-to-llvm{use-bare-pointers-for-host=true use-bare-pointers-for-kernels=true}," + "convert-arith-to-llvm," + "convert-func-to-llvm," + "reconcile-unrealized-casts," + "gpu-module-to-binary{format=fatbin}" + ")" + ) + + @classmethod + def compile(cls, module: ir.Module) -> ir.Module: + module.operation.verify() + + module = ir.Module.parse(module.operation.get_asm(enable_debug_info=env.debug.enable_debug_info)) + pm = PassManager.parse(cls.PIPELINE) + + pm.enable_verifier(env.debug.enable_verifier) + pm.enable_ir_printing(print_after_all=env.debug.print_after_all) + pm.run(module.operation) + + return module + + +class JitCacheManager: + """Directory-based cache manager. + + Cache directory structure: + {cache_root}/{func_name}_{manager_key}/ + {cache_key}.pkl - serialized compiled kernel + + Each compiled kernel is saved immediately after compilation. 
+ """ + + def __init__(self, cache_dir: Path): + self.cache_dir = cache_dir + self.memory_cache: Dict[str, Any] = {} + + def _cache_file(self, cache_key: str) -> Path: + safe_key = hashlib.sha256(cache_key.encode()).hexdigest()[:16] + return self.cache_dir / f"{safe_key}.pkl" + + def get(self, cache_key: str) -> Optional[Any]: + if cache_key in self.memory_cache: + return self.memory_cache[cache_key] + + cache_file = self._cache_file(cache_key) + if cache_file.exists(): + try: + with open(cache_file, "rb") as f: + value = pickle.load(f) + self.memory_cache[cache_key] = value + log().debug(f"Cache hit from disk: {cache_file.name}") + return value + except Exception as e: + log().warning(f"Failed to load cache {cache_file}: {e}") + return None + + def set(self, cache_key: str, value: Any) -> None: + self.memory_cache[cache_key] = value + self.cache_dir.mkdir(parents=True, exist_ok=True) + cache_file = self._cache_file(cache_key) + try: + with open(cache_file, "wb") as f: + pickle.dump(value, f) + log().debug(f"Cache saved: {cache_file.name}") + except Exception as e: + log().warning(f"Failed to save cache {cache_file}: {e}") + + def load_all(self) -> int: + if not self.cache_dir.exists(): + return 0 + count = 0 + for cache_file in self.cache_dir.glob("*.pkl"): + try: + with open(cache_file, "rb") as f: + pickle.load(f) + count += 1 + except Exception: + pass + log().debug(f"Found {count} cached entries in {self.cache_dir}") + return count + + def __contains__(self, cache_key: str) -> bool: + return cache_key in self.memory_cache or self._cache_file(cache_key).exists() + + +class JitFunction: + def __init__(self, func: Callable): + self.func = func + self.managerKey = _jit_function_cache_key(func) + + cache_root = env.runtime.cache_dir + if cache_root: + cache_dir = Path(cache_root) / f"{func.__name__}_{self.managerKey}" + self.cacheManager = JitCacheManager(cache_dir) + self.cacheManager.load_all() + else: + self.cacheManager = None + + def _make_cache_key(self, 
bound_args: Dict) -> str: + key_parts = [] + for name, arg in bound_args.items(): + key_parts.append(f"{name}:{self._get_type_signature(arg)}") + return "|".join(key_parts) + + def _get_type_signature(self, obj) -> str: + if hasattr(obj, "__cache_signature__"): + return obj.__cache_signature__() + elif hasattr(obj, "dtype") and hasattr(obj, "shape"): + return f"tensor[{obj.dtype},{obj.shape}]" + elif isinstance(obj, (int, float, bool, str)): + return f"{type(obj).__name__}:{obj}" + return type(obj).__name__ + + def __call__(self, *args, **kwargs): + if ir.Context.current is not None: + log().debug(f"JitFunction {self.func.__name__} within ir.Context") + return self.func(*args, **kwargs) + + sig = inspect.signature(self.func) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + cache_key = self._make_cache_key(bound.arguments) + cached_func = self.cacheManager.get(cache_key) if self.cacheManager else None + + if cached_func is not None and env.runtime.enable_cache: + log().info(f"Cache hit for {self.func.__name__}") + with ir.Context(): + _, jit_args, _, _ = convert_to_jit_arguments(sig, bound) + return cached_func(*jit_args) + + with ir.Context() as ctx: + param_names, jit_args, dsl_types, constexpr_values = convert_to_jit_arguments(sig, bound) + ir_types = get_ir_types(jit_args) + loc = ir.Location.unknown(ctx) + + log().info(f"jit_args={jit_args}") + log().info(f"dsl_types={dsl_types}") + + module = ir.Module.create(loc=loc) + module.operation.attributes["gpu.container_module"] = ir.UnitAttr.get() + + func_tracker = FuncLocationTracker(self.func) + + with ir.InsertionPoint(module.body), loc: + gpu_module = create_gpu_module("kernels", targets=['#rocdl.target']) + + func_op = func.FuncOp(self.func.__name__, (ir_types, [])) + func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + entry_block = func_op.add_entry_block() + + with CompilationContext.create(func_tracker) as comp_ctx: + comp_ctx.gpu_module_op = gpu_module + 
comp_ctx.gpu_module_body = get_gpu_module_body(gpu_module) + + with ir.InsertionPoint(entry_block): + ir_args = list(func_op.regions[0].blocks[0].arguments) + dsl_args = new_from_ir_values(dsl_types, jit_args, ir_args) + log().info(f"dsl_args={dsl_args}") + named_args = dict(zip(param_names, dsl_args)) + named_args.update(constexpr_values) + self.func(**named_args) + func.ReturnOp([]) + + original_ir = module.operation.get_asm(enable_debug_info=True) + + compiled_module = MlirCompiler.compile(module) + + compiled_func = JitCompiledFunction( + compiled_module, + self.func.__name__, + original_ir, + ) + + if self.cacheManager: + self.cacheManager.set(cache_key, compiled_func) + + return compiled_func(*jit_args) + + +def jit(func: Optional[Callable] = None) -> JitFunction: + if func is None: + return lambda f: JitFunction(f) + return JitFunction(func) diff --git a/python/flydsl/compiler/kernel_function.py b/python/flydsl/compiler/kernel_function.py new file mode 100644 index 00000000..2fcae545 --- /dev/null +++ b/python/flydsl/compiler/kernel_function.py @@ -0,0 +1,396 @@ +import inspect +from contextlib import contextmanager +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_origin + +from .._mlir import ir +from .._mlir.dialects import arith, gpu +from ..expr.typing import Constexpr +from .protocol import extract_ir_values, get_ir_types, new_from_ir_values + +# ============================================================================= +# GPU Operation Helpers +# ============================================================================= + + +def create_gpu_module( + sym_name: str, + targets: Optional[List[str]] = None, + *, + loc=None, + ip=None, +) -> gpu.GPUModuleOp: + target_attrs = [] + if targets: + for t in targets: + if isinstance(t, str): + target_attrs.append(ir.Attribute.parse(t)) + else: + target_attrs.append(t) + module_op = gpu.GPUModuleOp( + sym_name, targets=ir.ArrayAttr.get(target_attrs) if target_attrs else None, 
loc=loc, ip=ip + ) + module_op.regions[0].blocks.append() + return module_op + + +def get_gpu_module_body(module_op: gpu.GPUModuleOp): + return module_op.regions[0].blocks[0] + + +def create_gpu_func( + sym_name: str, + function_type: ir.TypeAttr, + *, + loc=None, + ip=None, +) -> gpu.GPUFuncOp: + return gpu.GPUFuncOp(function_type, sym_name=sym_name, kernel=True, loc=loc, ip=ip) + + +# ============================================================================= +# Location Tracking Utilities +# ============================================================================= + + +def get_source_location(depth: int = 2) -> Tuple[str, int, int]: + """Get source file location from call stack. + + Args: + depth: Stack depth to look up (2 = caller's caller) + + Returns: + Tuple of (filename, line, column) + """ + frame = inspect.currentframe() + try: + for _ in range(depth): + if frame is not None: + frame = frame.f_back + if frame is not None: + return (frame.f_code.co_filename, frame.f_lineno, 0) + finally: + del frame + return ("", 0, 0) + + +def create_file_location(filename: str, line: int, col: int = 0, context=None) -> ir.Location: + """Create an MLIR file location.""" + ctx = context or ir.Context.current + return ir.Location.file(filename, line, col, context=ctx) + + +def create_caller_location(depth: int = 2, context=None) -> ir.Location: + """Create an MLIR location from the caller's source position.""" + filename, line, col = get_source_location(depth + 1) + return create_file_location(filename, line, col, context) + + +class FuncLocationTracker: + """Track source locations for a Python function being traced.""" + + def __init__(self, func: Callable): + self._func = func + self._filename = inspect.getfile(func) + try: + self._source_lines, self._start_line = inspect.getsourcelines(func) + except (OSError, TypeError): + self._source_lines = [] + self._start_line = 0 + + @property + def filename(self) -> str: + return self._filename + + @property + def 
start_line(self) -> int: + return self._start_line + + def get_func_location(self, context=None) -> ir.Location: + """Get location for the function definition.""" + return create_file_location(self._filename, self._start_line, 0, context) + + @contextmanager + def func_scope(self): + """Enter a location scope for this function.""" + loc = self.get_func_location() + with loc: + yield loc + + +# ============================================================================= +# Launch Configuration +# ============================================================================= + +DimValueType = Union[int, ir.Value] +DimType = Union[int, ir.Value, Tuple[DimValueType, ...], List[DimValueType]] + + +def _to_index_value(val: DimValueType) -> ir.Value: + if isinstance(val, ir.Value): + if val.type == ir.IndexType.get(): + return val + return arith.index_cast(ir.IndexType.get(), val) + return arith.constant(ir.IndexType.get(), val) + + +def _normalize_dim(dim: DimType) -> Tuple[DimValueType, DimValueType, DimValueType]: + if isinstance(dim, (int, ir.Value)): + return (dim, 1, 1) + elif len(dim) == 1: + return (dim[0], 1, 1) + elif len(dim) == 2: + return (dim[0], dim[1], 1) + return (dim[0], dim[1], dim[2]) + + +# ============================================================================= +# Compilation Context (per-compilation state) +# ============================================================================= + + +class CompilationContext: + """Context for tracking compilation state within a @jit function. 
+ + Manages: + - GPU module op for kernel definitions + - Kernel counter for unique naming + - Location trackers for debugging + """ + + _current: Optional["CompilationContext"] = None + + def __init__(self, func_tracker: Optional[FuncLocationTracker] = None): + self.gpu_module_op = None + self.kernel_counter = 0 + self.func_tracker = func_tracker + self.kernel_trackers: Dict[str, FuncLocationTracker] = {} + + @classmethod + def get_current(cls) -> Optional["CompilationContext"]: + return cls._current + + @classmethod + @contextmanager + def create(cls, func_tracker: Optional[FuncLocationTracker] = None): + prev = cls._current + ctx = CompilationContext(func_tracker) + cls._current = ctx + try: + yield ctx + finally: + cls._current = prev + + def next_kernel_id(self) -> int: + """Get next unique kernel ID.""" + kid = self.kernel_counter + self.kernel_counter += 1 + return kid + + def register_kernel_tracker(self, name: str, tracker: FuncLocationTracker): + """Register a location tracker for a kernel function.""" + self.kernel_trackers[name] = tracker + + def get_kernel_tracker(self, name: str) -> Optional[FuncLocationTracker]: + """Get the location tracker for a kernel function.""" + return self.kernel_trackers.get(name) + + +# ============================================================================= +# Kernel Launcher +# ============================================================================= + + +class KernelLauncher: + """Holds kernel reference and generates gpu.launch_func on launch(). + + Created by calling a @kernel decorated function. Call .launch() + to emit the actual launch operation. 
+ """ + + def __init__( + self, + kernel_name: str, + kernel_args: Tuple, + call_location: Optional[ir.Location] = None, + ): + self._kernel_name = kernel_name + self._kernel_args = kernel_args + self._call_location = call_location + + def launch( + self, + *, + grid: DimType = (1, 1, 1), + block: DimType = (1, 1, 1), + smem: Union[int, ir.Value] = 0, + stream: Optional[ir.Value] = None, + ) -> None: + """Emit gpu.launch_func operation with the given configuration. + + Args: + grid: Grid dimensions (x, y, z). Can be int, ir.Value, tuple, or list. + block: Block dimensions (x, y, z). Can be int, ir.Value, tuple, or list. + smem: Dynamic shared memory size in bytes. Can be int or ir.Value. + stream: CUDA/HIP stream as ir.Value. None means default stream. + """ + launch_loc = create_caller_location(depth=2) + + kernel_operands = [] + for arg in self._kernel_args: + kernel_operands.extend(extract_ir_values(arg)) + + grid_dims = _normalize_dim(grid) + block_dims = _normalize_dim(block) + + with launch_loc: + grid_x = _to_index_value(grid_dims[0]) + grid_y = _to_index_value(grid_dims[1]) + grid_z = _to_index_value(grid_dims[2]) + block_x = _to_index_value(block_dims[0]) + block_y = _to_index_value(block_dims[1]) + block_z = _to_index_value(block_dims[2]) + + smem_val = None + if isinstance(smem, ir.Value): + smem_val = smem + elif smem > 0: + smem_val = arith.constant(ir.IntegerType.get_signless(32), smem) + + async_object = stream + + gpu.LaunchFuncOp( + ["kernels", self._kernel_name], + (grid_x, grid_y, grid_z), + (block_x, block_y, block_z), + kernel_operands, + dynamic_shared_memory_size=smem_val, + async_object=async_object, + loc=launch_loc, + ip=None, + ) + + +# ============================================================================= +# Kernel Function +# ============================================================================= + + +class KernelFunction: + """Wrapper for @kernel decorated functions. 
+ + When called, emits a gpu.func and returns a KernelLauncher for + configuring and launching the kernel. + """ + + def __init__(self, func: Callable, some_args=None): + self._func = func + self._some_args = some_args + self._kernel_name: Optional[str] = None + self._location_tracker = FuncLocationTracker(func) + + @staticmethod + def _is_constexpr_annotation(annotation) -> bool: + if annotation is Constexpr: + return True + return get_origin(annotation) is Constexpr + + def _emit_kernel(self, ctx: CompilationContext, args: Tuple, kwargs: Dict) -> Tuple[Any, ...]: + """Emit gpu.func for this kernel into the GPU module. + + Returns: + Tuple of non-constexpr argument values for use in launch. + """ + sig = inspect.signature(self._func) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + param_names: List[str] = [] + param_values: List[Any] = [] + constexpr_values: Dict[str, Any] = {} + + for param_name, value in bound.arguments.items(): + param = sig.parameters[param_name] + annotation = param.annotation + if annotation is not inspect.Parameter.empty and self._is_constexpr_annotation(annotation): + constexpr_values[param_name] = value + else: + param_names.append(param_name) + param_values.append(value) + + kernel_arg_types = [] + for value in param_values: + kernel_arg_types.extend(get_ir_types(value)) + + kernel_id = ctx.next_kernel_id() + self._kernel_name = f"{self._func.__name__}_{kernel_id}" + + ctx.register_kernel_tracker(self._kernel_name, self._location_tracker) + + kernel_loc = self._location_tracker.get_func_location() + + with ir.InsertionPoint(ctx.gpu_module_body): + func_type = ir.FunctionType.get(kernel_arg_types, []) + with kernel_loc: + gpu_func = create_gpu_func(self._kernel_name, ir.TypeAttr.get(func_type)) + gpu_func.regions[0].blocks.append(*kernel_arg_types) + entry_block = gpu_func.regions[0].blocks[0] + + with ir.InsertionPoint(entry_block), kernel_loc: + block_args = list(entry_block.arguments) + dsl_args: Dict[str, Any] = {} + 
idx = 0 + for param_name, value in zip(param_names, param_values): + n = len(get_ir_types(value)) + dsl_args[param_name] = new_from_ir_values(type(value), value, list(block_args[idx : idx + n])) + idx += n + + dsl_args.update(constexpr_values) + self._func(**dsl_args) + gpu.ReturnOp([]) + + return tuple(param_values) + + def __call__(self, *args, **kwargs) -> KernelLauncher: + ctx = CompilationContext.get_current() + if ctx is None: + raise RuntimeError("@kernel can only be called inside @jit function") + + call_loc = create_caller_location(depth=2) + + kernel_args = self._emit_kernel(ctx, args, kwargs) + + return KernelLauncher(self._kernel_name, kernel_args, call_loc) + + +# ============================================================================= +# Kernel Decorator +# ============================================================================= + + +def kernel(func: Optional[Callable] = None, *, some_args=None) -> KernelFunction: + """Decorator for GPU kernel functions. + + Usage: + @kernel + def my_kernel(a: Tensor, b: Tensor): + # kernel body + ... + + # Or with arguments + @kernel(some_args=value) + def my_kernel(a: Tensor): + ... + + The decorated function can be called inside a @jit function to + define the kernel, then .launch(config) is called to emit the launch op. 
+ + Args: + func: Function to decorate + some_args: Optional kernel-specific arguments + + Returns: + KernelFunction wrapper + """ + if func is None: + return lambda f: KernelFunction(f, some_args=some_args) + return KernelFunction(func, some_args=some_args) diff --git a/python/flydsl/compiler/protocol.py b/python/flydsl/compiler/protocol.py new file mode 100644 index 00000000..03126126 --- /dev/null +++ b/python/flydsl/compiler/protocol.py @@ -0,0 +1,61 @@ +import ctypes +from itertools import chain +from typing import List, Protocol, runtime_checkable + +from .._mlir import ir + + +@runtime_checkable +class DslType(Protocol): + @classmethod + def __new_from_ir_values__(cls, values: List[ir.Value]) -> "DslType": ... + def __extract_ir_values__(self) -> List[ir.Value]: ... + + +@runtime_checkable +class JitArgument(Protocol): + def __ir_types__(self) -> List[ir.Type]: ... + def __c_pointers__(self) -> List[ctypes.c_void_p]: ... + + +def get_ir_types(obj) -> List[ir.Type]: + if isinstance(obj, ir.Value): + return [obj.type] + elif hasattr(obj, "__ir_types__"): + return obj.__ir_types__() + elif hasattr(obj, "__extract_ir_values__"): + return [v.type for v in obj.__extract_ir_values__()] + elif isinstance(obj, (tuple, list)): + return list(chain.from_iterable(get_ir_types(x) for x in obj)) + raise TypeError(f"Cannot get IR types from {obj}") + + +def get_c_pointers(obj) -> List[ctypes.c_void_p]: + if hasattr(obj, "__c_pointers__"): + return obj.__c_pointers__() + elif isinstance(obj, (tuple, list)): + return list(chain.from_iterable(get_c_pointers(x) for x in obj)) + raise TypeError(f"Cannot get C pointers from {obj}") + + +def extract_ir_values(obj) -> List[ir.Value]: + if isinstance(obj, ir.Value): + return [obj] + elif hasattr(obj, "__extract_ir_values__"): + return obj.__extract_ir_values__() + elif isinstance(obj, (tuple, list)): + return list(chain.from_iterable(extract_ir_values(x) for x in obj)) + raise TypeError(f"Cannot extract IR values from {obj}") + + 
+def new_from_ir_values(dsl_type, args, values: List[ir.Value]) -> DslType: + if hasattr(dsl_type, "__new_from_ir_values__"): + return dsl_type.__new_from_ir_values__(values) + elif isinstance(dsl_type, (tuple, list)): + elem = [] + for ty, arg in zip(dsl_type, args, strict=True): + val_num = len(get_ir_types(arg)) + elem.append(new_from_ir_values(ty, arg, values[:val_num])) + values = values[val_num:] + return type(dsl_type)(elem) + raise TypeError(f"Cannot construct from IR values for {dsl_type}") diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py new file mode 100644 index 00000000..27754eaa --- /dev/null +++ b/python/flydsl/expr/primitive.py @@ -0,0 +1,437 @@ +from flydsl.lang.meta import dsl_api_wrapper + +from .._mlir import ir +from .._mlir.dialects import arith, fly, gpu +from .._mlir.dialects.fly import ( + # Enum Attributes + AddressSpace, + CachePolicy, + # Type + CopyAtomUniversalCopyType, + IntTupleType, + LayoutType, + MemRefType, + MmaAtomUniversalFMAType, + PointerType, + SwizzleType, +) +from .._mlir.extras import types as T + +# @ir.register_value_caster(T.F16Type.static_typeid) +# @ir.register_value_caster(T.F32Type.static_typeid) +# @ir.register_value_caster(T.F64Type.static_typeid) +# @ir.register_value_caster(T.IntegerType.static_typeid) +# class ArithValue(ir.Value): +# def __init__(self, v): +# super().__init__(v) + +# __add__ = partialmethod(_binary_op, op="add") +# __sub__ = partialmethod(_binary_op, op="sub") +# __mul__ = partialmethod(_binary_op, op="mul") + +# def __str__(self): +# return super().__str__().replace(ir.Value.__name__, ArithValue.__name__) + + +class classproperty(property): + def __get__(self, owner_self, owner_cls): + return self.fget(owner_cls) + + +class block_idx: + @classproperty + def x(cls): + return gpu.block_id("x") + + @classproperty + def y(cls): + return gpu.block_id("y") + + @classproperty + def z(cls): + return gpu.block_id("z") + + +class block_dim: + @classproperty + def x(cls): 
+ return gpu.block_dim("x") + + @classproperty + def y(cls): + return gpu.block_dim("y") + + @classproperty + def z(cls): + return gpu.block_dim("z") + + +class thread_idx: + @classproperty + def x(cls): + return gpu.thread_id("x") + + @classproperty + def y(cls): + return gpu.thread_id("y") + + @classproperty + def z(cls): + return gpu.thread_id("z") + + +class grid_dim: + @classproperty + def x(cls): + return gpu.grid_dim("x") + + @classproperty + def y(cls): + return gpu.grid_dim("y") + + @classproperty + def z(cls): + return gpu.grid_dim("z") + + +def make_int32(value): + return fly.make_int32(value) + + +def make_int32_tuple(value): + return fly.make_int32_tuple(value) + + +def rank(int_or_tuple): + return fly.rank(int_or_tuple) + + +def depth(int_or_tuple): + return fly.depth(int_or_tuple) + + +@dsl_api_wrapper +def int_tuple_add(lhs, rhs, loc=None, ip=None): + return fly.int_tuple_add(lhs, rhs, loc=loc, ip=ip) + + +@dsl_api_wrapper +def int_tuple_sub(lhs, rhs, loc=None, ip=None): + return fly.int_tuple_sub(lhs, rhs, loc=loc, ip=ip) + + +@dsl_api_wrapper +def int_tuple_mul(lhs, rhs, loc=None, ip=None): + return fly.int_tuple_mul(lhs, rhs, loc=loc, ip=ip) + + +@dsl_api_wrapper +def int_tuple_div(lhs, rhs, loc=None, ip=None): + return fly.int_tuple_div(lhs, rhs, loc=loc, ip=ip) + + +@dsl_api_wrapper +def int_tuple_product(int_tuple, loc=None, ip=None): + return fly.int_tuple_product(int_tuple, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_identity_tensor(shape, loc=None, ip=None): + return fly.make_identity_tensor(shape, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_identity_layout(shape, loc=None, ip=None): + return fly.make_identity_layout(shape, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_shape(*shape, loc=None, ip=None): + IntTupleTy, dyncElems = fly.infer_int_tuple_type(shape) + return fly.make_shape(IntTupleTy, dyncElems, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_stride(*stride, loc=None, ip=None): + IntTupleTy, dyncElems = 
fly.infer_int_tuple_type(stride) + return fly.make_stride(IntTupleTy, dyncElems, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_coord(*coord, loc=None, ip=None): + IntTupleTy, dyncElems = fly.infer_int_tuple_type(coord) + return fly.make_coord(IntTupleTy, dyncElems, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_int_tuple(elems, loc=None, ip=None): + IntTupleTy, dyncElems = fly.infer_int_tuple_type(elems) + return fly.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_layout(shape, stride, loc=None, ip=None): + if not isinstance(shape, ir.Value): + shapeTy, dyncElems = fly.infer_int_tuple_type(shape) + shape = fly.make_shape(shapeTy, dyncElems, loc=loc, ip=ip) + if not isinstance(stride, ir.Value): + strideTy, dyncElems = fly.infer_int_tuple_type(stride) + stride = fly.make_stride(strideTy, dyncElems, loc=loc, ip=ip) + return fly.make_layout(shape, stride=stride, loc=loc, ip=ip) + + +@dsl_api_wrapper +def size(int_tuple, loc=None, ip=None): + return fly.size(int_tuple, loc=loc, ip=ip) + + +@dsl_api_wrapper +def get_scalar(int_tuple, loc=None, ip=None): + return fly.get_scalar(int_tuple, loc=loc, ip=ip) + + +@dsl_api_wrapper +def slice(src, coord, loc=None, ip=None): + if not isinstance(coord, ir.Value): + coordTy, dyncElems = fly.infer_int_tuple_type(coord) + coord = fly.make_coord(coordTy, dyncElems, loc=loc, ip=ip) + return fly.slice(src, coord, loc=loc, ip=ip) + + +@dsl_api_wrapper +def crd2idx(crd, layout, loc=None, ip=None): + return fly.crd2idx(crd, layout, loc=loc, ip=ip) + + +@dsl_api_wrapper +def composition(layout, tiler, loc=None, ip=None): + return fly.composition(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def complement(layout, codomain_size, loc=None, ip=None): + if not isinstance(codomain_size, ir.Value): + codomain_sizeTy, dyncElems = fly.infer_int_tuple_type(codomain_size) + codomain_size = fly.make_shape(codomain_sizeTy, dyncElems, loc=loc, ip=ip) + return fly.complement(layout, 
codomain_size=codomain_size, loc=loc, ip=ip) + + +@dsl_api_wrapper +def coalesce(layout, pattern=None, loc=None, ip=None): + return fly.coalesce(layout, pattern=pattern, loc=loc, ip=ip) + + +@dsl_api_wrapper +def zip(lhs, rhs, loc=None, ip=None): + return fly.zip(lhs, rhs, loc=loc, ip=ip) + + +@dsl_api_wrapper +def select(int_tuple, indices, loc=None, ip=None): + return fly.select(int_tuple, indices=indices, loc=loc, ip=ip) + + +@dsl_api_wrapper +def group(int_tuple, begin: int, end: int, loc=None, ip=None): + return fly.group(int_tuple, begin=begin, end=end, loc=loc, ip=ip) + + +@dsl_api_wrapper +def append(base, elem, n: int | None = None, loc=None, ip=None): + return fly.append(base, elem, n=n, loc=loc, ip=ip) + + +@dsl_api_wrapper +def prepend(base, elem, n: int | None = None, loc=None, ip=None): + return fly.prepend(base, elem, n=n, loc=loc, ip=ip) + + +@dsl_api_wrapper +def logical_divide(layout, divisor, loc=None, ip=None): + return fly.logical_divide(layout, divisor, loc=loc, ip=ip) + + +@dsl_api_wrapper +def zipped_divide(layout, divisor, loc=None, ip=None): + return fly.zipped_divide(layout, divisor, loc=loc, ip=ip) + + +@dsl_api_wrapper +def tiled_divide(layout, divisor, loc=None, ip=None): + return fly.tiled_divide(layout, divisor, loc=loc, ip=ip) + + +@dsl_api_wrapper +def flat_divide(layout, divisor, loc=None, ip=None): + return fly.flat_divide(layout, divisor, loc=loc, ip=ip) + + +@dsl_api_wrapper +def logical_product(layout, tiler, loc=None, ip=None): + return fly.logical_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def zipped_product(layout, tiler, loc=None, ip=None): + return fly.zipped_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def tiled_product(layout, tiler, loc=None, ip=None): + return fly.tiled_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def flat_product(layout, tiler, loc=None, ip=None): + return fly.flat_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def 
block_product(layout, tiler, loc=None, ip=None): + return fly.block_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def raked_product(layout, tiler, loc=None, ip=None): + return fly.raked_product(layout, tiler, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_atom(atom_type, loc=None, ip=None): + return fly.make_atom(atom_type, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_tile(layouts, loc=None, ip=None): + return fly.make_tile(layouts, loc=loc, ip=ip) + + +@dsl_api_wrapper +def mma_atom_call(mma_atom, d, a, b, c, loc=None, ip=None): + return fly.mma_atom_call(mma_atom, d, a, b, c, loc=loc, ip=ip) + + +@dsl_api_wrapper +def copy_atom_call(copy_atom, src, dst, loc=None, ip=None): + return fly.copy_atom_call(copy_atom, src, dst, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_tiled_copy(copy_atom, layout_tv, tile_mn, loc=None, ip=None): + return fly.make_tiled_copy(copy_atom, layout_tv, tile_mn, loc=loc, ip=ip) + + +@dsl_api_wrapper +def memref_alloca(memref_type, layout, loc=None, ip=None): + return fly.memref_alloca(memref_type, layout, loc=loc, ip=ip) + + +@dsl_api_wrapper +def memref_load(memref, indices, loc=None, ip=None): + # `fly.memref.load` expects `indices` as `!fly.int_tuple` (typically a scalar offset). + # Accept convenience forms: + # - int_tuple Value (pass through) + # - python int / tuple/list (make_int_tuple) + # - index/i32/i64 Value (cast index->i32 then make_int_tuple) + if isinstance(indices, ir.Value): + if str(indices.type).startswith("!fly.int_tuple"): + return fly.memref_load(memref, indices, loc=loc, ip=ip) + # Common case: user passes `index` as a 1-D coordinate/offset. + if str(indices.type) == "index": + indices = arith.IndexCastOp(T.i32(), indices) + indices = make_int_tuple(indices, loc=loc, ip=ip) + return fly.memref_load(memref, indices, loc=loc, ip=ip) + + # List/tuple (e.g. [row]) or python int. 
+ indices = make_int_tuple(indices, loc=loc, ip=ip) + return fly.memref_load(memref, indices, loc=loc, ip=ip) + + +@dsl_api_wrapper +def memref_store(value, memref, indices, loc=None, ip=None): + if isinstance(indices, ir.Value): + if str(indices.type).startswith("!fly.int_tuple"): + return fly.memref_store(value, memref, indices, loc=loc, ip=ip) + if str(indices.type) == "index": + indices = arith.IndexCastOp(T.i32(), indices) + indices = make_int_tuple(indices, loc=loc, ip=ip) + return fly.memref_store(value, memref, indices, loc=loc, ip=ip) + + indices = make_int_tuple(indices, loc=loc, ip=ip) + return fly.memref_store(value, memref, indices, loc=loc, ip=ip) + + +@dsl_api_wrapper +def memref_load_vec(memref, loc=None, ip=None): + return fly.memref_load_vec(memref, loc=loc, ip=ip) + + +@dsl_api_wrapper +def memref_store_vec(vector, memref, loc=None, ip=None): + return fly.memref_store_vec(vector, memref, loc=loc, ip=ip) + + +@dsl_api_wrapper +def get_layout(memref, loc=None, ip=None): + return fly.get_layout(memref, loc=loc, ip=ip) + + +@dsl_api_wrapper +def get_iter(memref, loc=None, ip=None): + return fly.get_iter(memref, loc=loc, ip=ip) + + +@dsl_api_wrapper +def make_view(iter, layout, loc=None, ip=None): + return fly.make_view(iter, layout, loc=loc, ip=ip) + + +@dsl_api_wrapper +def add_offset(ptr, offset, loc=None, ip=None): + if not isinstance(offset, ir.Value): + offset = make_int_tuple(offset, loc=loc, ip=ip) + return fly.add_offset(ptr, offset, loc=loc, ip=ip) + + +@dsl_api_wrapper +def cooperative_copy(tiled_copy, partition_idx, src, dst, loc=None, ip=None): + return fly.cooperative_copy( + tiled_copy, + partition_idx, + src, + dst, + loc=loc, + ip=ip, + ) + + +@dsl_api_wrapper +def printf(*args, format_str="", loc=None, ip=None): + def _convert_printf_value(val): + """Convert Python values to MLIR Values for printf.""" + if isinstance(val, ir.Value): + return val + elif isinstance(val, bool): + return arith.constant(T.i1(), int(val)) + elif 
isinstance(val, int): + return arith.constant(T.i32(), val) + elif isinstance(val, float): + return arith.constant(T.f64(), val) + elif hasattr(val, "__extract_ir_values__"): + ir_values = val.__extract_ir_values__() + if len(ir_values) == 1: + return ir_values[0] + raise ValueError(f"Cannot use multi-value type in printf: {type(val)}") + elif hasattr(val, "value") and isinstance(val.value, ir.Value): + return val.value + else: + raise ValueError(f"Cannot convert {type(val)} to MLIR Value for printf") + + if len(args) > 0 and isinstance(args[0], str): + format_str = args[0] + raw_values = list(args[1:]) + else: + raw_values = list(args) + + values = [_convert_printf_value(v) for v in raw_values] + return fly.print_(format_str, values, loc=loc, ip=ip) diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py new file mode 100644 index 00000000..3285d179 --- /dev/null +++ b/python/flydsl/expr/typing.py @@ -0,0 +1,66 @@ +import ctypes +from typing import Generic, TypeVar + +from .._mlir import ir +from .._mlir.dialects import gpu, llvm +from .._mlir.extras import types as T + +ValueT = TypeVar("ValueT") + + +class Constexpr(Generic[ValueT]): + """ + Constexpr is transparent for mlir, it will be replaced by the actual value at compile time. 
+ """ + + pass + + +class Int32: + def __init__(self, value): + self.value = value + + def __ir_types__(self): + return [T.i32()] + + def __c_pointers__(self): + return [ctypes.cast(ctypes.pointer(ctypes.c_int32(self.value)), ctypes.c_void_p)] + + @classmethod + def __new_from_ir_values__(self, values): + return Int32(values[0]) + + def __extract_ir_values__(self): + return [self.value] + + +class Tensor: + def __init__(self, value: ir.Value): + self.value = value + + @classmethod + def __new_from_ir_values__(cls, values): + return Tensor(values[0]) + + def __extract_ir_values__(self): + return [self.value] + + +class Stream: + def __init__(self, value): + self.value = value + + def __ir_types__(self): + return [gpu.AsyncTokenType.get()] + + def __c_pointers__(self): + if self.value is None: + return [ctypes.cast(ctypes.pointer(ctypes.c_void_p(0)), ctypes.c_void_p)] + return [ctypes.cast(ctypes.pointer(ctypes.c_void_p(self.value.cuda_stream)), ctypes.c_void_p)] + + @classmethod + def __new_from_ir_values__(cls, values): + return Stream(values[0]) + + def __extract_ir_values__(self): + return [self.value] diff --git a/python/flydsl/lang/__init__.py b/python/flydsl/lang/__init__.py deleted file mode 100644 index af4aefbe..00000000 --- a/python/flydsl/lang/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .ir import * -from .typing import * diff --git a/python/flydsl/lang/ir/__init__.py b/python/flydsl/lang/ir/__init__.py deleted file mode 100644 index 3c1fa38c..00000000 --- a/python/flydsl/lang/ir/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# from .types import * - -from .core import * -from .module import * - -# from .gpu import * diff --git a/python/flydsl/lang/ir/core.py b/python/flydsl/lang/ir/core.py deleted file mode 100644 index 1f8b0b32..00000000 --- a/python/flydsl/lang/ir/core.py +++ /dev/null @@ -1,386 +0,0 @@ -from functools import partialmethod -from functools import lru_cache - -from flydsl.lang.meta import dsl_api_wrapper - - -from .module import 
_global_ctx - -from ..._mlir import ir -from ..._mlir.dialects import fly as _fly_ir -from ..._mlir.dialects._fly_enum_gen import AddressSpace, CachePolicy - -from ..._mlir.dialects import arith -from ..._mlir.extras import types as T - -from ..._mlir.dialects.fly import ( - IntTupleType, - LayoutType, - SwizzleType, - PointerType, - MemRefType, - CopyAtomUniversalCopyType, - MmaAtomUniversalFMAType, - DLTensorAdaptor, -) - -from ..._mlir.dialects.fly_rocdl import ( - MmaAtomCDNA3_MFMAType, -) - - -def _binary_op(lhs, rhs, op: str) -> "ArithValue": - op = op.capitalize() - if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type): - op += "F" - elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type( - lhs.type - ): - op += "I" - else: - raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}") - - op = getattr(arith, f"{op}Op") - return op(lhs, rhs).result - - -@ir.register_value_caster(T.F16Type.static_typeid) -@ir.register_value_caster(T.F32Type.static_typeid) -@ir.register_value_caster(T.F64Type.static_typeid) -@ir.register_value_caster(T.IntegerType.static_typeid) -class ArithValue(ir.Value): - def __init__(self, v): - super().__init__(v) - - __add__ = partialmethod(_binary_op, op="add") - __sub__ = partialmethod(_binary_op, op="sub") - __mul__ = partialmethod(_binary_op, op="mul") - - def __str__(self): - return super().__str__().replace(ir.Value.__name__, ArithValue.__name__) - - -def make_int32(value): - return _fly_ir.make_int32(value) - - -def make_int32_tuple(value): - return _fly_ir.make_int32_tuple(value) - - -def rank(int_or_tuple): - return _fly_ir.rank(int_or_tuple) - - -def depth(int_or_tuple): - return _fly_ir.depth(int_or_tuple) - - -@dsl_api_wrapper -def int_tuple_add(lhs, rhs, loc=None, ip=None): - return _fly_ir.int_tuple_add(lhs, rhs, loc=loc, ip=ip) - - -@dsl_api_wrapper -def int_tuple_sub(lhs, rhs, loc=None, ip=None): - return _fly_ir.int_tuple_sub(lhs, rhs, loc=loc, ip=ip) - - -@dsl_api_wrapper 
-def int_tuple_mul(lhs, rhs, loc=None, ip=None): - return _fly_ir.int_tuple_mul(lhs, rhs, loc=loc, ip=ip) - - -@dsl_api_wrapper -def int_tuple_div(lhs, rhs, loc=None, ip=None): - return _fly_ir.int_tuple_div(lhs, rhs, loc=loc, ip=ip) - - -@dsl_api_wrapper -def int_tuple_product(int_tuple, loc=None, ip=None): - return _fly_ir.int_tuple_product(int_tuple, loc=loc, ip=ip) - - -@dsl_api_wrapper -def make_identity_tensor(shape, loc=None, ip=None): - return _fly_ir.make_identity_tensor(shape, loc=loc, ip=ip) - - -@dsl_api_wrapper -def make_identity_layout(shape, loc=None, ip=None): - return _fly_ir.make_identity_layout(shape, loc=loc, ip=ip) - - -@dsl_api_wrapper -def make_shape(*shape, loc=None, ip=None): - IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(shape) - return _fly_ir.make_shape(IntTupleTy, dyncElems, loc=loc, ip=ip) - - -@dsl_api_wrapper -def make_stride(*stride, loc=None, ip=None): - IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(stride) - return _fly_ir.make_stride(IntTupleTy, dyncElems, loc=loc, ip=ip) - - -@dsl_api_wrapper -def make_coord(*coord, loc=None, ip=None): - IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(coord) - return _fly_ir.make_coord(IntTupleTy, dyncElems, loc=loc, ip=ip) - - -@dsl_api_wrapper -def make_int_tuple(elems, loc=None, ip=None): - IntTupleTy, dyncElems = _fly_ir.infer_int_tuple_type(elems) - return _fly_ir.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) - - -@dsl_api_wrapper -def make_layout(shape, stride, loc=None, ip=None): - if not isinstance(shape, ir.Value): - shapeTy, dyncElems = _fly_ir.infer_int_tuple_type(shape) - shape = _fly_ir.make_shape(shapeTy, dyncElems, loc=loc, ip=ip) - if not isinstance(stride, ir.Value): - strideTy, dyncElems = _fly_ir.infer_int_tuple_type(stride) - stride = _fly_ir.make_stride(strideTy, dyncElems, loc=loc, ip=ip) - return _fly_ir.make_layout(shape, stride=stride, loc=loc, ip=ip) - - -@dsl_api_wrapper -def size(int_tuple, loc=None, ip=None): - return 
_fly_ir.size(int_tuple, loc=loc, ip=ip) - - -@dsl_api_wrapper -def get_scalar(int_tuple, loc=None, ip=None): - return _fly_ir.get_scalar(int_tuple, loc=loc, ip=ip) - - -@dsl_api_wrapper -def slice(src, coord, loc=None, ip=None): - if not isinstance(coord, ir.Value): - coordTy, dyncElems = _fly_ir.infer_int_tuple_type(coord) - coord = _fly_ir.make_coord(coordTy, dyncElems, loc=loc, ip=ip) - return _fly_ir.slice(src, coord, loc=loc, ip=ip) - - -@dsl_api_wrapper -def crd2idx(crd, layout, loc=None, ip=None): - return _fly_ir.crd2idx(crd, layout, loc=loc, ip=ip) - - -@dsl_api_wrapper -def composition(layout, tiler, loc=None, ip=None): - return _fly_ir.composition(layout, tiler, loc=loc, ip=ip) - - -@dsl_api_wrapper -def complement(layout, codomain_size, loc=None, ip=None): - if not isinstance(codomain_size, ir.Value): - codomain_sizeTy, dyncElems = _fly_ir.infer_int_tuple_type(codomain_size) - codomain_size = _fly_ir.make_shape(codomain_sizeTy, dyncElems, loc=loc, ip=ip) - return _fly_ir.complement(layout, codomain_size=codomain_size, loc=loc, ip=ip) - - -@dsl_api_wrapper -def coalesce(layout, pattern=None, loc=None, ip=None): - return _fly_ir.coalesce(layout, pattern=pattern, loc=loc, ip=ip) - - -@dsl_api_wrapper -def zip(lhs, rhs, loc=None, ip=None): - return _fly_ir.zip(lhs, rhs, loc=loc, ip=ip) - - -@dsl_api_wrapper -def select(int_tuple, indices, loc=None, ip=None): - return _fly_ir.select(int_tuple, indices=indices, loc=loc, ip=ip) - - -@dsl_api_wrapper -def group(int_tuple, begin: int, end: int, loc=None, ip=None): - return _fly_ir.group(int_tuple, begin=begin, end=end, loc=loc, ip=ip) - - -@dsl_api_wrapper -def append(base, elem, n: int | None = None, loc=None, ip=None): - return _fly_ir.append(base, elem, n=n, loc=loc, ip=ip) - - -@dsl_api_wrapper -def prepend(base, elem, n: int | None = None, loc=None, ip=None): - return _fly_ir.prepend(base, elem, n=n, loc=loc, ip=ip) - - -@dsl_api_wrapper -def logical_divide(layout, divisor, loc=None, ip=None): - return 
_fly_ir.logical_divide(layout, divisor, loc=loc, ip=ip) - - -@dsl_api_wrapper -def zipped_divide(layout, divisor, loc=None, ip=None): - return _fly_ir.zipped_divide(layout, divisor, loc=loc, ip=ip) - - -@dsl_api_wrapper -def tiled_divide(layout, divisor, loc=None, ip=None): - return _fly_ir.tiled_divide(layout, divisor, loc=loc, ip=ip) - - -@dsl_api_wrapper -def flat_divide(layout, divisor, loc=None, ip=None): - return _fly_ir.flat_divide(layout, divisor, loc=loc, ip=ip) - - -@dsl_api_wrapper -def logical_product(layout, tiler, loc=None, ip=None): - return _fly_ir.logical_product(layout, tiler, loc=loc, ip=ip) - - -@dsl_api_wrapper -def zipped_product(layout, tiler, loc=None, ip=None): - return _fly_ir.zipped_product(layout, tiler, loc=loc, ip=ip) - - -@dsl_api_wrapper -def tiled_product(layout, tiler, loc=None, ip=None): - return _fly_ir.tiled_product(layout, tiler, loc=loc, ip=ip) - - -@dsl_api_wrapper -def flat_product(layout, tiler, loc=None, ip=None): - return _fly_ir.flat_product(layout, tiler, loc=loc, ip=ip) - - -@dsl_api_wrapper -def block_product(layout, tiler, loc=None, ip=None): - return _fly_ir.block_product(layout, tiler, loc=loc, ip=ip) - - -@dsl_api_wrapper -def raked_product(layout, tiler, loc=None, ip=None): - return _fly_ir.raked_product(layout, tiler, loc=loc, ip=ip) - - -@dsl_api_wrapper -def make_atom(atom_type, loc=None, ip=None): - return _fly_ir.make_atom(atom_type, loc=loc, ip=ip) - - -@dsl_api_wrapper -def make_tile(layouts, loc=None, ip=None): - return _fly_ir.make_tile(layouts, loc=loc, ip=ip) - - -@dsl_api_wrapper -def mma_atom_call(mma_atom, d, a, b, c, loc=None, ip=None): - return _fly_ir.mma_atom_call(mma_atom, d, a, b, c, loc=loc, ip=ip) - - -@dsl_api_wrapper -def copy_atom_call(copy_atom, src, dst, loc=None, ip=None): - return _fly_ir.copy_atom_call(copy_atom, src, dst, loc=loc, ip=ip) - - -@dsl_api_wrapper -def make_tiled_copy(copy_atom, layout_tv, tile_mn, loc=None, ip=None): - return _fly_ir.make_tiled_copy(copy_atom, 
layout_tv, tile_mn, loc=loc, ip=ip) - - -@dsl_api_wrapper -def memref_alloca(memref_type, layout, loc=None, ip=None): - return _fly_ir.memref_alloca(memref_type, layout, loc=loc, ip=ip) - - -@dsl_api_wrapper -def memref_load(memref, indices, loc=None, ip=None): - # `fly.memref.load` expects `indices` as `!fly.int_tuple` (typically a scalar offset). - # Accept convenience forms: - # - int_tuple Value (pass through) - # - python int / tuple/list (make_int_tuple) - # - index/i32/i64 Value (cast index->i32 then make_int_tuple) - if isinstance(indices, ir.Value): - if str(indices.type).startswith("!fly.int_tuple"): - return _fly_ir.memref_load(memref, indices, loc=loc, ip=ip) - # Common case: user passes `index` as a 1-D coordinate/offset. - if str(indices.type) == "index": - indices = arith.IndexCastOp(T.i32(), indices) - indices = make_int_tuple(indices, loc=loc, ip=ip) - return _fly_ir.memref_load(memref, indices, loc=loc, ip=ip) - - # List/tuple (e.g. [row]) or python int. - indices = make_int_tuple(indices, loc=loc, ip=ip) - return _fly_ir.memref_load(memref, indices, loc=loc, ip=ip) - - -@dsl_api_wrapper -def memref_store(value, memref, indices, loc=None, ip=None): - if isinstance(indices, ir.Value): - if str(indices.type).startswith("!fly.int_tuple"): - return _fly_ir.memref_store(value, memref, indices, loc=loc, ip=ip) - if str(indices.type) == "index": - indices = arith.IndexCastOp(T.i32(), indices) - indices = make_int_tuple(indices, loc=loc, ip=ip) - return _fly_ir.memref_store(value, memref, indices, loc=loc, ip=ip) - - indices = make_int_tuple(indices, loc=loc, ip=ip) - return _fly_ir.memref_store(value, memref, indices, loc=loc, ip=ip) - - -@dsl_api_wrapper -def memref_load_vec(memref, loc=None, ip=None): - return _fly_ir.memref_load_vec(memref, loc=loc, ip=ip) - - -@dsl_api_wrapper -def memref_store_vec(vector, memref, loc=None, ip=None): - return _fly_ir.memref_store_vec(vector, memref, loc=loc, ip=ip) - - -@dsl_api_wrapper -def get_layout(memref, 
loc=None, ip=None): - return _fly_ir.get_layout(memref, loc=loc, ip=ip) - - -@dsl_api_wrapper -def get_iter(memref, loc=None, ip=None): - return _fly_ir.get_iter(memref, loc=loc, ip=ip) - - -@dsl_api_wrapper -def make_view(iter, layout, loc=None, ip=None): - return _fly_ir.make_view(iter, layout, loc=loc, ip=ip) - - -@dsl_api_wrapper -def add_offset(ptr, offset, loc=None, ip=None): - if not isinstance(offset, ir.Value): - offset = make_int_tuple(offset, loc=loc, ip=ip) - return _fly_ir.add_offset(ptr, offset, loc=loc, ip=ip) - - -@dsl_api_wrapper -def cooperative_copy(tiled_copy, partition_idx, src, dst, loc=None, ip=None): - return _fly_ir.cooperative_copy( - tiled_copy, - partition_idx, - src, - dst, - loc=loc, - ip=ip, - ) - - -@dsl_api_wrapper -def print_op(*args, format_str="", loc=None, ip=None): - if len(args) > 0 and isinstance(args[0], str): - format_str = args[0] - values = list(args[1:]) - else: - values = list(args) - return _fly_ir.print_(format_str, values, loc=loc, ip=ip) - - -# ============================================================================== -# Fly Type Classes (MLIR-style API) -# ============================================================================== diff --git a/python/flydsl/lang/ir/gpu.py b/python/flydsl/lang/ir/gpu.py deleted file mode 100644 index e5dd1964..00000000 --- a/python/flydsl/lang/ir/gpu.py +++ /dev/null @@ -1,457 +0,0 @@ -import inspect -from functools import partial -import sys -from pathlib import Path -from functools import wraps -from typing import Any, List, Optional, Tuple, Union, Callable -from typing import Optional, List, Union, TypeVar - -from ..._mlir.dialects._func_ops_gen import FuncOp -from ..._mlir.extras import types as T -from ..._mlir.extras.meta import region_op, op_region_builder - - -from ..._mlir.dialects._ods_common import ( - _cext, - get_default_loc_context, - get_op_result_or_op_results, -) -from ..._mlir.dialects._gpu_ops_gen import _Dialect -from ..._mlir.dialects._gpu_ops_gen import 
* -from ..._mlir.dialects._gpu_enum_gen import * - - -from ..._mlir.ir import ( - ArrayAttr, - AttrBuilder, - Attribute, - Context, - InsertionPoint, - ShapedType, - Type, - UnitAttr, - Value, - FlatSymbolRefAttr, - FunctionType, - InsertionPoint, - OpView, - Operation, - OpResultList, - Type, - TypeAttr, - Value, - register_attribute_builder, -) - -_block_id = block_id -_thread_id = thread_id -_block_dim = block_dim -_grid_dim = grid_dim - - -class classproperty(property): - def __get__(self, owner_self, owner_cls): - return self.fget(owner_cls) - - -class block_idx: - @classproperty - def x(cls): - return _block_id("x") - - @classproperty - def y(cls): - return _block_id("y") - - @classproperty - def z(cls): - return _block_id("z") - - -class block_dim: - @classproperty - def x(cls): - return _block_dim("x") - - @classproperty - def y(cls): - return _block_dim("y") - - @classproperty - def z(cls): - return _block_dim("z") - - -class thread_idx: - @classproperty - def x(cls): - return _thread_id("x") - - @classproperty - def y(cls): - return _thread_id("y") - - @classproperty - def z(cls): - return _thread_id("z") - - -class grid_dim: - @classproperty - def x(cls): - return _grid_dim("x") - - @classproperty - def y(cls): - return _grid_dim("y") - - @classproperty - def z(cls): - return _grid_dim("z") - - -def gpu_attr(mnemonic, attr_value): - return Attribute.parse(f"#gpu.{mnemonic}<{attr_value}>") - - -class ModuleMeta(type): - def __new__(cls, name, bases, classdict, **kwargs): - ip = classdict.pop("ip") - new = super().__new__(cls, name, bases, classdict) - for k, v in classdict.items(): - if callable(v): - v.qualname = name - ip.__exit__(None, None, None) - return new - - -@_cext.register_operation(_Dialect, replace=True) -class GPUModuleOp(GPUModuleOp): - def __init__( - self, sym_name, targets: Optional[List[Attribute]] = None, *, loc=None, ip=None - ): - if targets is None: - targets = [] - for i, t in enumerate(targets): - if isinstance(t, str): - 
targets[i] = Attribute.parse(t) - _ods_context = get_default_loc_context(loc) - sym_name = ( - sym_name - if ( - issubclass(type(sym_name), Attribute) - or not AttrBuilder.contains("SymbolNameAttr") - ) - else AttrBuilder.get("SymbolNameAttr")(sym_name, context=_ods_context) - ) - super().__init__(sym_name=sym_name, targets=ArrayAttr.get(targets), ip=ip) - self.regions[0].blocks.append() - - @property - def body(self): - return self.regions[0].blocks[0] - - -module = region_op(GPUModuleOp) - - -class GPUModuleMeta(ModuleMeta): - @classmethod - def __prepare__(cls, name, bases, **kwargs): - loc = kwargs.pop("loc", None) - if loc is None: - loc = get_user_code_loc() - targets = kwargs.pop("targets", None) - gpu_module_op = GPUModuleOp( - sym_name=name, - targets=targets, - ip=kwargs.pop("ip", None), - loc=loc, - ) - ip = InsertionPoint(gpu_module_op.body) - ip.__enter__() - return {"ip": ip, "gpu_module_op": gpu_module_op} - - -@_cext.register_operation(_Dialect, replace=True) -class GPUFuncOp(GPUFuncOp): - def __init__( - self, - sym_name, - function_type, - *, - sym_visibility=None, - arg_attrs=None, - res_attrs=None, - workgroup_attrib_attrs=None, - private_attrib_attrs=None, - loc=None, - ip=None, - ): - super().__init__( - function_type=function_type, - arg_attrs=arg_attrs, - res_attrs=res_attrs, - workgroup_attrib_attrs=workgroup_attrib_attrs, - private_attrib_attrs=private_attrib_attrs, - loc=loc, - ip=ip, - ) - self.operation.attributes["gpu.kernel"] = UnitAttr.get() - _ods_context = get_default_loc_context(loc) - self.operation.attributes["sym_name"] = ( - sym_name - if ( - issubclass(type(sym_name), Attribute) - or not AttrBuilder.contains("SymbolNameAttr") - ) - else AttrBuilder.get("SymbolNameAttr")(sym_name, context=_ods_context) - ) - if sym_visibility is not None: - self.operation.attributes["sym_visibility"] = ( - sym_visibility - if ( - issubclass(type(sym_visibility), Attribute) - or not AttrBuilder.contains("StrAttr") - ) - else 
AttrBuilder.get("StrAttr")(sym_visibility, context=_ods_context) - ) - - -def isalambda(v): - LAMBDA = lambda: 0 - return isinstance(v, type(LAMBDA)) and v.__name__ == LAMBDA.__name__ - - -def prep_func_types(sig, return_types): - assert not ( - not sig.return_annotation is inspect.Signature.empty and len(return_types) > 0 - ), f"func can use return annotation or explicit return_types but not both" - return_types = ( - sig.return_annotation - if not sig.return_annotation is inspect.Signature.empty - else return_types - ) - if not isinstance(return_types, (tuple, list)): - return_types = [return_types] - return_types = list(return_types) - assert all( - isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in return_types - ), f"all return types must be ..._mlir types or strings or TypeVars or lambdas {return_types=}" - - input_types = [ - p.annotation - for p in sig.parameters.values() - if not p.annotation is inspect.Signature.empty - ] - assert all( - isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in input_types - ), f"all input types must be ..._mlir types or strings or TypeVars or lambdas {input_types=}" - user_loc = None - # If ir.Context is none (like for deferred func emit) - if user_loc is None: - user_locs = None - else: - user_locs = [user_loc] * len(sig.parameters) - return input_types, return_types, user_locs - - -@_cext.register_operation(_Dialect, replace=True) -class LaunchFuncOp(LaunchFuncOp): - def __init__( - self, - kernel: List[str], - grid_size: Tuple[Any, Any, Any], - block_size: Tuple[Any, Any, Any], - kernel_operands: List[Value] = None, - async_dependencies=None, - dynamic_shared_memory_size: Optional[Value] = None, - async_object=None, - *, - loc=None, - ip=None, - ): - _ods_context = get_default_loc_context(loc) - if async_dependencies is None: - async_dependencies = [] - async_token = None - grid_size_x, grid_size_y, grid_size_z = grid_size - block_size_x, block_size_y, block_size_z = block_size - - super().__init__( - 
async_token, - async_dependencies, - kernel, - grid_size_x, - grid_size_y, - grid_size_z, - block_size_x, - block_size_y, - block_size_z, - kernel_operands, - dynamicSharedMemorySize=dynamic_shared_memory_size, - asyncObject=async_object, - loc=loc, - ip=ip, - ) - - -class GPUFunc: - def __init__( - self, - body_builder, - func_op_ctor, - return_op_ctor, - call_op_ctor, - *, - return_types=None, - sym_visibility=None, - sym_name=None, - arg_attrs=None, - res_attrs=None, - func_attrs=None, - function_type=None, - generics: List[Union[TypeVar]] = None, - qualname=None, - loc=None, - ip=None, - ): - assert inspect.isfunction(body_builder), body_builder - assert inspect.isclass(func_op_ctor), func_op_ctor - if return_op_ctor is not None: - assert inspect.isclass(return_op_ctor), return_op_ctor - assert inspect.isclass(call_op_ctor), call_op_ctor - - self.body_builder = body_builder - if sym_name is None: - sym_name = self.body_builder.__name__ - self.func_name = sym_name - self.func_op_ctor = func_op_ctor - self.return_op_ctor = return_op_ctor - self.call_op_ctor = call_op_ctor - self.arg_attrs = arg_attrs - self.res_attrs = res_attrs - self.generics = generics - self.loc = loc - self.ip = ip - self._func_op = None - # in case this function lives inside a class - self.qualname = qualname - - self.sym_visibility = sym_visibility - self.func_attrs = func_attrs - if self.func_attrs is None: - self.func_attrs = {} - self.function_type = function_type - - if return_types is None: - return_types = [] - sig = inspect.signature(self.body_builder) - self.input_types, self.return_types, self.arg_locs = prep_func_types( - sig, return_types - ) - - def __str__(self): - return str(f"{self.__class__} {self.__dict__}") - - def emit(self, *call_args, decl=False, force=False): - if self._func_op is None or decl or force: - if self.function_type is None: - if len(call_args) == 0: - input_types = self.input_types[:] - locals = {"T": T} - for i, v in enumerate(input_types): - if 
isinstance(v, TypeVar): - v = v.__name__ - if isinstance(v, str): - input_types[i] = Type( - eval(v, self.body_builder.__globals__, locals) - ) - elif isalambda(v): - input_types[i] = v() - else: - input_types = [a.type for a in call_args] - - function_type = TypeAttr.get( - FunctionType.get( - inputs=input_types, - results=self.return_types, - ) - ) - else: - input_types = self.function_type.inputs - function_type = TypeAttr.get(self.function_type) - - self._func_op = self.func_op_ctor( - self.func_name, - function_type, - sym_visibility=self.sym_visibility, - arg_attrs=self.arg_attrs, - res_attrs=self.res_attrs, - loc=self.loc, - ip=self.ip or InsertionPoint.current, - ) - if isinstance(self._func_op, FuncOp): - self._func_op.attributes["llvm.emit_c_interface"] = UnitAttr.get() - for k, v in self.func_attrs.items(): - self._func_op.attributes[k] = v - - self._func_op.regions[0].blocks.append(*input_types, arg_locs=self.arg_locs) - builder_wrapper = op_region_builder( - self._func_op, self._func_op.regions[0], terminator=self.return_op_ctor - ) - - return_types = [] - - def grab_results(*args): - nonlocal return_types - results = self.body_builder(*args) - if isinstance(results, (tuple, list, OpResultList)): - return_types.extend([r.type for r in results]) - elif results is not None: - return_types.append(results.type) - return results - - if self.function_type is None: - builder_wrapper(grab_results) - function_type = FunctionType.get( - inputs=input_types, results=return_types - ) - self._func_op.attributes["function_type"] = TypeAttr.get(function_type) - else: - builder_wrapper(self.body_builder) - - return self._func_op - - -def gpu_func( - f, - *, - sym_visibility=None, - arg_attrs=None, - res_attrs=None, - func_attrs=None, - emit=False, - generics=None, - loc=None, - ip=None, -): - if generics is None and hasattr(f, "__type_params__") and f.__type_params__: - generics = f.__type_params__ - func_ = GPUFunc( - body_builder=f, - func_op_ctor=GPUFuncOp, - 
return_op_ctor=ReturnOp, - call_op_ctor=LaunchFuncOp, - sym_visibility=sym_visibility, - arg_attrs=arg_attrs, - res_attrs=res_attrs, - func_attrs=func_attrs, - generics=generics, - loc=loc, - ip=ip, - ) - func_.__name__ = f.__name__ - if emit: - func_.emit() - return func_ diff --git a/python/flydsl/lang/ir/module.py b/python/flydsl/lang/ir/module.py deleted file mode 100644 index cc4c7306..00000000 --- a/python/flydsl/lang/ir/module.py +++ /dev/null @@ -1,212 +0,0 @@ -import inspect -from typing import Optional - -from ..._mlir import ir -from ..._mlir.extras import types as T -from ..._mlir.dialects import arith, func, _gpu_ops_gen - - -from .gpu import ( - gpu_func, - prep_func_types, - LaunchFuncOp, - block_idx, - thread_idx, - block_dim, - grid_dim, -) - - -class GlobalRAIIMLIRContext: - context: ir.Context - location: ir.Location - - def __init__(self, allow_unregistered_dialects=False): - self.context = ir.Context() - if allow_unregistered_dialects: - self.context.allow_unregistered_dialects = True - self.context.__enter__() - self.location = ir.Location.unknown() - self.location.__enter__() - - def __del__(self): - self.location.__exit__(None, None, None) - self.context.__exit__(None, None, None) - - -class MlirModule: - GPU_MODULE_NAME = "kernels" - - cls_kernel_fn = [] - cls_jit_fn = [] - cls_kernel_sym = {} - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - - # Initialize MLIR module for this subclass FIRST - cls.module = ir.Module.create() - cls.module.operation.attributes["gpu.container_module"] = ir.UnitAttr.get() - - with ir.InsertionPoint(cls.module.body): - cls.gpu_module = _gpu_ops_gen.module(cls.GPU_MODULE_NAME) - - # After MLIR module is created, collect functions registered by descriptors - # Descriptors __set_name__ runs during class creation, adding to temporary lists - # We need to move them to the class-specific lists - temp_kernel_fn = [] - temp_jit_fn = [] - temp_kernel_sym = {} - - # Collect from class 
__dict__ directly (not inherited) - for name, value in cls.__dict__.items(): - if isinstance(value, _KernelDescriptor): - # This descriptor belongs to this class - if hasattr(value, "_wrapper"): - temp_kernel_fn.append(value._wrapper) - temp_kernel_sym[name] = name - elif isinstance(value, _JitDescriptor): - if hasattr(value, "_wrapper"): - temp_jit_fn.append(value._wrapper) - - # Set class-specific lists - cls.cls_kernel_fn = temp_kernel_fn - cls.cls_jit_fn = temp_jit_fn - cls.cls_kernel_sym = temp_kernel_sym - - def __init__(self): - self.kernel_func_op = {} - for fn in self.cls_jit_fn: - fn(self) - for fn in self.cls_kernel_fn: - fn(self) - - def __repr__(self): - return str(self.module) - - def __getattr__(self, name: str): - if name in self.cls_kernel_sym.keys(): - return ir.SymbolRefAttr.get( - [self.GPU_MODULE_NAME, self.cls_kernel_sym[name]] - ) - raise AttributeError(f"{name} not found in kernel functions.") - - @classmethod - def create_gpu_module(cls, module_attrs=None): - cls.gpu_module = _gpu_ops_gen.module("kernels") - - @classmethod - def create_from_mlir_source(cls, file_path: str): - pass - - @classmethod - def kernel(cls, fn): - def wrapper(self, *args, **kwargs): - if len(self.gpu_module.bodyRegion.blocks) == 0: - self.gpu_module.bodyRegion.blocks.append() - with ir.InsertionPoint.at_block_begin(self.gpu_module.bodyRegion.blocks[0]): - self.kernel_func_op[fn.__name__] = gpu_func(fn, emit=True) - - cls.cls_kernel_fn.append(wrapper) - cls.cls_kernel_sym[fn.__name__] = fn.__name__ - return fn - - @classmethod - def jit(cls, fn): - def wrapper(self): - with ir.InsertionPoint.at_block_begin(self.module.body): - sig = inspect.signature(fn) - input_types, return_types, _ = prep_func_types(sig, []) - func.FuncOp.from_py_func(*input_types)(fn) - - cls.cls_jit_fn.append(wrapper) - return fn - - -class _KernelDescriptor: - """Descriptor that automatically registers kernel to the correct class.""" - - def __init__(self, fn): - self.fn = fn - self.name = 
fn.__name__ - self._wrapper = None - - def __set_name__(self, owner, name): - """Called when the descriptor is assigned to a class attribute.""" - # Check if owner is a subclass of MlirModule - try: - if issubclass(owner, MlirModule): - # Capture fn in the closure - fn = self.fn - - def wrapper(instance_self, *args, **kwargs): - if len(instance_self.gpu_module.bodyRegion.blocks) == 0: - instance_self.gpu_module.bodyRegion.blocks.append() - with ir.InsertionPoint.at_block_begin( - instance_self.gpu_module.bodyRegion.blocks[0] - ): - instance_self.kernel_func_op[fn.__name__] = gpu_func( - fn, emit=True - ) - - # Store the wrapper in the descriptor for later collection - self._wrapper = wrapper - self._name = name - except TypeError: - # owner is not a class, skip - pass - - def __get__(self, obj, objtype=None): - """Return the original function for method access.""" - if obj is None: - return self.fn - return self.fn.__get__(obj, objtype) - - -class _JitDescriptor: - """Descriptor that automatically registers jit function to the correct class.""" - - def __init__(self, fn): - self.fn = fn - self.name = fn.__name__ - self._wrapper = None - - def __set_name__(self, owner, name): - """Called when the descriptor is assigned to a class attribute.""" - # Check if owner is a subclass of MlirModule - try: - if issubclass(owner, MlirModule): - # Capture fn in the closure - fn = self.fn - - def wrapper(instance_self): - with ir.InsertionPoint.at_block_begin(instance_self.module.body): - sig = inspect.signature(fn) - input_types, return_types, _ = prep_func_types(sig, []) - func.FuncOp.from_py_func(*input_types)(fn) - - # Store the wrapper in the descriptor for later collection - self._wrapper = wrapper - except TypeError: - # owner is not a class, skip - pass - - def __get__(self, obj, objtype=None): - """Return the original function for method access.""" - if obj is None: - return self.fn - return self.fn.__get__(obj, objtype) - - -# Use descriptor-based decorators that 
return descriptors -def kernel(fn): - """Decorator that returns a descriptor for automatic class detection.""" - return _KernelDescriptor(fn) - - -def jit(fn): - """Decorator that returns a descriptor for automatic class detection.""" - return _JitDescriptor(fn) - - -_global_ctx = GlobalRAIIMLIRContext() diff --git a/python/flydsl/lang/ir/types.py b/python/flydsl/lang/ir/types.py deleted file mode 100644 index 011fa158..00000000 --- a/python/flydsl/lang/ir/types.py +++ /dev/null @@ -1,5 +0,0 @@ -# from fly_mlir.extras import types as T - - -class Tensor: - pass diff --git a/python/flydsl/lang/meta.py b/python/flydsl/lang/meta.py deleted file mode 100644 index ed52f839..00000000 --- a/python/flydsl/lang/meta.py +++ /dev/null @@ -1,37 +0,0 @@ -import inspect -from functools import wraps - -from .._mlir import ir - - -def dsl_api_wrapper(op): - @wraps(op) - def wrapper(*args, **kwargs): - loc = kwargs.pop("loc", None) - if loc is None: - frame = inspect.currentframe().f_back - frameInfo = inspect.getframeinfo(frame) - # Compatible with different Python versions: positions attribute is available in Python 3.11+ - if hasattr(frameInfo, 'positions') and frameInfo.positions: - lineno = frameInfo.positions.lineno - col_offset = frameInfo.positions.col_offset - else: - lineno = frameInfo.lineno - col_offset = 0 # Older versions don't provide column offset information - file_loc = ir.Location.file( - frameInfo.filename, - lineno, - col_offset, - ) - loc = ir.Location.name( - ( - "".join([c.strip() for c in frameInfo.code_context]) - if frameInfo.code_context - else frameInfo.function - ), - childLoc=file_loc, - ) - with loc: - return op(*args, **kwargs) - - return wrapper diff --git a/python/flydsl/lang/typing.py b/python/flydsl/lang/typing.py deleted file mode 100644 index 7e437ecc..00000000 --- a/python/flydsl/lang/typing.py +++ /dev/null @@ -1,32 +0,0 @@ -import ctypes -import numpy as np -import operator -from typing_extensions import deprecated -from functools import 
reduce -from typing import ( - Generic, - Protocol, - Union, - Any, - List, - Type, - TypeVar, - overload, - runtime_checkable, - get_origin, -) -from types import FunctionType -from dataclasses import dataclass -from abc import ABC, abstractmethod - - -class NumericType: - pass - - -class Int32: - def __init__(self, value): - self.value = value - - def __repr__(self): - return f"Int32({self.value})" diff --git a/python/flydsl/utils/env.py b/python/flydsl/utils/env.py index 82441a36..1d78c619 100644 --- a/python/flydsl/utils/env.py +++ b/python/flydsl/utils/env.py @@ -194,25 +194,39 @@ class CompileEnvManager(EnvManager): env_prefix = "COMPILE" opt_level = OptInt(2, min_value=0, max_value=3, description="Optimization level") - enable_debug_info = OptBool(True, description="Generate debug info in compiled code") - enable_verifier = OptBool(False, description="Verify IR module") - print_after_all = OptBool(False, description="Print IR after each MLIR pass") -class RuntimeEnvManager(EnvManager): - env_prefix = "RUNTIME" +class DebugEnvManager(EnvManager): + env_prefix = "DEBUG" + dump_asm = OptBool(False, description="Dump ASM to file") + dump_ir = OptBool(False, description="Dump IR to file") + dump_dir = OptStr(str(Path.home() / ".flydsl" / "debug"), description="Directory for dumping IR") + + # Logging options log_level = OptStr("WARNING", choices=["DEBUG", "INFO", "WARNING", "ERROR"], description="Logging level") log_to_file = OptStr("", description="Log file path, empty to disable file logging") log_to_console = OptBool(False, description="Enable console logging") + + # MLIR pass manager options + print_after_all = OptBool(False, description="Print IR after each MLIR pass") + enable_debug_info = OptBool(True, description="Generate debug info in compiled code") + enable_verifier = OptBool(True, description="Verify IR module") + + +class RuntimeEnvManager(EnvManager): + env_prefix = "RUNTIME" + cache_dir = OptStr(str(Path.home() / ".flydsl" / "cache"), 
description="Directory for caching compiled kernels") enable_cache = OptBool(True, description="Enable kernel caching") compile = CompileEnvManager() +debug = DebugEnvManager() runtime = RuntimeEnvManager() __all__ = [ "compile", + "debug", "runtime", ] diff --git a/python/flydsl/utils/logger.py b/python/flydsl/utils/logger.py index 5dc1c832..c780646a 100644 --- a/python/flydsl/utils/logger.py +++ b/python/flydsl/utils/logger.py @@ -3,7 +3,7 @@ __all__ = ["log"] -_FORMAT = "%(asctime)s - %(levelname)-8s - [%(funcName)s] - %(message)s" +_FORMAT = "%(asctime)s - %(levelname)-8s - [%(filename)s : %(funcName)s] - %(message)s" _FORMAT_SIMPLE = "| %(levelname)-8s - [%(funcName)s] - %(message)s" _logger: logging.Logger = None @@ -15,22 +15,22 @@ def _init_logger(): if _initialized: return - from .env import runtime + from .env import debug _logger = logging.getLogger("flydsl") _logger.setLevel(logging.DEBUG) _logger.propagate = False - level = getattr(logging, runtime.log_level) + level = getattr(logging, debug.log_level) - if runtime.log_to_console: + if debug.log_to_console: console_handler = logging.StreamHandler(sys.stderr) console_handler.setFormatter(logging.Formatter(_FORMAT_SIMPLE)) console_handler.setLevel(level) _logger.addHandler(console_handler) - if runtime.log_to_file: - file_handler = logging.FileHandler(runtime.log_to_file, mode="a", encoding="utf-8") + if debug.log_to_file: + file_handler = logging.FileHandler(debug.log_to_file, mode="w", encoding="utf-8") file_handler.setFormatter(logging.Formatter(_FORMAT)) file_handler.setLevel(level) _logger.addHandler(file_handler) diff --git a/python/mlir_flydsl/CMakeLists.txt b/python/mlir_flydsl/CMakeLists.txt index a3011d17..63ababab 100644 --- a/python/mlir_flydsl/CMakeLists.txt +++ b/python/mlir_flydsl/CMakeLists.txt @@ -152,3 +152,29 @@ add_custom_target(CopyFlyPythonSources ALL ) add_dependencies(CopyFlyPythonSources FlyPythonModules) + 
+################################################################################ +# Create symlinks to MLIR runtime libraries +################################################################################ + +set(_MLIR_RUNTIME_LIBS + mlir_rocm_runtime + mlir_runner_utils + mlir_c_runner_utils +) + +get_filename_component(_LLVM_LIB_DIR "${MLIR_DIR}/../../../lib" ABSOLUTE) + +foreach(_lib ${_MLIR_RUNTIME_LIBS}) + set(_src_lib "${_LLVM_LIB_DIR}/lib${_lib}.so") + set(_dst_link "${_MLIR_LIBS_DIR}/lib${_lib}.so") + add_custom_command( + OUTPUT "${_dst_link}" + COMMAND ${CMAKE_COMMAND} -E create_symlink "${_src_lib}" "${_dst_link}" + DEPENDS FlyPythonModules + COMMENT "Creating symlink to lib${_lib}.so" + ) + list(APPEND _RUNTIME_SYMLINKS "${_dst_link}") +endforeach() + +add_custom_target(MlirRuntimeSymlinks ALL DEPENDS ${_RUNTIME_SYMLINKS}) diff --git a/scripts/build_llvm.sh b/scripts/build_llvm.sh index 28e290b3..7f7bc09b 100755 --- a/scripts/build_llvm.sh +++ b/scripts/build_llvm.sh @@ -1,73 +1,77 @@ #!/bin/bash set -e -# Default to downloading llvm-project in the parent directory of flir SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" BASE_DIR="$(cd "${REPO_ROOT}/.." && pwd)" -LLVM_SRC_DIR="$BASE_DIR/llvm-project" -LLVM_BUILD_DIR="$LLVM_SRC_DIR/buildmlir" -LLVM_INSTALL_DIR="${LLVM_INSTALL_DIR:-$LLVM_SRC_DIR/mlir_install}" -LLVM_INSTALL_TGZ="${LLVM_INSTALL_TGZ:-$LLVM_SRC_DIR/mlir_install.tgz}" -LLVM_PACKAGE_INSTALL="${LLVM_PACKAGE_INSTALL:-1}" -# LLVM_COMMIT="${LLVM_COMMIT:-04f968b02917}" -LLVM_COMMIT="${LLVM_COMMIT:-edf06d742821f34060f924dd9db5e01bed90c030}" - -echo "Base directory: $BASE_DIR" -echo "LLVM Source: $LLVM_SRC_DIR" -echo "LLVM Build: $LLVM_BUILD_DIR" -echo "LLVM Install: $LLVM_INSTALL_DIR" -echo "LLVM Tarball: $LLVM_INSTALL_TGZ" - -# 1. 
Clone LLVM + +LLVM_HASH_FILE="${REPO_ROOT}/cmake/llvm-hash.txt" +if [[ -f "${LLVM_HASH_FILE}" ]]; then + LLVM_COMMIT_DEFAULT=$(cat "${LLVM_HASH_FILE}" | tr -d '[:space:]') +else + LLVM_COMMIT_DEFAULT="edf06d742821" +fi + +LLVM_SRC_DIR="${LLVM_SRC_DIR:-$BASE_DIR/llvm-project}" +LLVM_BUILD_DIR="${LLVM_BUILD_DIR:-$LLVM_SRC_DIR/build}" +LLVM_INSTALL_DIR="${LLVM_INSTALL_DIR:-$LLVM_SRC_DIR/install}" +LLVM_COMMIT="${LLVM_COMMIT:-$LLVM_COMMIT_DEFAULT}" +LLVM_PACKAGE_INSTALL="${LLVM_PACKAGE_INSTALL:-0}" + +echo "==============================================" +echo "FlyDSL LLVM/MLIR Build Script" +echo "==============================================" +echo "LLVM Source: ${LLVM_SRC_DIR}" +echo "LLVM Build: ${LLVM_BUILD_DIR}" +echo "LLVM Commit: ${LLVM_COMMIT}" +echo "LLVM Install: ${LLVM_INSTALL_DIR}" +echo "==============================================" + if [ ! -d "$LLVM_SRC_DIR" ]; then - echo "Cloning llvm-project..." - git clone https://github.com/ROCm/llvm-project.git "$LLVM_SRC_DIR" + echo "Cloning llvm-project from ROCm fork..." + git clone --depth 1 https://github.com/ROCm/llvm-project.git "$LLVM_SRC_DIR" fi -echo "Checking out llvm-project commit ${LLVM_COMMIT} (amd-staging)..." -pushd "$LLVM_SRC_DIR" +pushd "$LLVM_SRC_DIR" > /dev/null -# Check if we need to switch remote to ROCm fork CURRENT_REMOTE=$(git remote get-url origin) if [[ "$CURRENT_REMOTE" == *"github.com/llvm/llvm-project"* ]]; then - echo "Detected upstream LLVM. Switching origin to ROCm fork for amd-staging..." + echo "Switching origin to ROCm fork..." git remote set-url origin https://github.com/ROCm/llvm-project.git fi -git fetch origin amd-staging -git checkout "${LLVM_COMMIT}" -popd +CURRENT_COMMIT=$(git rev-parse HEAD 2>/dev/null || echo "none") +SHORT_CURRENT=$(echo "$CURRENT_COMMIT" | cut -c1-12) +SHORT_TARGET=$(echo "$LLVM_COMMIT" | cut -c1-12) -# 2. 
Create Build Directory -mkdir -p "$LLVM_BUILD_DIR" -cd "$LLVM_BUILD_DIR" +if [[ "$SHORT_CURRENT" != "$SHORT_TARGET"* && "$SHORT_TARGET" != "$SHORT_CURRENT"* ]]; then + echo "Fetching and checking out commit ${LLVM_COMMIT}..." + git fetch --depth 1 origin "${LLVM_COMMIT}" + git checkout "${LLVM_COMMIT}" +else + echo "Already at commit ${SHORT_CURRENT}" +fi -# 3. Configure CMake -echo "Configuring LLVM..." +popd > /dev/null -# Install dependencies for Python bindings -echo "Installing Python dependencies..." -pip install nanobind numpy pybind11 +mkdir -p "$LLVM_BUILD_DIR" -# Check for ninja GENERATOR="Unix Makefiles" if command -v ninja &> /dev/null; then GENERATOR="Ninja" - echo "Using Ninja generator." -else - echo "Ninja not found. Using Unix Makefiles (this might be slower)." fi -# Build only MLIR and necessary Clang tools, targeting native architecture, in Release mode -# Explicitly set nanobind directory if found to help CMake locate it +echo "Installing Python build dependencies..." +pip install -q nanobind numpy pybind11 + NANOBIND_DIR=$(python3 -c "import nanobind; import os; print(os.path.dirname(nanobind.__file__) + '/cmake')") +echo "Configuring LLVM with ${GENERATOR}..." cmake -G "$GENERATOR" \ -S "$LLVM_SRC_DIR/llvm" \ -B "$LLVM_BUILD_DIR" \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \ - -DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" \ + -DLLVM_TARGETS_TO_BUILD="X86;AMDGPU" \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_CXX_STANDARD=17 \ -DLLVM_ENABLE_ASSERTIONS=ON \ @@ -80,38 +84,27 @@ cmake -G "$GENERATOR" \ -DLLVM_BUILD_LLVM_DYLIB=OFF \ -DLLVM_LINK_LLVM_DYLIB=OFF -# 4. Build -echo "Starting build with $(nproc) parallel jobs..." -cmake --build . -j$(nproc) +NPROC=$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4) +echo "Building with ${NPROC} parallel jobs..." +cmake --build "$LLVM_BUILD_DIR" -j"$NPROC" if [[ "${LLVM_PACKAGE_INSTALL}" == "1" ]]; then - echo "==============================================" - echo "Installing MLIR/LLVM to a clean prefix..." 
- rm -rf "${LLVM_INSTALL_DIR}" - mkdir -p "${LLVM_INSTALL_DIR}" - cmake --install "${LLVM_BUILD_DIR}" --prefix "${LLVM_INSTALL_DIR}" - - if [[ ! -d "${LLVM_INSTALL_DIR}/lib/cmake/mlir" ]]; then - echo "Error: install prefix missing lib/cmake/mlir: ${LLVM_INSTALL_DIR}" >&2 - exit 1 - fi - - echo "Creating tarball..." - tar -C "$(dirname "${LLVM_INSTALL_DIR}")" -czf "${LLVM_INSTALL_TGZ}" "$(basename "${LLVM_INSTALL_DIR}")" + echo "Installing to ${LLVM_INSTALL_DIR}..." + rm -rf "${LLVM_INSTALL_DIR}" + mkdir -p "${LLVM_INSTALL_DIR}" + cmake --install "${LLVM_BUILD_DIR}" --prefix "${LLVM_INSTALL_DIR}" + + if [[ ! -d "${LLVM_INSTALL_DIR}/lib/cmake/mlir" ]]; then + echo "Error: install prefix missing lib/cmake/mlir" >&2 + exit 1 + fi + + LLVM_INSTALL_TGZ="${LLVM_INSTALL_DIR}.tar.gz" + echo "Creating tarball: ${LLVM_INSTALL_TGZ}" + tar -C "$(dirname "${LLVM_INSTALL_DIR}")" -czf "${LLVM_INSTALL_TGZ}" "$(basename "${LLVM_INSTALL_DIR}")" fi echo "==============================================" -echo "LLVM/MLIR build completed successfully!" -echo "" -echo "To configure flir, use:" -echo "cmake .. -DMLIR_DIR=$LLVM_BUILD_DIR/lib/cmake/mlir" -if [[ "${LLVM_PACKAGE_INSTALL}" == "1" ]]; then - echo "" - echo "Packaged install prefix:" - echo " ${LLVM_INSTALL_DIR}" - echo "Use with:" - echo " export MLIR_PATH=${LLVM_INSTALL_DIR}" - echo "Tarball:" - echo " ${LLVM_INSTALL_TGZ}" -fi +echo "LLVM/MLIR build complete!" +echo "MLIR_DIR: ${LLVM_BUILD_DIR}/lib/cmake/mlir" echo "==============================================" diff --git a/scripts/dumpir.sh b/scripts/dumpir.sh deleted file mode 100755 index da06c5ab..00000000 --- a/scripts/dumpir.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -# dumpir.sh