From 6838cbf1e40af25101fb6887102034a5936c5a0b Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Fri, 20 Mar 2026 20:03:02 +0000 Subject: [PATCH 1/7] Add BufferFatPtr class and update rocdl conversion Made-with: Cursor --- lib/Conversion/FlyToROCDL/BufferFatPtr.h | 96 ++++ lib/Conversion/FlyToROCDL/FlyToROCDL.cpp | 415 +++++----------- python/flydsl/expr/primitive.py | 580 ++++++++++++++--------- tests/mlir/Conversion/pointer_ops.mlir | 18 +- 4 files changed, 594 insertions(+), 515 deletions(-) create mode 100644 lib/Conversion/FlyToROCDL/BufferFatPtr.h diff --git a/lib/Conversion/FlyToROCDL/BufferFatPtr.h b/lib/Conversion/FlyToROCDL/BufferFatPtr.h new file mode 100644 index 00000000..0180503f --- /dev/null +++ b/lib/Conversion/FlyToROCDL/BufferFatPtr.h @@ -0,0 +1,96 @@ +#ifndef FLYDSL_LIB_CONVERSION_FLYTOROCDL_BUFFERFATPTR_H +#define FLYDSL_LIB_CONVERSION_FLYTOROCDL_BUFFERFATPTR_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/AddressSpaceUtils.h" + +namespace mlir::fly { + +class BufferFatPtr { + static constexpr unsigned kRsrcAddrSpace = mapToLLVMAddressSpace(AddressSpace::BufferDesc); + static constexpr unsigned kOffsetBitWidth = 32; + + fly::PointerType ptrTy; + Value fatPtr; + +public: + BufferFatPtr(fly::PointerType ptrTy, Value v) : ptrTy(ptrTy), fatPtr(v) { + assert(ptrTy.getAddressSpace().getValue() == AddressSpace::BufferDesc); + } + + static LLVM::LLVMStructType getType(MLIRContext *ctx) { + return LLVM::LLVMStructType::getLiteral(ctx, {LLVM::LLVMPointerType::get(ctx, kRsrcAddrSpace), + IntegerType::get(ctx, kOffsetBitWidth)}); + } + static Value pack(OpBuilder &b, Location loc, Value bufferRsrc, Value valOffset = nullptr) { + auto structTy = getType(b.getContext()); + Value undef = LLVM::UndefOp::create(b, loc, structTy); + if (!valOffset) { + valOffset = arith::ConstantIntOp::create(b, loc, 0, kOffsetBitWidth); + } + Value withRsrc = LLVM::InsertValueOp::create(b, loc, undef, bufferRsrc, ArrayRef{0}); + return LLVM::InsertValueOp::create(b, loc, withRsrc, valOffset, ArrayRef{1}); + } + + Value bufferRsrc(OpBuilder &b, Location loc) const { + return LLVM::ExtractValueOp::create(b, loc, fatPtr, ArrayRef{0}); + } + + Value valOffset(OpBuilder &b, Location loc) const { + return LLVM::ExtractValueOp::create(b, loc, fatPtr, ArrayRef{1}); + } + + Value byteOffset(OpBuilder &b, Location loc) const { + int64_t bits = ptrTy.getElemTy().getIntOrFloatBitWidth(); + Value off = valOffset(b, loc); + if (bits == 8) + return off; + if (bits > 8 && bits % 8 == 0) { + int64_t elemBytes = bits / 8; + Value scale = arith::ConstantIntOp::create(b, loc, elemBytes, kOffsetBitWidth).getResult(); + return arith::MulIOp::create(b, loc, off, scale); + } + Value scale = arith::ConstantIntOp::create(b, loc, bits, kOffsetBitWidth).getResult(); + off = arith::MulIOp::create(b, loc, off, scale); + Value const8 = arith::ConstantIntOp::create(b, loc, 8, kOffsetBitWidth).getResult(); + return arith::DivUIOp::create(b, loc, off, const8); + } + + Value swizzleByteOffset(OpBuilder &b, Location loc) const { + Value off = byteOffset(b, loc); + SwizzleAttr swizzle = ptrTy.getSwizzle(); + if (swizzle.isTrivialSwizzle()) + return off; + auto offsetTy = IntegerType::get(b.getContext(), kOffsetBitWidth); + int64_t bitMaskValue = ((int64_t{1} << swizzle.getMask()) - 1) + << (swizzle.getBase() + swizzle.getShift()); + Value bitMask = arith::ConstantIntOp::create(b, loc, offsetTy, bitMaskValue); + Value shiftAmt = arith::ConstantIntOp::create(b, loc, offsetTy, swizzle.getShift()); + Value masked = arith::AndIOp::create(b, loc, off, bitMask); + Value shifted = arith::ShRUIOp::create(b, loc, masked, shiftAmt); + return arith::XOrIOp::create(b, loc, off, shifted); + } + + Value addOffset(OpBuilder &b, Location loc, Value delta) const { + Type offTy = IntegerType::get(b.getContext(), kOffsetBitWidth); + if (delta.getType() != offTy) { + if (delta.getType().isIndex()) + delta = arith::IndexCastOp::create(b, loc, offTy, delta); + else if (delta.getType().getIntOrFloatBitWidth() < kOffsetBitWidth) + delta = arith::ExtSIOp::create(b, loc, offTy, delta); + else + delta = arith::TruncIOp::create(b, loc, offTy, delta); + } + Value newOff = arith::AddIOp::create(b, loc, valOffset(b, loc), delta); + return pack(b, loc, bufferRsrc(b, loc), newOff); + } +}; + +} // namespace mlir::fly + +#endif // FLYDSL_LIB_CONVERSION_FLYTOROCDL_BUFFERFATPTR_H diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp index 1c9a4c60..2d4978b6 100644 --- a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp +++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp @@ -22,6 +22,8 @@ #include "flydsl/Dialect/Fly/Utils/LayoutUtils.h" #include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" +#include "./BufferFatPtr.h" + namespace mlir { #define GEN_PASS_DEF_FLYTOROCDLCONVERSIONPASS #include "flydsl/Conversion/Passes.h.inc" @@ -32,85 +34,20 @@ using namespace mlir::fly; namespace { -inline unsigned mapToLLVMAddressSpace(AddressSpace space) { - switch (space) { - case AddressSpace::Global: - return 1; - case AddressSpace::Shared: - return 3; - case AddressSpace::Register: - return 5; - case AddressSpace::BufferDesc: - return 8; - } - return 0; -} - -static LLVM::LLVMStructType getBufferFatPtrType(MLIRContext *ctx) { - return LLVM::LLVMStructType::getLiteral( - ctx, {LLVM::LLVMPointerType::get(ctx, 8), IntegerType::get(ctx, 32)}); -} - -static bool isBufferFatPtr(Type ty) { - auto st = dyn_cast(ty); - if (!st || st.getBody().size() != 2) - return false; - auto ptrTy = dyn_cast(st.getBody()[0]); - return ptrTy && ptrTy.getAddressSpace() == 8 && st.getBody()[1].isInteger(32); -} - -static Value extractBufferRsrc(OpBuilder &b, Location loc, Value fatPtr) { - return LLVM::ExtractValueOp::create(b, loc, fatPtr, ArrayRef{0}); -} - -static Value extractBufferOffset(OpBuilder &b, Location loc, Value fatPtr) { - return LLVM::ExtractValueOp::create(b, loc, fatPtr, ArrayRef{1}); -} - -static Value createBufferFatPtr(OpBuilder &b, Location loc, MLIRContext *ctx, Value rsrc, - Value byteOffset) { - auto structTy = getBufferFatPtrType(ctx); - Value undef = LLVM::UndefOp::create(b, loc, structTy); - Value withRsrc = LLVM::InsertValueOp::create(b, loc, undef, rsrc, ArrayRef{0}); - return LLVM::InsertValueOp::create(b, loc, withRsrc, byteOffset, ArrayRef{1}); -} - -static int64_t getElemByteWidth(Type elemTy) { - if (auto ft = dyn_cast(elemTy)) - return ft.getWidth() / 8; - if (auto it = dyn_cast(elemTy)) - return it.getWidth() / 8; - return 0; -} - -static FailureOr toI32(Value v, Location loc, ConversionPatternRewriter &rewriter) { - Type i32Ty = rewriter.getI32Type(); - if (v.getType() == i32Ty) - return v; - if (v.getType().isIndex()) - return arith::IndexCastOp::create(rewriter, loc, i32Ty, v).getResult(); - if (auto intTy = dyn_cast(v.getType())) { - if (intTy.getWidth() < 32) - return arith::ExtSIOp::create(rewriter, loc, i32Ty, v).getResult(); - if (intTy.getWidth() > 32) - return arith::TruncIOp::create(rewriter, loc, i32Ty, v).getResult(); - } - return failure(); -} - -static FailureOr toI64(Value v, Location loc, ConversionPatternRewriter &rewriter) { - Type i64Ty = rewriter.getI64Type(); - if (v.getType() == i64Ty) - return v; - if (v.getType().isIndex()) - return arith::IndexCastOp::create(rewriter, loc, i64Ty, v).getResult(); - if (auto intTy = dyn_cast(v.getType())) { - if (intTy.getWidth() < 64) - return arith::ExtSIOp::create(rewriter, loc, i64Ty, v).getResult(); - if (intTy.getWidth() > 64) - return arith::TruncIOp::create(rewriter, loc, i64Ty, v).getResult(); - } - return failure(); +Value applySwizzleOnPtr(OpBuilder &b, Location loc, Value ptr, SwizzleAttr swizzle) { + if (swizzle.isTrivialSwizzle()) + return ptr; + auto ptrTy = cast(ptr.getType()); + auto i64Ty = b.getI64Type(); + Value ptrInt = LLVM::PtrToIntOp::create(b, loc, i64Ty, ptr); + int64_t bitMaskValue = ((int64_t{1} << swizzle.getMask()) - 1) + << (swizzle.getBase() + swizzle.getShift()); + Value bitMask = arith::ConstantIntOp::create(b, loc, i64Ty, bitMaskValue).getResult(); + Value shiftAmt = arith::ConstantIntOp::create(b, loc, i64Ty, swizzle.getShift()).getResult(); + Value masked = arith::AndIOp::create(b, loc, ptrInt, bitMask).getResult(); + Value shifted = arith::ShRUIOp::create(b, loc, masked, shiftAmt).getResult(); + Value swizzled = arith::XOrIOp::create(b, loc, ptrInt, shifted).getResult(); + return LLVM::IntToPtrOp::create(b, loc, ptrTy, swizzled); } class MakePtrOpLowering : public OpConversionPattern { @@ -126,7 +63,6 @@ class MakePtrOpLowering : public OpConversionPattern { Location loc = op.getLoc(); AddressSpace addrSpace = flyPtrTy.getAddressSpace().getValue(); - auto args = adaptor.getArgs(); if (addrSpace == AddressSpace::Register) { auto dictAttrs = op.getDictAttrs(); @@ -143,6 +79,7 @@ class MakePtrOpLowering : public OpConversionPattern { rewriter.replaceOp(op, ptr); return success(); } else if (addrSpace == AddressSpace::BufferDesc) { + auto args = adaptor.getArgs(); if (args.size() != 4) return rewriter.notifyMatchFailure( op, "buffer_rsrc make_ptr expects 4 args: base, stride, numRecords, flags"); @@ -152,30 +89,15 @@ class MakePtrOpLowering : public OpConversionPattern { Value numRecords = args[2]; Value flags = args[3]; - auto rsrcPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 8); - Value rsrc = ROCDL::MakeBufferRsrcOp::create(rewriter, loc, rsrcPtrTy, base, stride, - numRecords, flags); - Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32).getResult(); - Value fatPtr = createBufferFatPtr(rewriter, loc, rewriter.getContext(), rsrc, zero); - rewriter.replaceOp(op, fatPtr); - return success(); - } - - auto resultTy = dyn_cast(getTypeConverter()->convertType(flyPtrTy)); - if (!resultTy) - return failure(); - - if (args.size() == 1) { - Value src = args[0]; - if (src.getType() == resultTy) { - rewriter.replaceOp(op, src); - return success(); - } - rewriter.replaceOpWithNewOp(op, resultTy, src); + auto rsrcPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), + mapToLLVMAddressSpace(AddressSpace::BufferDesc)); + Value bufferRsrc = ROCDL::MakeBufferRsrcOp::create(rewriter, loc, rsrcPtrTy, base, stride, + numRecords, flags); + rewriter.replaceOp(op, BufferFatPtr::pack(rewriter, loc, bufferRsrc)); return success(); } - return rewriter.notifyMatchFailure(op, "unsupported make_ptr operand count"); + return rewriter.notifyMatchFailure(op, "unsupported make_ptr address space"); } }; @@ -278,45 +200,6 @@ class PtrToIntOpLowering : public OpConversionPattern { } }; -/// Materialize a scalar index from a non-array `!fly.int_tuple` value. -/// This is used for pointer/memref offset computations. -static FailureOr materializeScalarIndex(Value intTuple, Location loc, - ConversionPatternRewriter &rewriter) { - auto tupleTy = dyn_cast(intTuple.getType()); - if (!tupleTy) - return failure(); - - IntTupleAttr profile = tupleTy.getAttr(); - if (!profile.isLeaf()) - return failure(); - - // Static scalar. - if (auto intAttr = dyn_cast(profile.getValue())) { - if (intAttr.isStatic()) { - Value c = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getValue()); - return c; - } - } - if (profile.getLeafAsInt().isNone()) { - Value c = arith::ConstantIndexOp::create(rewriter, loc, 0); - return c; - } - - // Dynamic scalar: expect it comes from fly.make_int_tuple with exactly one operand. - if (Operation *defOp = intTuple.getDefiningOp()) { - if (defOp->getName().getStringRef() == "fly.make_int_tuple" && defOp->getNumOperands() == 1) { - Value v = defOp->getOperand(0); - if (v.getType().isIndex()) - return v; - // Most Fly scalars are i32; cast to index when needed. - if (v.getType().isSignlessInteger()) - return arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), v).getResult(); - } - } - - return failure(); -} - class GetIterOpLowering : public OpConversionPattern { public: GetIterOpLowering(const TypeConverter &typeConverter, MLIRContext *context) @@ -324,12 +207,29 @@ class GetIterOpLowering : public OpConversionPattern { LogicalResult matchAndRewrite(GetIterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value mem = adaptor.getMemref(); - Type resTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!resTy) - return failure(); - assert(mem.getType() == resTy); - rewriter.replaceOp(op, mem); + rewriter.replaceOp(op, adaptor.getMemref()); + return success(); + } +}; + +class ApplySwizzleOpLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ApplySwizzleOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getPtr()); + return success(); + } +}; + +class RecastIterOpLowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(RecastIterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getSrc()); return success(); } }; @@ -343,56 +243,38 @@ class AddOffsetOpLowering : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); Value base = adaptor.getPtr(); + Value offset = adaptor.getOffset(); auto flyPtrTy = dyn_cast(op.getPtr().getType()); if (!flyPtrTy) return failure(); - auto offsetIdx = materializeScalarIndex(op.getOffset(), loc, rewriter); - if (failed(offsetIdx)) - return failure(); + auto offsetTy = dyn_cast(offset.getType()); + IntTupleAttr offsetAttr = offsetTy.getAttr(); + if (!offsetAttr.isLeaf()) + return rewriter.notifyMatchFailure(op, "offset must be a leaf int tuple"); - if (flyPtrTy.getAddressSpace().getValue() == AddressSpace::BufferDesc) { - if (!isBufferFatPtr(base.getType())) - return failure(); - Type elemTy = flyPtrTy.getElemTy(); - int64_t elemBytes = getElemByteWidth(elemTy); - if (elemBytes <= 0) - return failure(); - - FailureOr offsetI32 = toI32(*offsetIdx, loc, rewriter); - if (failed(offsetI32)) - return failure(); - - Value byteOffsetDelta = *offsetI32; - if (elemBytes > 1) { - Value scale = arith::ConstantIntOp::create(rewriter, loc, elemBytes, 32).getResult(); - byteOffsetDelta = arith::MulIOp::create(rewriter, loc, byteOffsetDelta, scale); - } + Value offsetVal; + auto offsetInt = offsetAttr.extractIntFromLeaf(); + if (offsetInt.isStatic()) { + offsetVal = arith::ConstantIntOp::create(rewriter, loc, offsetInt.getValue(), 32); + } else { + Operation *defOp = offset.getDefiningOp(); + offsetVal = defOp->getOperand(0); + } - Value oldOffset = extractBufferOffset(rewriter, loc, base); - Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, byteOffsetDelta); - Value rsrc = extractBufferRsrc(rewriter, loc, base); - Value result = createBufferFatPtr(rewriter, loc, rewriter.getContext(), rsrc, newOffset); - rewriter.replaceOp(op, result); + if (flyPtrTy.getAddressSpace().getValue() == AddressSpace::BufferDesc) { + BufferFatPtr bp(flyPtrTy, base); + rewriter.replaceOp(op, bp.addOffset(rewriter, loc, offsetVal)); return success(); } - auto basePtrTy = dyn_cast(base.getType()); - if (!basePtrTy) - return failure(); - - auto resultTy = - dyn_cast(getTypeConverter()->convertType(op.getResult().getType())); - if (!resultTy) - return failure(); - - FailureOr offsetI64 = toI64(*offsetIdx, loc, rewriter); - if (failed(offsetI64)) + auto ptrTy = dyn_cast(base.getType()); + if (!ptrTy) return failure(); Type elemTy = flyPtrTy.getElemTy(); - Value gep = LLVM::GEPOp::create(rewriter, loc, resultTy, elemTy, base, ValueRange{*offsetI64}); + Value gep = LLVM::GEPOp::create(rewriter, loc, ptrTy, elemTy, base, ValueRange{offsetVal}); rewriter.replaceOp(op, gep); return success(); } @@ -405,19 +287,15 @@ class MakeViewOpLowering : public OpConversionPattern { LogicalResult matchAndRewrite(MakeViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value base = adaptor.getIter(); - Type resultTy = getTypeConverter()->convertType(op.getResult().getType()); - if (!resultTy) - return failure(); - if (base.getType() == resultTy) { - rewriter.replaceOp(op, base); + if (isa(op.getResult().getType())) { + assert(op.getResult().use_empty() && "coord_tensor result should have no uses"); + rewriter.eraseOp(op); return success(); - } - if (isa(base.getType()) && isa(resultTy)) { - rewriter.replaceOpWithNewOp(op, resultTy, base); + } else { + Value base = adaptor.getIter(); + rewriter.replaceOp(op, base); return success(); } - return failure(); } }; @@ -438,7 +316,6 @@ class MemRefLoadVecOpLowering : public OpConversionPattern { if (!resVecTy) return failure(); - // Opaque pointers: we can directly load a vector from the base address. Value loaded = LLVM::LoadOp::create(rewriter, loc, resVecTy, input); rewriter.replaceOp(op, loaded); return success(); @@ -485,25 +362,20 @@ class PtrLoadOpLowering : public OpConversionPattern { Type elemTy = flyPtrTy.getElemTy(); if (flyPtrTy.getAddressSpace().getValue() == AddressSpace::BufferDesc) { - if (!isBufferFatPtr(ptr.getType())) - return failure(); - Value rsrc = extractBufferRsrc(rewriter, loc, ptr); - Value offset = extractBufferOffset(rewriter, loc, ptr); + BufferFatPtr bp(flyPtrTy, ptr); Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32).getResult(); ArrayAttr noAttrs; - Value loaded = ROCDL::RawPtrBufferLoadOp::create(rewriter, loc, elemTy, rsrc, offset, zero, - zero, noAttrs, noAttrs, noAttrs); + Value loaded = ROCDL::RawPtrBufferLoadOp::create( + rewriter, loc, elemTy, bp.bufferRsrc(rewriter, loc), bp.swizzleByteOffset(rewriter, loc), + zero, zero, noAttrs, noAttrs, noAttrs); + rewriter.replaceOp(op, loaded); + return success(); + } else { + ptr = applySwizzleOnPtr(rewriter, loc, ptr, flyPtrTy.getSwizzle()); + Value loaded = LLVM::LoadOp::create(rewriter, loc, elemTy, ptr); rewriter.replaceOp(op, loaded); return success(); } - - auto ptrTy = dyn_cast(ptr.getType()); - if (!ptrTy) - return failure(); - - Value loaded = LLVM::LoadOp::create(rewriter, loc, elemTy, ptr); - rewriter.replaceOp(op, loaded); - return success(); } }; @@ -522,25 +394,20 @@ class PtrStoreOpLowering : public OpConversionPattern { return failure(); if (flyPtrTy.getAddressSpace().getValue() == AddressSpace::BufferDesc) { - if (!isBufferFatPtr(ptr.getType())) - return failure(); - Value rsrc = extractBufferRsrc(rewriter, loc, ptr); - Value offset = extractBufferOffset(rewriter, loc, ptr); - Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32).getResult(); + BufferFatPtr bp(flyPtrTy, ptr); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); ArrayAttr noAttrs; - ROCDL::RawPtrBufferStoreOp::create(rewriter, loc, value, rsrc, offset, zero, zero, noAttrs, + ROCDL::RawPtrBufferStoreOp::create(rewriter, loc, value, bp.bufferRsrc(rewriter, loc), + bp.swizzleByteOffset(rewriter, loc), zero, zero, noAttrs, noAttrs, noAttrs); rewriter.eraseOp(op); return success(); + } else { + ptr = applySwizzleOnPtr(rewriter, loc, ptr, flyPtrTy.getSwizzle()); + LLVM::StoreOp::create(rewriter, loc, value, ptr); + rewriter.eraseOp(op); + return success(); } - - auto ptrTy = dyn_cast(ptr.getType()); - if (!ptrTy) - return failure(); - - LLVM::StoreOp::create(rewriter, loc, value, ptr); - rewriter.eraseOp(op); - return success(); } }; @@ -575,7 +442,7 @@ class CopyAtomCallLowering : public OpConversionPattern { if (!isa(src.getType()) || !isa(dst.getType())) return rewriter.notifyMatchFailure(op, "src/dst are not llvm.ptr for universal copy"); - return lowerUniversalCopy(op, rewriter, loc, copyAtom, srcFlyTy, src, dst); + return lowerUniversalCopy(op, rewriter, loc, copyAtom, srcFlyTy, dstFlyTy, src, dst); } else if (isa(copyOpType)) return lowerCDNA3BufferCopy(op, rewriter, loc, copyAtom, srcFlyTy, dstFlyTy, src, dst); return rewriter.notifyMatchFailure(op, "unsupported CopyOp type"); @@ -598,45 +465,27 @@ class CopyAtomCallLowering : public OpConversionPattern { private: LogicalResult lowerUniversalCopy(CopyAtomCall op, ConversionPatternRewriter &rewriter, Location loc, CopyAtomType copyAtomTy, fly::MemRefType srcFlyTy, - Value src, Value dst) const { + fly::MemRefType dstFlyTy, Value src, Value dst) const { LayoutBuilder attrBuilder(rewriter.getContext()); auto thrValLayoutSrc = dyn_cast(copyAtomTy.getThrValLayoutSrc()); if (!thrValLayoutSrc) return rewriter.notifyMatchFailure(op, "getThrValLayoutSrc returned null or non-LayoutAttr"); - IntAttr numValSrcAttr = - intTupleProduct(attrBuilder, thrValLayoutSrc.getShape().at(1)).getLeafAsInt(); - if (!numValSrcAttr.isStatic()) - return rewriter.notifyMatchFailure(op, "NumValSrc is not static"); - int64_t numValSrc = numValSrcAttr.getValue(); + int32_t numValSrc = + intTupleProduct(attrBuilder, thrValLayoutSrc.getShape().at(1)).getLeafAsInt().getValue(); Type elemTy = srcFlyTy.getElemTy(); - int64_t elemBits = 0; - if (auto ft = dyn_cast(elemTy)) - elemBits = ft.getWidth(); - else if (auto it = dyn_cast(elemTy)) - elemBits = it.getWidth(); - else - return rewriter.notifyMatchFailure(op, "unsupported element type for memcpy sizing"); - if (elemBits <= 0) - return rewriter.notifyMatchFailure(op, "invalid element bit width"); - - int64_t copyBytes = numValSrc * elemBits / 8; - Value len = arith::ConstantIntOp::create(rewriter, loc, copyBytes, /*width=*/64).getResult(); - LLVM::MemcpyOp::create(rewriter, loc, dst, src, len, /*isVolatile=*/false); + int32_t elemBits = elemTy.getIntOrFloatBitWidth(); + Value srcPtr = applySwizzleOnPtr(rewriter, loc, src, srcFlyTy.getSwizzle()); + Value dstPtr = applySwizzleOnPtr(rewriter, loc, dst, dstFlyTy.getSwizzle()); + int32_t copyBytes = numValSrc * elemBits / 8; + Value len = arith::ConstantIntOp::create(rewriter, loc, copyBytes, /*width=*/32).getResult(); + LLVM::MemcpyOp::create(rewriter, loc, dstPtr, srcPtr, len, /*isVolatile=*/false); rewriter.eraseOp(op); return success(); } - static int64_t getElemTypeBitWidth(Type elemTy) { - if (auto ft = dyn_cast(elemTy)) - return ft.getWidth(); - if (auto it = dyn_cast(elemTy)) - return it.getWidth(); - return 0; - } - LogicalResult lowerCDNA3BufferCopy(CopyAtomCall op, ConversionPatternRewriter &rewriter, Location loc, CopyAtomType copyAtomTy, fly::MemRefType srcFlyTy, fly::MemRefType dstFlyTy, Value src, @@ -646,18 +495,11 @@ class CopyAtomCallLowering : public OpConversionPattern { auto thrValLayoutSrc = dyn_cast(copyAtomTy.getThrValLayoutSrc()); if (!thrValLayoutSrc) return rewriter.notifyMatchFailure(op, "getThrValLayoutSrc returned null"); - IntAttr numValSrcAttr = - intTupleProduct(attrBuilder, thrValLayoutSrc.getShape().at(1)).getLeafAsInt(); - if (!numValSrcAttr.isStatic()) - return rewriter.notifyMatchFailure(op, "NumValSrc is not static"); - int64_t numValSrc = numValSrcAttr.getValue(); + int32_t numValSrc = + intTupleProduct(attrBuilder, thrValLayoutSrc.getShape().at(1)).getLeafAsInt().getValue(); Type elemTy = srcFlyTy.getElemTy(); - int64_t elemBits = getElemTypeBitWidth(elemTy); - if (elemBits <= 0) - return rewriter.notifyMatchFailure(op, "unsupported element type"); - - int64_t vecWidth = numValSrc; + int32_t vecWidth = numValSrc; Type vecTy = vecWidth == 1 ? elemTy : VectorType::get({vecWidth}, elemTy); AddressSpace srcAS = srcFlyTy.getAddressSpace().getValue(); @@ -666,38 +508,29 @@ class CopyAtomCallLowering : public OpConversionPattern { bool srcIsBuffer = (srcAS == AddressSpace::BufferDesc); bool dstIsBuffer = (dstAS == AddressSpace::BufferDesc); + if (srcIsBuffer == dstIsBuffer) + return rewriter.notifyMatchFailure( + op, "CDNA3 buffer copy requires exactly one side to be BufferDesc"); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32).getResult(); ArrayAttr noAttrs; - auto unpackBuffer = [&](Value val) -> std::pair { - if (isBufferFatPtr(val.getType())) - return {extractBufferRsrc(rewriter, loc, val), extractBufferOffset(rewriter, loc, val)}; - return {val, zero}; + auto unpackBuffer = [&](Value val, fly::MemRefType flyTy) -> std::pair { + BufferFatPtr bp(flyTy.getPointerType(), val); + return {bp.bufferRsrc(rewriter, loc), bp.swizzleByteOffset(rewriter, loc)}; }; - if (srcIsBuffer && !dstIsBuffer) { - auto [srcRsrc, srcOff] = unpackBuffer(src); + if (srcIsBuffer) { + auto [srcRsrc, srcOff] = unpackBuffer(src, srcFlyTy); Value loaded = ROCDL::RawPtrBufferLoadOp::create(rewriter, loc, vecTy, srcRsrc, srcOff, zero, zero, noAttrs, noAttrs, noAttrs); LLVM::StoreOp::create(rewriter, loc, loaded, dst); - } else if (!srcIsBuffer && dstIsBuffer) { - auto [dstRsrc, dstOff] = unpackBuffer(dst); + } else { + auto [dstRsrc, dstOff] = unpackBuffer(dst, dstFlyTy); Value loaded = LLVM::LoadOp::create(rewriter, loc, vecTy, src); ROCDL::RawPtrBufferStoreOp::create(rewriter, loc, loaded, dstRsrc, dstOff, zero, zero, noAttrs, noAttrs, noAttrs); - } else if (srcIsBuffer && dstIsBuffer) { - auto [srcRsrc, srcOff] = unpackBuffer(src); - auto [dstRsrc, dstOff] = unpackBuffer(dst); - Value loaded = ROCDL::RawPtrBufferLoadOp::create(rewriter, loc, vecTy, srcRsrc, srcOff, zero, - zero, noAttrs, noAttrs, noAttrs); - ROCDL::RawPtrBufferStoreOp::create(rewriter, loc, loaded, dstRsrc, dstOff, zero, zero, - noAttrs, noAttrs, noAttrs); - } else { - int64_t copyBytes = numValSrc * elemBits / 8; - Value len = arith::ConstantIntOp::create(rewriter, loc, copyBytes, 64).getResult(); - LLVM::MemcpyOp::create(rewriter, loc, dst, src, len, false); } - rewriter.eraseOp(op); return success(); } @@ -946,13 +779,13 @@ class FlyTypeConverter : public TypeConverter { addConversion([&](fly::MemRefType flyMemRefTy) -> Type { if (flyMemRefTy.getAddressSpace().getValue() == AddressSpace::BufferDesc) - return getBufferFatPtrType(flyMemRefTy.getContext()); + return BufferFatPtr::getType(flyMemRefTy.getContext()); unsigned as = mapToLLVMAddressSpace(flyMemRefTy.getAddressSpace().getValue()); return LLVM::LLVMPointerType::get(flyMemRefTy.getContext(), as); }); addConversion([&](fly::PointerType flyPtrTy) -> Type { if (flyPtrTy.getAddressSpace().getValue() == AddressSpace::BufferDesc) - return getBufferFatPtrType(flyPtrTy.getContext()); + return BufferFatPtr::getType(flyPtrTy.getContext()); unsigned as = mapToLLVMAddressSpace(flyPtrTy.getAddressSpace().getValue()); return LLVM::LLVMPointerType::get(flyPtrTy.getContext(), as); }); @@ -972,7 +805,7 @@ class ExtractAlignedPointerAsIndexLowering if (!resultType) resultType = op.getResult().getType(); if (src.getType() != resultType) - src = rewriter.create(op.getLoc(), resultType, src); + src = LLVM::AddrSpaceCastOp::create(rewriter, op.getLoc(), resultType, src); rewriter.replaceOp(op, src); return success(); } @@ -997,7 +830,7 @@ class FlyToROCDLConversionPass // Constructors target.addLegalOp(); - target.addLegalOp(); + target.addLegalOp(); FlyTypeConverter typeConverter; @@ -1039,16 +872,16 @@ class FlyToROCDLConversionPass patterns.add(typeConverter, context); patterns.add(typeConverter, context); - patterns.add(typeConverter, context); + patterns.add(typeConverter, + context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); + + // TODO: deprecated in the future patterns.add(typeConverter, context); populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 7f91f2c1..6e07a448 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -133,49 +133,20 @@ def depth(int_or_tuple): return fly.depth(int_or_tuple) -@traced_op -def static(result_type, loc=None, ip=None): - return fly.static(result_type, loc=loc, ip=ip) - - -@traced_op -def int_tuple_add(lhs, rhs, loc=None, ip=None): - return fly.int_tuple_add(lhs, rhs, loc=loc, ip=ip) +# ===----------------------------------------------------------------------=== # +# Constructors +# ===----------------------------------------------------------------------=== # @traced_op -def int_tuple_sub(lhs, rhs, loc=None, ip=None): - return fly.int_tuple_sub(lhs, rhs, loc=loc, ip=ip) - - -@traced_op -def int_tuple_mul(lhs, rhs, loc=None, ip=None): - return fly.int_tuple_mul(lhs, rhs, loc=loc, ip=ip) - - -@traced_op -def int_tuple_div(lhs, rhs, loc=None, ip=None): - return fly.int_tuple_div(lhs, rhs, loc=loc, ip=ip) - - -@traced_op -def int_tuple_product(int_tuple, loc=None, ip=None): - return fly.int_tuple_product(int_tuple, loc=loc, ip=ip) - - -@traced_op -def int_tuple_product_each(int_tuple, loc=None, ip=None): - return fly.int_tuple_product_each(int_tuple, loc=loc, ip=ip) - - -@traced_op -def make_identity_tensor(shape, loc=None, ip=None): - return fly.make_identity_tensor(shape, loc=loc, ip=ip) +def static(result_type, loc=None, ip=None): + return fly.static(result_type, loc=loc, ip=ip) @traced_op -def make_identity_layout(shape, loc=None, ip=None): - return fly.make_identity_layout(shape, loc=loc, ip=ip) +def make_int_tuple(elems, loc=None, ip=None): + IntTupleTy, dyncElems = fly.infer_int_tuple_type(elems) + return fly.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) @traced_op @@ -196,12 +167,6 @@ def make_coord(*coord, loc=None, ip=None): return fly.make_coord(IntTupleTy, dyncElems, loc=loc, ip=ip) -@traced_op -def make_int_tuple(elems, loc=None, ip=None): - IntTupleTy, dyncElems = fly.infer_int_tuple_type(elems) - return fly.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) - - @traced_op def make_layout(shape, stride, loc=None, ip=None): if not isinstance(shape, ir.Value): @@ -213,6 +178,11 @@ def make_layout(shape, stride, loc=None, ip=None): return fly.make_layout(shape, stride=stride, loc=loc, ip=ip) +@traced_op +def make_layout_like(ref, loc=None, ip=None): + return fly.make_layout_like(ref, loc=loc, ip=ip) + + @traced_op def make_ordered_layout(shape, order, loc=None, ip=None): if not isinstance(shape, ir.Value): @@ -225,27 +195,28 @@ def make_ordered_layout(shape, order, loc=None, ip=None): @traced_op -def make_fragment_like(tensor, dtype=None, loc=None, ip=None): - return fly.make_fragment_like(tensor, dtype=dtype, loc=loc, ip=ip) +def make_composed_layout(inner, offset, outer, loc=None, ip=None): + return fly.make_composed_layout(inner, offset, outer, loc=loc, ip=ip) @traced_op -def size(int_tuple, loc=None, ip=None): - result = fly.size(int_tuple, loc=loc, ip=ip) - # If the int_tuple is static, return the static value - result_ty = IntTupleType(result.type) - if result_ty.is_leaf and result_ty.is_static: - return Int32(result_ty.static_value) - return result +def make_identity_layout(shape, loc=None, ip=None): + return fly.make_identity_layout(shape, loc=loc, ip=ip) @traced_op -def cosize(layout, loc=None, ip=None): - result = fly.cosize(layout, loc=loc, ip=ip) - result_ty = IntTupleType(result.type) - if result_ty.is_leaf and result_ty.is_static: - return Int32(result_ty.static_value) - return result +def make_view(iter, layout, loc=None, ip=None): + return fly.make_view(iter, layout, loc=loc, ip=ip) + + +@traced_op +def make_fragment_like(tensor, dtype=None, loc=None, ip=None): + return fly.make_fragment_like(tensor, dtype=dtype, loc=loc, ip=ip) + + +# ===----------------------------------------------------------------------=== # +# Extractors +# ===----------------------------------------------------------------------=== # @traced_op @@ -253,6 +224,11 @@ def get_scalar(int_tuple, loc=None, ip=None): return fly.get_scalar(int_tuple, loc=loc, ip=ip) +@traced_op +def get_leaves(input, loc=None, ip=None): + return fly.get_leaves(input, loc=loc, ip=ip) + + @traced_op def get_shape(layout, loc=None, ip=None): return fly.get_shape(layout, loc=loc, ip=ip) @@ -264,101 +240,114 @@ def get_stride(layout, loc=None, ip=None): @traced_op -def slice(src, coord, loc=None, ip=None): - if not isinstance(coord, ir.Value): - coordTy, dyncElems = fly.infer_int_tuple_type(coord) - coord = fly.make_coord(coordTy, dyncElems, loc=loc, ip=ip) - return fly.slice(src, coord, loc=loc, ip=ip) +def get_layout(memref, loc=None, ip=None): + return fly.get_layout(memref, loc=loc, ip=ip) @traced_op -def get_leaf(int_tuple, leaf_idx, loc=None, ip=None): - return fly.get_leaf(int_tuple, leaf_idx, loc=loc, ip=ip) +def get_iter(memref, loc=None, ip=None): + return fly.get_iter(memref, loc=loc, ip=ip) @traced_op -def get_flat_coord(index, layout, loc=None, ip=None): - return fly.get_flat_coord(index, layout, loc=loc, ip=ip) +def composed_get_inner(input, loc=None, ip=None): + return fly.composed_get_inner(input, loc=loc, ip=ip) -def _to_i32(v): - """Cast index-type ir.Value to i32 (required by fly.make_int_tuple).""" - if isinstance(v, ir.Value) and isinstance(v.type, ir.IndexType): - return _arith.IndexCastOp(T.i32(), v).result - return v +@traced_op +def composed_get_offset(input, loc=None, ip=None): + return fly.composed_get_offset(input, loc=loc, ip=ip) @traced_op -def crd2idx(crd, layout, loc=None, ip=None): - if isinstance(crd, (list, tuple)): - crd_i32 = [_to_i32(c) for c in crd] - IntTupleTy, dyncElems = fly.infer_int_tuple_type(tuple(crd_i32)) - crd = fly.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) - return fly.crd2idx(crd, layout, loc=loc, ip=ip) +def composed_get_outer(input, loc=None, ip=None): + return fly.composed_get_outer(input, loc=loc, ip=ip) + + +# ===----------------------------------------------------------------------=== # +# IntTuple operations +# ===----------------------------------------------------------------------=== # @traced_op -def idx2crd(idx, layout, loc=None, ip=None): - if isinstance(idx, ir.Value) and not str(idx.type).startswith("!fly.int_tuple"): - idx = _to_i32(idx) - IntTupleTy, dyncElems = fly.infer_int_tuple_type(idx) - idx = fly.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) - return fly.idx2crd(idx, layout, loc=loc, ip=ip) +def int_tuple_add(lhs, rhs, loc=None, ip=None): + return fly.int_tuple_add(lhs, rhs, loc=loc, ip=ip) @traced_op -def get(int_tuple, mode, loc=None, ip=None): - if isinstance(int_tuple, (list, tuple)): - return int_tuple[mode] - selected = fly.select(int_tuple, indices=[mode], loc=loc, ip=ip) - result = fly.get_scalar(selected, loc=loc, ip=ip) - if isinstance(result, ir.Value) and not isinstance(result.type, ir.IndexType): - result = _arith.IndexCastOp(T.index(), result).result - return result +def int_tuple_sub(lhs, rhs, loc=None, ip=None): + return fly.int_tuple_sub(lhs, rhs, loc=loc, ip=ip) @traced_op -def composition(layout, tiler, loc=None, ip=None): - return fly.composition(layout, tiler, loc=loc, ip=ip) +def int_tuple_mul(lhs, rhs, loc=None, ip=None): + return fly.int_tuple_mul(lhs, rhs, loc=loc, ip=ip) @traced_op -def complement(layout, codomain_size, loc=None, ip=None): - if not isinstance(codomain_size, ir.Value): - codomain_sizeTy, dyncElems = fly.infer_int_tuple_type(codomain_size) - codomain_size = fly.make_shape(codomain_sizeTy, dyncElems, loc=loc, ip=ip) - return fly.complement(layout, codomain_size=codomain_size, loc=loc, ip=ip) +def int_tuple_div(lhs, rhs, loc=None, ip=None): + return fly.int_tuple_div(lhs, rhs, loc=loc, ip=ip) @traced_op -def right_inverse(layout, loc=None, ip=None): - return fly.right_inverse(layout, loc=loc, ip=ip) +def int_tuple_mod(lhs, rhs, loc=None, ip=None): + return fly.int_tuple_mod(lhs, rhs, loc=loc, ip=ip) @traced_op -def coalesce(layout, pattern=None, loc=None, ip=None): - return fly.coalesce(layout, pattern=pattern, loc=loc, ip=ip) +def int_tuple_product(int_tuple, loc=None, ip=None): + return fly.int_tuple_product(int_tuple, loc=loc, ip=ip) @traced_op -def recast_layout(layout, old_type_bits, new_type_bits, loc=None, ip=None): - def _to_static_bits(v): - if isinstance(v, int): - return v - if isinstance(v, ir.Type): - if hasattr(v, "width"): - return int(v.width) - raise TypeError(f"recast_layout only supports int/type-with-width, got type {v}") - raise TypeError(f"recast_layout only supports int/Type, got {type(v)}") +def int_tuple_product_each(int_tuple, loc=None, ip=None): + return fly.int_tuple_product_each(int_tuple, loc=loc, ip=ip) + + +@traced_op +def int_tuple_product_like(lhs, rhs, loc=None, ip=None): + return fly.int_tuple_product_like(lhs, rhs, loc=loc, ip=ip) - old_type_bits = _to_static_bits(old_type_bits) - new_type_bits = _to_static_bits(new_type_bits) - return fly.recast_layout(new_type_bits=new_type_bits, old_type_bits=old_type_bits, src=layout, loc=loc, ip=ip) + +@traced_op +def shape_div(lhs, rhs, loc=None, ip=None): + return fly.shape_div(lhs, rhs, loc=loc, ip=ip) + + +@traced_op +def ceil_div(lhs, rhs, loc=None, ip=None): + return fly.ceil_div(lhs, rhs, loc=loc, ip=ip) + + +@traced_op +def elem_less(lhs, rhs, loc=None, ip=None): + return fly.elem_less(lhs, rhs, loc=loc, ip=ip) + + +@traced_op +def equal(lhs, rhs, loc=None, ip=None): + return fly.equal(lhs, rhs, loc=loc, ip=ip) + + +# ===----------------------------------------------------------------------=== # +# IntTupleLike operations +# ===----------------------------------------------------------------------=== # + + +@traced_op +def get(int_tuple, mode, loc=None, ip=None): + if isinstance(int_tuple, (list, tuple)): + return int_tuple[mode] + selected = fly.select(int_tuple, indices=[mode], loc=loc, ip=ip) + result = fly.get_scalar(selected, loc=loc, ip=ip) + if isinstance(result, ir.Value) and not isinstance(result.type, ir.IndexType): + result = _arith.IndexCastOp(T.index(), result).result + return result @traced_op -def zip(lhs, rhs, loc=None, ip=None): - return fly.zip(lhs, rhs, loc=loc, ip=ip) +def take(int_tuple, begin: int, end: int, loc=None, ip=None): + return fly.take(int_tuple, begin=begin, end=end, loc=loc, ip=ip) @traced_op @@ -382,135 +371,189 @@ def prepend(base, elem, n: int | None = None, loc=None, ip=None): @traced_op -def logical_divide(layout, divisor, loc=None, ip=None): - return fly.logical_divide(layout, divisor, loc=loc, ip=ip) +def slice(src, coord, loc=None, ip=None): + if not isinstance(coord, ir.Value): + coordTy, dyncElems = fly.infer_int_tuple_type(coord) + coord = fly.make_coord(coordTy, dyncElems, loc=loc, ip=ip) + return fly.slice(src, coord, loc=loc, ip=ip) @traced_op -def zipped_divide(layout, divisor, loc=None, ip=None): - return fly.zipped_divide(layout, divisor, loc=loc, ip=ip) +def dice(src, coord, loc=None, ip=None): + if not isinstance(coord, ir.Value): + coordTy, dyncElems = fly.infer_int_tuple_type(coord) + coord = fly.make_coord(coordTy, dyncElems, loc=loc, ip=ip) + return fly.dice(src, coord, loc=loc, ip=ip) + + +# ===----------------------------------------------------------------------=== # +# LayoutLike operations +# ===----------------------------------------------------------------------=== # @traced_op -def tiled_divide(layout, divisor, loc=None, ip=None): - return fly.tiled_divide(layout, divisor, loc=loc, ip=ip) +def size(int_tuple, loc=None, ip=None): + result = fly.size(int_tuple, loc=loc, ip=ip) + result_ty = IntTupleType(result.type) + if result_ty.is_leaf and result_ty.is_static: + return Int32(result_ty.static_value) + return result @traced_op -def flat_divide(layout, divisor, loc=None, ip=None): - return fly.flat_divide(layout, divisor, loc=loc, ip=ip) +def coprofile(layout, loc=None, ip=None): + return fly.coprofile(layout, loc=loc, ip=ip) @traced_op -def logical_product(layout, tiler, loc=None, ip=None): - return fly.logical_product(layout, tiler, loc=loc, ip=ip) +def coshape(layout, loc=None, ip=None): + return fly.coshape(layout, loc=loc, ip=ip) @traced_op -def zipped_product(layout, tiler, loc=None, ip=None): - return fly.zipped_product(layout, tiler, loc=loc, ip=ip) +def cosize(layout, loc=None, ip=None): + result = fly.cosize(layout, loc=loc, ip=ip) + result_ty = IntTupleType(result.type) + if result_ty.is_leaf and result_ty.is_static: + return Int32(result_ty.static_value) + return result @traced_op -def tiled_product(layout, tiler, loc=None, ip=None): - return fly.tiled_product(layout, tiler, loc=loc, ip=ip) +def crd2idx(crd, layout, loc=None, ip=None): + if not isinstance(crd, ir.Value): + crdTy, dyncElems = fly.infer_int_tuple_type(crd) + crd = fly.make_coord(crdTy, dyncElems, loc=loc, ip=ip) + return fly.crd2idx(crd, layout, loc=loc, ip=ip) @traced_op -def flat_product(layout, tiler, loc=None, ip=None): - return fly.flat_product(layout, tiler, loc=loc, ip=ip) +def idx2crd(idx, layout, loc=None, ip=None): + if isinstance(idx, ir.Value) and not str(idx.type).startswith("!fly.int_tuple"): + IntTupleTy, dyncElems = fly.infer_int_tuple_type((idx,)) + idx = fly.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) + return fly.idx2crd(idx, layout, loc=loc, ip=ip) @traced_op -def block_product(layout, tiler, loc=None, ip=None): - return fly.block_product(layout, tiler, loc=loc, ip=ip) +def get_flat_coord(index, layout, loc=None, ip=None): + return fly.get_flat_coord(index, layout, loc=loc, ip=ip) @traced_op -def raked_product(layout, tiler, loc=None, ip=None): - return fly.raked_product(layout, tiler, loc=loc, ip=ip) +def get_1d_coord(index, layout, loc=None, ip=None): + return fly.get_1d_coord(index, layout, loc=loc, ip=ip) @traced_op -def memref_alloca(memref_type, layout, loc=None, ip=None): - return fly.memref_alloca(memref_type, layout, loc=loc, ip=ip) +def coalesce(layout, pattern=None, loc=None, ip=None): + return fly.coalesce(layout, pattern=pattern, loc=loc, ip=ip) @traced_op -def memref_load(memref, indices, loc=None, ip=None): - # `fly.memref.load` expects `indices` as `!fly.int_tuple` (typically a scalar offset). - # Accept convenience forms: - # - int_tuple Value (pass through) - # - python int / tuple/list (make_int_tuple) - # - index/i32/i64 Value (cast index->i32 then make_int_tuple) - if isinstance(indices, ir.Value): - if str(indices.type).startswith("!fly.int_tuple"): - return fly.memref_load(memref, indices, loc=loc, ip=ip) - # Common case: user passes `index` as a 1-D coordinate/offset. - if str(indices.type) == "index": - indices = _arith.IndexCastOp(T.i32(), indices) - indices = make_int_tuple(indices, loc=loc, ip=ip) - return fly.memref_load(memref, indices, loc=loc, ip=ip) +def composition(layout, tiler, loc=None, ip=None): + return fly.composition(layout, tiler, loc=loc, ip=ip) - # List/tuple (e.g. [row]) or python int. - indices = make_int_tuple(indices, loc=loc, ip=ip) - return fly.memref_load(memref, indices, loc=loc, ip=ip) + +@traced_op +def complement(layout, codomain_size=None, loc=None, ip=None): + if codomain_size is not None and not isinstance(codomain_size, ir.Value): + codomain_sizeTy, dyncElems = fly.infer_int_tuple_type(codomain_size) + codomain_size = fly.make_shape(codomain_sizeTy, dyncElems, loc=loc, ip=ip) + return fly.complement(layout, codomain_size=codomain_size, loc=loc, ip=ip) @traced_op -def memref_store(value, memref, indices, loc=None, ip=None): - if isinstance(indices, ir.Value): - if str(indices.type).startswith("!fly.int_tuple"): - return fly.memref_store(value, memref, indices, loc=loc, ip=ip) - if str(indices.type) == "index": - indices = _arith.IndexCastOp(T.i32(), indices) - indices = make_int_tuple(indices, loc=loc, ip=ip) - return fly.memref_store(value, memref, indices, loc=loc, ip=ip) +def right_inverse(layout, loc=None, ip=None): + return fly.right_inverse(layout, loc=loc, ip=ip) - indices = make_int_tuple(indices, loc=loc, ip=ip) - return fly.memref_store(value, memref, indices, loc=loc, ip=ip) + +@traced_op +def left_inverse(layout, loc=None, ip=None): + return fly.left_inverse(layout, loc=loc, ip=ip) @traced_op -def memref_load_vec(memref, loc=None, ip=None): - return fly.memref_load_vec(memref, loc=loc, ip=ip) +def logical_divide(layout, divisor, loc=None, ip=None): + return fly.logical_divide(layout, divisor, loc=loc, ip=ip) @traced_op -def memref_store_vec(vector, memref, loc=None, ip=None): - return fly.memref_store_vec(vector, memref, loc=loc, ip=ip) +def zipped_divide(layout, divisor, loc=None, ip=None): + return fly.zipped_divide(layout, divisor, loc=loc, ip=ip) @traced_op -def get_layout(memref, loc=None, ip=None): - return fly.get_layout(memref, loc=loc, ip=ip) +def tiled_divide(layout, divisor, loc=None, ip=None): + return fly.tiled_divide(layout, divisor, loc=loc, ip=ip) @traced_op -def get_iter(memref, loc=None, ip=None): - return fly.get_iter(memref, loc=loc, ip=ip) +def flat_divide(layout, divisor, loc=None, ip=None): + return fly.flat_divide(layout, divisor, loc=loc, ip=ip) @traced_op -def make_view(iter, layout, loc=None, ip=None): - return fly.make_view(iter, layout, loc=loc, ip=ip) +def logical_product(layout, tiler, loc=None, ip=None): + return fly.logical_product(layout, tiler, loc=loc, ip=ip) @traced_op -def make_ptr(result_type, args, loc=None, ip=None): - return fly.make_ptr(result_type, args, loc=loc, ip=ip) +def zipped_product(layout, tiler, loc=None, ip=None): + return fly.zipped_product(layout, tiler, loc=loc, ip=ip) @traced_op -def get_dyn_shared(loc=None, ip=None): - return fly.get_dyn_shared(loc=loc, ip=ip) +def tiled_product(layout, tiler, loc=None, ip=None): + return fly.tiled_product(layout, tiler, loc=loc, ip=ip) @traced_op -def add_offset(ptr, offset, loc=None, ip=None): - if not isinstance(offset, ir.Value): - offset = make_int_tuple(offset, loc=loc, ip=ip) - return fly.add_offset(ptr, offset, loc=loc, ip=ip) +def flat_product(layout, tiler, loc=None, ip=None): + return fly.flat_product(layout, tiler, loc=loc, ip=ip) + + +@traced_op +def block_product(layout, tiler, loc=None, ip=None): + return fly.block_product(layout, tiler, loc=loc, ip=ip) + + +@traced_op +def raked_product(layout, tiler, loc=None, ip=None): + return fly.raked_product(layout, tiler, loc=loc, ip=ip) + + +@traced_op +def recast_layout(layout, old_type_bits, new_type_bits, loc=None, ip=None): + def _to_static_bits(v): + if isinstance(v, int): + return v + if isinstance(v, ir.Type): + if hasattr(v, "width"): + return int(v.width) + raise TypeError(f"recast_layout only supports int/type-with-width, got type {v}") + raise TypeError(f"recast_layout only supports int/Type, got {type(v)}") + + old_type_bits = _to_static_bits(old_type_bits) + new_type_bits = _to_static_bits(new_type_bits) + return fly.recast_layout(new_type_bits=new_type_bits, old_type_bits=old_type_bits, src=layout, loc=loc, ip=ip) + + +@traced_op +def tile_to_shape(block, trg_shape, ord_shape, loc=None, ip=None): + return fly.tile_to_shape(block, trg_shape, ord_shape, loc=loc, ip=ip) + + +# ===----------------------------------------------------------------------=== # +# Atom and Tiled Mma/Copy ops +# ===----------------------------------------------------------------------=== # + + +@traced_op +def make_mma_atom(atom_type, loc=None, ip=None): + from .derived import MmaAtom + + return MmaAtom(fly.make_mma_atom(atom_type, loc=loc, ip=ip)) @traced_op @@ -534,25 +577,8 @@ def make_copy_atom(copy_op_type, elem_type, loc=None, ip=None): @traced_op -def make_mma_atom(atom_type, loc=None, ip=None): - from .derived import MmaAtom - - return MmaAtom(fly.make_mma_atom(atom_type, loc=loc, ip=ip)) - - -@traced_op -def make_tile(*args, loc=None, ip=None): - if len(args) == 1 and isinstance(args[0], (list, tuple)): - modes = args[0] - else: - modes = args - resolved = [] - for m in modes: - if isinstance(m, int): - resolved.append(make_layout(m, 1, loc=loc, ip=ip)) - else: - resolved.append(m) - return fly.make_tile(resolved, loc=loc, ip=ip) +def copy_atom_call(copy_atom, src, dst, *, pred=None, loc=None, ip=None): + return fly.copy_atom_call(copy_atom, src, dst, pred=pred, loc=loc, ip=ip) @traced_op @@ -560,14 +586,6 @@ def mma_atom_call(mma_atom, d, a, b, c, loc=None, ip=None): return fly.mma_atom_call(mma_atom, d, a, b, c, loc=loc, ip=ip) -@traced_op -def copy_atom_call(copy_atom, src, dst, *, pred=None, loc=None, ip=None): - kwargs = dict(loc=loc, ip=ip) - if pred is not None: - kwargs["pred"] = pred - return fly.copy_atom_call(copy_atom, src, dst, **kwargs) - - @traced_op def make_tiled_copy(copy_atom, layout_thr_val, tile_mn, loc=None, ip=None): from .derived import TiledCopy @@ -602,6 +620,16 @@ def tiled_mma_partition(operand_id, tiled_mma, t, coord, loc=None, ip=None): return fly.tiled_mma_partition(operand_id, tiled_mma, t, coord, loc=loc, ip=ip) +@traced_op +def tiled_mma_partition_shape(operand_id, tiled_mma, shape, loc=None, ip=None): + return fly.tiled_mma_partition_shape(operand_id, tiled_mma, shape, loc=loc, ip=ip) + + +@traced_op +def mma_atom_make_fragment(operand_id, tiled_mma, input, loc=None, ip=None): + return fly.mma_atom_make_fragment(operand_id, tiled_mma, input, loc=loc, ip=ip) + + @traced_op def copy(copy_atom, src, dst, *, pred=None, loc=None, ip=None): return fly.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip) @@ -612,11 +640,104 @@ def gemm(mma_atom, d, a, b, c, loc=None, ip=None): return fly.gemm(mma_atom, d, a, b, c, loc=loc, ip=ip) +# ===----------------------------------------------------------------------=== # +# MemRef and Ptr operations +# ===----------------------------------------------------------------------=== # + + +@traced_op +def make_ptr(result_type, args, loc=None, ip=None): + return fly.make_ptr(result_type, args, loc=loc, ip=ip) + + +@traced_op +def inttoptr(result_type, src, loc=None, ip=None): + return fly.inttoptr(result_type, src, loc=loc, ip=ip) + + +@traced_op +def ptrtoint(ptr, loc=None, ip=None): + return fly.ptrtoint(ptr, loc=loc, ip=ip) + + +@traced_op +def add_offset(ptr, offset, loc=None, ip=None): + if not isinstance(offset, ir.Value): + offset = make_int_tuple(offset, loc=loc, ip=ip) + return fly.add_offset(ptr, offset, loc=loc, ip=ip) + + +@traced_op +def apply_swizzle(ptr, swizzle, loc=None, ip=None): + return fly.apply_swizzle(ptr, swizzle, loc=loc, ip=ip) + + +@traced_op +def ptr_load(ptr, loc=None, ip=None): + return fly.ptr_load(ptr, loc=loc, ip=ip) + + +@traced_op +def ptr_store(value, ptr, loc=None, ip=None): + return fly.ptr_store(value, ptr, loc=loc, ip=ip) + + +@traced_op +def recast_iter(result_type, src, loc=None, ip=None): + return fly.recast_iter(result_type, src, loc=loc, ip=ip) + + +@traced_op +def memref_alloca(memref_type, layout, loc=None, ip=None): + return fly.memref_alloca(memref_type, layout, loc=loc, ip=ip) + + +@traced_op +def memref_load_vec(memref, loc=None, ip=None): + return fly.memref_load_vec(memref, loc=loc, ip=ip) + + +@traced_op +def memref_store_vec(vector, memref, loc=None, ip=None): + return fly.memref_store_vec(vector, memref, loc=loc, ip=ip) + + +@traced_op +def memref_load(memref, indices, loc=None, ip=None): + if isinstance(indices, ir.Value): + if str(indices.type).startswith("!fly.int_tuple"): + return fly.memref_load(memref, indices, loc=loc, ip=ip) + if str(indices.type) == "index": + indices = _arith.IndexCastOp(T.i32(), indices) + indices = make_int_tuple(indices, loc=loc, ip=ip) + return fly.memref_load(memref, indices, loc=loc, ip=ip) + + indices = make_int_tuple(indices, loc=loc, ip=ip) + return fly.memref_load(memref, indices, loc=loc, ip=ip) + + +@traced_op +def memref_store(value, memref, indices, loc=None, ip=None): + if isinstance(indices, ir.Value): + if str(indices.type).startswith("!fly.int_tuple"): + return fly.memref_store(value, memref, indices, loc=loc, ip=ip) + if str(indices.type) == "index": + indices = _arith.IndexCastOp(T.i32(), indices) + indices = make_int_tuple(indices, loc=loc, ip=ip) + return fly.memref_store(value, memref, indices, loc=loc, ip=ip) + + indices = make_int_tuple(indices, loc=loc, ip=ip) + return fly.memref_store(value, memref, indices, loc=loc, ip=ip) + + +# ===----------------------------------------------------------------------=== # +# Utility ops +# ===----------------------------------------------------------------------=== # + + @traced_op def printf(*args, format_str="", loc=None, ip=None): def _convert_printf_value(val): - """Convert Python values to MLIR Values for printf. - Returns tuple of (is_static, value) where is_static=True means value is a string to embed.""" if isinstance(val, ir.Value): return (False, val) elif isinstance(val, type): @@ -671,3 +792,36 @@ def _convert_printf_value(val): final_format = "".join(result_parts) return fly.print_(final_format, ir_values, loc=loc, ip=ip) + + +@traced_op +def assume(result_type, dst, src, loc=None, ip=None): + return fly.assume(result_type, dst, src, loc=loc, ip=ip) + + +# ===----------------------------------------------------------------------=== # +# Deprecated +# ===----------------------------------------------------------------------=== # + + +@traced_op +def make_tile(*args, loc=None, ip=None): + if len(args) == 1 and isinstance(args[0], (list, tuple)): + modes = args[0] + else: + modes = args + resolved = [] + for m in modes: + if isinstance(m, int): + resolved.append(make_layout(m, 1, loc=loc, ip=ip)) + else: + resolved.append(m) + return fly.make_tile(resolved, loc=loc, ip=ip) + + +@traced_op +def make_identity_tensor(*shape, loc=None, ip=None): + base = make_int_tuple(tuple([0 for i in range(len(shape))]), loc=loc, ip=ip) + shapeTuple = make_int_tuple(shape, loc=loc, ip=ip) + layout = make_identity_layout(shapeTuple, loc=loc, ip=ip) + return fly.make_view(base, layout, loc=loc, ip=ip) diff --git a/tests/mlir/Conversion/pointer_ops.mlir b/tests/mlir/Conversion/pointer_ops.mlir index f72637bd..57809570 100644 --- a/tests/mlir/Conversion/pointer_ops.mlir +++ b/tests/mlir/Conversion/pointer_ops.mlir @@ -20,8 +20,7 @@ func.func @test_get_iter_global(%mem: !fly.memref) { // CHECK-NOT: fly.get_iter %iter = fly.get_iter(%mem) : (!fly.memref) -> !fly.ptr %offset = fly.make_int_tuple() : () -> !fly.int_tuple<8> - // get_iter is eliminated; %MEM is directly used as the GEP base pointer. - // CHECK: llvm.getelementptr %[[MEM]][{{.*}}] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + // CHECK: llvm.getelementptr %[[MEM]][{{.*}}] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32 %result = fly.add_offset(%iter, %offset) : (!fly.ptr, !fly.int_tuple<8>) -> !fly.ptr return } @@ -32,7 +31,7 @@ func.func @test_get_iter_shared(%mem: !fly.memref) { // CHECK-NOT: fly.get_iter %iter = fly.get_iter(%mem) : (!fly.memref) -> !fly.ptr %offset = fly.make_int_tuple() : () -> !fly.int_tuple<4> - // CHECK: llvm.getelementptr %[[MEM]][{{.*}}] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32 + // CHECK: llvm.getelementptr %[[MEM]][{{.*}}] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32 %result = fly.add_offset(%iter, %offset) : (!fly.ptr, !fly.int_tuple<4>) -> !fly.ptr return } @@ -43,7 +42,7 @@ func.func @test_get_iter_register(%mem: !fly.memref) { // CHECK-NOT: fly.get_iter %iter = fly.get_iter(%mem) : (!fly.memref) -> !fly.ptr %offset = fly.make_int_tuple() : () -> !fly.int_tuple<2> - // CHECK: llvm.getelementptr %[[MEM]][{{.*}}] : (!llvm.ptr<5>, i64) -> !llvm.ptr<5>, f32 + // CHECK: llvm.getelementptr %[[MEM]][{{.*}}] : (!llvm.ptr<5>, i32) -> !llvm.ptr<5>, f32 %result = fly.add_offset(%iter, %offset) : (!fly.ptr, !fly.int_tuple<2>) -> !fly.ptr return } @@ -56,9 +55,8 @@ func.func @test_get_iter_register(%mem: !fly.memref) { // CHECK-SAME: (%[[PTR:.*]]: !llvm.ptr<1>) func.func @test_add_offset_static(%ptr: !fly.ptr) { %offset = fly.make_int_tuple() : () -> !fly.int_tuple<4> - // CHECK: %[[C4:.*]] = arith.constant 4 : index - // CHECK: %[[I64:.*]] = arith.index_cast %[[C4]] : index to i64 - // CHECK: llvm.getelementptr %[[PTR]][%[[I64]]] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + // CHECK: %[[C4:.*]] = arith.constant 4 : i32 + // CHECK: llvm.getelementptr %[[PTR]][%[[C4]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32 %result = fly.add_offset(%ptr, %offset) : (!fly.ptr, !fly.int_tuple<4>) -> !fly.ptr return } @@ -67,9 +65,7 @@ func.func @test_add_offset_static(%ptr: !fly.ptr) { // CHECK-SAME: (%[[PTR:.*]]: !llvm.ptr<1>, %[[OFF:.*]]: i32) func.func @test_add_offset_dynamic(%ptr: !fly.ptr, %off: i32) { %offset = fly.make_int_tuple(%off) : (i32) -> !fly.int_tuple - // CHECK: arith.index_cast %[[OFF]] : i32 to index - // CHECK: arith.index_cast {{.*}} : index to i64 - // CHECK: llvm.getelementptr %[[PTR]][{{.*}}] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + // CHECK: llvm.getelementptr %[[PTR]][%[[OFF]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32 %result = fly.add_offset(%ptr, %offset) : (!fly.ptr, !fly.int_tuple) -> !fly.ptr return } @@ -108,7 +104,7 @@ func.func @test_make_view(%ptr: !fly.ptr) -> f32 { %view = fly.make_view(%ptr, %layout) : (!fly.ptr, !fly.layout<(4, 8) : (1, 4)>) -> !fly.memref %iter = fly.get_iter(%view) : (!fly.memref) -> !fly.ptr %offset = fly.make_int_tuple() : () -> !fly.int_tuple<7> - // CHECK: llvm.getelementptr %[[PTR]][{{.*}}] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + // CHECK: llvm.getelementptr %[[PTR]][{{.*}}] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32 %gep = fly.add_offset(%iter, %offset) : (!fly.ptr, !fly.int_tuple<7>) -> !fly.ptr // CHECK: %[[VAL:.*]] = llvm.load %val = fly.ptr.load(%gep) : (!fly.ptr) -> f32 From 927bd8c7581a0f7be76eade885e3731f284308d3 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 23 Mar 2026 06:25:57 +0000 Subject: [PATCH 2/7] Add DecompositionOp and support swizzle as the mapping of crd2idx --- include/flydsl/Dialect/Fly/IR/FlyOps.td | 6 +- .../Dialect/Fly/Transforms/MemrefLowering.td | 21 +++++ lib/Dialect/Fly/IR/FlyOps.cpp | 80 ++++++++++++++++++- lib/Dialect/Fly/Transforms/LayoutLowering.cpp | 11 ++- 4 files changed, 113 insertions(+), 5 deletions(-) diff --git a/include/flydsl/Dialect/Fly/IR/FlyOps.td b/include/flydsl/Dialect/Fly/IR/FlyOps.td index 0dcc13ec..7288a535 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyOps.td +++ b/include/flydsl/Dialect/Fly/IR/FlyOps.td @@ -251,7 +251,7 @@ def Fly_CosizeOp : Fly_Op<"cosize", [Pure, DeclareOpInterfaceMethods]> { - let arguments = (ins Fly_IntTuple:$coord, Fly_NarrowLayoutType:$layout); + let arguments = (ins Fly_IntTuple:$coord, AnyTypeOf<[Fly_NarrowLayoutType, Fly_Swizzle]>:$layout); let results = (outs Fly_IntTuple:$index); } def Fly_Idx2CrdOp : Fly_Op<"idx2crd", [Pure, DeclareOpInterfaceMethods]> { @@ -421,6 +421,10 @@ def Fly_ApplySwizzleOp : Fly_Op<"apply_swizzle", [Pure, DeclareOpInterfaceMethod let arguments = (ins Fly_Pointer:$ptr, Fly_Swizzle:$swizzle); let results = (outs Fly_Pointer:$result); } +def Fly_DecompositionOp : Fly_Op<"decomposition", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins Fly_TensorLikeType:$tensor); + let results = (outs Fly_TensorLikeType:$result); +} def Fly_PtrLoadOp : Fly_Op<"ptr.load", [DeclareOpInterfaceMethods]> { let arguments = (ins Fly_IteratorLikeType:$ptr); diff --git a/include/flydsl/Dialect/Fly/Transforms/MemrefLowering.td b/include/flydsl/Dialect/Fly/Transforms/MemrefLowering.td index dc095ebb..a5b72764 100644 --- a/include/flydsl/Dialect/Fly/Transforms/MemrefLowering.td +++ b/include/flydsl/Dialect/Fly/Transforms/MemrefLowering.td @@ -15,6 +15,27 @@ def : Pat<(Fly_AddOffsetOp Fly_IntTuple:$int_tuple, Fly_IntTuple:$offset), def : Pat<(Fly_AddOffsetOp (Fly_AddOffsetOp Fly_Pointer:$ptr, Fly_IntTuple:$offset1), Fly_IntTuple:$offset2), (Fly_AddOffsetOp $ptr, (Fly_IntTupleAddOp $offset1, $offset2))>; +def : Pat<(Fly_DecompositionOp Fly_MemRef:$memref), + (replaceWithValue $memref), + [(Fly_SimpleLayoutMemRef $memref)]>; +def : Pat<(Fly_DecompositionOp Fly_CoordTensor:$tensor), + (replaceWithValue $tensor), + [(Fly_SimpleLayoutCoordTensor $tensor)]>; +def : Pat<(Fly_DecompositionOp Fly_MemRef:$memref), + (Fly_MakeViewOp + (Fly_AddOffsetOp (Fly_GetIterOp $memref), + (Fly_Crd2IdxOp (Fly_ComposedGetOffsetOp (Fly_GetLayoutOp $memref)), + (Fly_ComposedGetInnerOp (Fly_GetLayoutOp $memref)))), + (Fly_ComposedGetOuterOp (Fly_GetLayoutOp $memref))), + [(Fly_ComposedLayoutMemRef $memref)]>; +def : Pat<(Fly_DecompositionOp Fly_CoordTensor:$tensor), + (Fly_MakeViewOp + (Fly_AddOffsetOp (Fly_GetIterOp $tensor), + (Fly_Crd2IdxOp (Fly_ComposedGetOffsetOp (Fly_GetLayoutOp $tensor)), + (Fly_ComposedGetInnerOp (Fly_GetLayoutOp $tensor)))), + (Fly_ComposedGetOuterOp (Fly_GetLayoutOp $tensor))), + [(Fly_ComposedLayoutCoordTensor $tensor)]>; + def : Pat<(Fly_MemRefLoadOp Fly_MemRef:$memref, $indices), (Fly_PtrLoadOp (Fly_AddOffsetOp (Fly_GetIterOp $memref), (Fly_Crd2IdxOp $indices, (Fly_GetLayoutOp $memref))))>; diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp index 983ce090..8196cf0e 100644 --- a/lib/Dialect/Fly/IR/FlyOps.cpp +++ b/lib/Dialect/Fly/IR/FlyOps.cpp @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (c) 2025 FlyDSL Project Contributors - #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" @@ -956,10 +955,15 @@ FLY_INFER_RETURN_TYPES(Crd2IdxOp) { layoutCrd2Idx(layoutBuilder, coordAttr, static_cast(composedTy.getAttr())); inferredReturnTypes.assign({IntTupleType::get(result)}); return success(); + } else if (auto swizzleTy = dyn_cast(operands[1].getType())) { + IntTupleAttr result = builder.applySwizzle(coordAttr, swizzleTy.getAttr()); + inferredReturnTypes.assign({IntTupleType::get(result)}); + return success(); } return emitOptionalError(location, - "Crd2IdxOp: expected LayoutType or ComposedLayoutType for layout, got ", + "Crd2IdxOp: expected LayoutType, ComposedLayoutType or SwizzleType " + "for layout, got ", operands[1].getType()); } @@ -1779,6 +1783,55 @@ FLY_INFER_RETURN_TYPES(ApplySwizzleOp) { return success(); } +FLY_INFER_RETURN_TYPES(DecompositionOp) { + Type tensorTy = operands[0].getType(); + + Attribute layoutAttr; + if (auto memrefTy = dyn_cast(tensorTy)) + layoutAttr = memrefTy.getLayout(); + else if (auto coordTensorTy = dyn_cast(tensorTy)) + layoutAttr = coordTensorTy.getLayout(); + else + return emitOptionalError(location, "DecompositionOp: expected TensorLikeType, got ", tensorTy); + + if (isa(layoutAttr)) { + inferredReturnTypes.assign({tensorTy}); + return success(); + } + + auto composed = dyn_cast(layoutAttr); + if (!composed) + return emitOptionalError( + location, "DecompositionOp: expected LayoutAttr or ComposedLayoutAttr, got ", layoutAttr); + + LayoutBuilder builder(context); + Attribute inner = composed.getInner(); + IntTupleAttr offset = composed.getOffset(); + LayoutAttr outer = composed.getOuter(); + + IntTupleAttr iterOffset; + if (auto swizzleInner = dyn_cast(inner)) + iterOffset = builder.applySwizzle(offset, swizzleInner); + else + iterOffset = layoutCrd2Idx(builder, offset, inner); + + if (auto memrefTy = dyn_cast(tensorTy)) { + int32_t valDiv = memrefTy.getValueDivisibility(); + IntAttr offsetInt = iterOffset.extractIntFromLeaf(); + int32_t offsetDiv = + offsetInt.isStatic() ? std::abs(offsetInt.getValue()) : offsetInt.getDivisibility(); + int32_t newValDiv = (offsetDiv == 0) ? valDiv : utils::divisibilityAdd(valDiv, offsetDiv); + inferredReturnTypes.assign({fly::MemRefType::get( + memrefTy.getElemTy(), memrefTy.getAddressSpace(), outer, + AlignAttr::get(memrefTy.getElemTy(), newValDiv), memrefTy.getSwizzle())}); + } else { + auto coordTensorTy = cast(tensorTy); + IntTupleAttr newBase = intTupleAdd(builder, coordTensorTy.getBase(), iterOffset); + inferredReturnTypes.assign({CoordTensorType::get(newBase, outer)}); + } + return success(); +} + FLY_INFER_RETURN_TYPES(PtrLoadOp) { if (auto ptrTy = dyn_cast(operands[0].getType())) { inferredReturnTypes.assign({ptrTy.getElemTy()}); @@ -1798,7 +1851,28 @@ FLY_INFER_RETURN_TYPES(MemRefLoadOp) { return success(); } if (auto coordTensorTy = dyn_cast(operands[0].getType())) { - inferredReturnTypes.push_back(IntTupleType::get(coordTensorTy.getBase())); + auto indicesTy = dyn_cast(operands[1].getType()); + if (!indicesTy) + return emitOptionalError(location, "MemRefLoadOp: expected IntTupleType for indices, got ", + operands[1].getType()); + + IntTupleBuilder builder(context); + IntTupleAttr baseAttr = coordTensorTy.getBase(); + IntTupleAttr indicesAttr = indicesTy.getAttr(); + Attribute layoutAttr = coordTensorTy.getLayout(); + + IntTupleAttr offsetAttr; + if (auto layout = dyn_cast(layoutAttr)) { + offsetAttr = layoutCrd2Idx(builder, indicesAttr, layout.getShape(), layout.getStride()); + } else if (auto composed = dyn_cast(layoutAttr)) { + LayoutBuilder layoutBuilder(context); + offsetAttr = layoutCrd2Idx(layoutBuilder, indicesAttr, static_cast(composed)); + } else { + return emitOptionalError(location, "MemRefLoadOp: unsupported layout type in CoordTensor"); + } + + IntTupleAttr resultAttr = intTupleAdd(builder, baseAttr, offsetAttr); + inferredReturnTypes.push_back(IntTupleType::get(resultAttr)); return success(); } return emitOptionalError(location, "MemRefLoadOp: expected MemRefType or CoordTensorType, got ", diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp index e03ae0ba..a2a5b3e2 100644 --- a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -836,6 +836,13 @@ class Crd2IdxLowering : public OpRewritePattern { if (!isNormalForm(cast>(layout))) return failure(); layoutAdaptor = LayoutValueAdaptor(layout, composedLayoutTy.getAttr()); + } else if (auto swizzleTy = dyn_cast(layout.getType())) { + LayoutBuilder layoutBuilder(rewriter, loc); + IntTupleValueAdaptor coordAdaptor = + IntTupleValueAdaptor::create(layoutBuilder, coord, coordTy.getAttr()); + IntTupleValueAdaptor result = layoutBuilder.applySwizzle(coordAdaptor, swizzleTy.getAttr()); + rewriter.replaceOp(op, layoutBuilder.finalize(result)); + return success(); } else { return failure(); } @@ -1838,7 +1845,9 @@ class ExpandCopyOpLowering : public OpRewritePattern { if (srcRank == 1) { if (srcLayoutAttr.getShape().isLeaf()) { - CopyAtomCall::create(rewriter, loc, copyAtomVal, src, dst, pred); + Value srcDecomposition = DecompositionOp::create(rewriter, loc, src); + Value dstDecomposition = DecompositionOp::create(rewriter, loc, dst); + CopyAtomCall::create(rewriter, loc, copyAtomVal, srcDecomposition, dstDecomposition, pred); rewriter.eraseOp(op); return success(); } From 8d6ee979be9d5e918d06712ca8758feac7cbf964 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 23 Mar 2026 07:41:20 +0000 Subject: [PATCH 3/7] minor fix --- include/flydsl/Dialect/Fly/IR/FlyOps.td | 3 +- lib/Conversion/FlyToROCDL/BufferFatPtr.h | 5 ++-- lib/Conversion/FlyToROCDL/FlyToROCDL.cpp | 19 ++++++++++++- lib/Dialect/Fly/IR/FlyAttrDefs.cpp | 7 ++++- lib/Dialect/Fly/IR/FlyTypeDefs.cpp | 10 +++---- lib/Dialect/Fly/Transforms/LayoutLowering.cpp | 28 +++++++++---------- lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp | 4 +-- tests/mlir/Conversion/dyn_shared.mlir | 2 +- tests/mlir/Conversion/memref_ops.mlir | 15 ++++------ tests/mlir/Transforms/layout_lowering.mlir | 2 +- .../Transforms/rewrite_func_signature.mlir | 4 +-- 11 files changed, 58 insertions(+), 41 deletions(-) diff --git a/include/flydsl/Dialect/Fly/IR/FlyOps.td b/include/flydsl/Dialect/Fly/IR/FlyOps.td index 7288a535..fcd73d78 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyOps.td +++ b/include/flydsl/Dialect/Fly/IR/FlyOps.td @@ -337,7 +337,7 @@ def Fly_MakeMmaAtomOp : Fly_Op<"make_mma_atom", [Pure]> { def Fly_MakeCopyAtomOp : Fly_Op<"make_copy_atom", [Pure]> { let arguments = (ins I32Attr:$valBits); let results = (outs Fly_CopyAtom:$result); - let assemblyFormat = "attr-dict `:` type($result)"; + let assemblyFormat = "attr-dict `:` qualified(type($result))"; } def Fly_CopyAtomCall : Fly_Op<"copy_atom_call"> { @@ -374,6 +374,7 @@ def Fly_TiledMmaPartitionOp : Fly_Op<"tiled_mma.partition", [Pure, DeclareOpInte let arguments = (ins Fly_MmaOperandAttr:$operand_id, Fly_TiledMma:$tiled_mma, Fly_TensorLikeType:$input, Fly_IntTuple:$coord); let results = (outs Fly_TensorLikeType:$result); + let assemblyFormat = "`(` $operand_id `,` $tiled_mma `,` $input `,` $coord `)` attr-dict `:` functional-type(operands, results)"; } def Fly_TiledMmaPartitionShapeOp : Fly_Op<"tiled_mma.partition_shape", [Pure, DeclareOpInterfaceMethods]> { let arguments = (ins Fly_MmaOperandAttr:$operand_id, Fly_TiledMma:$tiled_mma, Fly_IntTuple:$shape); diff --git a/lib/Conversion/FlyToROCDL/BufferFatPtr.h b/lib/Conversion/FlyToROCDL/BufferFatPtr.h index 0180503f..cbf3c86d 100644 --- a/lib/Conversion/FlyToROCDL/BufferFatPtr.h +++ b/lib/Conversion/FlyToROCDL/BufferFatPtr.h @@ -7,13 +7,12 @@ #include "mlir/IR/BuiltinTypes.h" #include "flydsl/Dialect/Fly/IR/FlyDialect.h" -#include "flydsl/Dialect/Fly/Utils/AddressSpaceUtils.h" namespace mlir::fly { class BufferFatPtr { - static constexpr unsigned kRsrcAddrSpace = mapToLLVMAddressSpace(AddressSpace::BufferDesc); - static constexpr unsigned kOffsetBitWidth = 32; + static constexpr unsigned kRsrcAddrSpace = 8; // BufferDesc + static constexpr unsigned kOffsetBitWidth = 32; // constrained by BufferCopy instruction fly::PointerType ptrTy; Value fatPtr; diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp index 2d4978b6..b65319bb 100644 --- a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp +++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp @@ -34,6 +34,22 @@ using namespace mlir::fly; namespace { +unsigned mapToLLVMAddressSpace(AddressSpace addrSpace) { + switch (addrSpace) { + case AddressSpace::Global: + return 1; + case AddressSpace::Shared: + return 3; + case AddressSpace::Register: + return 5; + case AddressSpace::BufferDesc: + return 8; + default: + assert(false && "Unsupported address space"); + return 0; + } +} + Value applySwizzleOnPtr(OpBuilder &b, Location loc, Value ptr, SwizzleAttr swizzle) { if (swizzle.isTrivialSwizzle()) return ptr; @@ -288,7 +304,8 @@ class MakeViewOpLowering : public OpConversionPattern { LogicalResult matchAndRewrite(MakeViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (isa(op.getResult().getType())) { - assert(op.getResult().use_empty() && "coord_tensor result should have no uses"); + if (!op.getResult().use_empty()) + return rewriter.notifyMatchFailure(op, "coord_tensor result should have no uses"); rewriter.eraseOp(op); return success(); } else { diff --git a/lib/Dialect/Fly/IR/FlyAttrDefs.cpp b/lib/Dialect/Fly/IR/FlyAttrDefs.cpp index 4c0a2c8c..884c3fe8 100644 --- a/lib/Dialect/Fly/IR/FlyAttrDefs.cpp +++ b/lib/Dialect/Fly/IR/FlyAttrDefs.cpp @@ -374,8 +374,13 @@ ::mlir::Attribute parseLeafAttr(::mlir::AsmParser &odsParser) { valueAttr = IntAttr::getStatic(ctx, value); } + auto nextLoc = odsParser.getCurrentLocation(); + const char *nextPtr = nextLoc.getPointer(); + if (!nextPtr || *nextPtr != 'E' || !std::isdigit(static_cast(*(nextPtr + 1)))) + return valueAttr; + StringRef strRefModes; - if (failed(odsParser.parseOptionalKeyword(&strRefModes)) || !strRefModes.starts_with("E")) + if (failed(odsParser.parseOptionalKeyword(&strRefModes))) return valueAttr; SmallVector modes; diff --git a/lib/Dialect/Fly/IR/FlyTypeDefs.cpp b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp index 5ec76550..5126d6ed 100644 --- a/lib/Dialect/Fly/IR/FlyTypeDefs.cpp +++ b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp @@ -262,7 +262,7 @@ Type CoordTensorType::parse(AsmParser &parser) { void CoordTensorType::print(AsmPrinter &printer) const { printer << "<"; printer.printStrippedAttrOrType(getBase()); - printer << ","; + printer << ", "; Attribute layoutAttr = getLayout(); if (auto layout = dyn_cast(layoutAttr)) printer.printStrippedAttrOrType(layout); @@ -306,11 +306,11 @@ static LogicalResult parseAlignAndSwizzle(AsmParser &parser, Type elemTy, AlignA static void printAlignAndSwizzle(AsmPrinter &printer, Type elemTy, AlignAttr alignment, SwizzleAttr swizzle, MLIRContext *ctx) { if (alignment != AlignAttr::getTrivialAlignment(elemTy)) { - printer << ","; + printer << ", "; printer.printStrippedAttrOrType(alignment); } if (swizzle != SwizzleAttr::getTrivialSwizzle(ctx)) { - printer << ","; + printer << ", "; printer.printStrippedAttrOrType(swizzle); } } @@ -332,7 +332,7 @@ Type PointerType::parse(AsmParser &parser) { } void PointerType::print(AsmPrinter &printer) const { - printer << "<" << getElemTy() << ","; + printer << "<" << getElemTy() << ", "; printer.printStrippedAttrOrType(getAddressSpace()); printAlignAndSwizzle(printer, getElemTy(), getAlignment(), getSwizzle(), getContext()); printer << ">"; @@ -360,7 +360,7 @@ Type MemRefType::parse(AsmParser &parser) { } void MemRefType::print(AsmPrinter &printer) const { - printer << "<" << getElemTy() << ","; + printer << "<" << getElemTy() << ", "; printer.printStrippedAttrOrType(getAddressSpace()); printer << ", "; Attribute layoutAttr = getLayout(); diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp index a2a5b3e2..4ed7326d 100644 --- a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -1928,17 +1928,18 @@ class ExpandGemmOpLowering : public OpRewritePattern { int32_t bRank = bLayoutAttr.rank(); int32_t cRank = cLayoutAttr.rank(); - int32_t loop_m = dLayoutAttr.getShape().at(1).getLeafAsInt().getValue(); - int32_t loop_n = dLayoutAttr.getShape().at(2).getLeafAsInt().getValue(); - - assert(loop_m == aLayoutAttr.getShape().at(1).getLeafAsInt().getValue() && - "Mismatch in loop_m"); - assert(loop_n == bLayoutAttr.getShape().at(1).getLeafAsInt().getValue() && - "Mismatch in loop_n"); - assert(loop_m == cLayoutAttr.getShape().at(1).getLeafAsInt().getValue() && - "Mismatch in loop_m"); - assert(loop_n == cLayoutAttr.getShape().at(2).getLeafAsInt().getValue() && - "Mismatch in loop_n"); + IntTupleBuilder attrBuilder(ctx); + auto get_static_product = [&](IntTupleAttr shape) { + return intTupleProduct(attrBuilder, shape).getLeafAsInt().getValue(); + }; + + int32_t loop_m = get_static_product(dLayoutAttr.getShape().at(1)); + int32_t loop_n = get_static_product(dLayoutAttr.getShape().at(2)); + + assert(loop_m == get_static_product(aLayoutAttr.getShape().at(1)) && "Mismatch in loop_m"); + assert(loop_n == get_static_product(bLayoutAttr.getShape().at(1)) && "Mismatch in loop_n"); + assert(loop_m == get_static_product(cLayoutAttr.getShape().at(1)) && "Mismatch in loop_m"); + assert(loop_n == get_static_product(cLayoutAttr.getShape().at(2)) && "Mismatch in loop_n"); if (dRank == 1 && aRank == 1 && bRank == 1 && cRank == 1) { MmaAtomCall::create(rewriter, loc, mmaAtomVal, d, a, b, c); @@ -1972,9 +1973,8 @@ class ExpandGemmOpLowering : public OpRewritePattern { rewriter.eraseOp(op); return success(); } else if (aRank == 3 && bRank == 3) { - int32_t loop_k = aLayoutAttr.getShape().at(2).getLeafAsInt().getValue(); - assert(loop_k == bLayoutAttr.getShape().at(2).getLeafAsInt().getValue() && - "Mismatch in loop_k"); + int32_t loop_k = get_static_product(aLayoutAttr.getShape().at(2)); + assert(loop_k == get_static_product(bLayoutAttr.getShape().at(2)) && "Mismatch in loop_k"); for (int32_t k = 0; k < loop_k; ++k) { Value cSrc = (k == 0) ? c : d; diff --git a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp index d9148597..adc1f07a 100644 --- a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp +++ b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp @@ -66,10 +66,8 @@ Attribute MmaAtomCDNA3_MFMAType::getThrValLayoutC() const { int GroupM = 64 / N; int ValM0 = 4; - int ValM1 = M / 4 / GroupM; - return FxLayout(FxShape(FxThr(N, GroupM), FxVal(ValM0, ValM1)), - FxStride(FxThr(M, ValM0), FxVal(1, ValM0 * GroupM))); + return FxLayout(FxShape(FxThr(N, GroupM), FxVal(ValM0)), FxStride(FxThr(M, ValM0), FxVal(1))); } LogicalResult MmaAtomCDNA3_MFMAType::verify(function_ref emitError, int32_t m, diff --git a/tests/mlir/Conversion/dyn_shared.mlir b/tests/mlir/Conversion/dyn_shared.mlir index 3caa36a1..af4c3c56 100644 --- a/tests/mlir/Conversion/dyn_shared.mlir +++ b/tests/mlir/Conversion/dyn_shared.mlir @@ -56,7 +56,7 @@ gpu.module @load_module { %shmem = fly.get_dyn_shared() : !fly.ptr> %off = fly.make_int_tuple(%offset) : (i32) -> !fly.int_tuple // CHECK: llvm.getelementptr {{.*}}[0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8 - // CHECK: llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK: llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 %ptr = fly.add_offset(%shmem, %off) : (!fly.ptr>, !fly.int_tuple) -> !fly.ptr // CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> i8 %val = fly.ptr.load(%ptr) : (!fly.ptr) -> i8 diff --git a/tests/mlir/Conversion/memref_ops.mlir b/tests/mlir/Conversion/memref_ops.mlir index a9d89ded..57e09baa 100644 --- a/tests/mlir/Conversion/memref_ops.mlir +++ b/tests/mlir/Conversion/memref_ops.mlir @@ -14,9 +14,8 @@ // CHECK-SAME: (%[[PTR:.*]]: !llvm.ptr<1>) func.func @test_memref_load(%mem: !fly.memref) -> f32 { %idx = fly.make_int_tuple() : () -> !fly.int_tuple<5> - // CHECK: %[[C5:.*]] = arith.constant 5 : index - // CHECK: %[[I64:.*]] = arith.index_cast %[[C5]] : index to i64 - // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[PTR]][%[[I64]]] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + // CHECK: %[[C5:.*]] = arith.constant 5 : i32 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[PTR]][%[[C5]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32 // CHECK: %[[VAL:.*]] = llvm.load %[[GEP]] : !llvm.ptr<1> -> f32 %val = fly.memref.load(%mem, %idx) : (!fly.memref, !fly.int_tuple<5>) -> f32 // CHECK: return %[[VAL]] @@ -27,9 +26,8 @@ func.func @test_memref_load(%mem: !fly.memref) -> f32 { // CHECK-SAME: (%[[PTR:.*]]: !llvm.ptr<1>, %[[VAL:.*]]: f32) func.func @test_memref_store(%mem: !fly.memref, %val: f32) { %idx = fly.make_int_tuple() : () -> !fly.int_tuple<3> - // CHECK: %[[C3:.*]] = arith.constant 3 : index - // CHECK: %[[I64:.*]] = arith.index_cast %[[C3]] : index to i64 - // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[PTR]][%[[I64]]] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + // CHECK: %[[C3:.*]] = arith.constant 3 : i32 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[PTR]][%[[C3]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32 // CHECK: llvm.store %[[VAL]], %[[GEP]] : f32, !llvm.ptr<1> fly.memref.store(%val, %mem, %idx) : (f32, !fly.memref, !fly.int_tuple<3>) -> () return @@ -39,9 +37,8 @@ func.func @test_memref_store(%mem: !fly.memref, %val: f32) { // CHECK-SAME: (%[[PTR:.*]]: !llvm.ptr<3>) func.func @test_memref_load_f16_shared(%mem: !fly.memref) -> f16 { %idx = fly.make_int_tuple() : () -> !fly.int_tuple<10> - // CHECK: %[[C10:.*]] = arith.constant 10 : index - // CHECK: %[[I64:.*]] = arith.index_cast %[[C10]] : index to i64 - // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[PTR]][%[[I64]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16 + // CHECK: %[[C10:.*]] = arith.constant 10 : i32 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[PTR]][%[[C10]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 // CHECK: %[[VAL:.*]] = llvm.load %[[GEP]] : !llvm.ptr<3> -> f16 %val = fly.memref.load(%mem, %idx) : (!fly.memref, !fly.int_tuple<10>) -> f16 // CHECK: return %[[VAL]] diff --git a/tests/mlir/Transforms/layout_lowering.mlir b/tests/mlir/Transforms/layout_lowering.mlir index 88047c0e..cd39ebe4 100644 --- a/tests/mlir/Transforms/layout_lowering.mlir +++ b/tests/mlir/Transforms/layout_lowering.mlir @@ -60,7 +60,7 @@ func.func @test_get_layout(%ptr: !fly.ptr) -> !fly.layout<(4,8):(1, // get_iter forwards the iterator (ptr) from make_view; all Fly ops are eliminated. // CHECK-LABEL: @test_get_iter -// CHECK-SAME: (%[[PTR:.*]]: !fly.ptr) +// CHECK-SAME: (%[[PTR:.*]]: !fly.ptr) func.func @test_get_iter(%ptr: !fly.ptr) -> !fly.ptr { %s = fly.make_int_tuple() : () -> !fly.int_tuple<(4, 8)> %d = fly.make_int_tuple() : () -> !fly.int_tuple<(1, 4)> diff --git a/tests/mlir/Transforms/rewrite_func_signature.mlir b/tests/mlir/Transforms/rewrite_func_signature.mlir index f60e61c7..6655353c 100644 --- a/tests/mlir/Transforms/rewrite_func_signature.mlir +++ b/tests/mlir/Transforms/rewrite_func_signature.mlir @@ -85,7 +85,7 @@ func.func @test_partially_dynamic_layout(%arg0: !fly.layout<4:?>) { // MemRef with static layout: lowered to a single fly.ptr argument. // CHECK-LABEL: @test_static_memref -// CHECK-SAME: (%[[P:.*]]: !fly.ptr) +// CHECK-SAME: (%[[P:.*]]: !fly.ptr) func.func @test_static_memref(%arg0: !fly.memref) { // CHECK: fly.static : !fly.layout<32:1> // CHECK: fly.make_view(%[[P]] @@ -96,7 +96,7 @@ func.func @test_static_memref(%arg0: !fly.memref) { // MemRef with dynamic layout: lowered to ptr arg + layout struct arg. // CHECK-LABEL: @test_dynamic_memref -// CHECK-SAME: (%[[P:.*]]: !fly.ptr, %[[L:.*]]: !llvm.struct, struct)>) +// CHECK-SAME: (%[[P:.*]]: !fly.ptr, %[[L:.*]]: !llvm.struct, struct)>) func.func @test_dynamic_memref(%arg0: !fly.memref) { // CHECK: llvm.extractvalue %[[L]][0] // CHECK: llvm.extractvalue %[[L]][1] From 58eafcb9817cbb6f3f7dbdf06a25dc152e2b2564 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 23 Mar 2026 17:39:08 +0000 Subject: [PATCH 4/7] remove redundant getResult() --- lib/Conversion/FlyToROCDL/BufferFatPtr.h | 9 ++++++--- lib/Conversion/FlyToROCDL/FlyToROCDL.cpp | 20 +++++++++----------- python/flydsl/expr/primitive.py | 8 -------- 3 files changed, 15 insertions(+), 22 deletions(-) diff --git a/lib/Conversion/FlyToROCDL/BufferFatPtr.h b/lib/Conversion/FlyToROCDL/BufferFatPtr.h index cbf3c86d..f2f811bf 100644 --- a/lib/Conversion/FlyToROCDL/BufferFatPtr.h +++ b/lib/Conversion/FlyToROCDL/BufferFatPtr.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors + #ifndef FLYDSL_LIB_CONVERSION_FLYTOROCDL_BUFFERFATPTR_H #define FLYDSL_LIB_CONVERSION_FLYTOROCDL_BUFFERFATPTR_H @@ -51,12 +54,12 @@ class BufferFatPtr { return off; if (bits > 8 && bits % 8 == 0) { int64_t elemBytes = bits / 8; - Value scale = arith::ConstantIntOp::create(b, loc, elemBytes, kOffsetBitWidth).getResult(); + Value scale = arith::ConstantIntOp::create(b, loc, elemBytes, kOffsetBitWidth); return arith::MulIOp::create(b, loc, off, scale); } - Value scale = arith::ConstantIntOp::create(b, loc, bits, kOffsetBitWidth).getResult(); + Value scale = arith::ConstantIntOp::create(b, loc, bits, kOffsetBitWidth); off = arith::MulIOp::create(b, loc, off, scale); - Value const8 = arith::ConstantIntOp::create(b, loc, 8, kOffsetBitWidth).getResult(); + Value const8 = arith::ConstantIntOp::create(b, loc, 8, kOffsetBitWidth); return arith::DivUIOp::create(b, loc, off, const8); } diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp index b65319bb..0ac60de7 100644 --- a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp +++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp @@ -58,11 +58,11 @@ Value applySwizzleOnPtr(OpBuilder &b, Location loc, Value ptr, SwizzleAttr swizz Value ptrInt = LLVM::PtrToIntOp::create(b, loc, i64Ty, ptr); int64_t bitMaskValue = ((int64_t{1} << swizzle.getMask()) - 1) << (swizzle.getBase() + swizzle.getShift()); - Value bitMask = arith::ConstantIntOp::create(b, loc, i64Ty, bitMaskValue).getResult(); - Value shiftAmt = arith::ConstantIntOp::create(b, loc, i64Ty, swizzle.getShift()).getResult(); - Value masked = arith::AndIOp::create(b, loc, ptrInt, bitMask).getResult(); - Value shifted = arith::ShRUIOp::create(b, loc, masked, shiftAmt).getResult(); - Value swizzled = arith::XOrIOp::create(b, loc, ptrInt, shifted).getResult(); + Value bitMask = arith::ConstantIntOp::create(b, loc, i64Ty, bitMaskValue); + Value shiftAmt = arith::ConstantIntOp::create(b, loc, i64Ty, swizzle.getShift()); + Value masked = arith::AndIOp::create(b, loc, ptrInt, bitMask); + Value shifted = arith::ShRUIOp::create(b, loc, masked, shiftAmt); + Value swizzled = arith::XOrIOp::create(b, loc, ptrInt, shifted); return LLVM::IntToPtrOp::create(b, loc, ptrTy, swizzled); } @@ -89,8 +89,7 @@ class MakePtrOpLowering : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "register make_ptr requires allocSize in ptrAttrs"); unsigned llvmAS = mapToLLVMAddressSpace(AddressSpace::Register); auto llvmPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), llvmAS); - Value nElems = - arith::ConstantIntOp::create(rewriter, loc, allocSize.getInt(), 64).getResult(); + Value nElems = arith::ConstantIntOp::create(rewriter, loc, allocSize.getInt(), 64); Value ptr = LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, flyPtrTy.getElemTy(), nElems, 0); rewriter.replaceOp(op, ptr); return success(); @@ -380,7 +379,7 @@ class PtrLoadOpLowering : public OpConversionPattern { if (flyPtrTy.getAddressSpace().getValue() == AddressSpace::BufferDesc) { BufferFatPtr bp(flyPtrTy, ptr); - Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32).getResult(); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); ArrayAttr noAttrs; Value loaded = ROCDL::RawPtrBufferLoadOp::create( rewriter, loc, elemTy, bp.bufferRsrc(rewriter, loc), bp.swizzleByteOffset(rewriter, loc), @@ -529,7 +528,7 @@ class CopyAtomCallLowering : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "CDNA3 buffer copy requires exactly one side to be BufferDesc"); - Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32).getResult(); + Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); ArrayAttr noAttrs; auto unpackBuffer = [&](Value val, fly::MemRefType flyTy) -> std::pair { @@ -672,8 +671,7 @@ class MmaAtomCallLowering : public OpConversionPattern { Value b = LLVM::LoadOp::create(rewriter, loc, abTyB, bPtr); Value c = LLVM::LoadOp::create(rewriter, loc, accTy, cPtr); auto zeroAttr = rewriter.getI32IntegerAttr(0); - Value res = - MfmaOp::create(rewriter, loc, accTy, a, b, c, zeroAttr, zeroAttr, zeroAttr).getResult(); + Value res = MfmaOp::create(rewriter, loc, accTy, a, b, c, zeroAttr, zeroAttr, zeroAttr); LLVM::StoreOp::create(rewriter, loc, res, dPtr); rewriter.eraseOp(op); return success(); diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 6e07a448..4cd626e1 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -817,11 +817,3 @@ def make_tile(*args, loc=None, ip=None): else: resolved.append(m) return fly.make_tile(resolved, loc=loc, ip=ip) - - -@traced_op -def make_identity_tensor(*shape, loc=None, ip=None): - base = make_int_tuple(tuple([0 for i in range(len(shape))]), loc=loc, ip=ip) - shapeTuple = make_int_tuple(shape, loc=loc, ip=ip) - layout = make_identity_layout(shapeTuple, loc=loc, ip=ip) - return fly.make_view(base, layout, loc=loc, ip=ip) From fb9de4c23de9be61b25bdaebd1952d56ea11d3bb Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Tue, 24 Mar 2026 03:48:45 +0000 Subject: [PATCH 5/7] fix MFMA 32x32 layoutC --- lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp index adc1f07a..d9148597 100644 --- a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp +++ b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp @@ -66,8 +66,10 @@ Attribute MmaAtomCDNA3_MFMAType::getThrValLayoutC() const { int GroupM = 64 / N; int ValM0 = 4; + int ValM1 = M / 4 / GroupM; - return FxLayout(FxShape(FxThr(N, GroupM), FxVal(ValM0)), FxStride(FxThr(M, ValM0), FxVal(1))); + return FxLayout(FxShape(FxThr(N, GroupM), FxVal(ValM0, ValM1)), + FxStride(FxThr(M, ValM0), FxVal(1, ValM0 * GroupM))); } LogicalResult MmaAtomCDNA3_MFMAType::verify(function_ref emitError, int32_t m, From 1ac9971fc3a608d79cda4bdf7556db8b645b7d1c Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Tue, 24 Mar 2026 06:44:54 +0000 Subject: [PATCH 6/7] add missing get_dyn_shared --- python/flydsl/expr/primitive.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 4cd626e1..c0ca367b 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -650,6 +650,11 @@ def make_ptr(result_type, args, loc=None, ip=None): return fly.make_ptr(result_type, args, loc=loc, ip=ip) +@traced_op +def get_dyn_shared(loc=None, ip=None): + return fly.get_dyn_shared(loc=loc, ip=ip) + + @traced_op def inttoptr(result_type, src, loc=None, ip=None): return fly.inttoptr(result_type, src, loc=loc, ip=ip) From a226da8b0e2221f6d6f748715881eb566dfc93bd Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 24 Mar 2026 14:45:55 +0000 Subject: [PATCH 7/7] fix gemm and moe run --- python/flydsl/expr/primitive.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index c0ca367b..28300593 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -419,9 +419,18 @@ def cosize(layout, loc=None, ip=None): return result +def _to_i32(v): + """Cast index-type ir.Value to i32 (required by fly.make_int_tuple).""" + if isinstance(v, ir.Value) and isinstance(v.type, ir.IndexType): + return _arith.IndexCastOp(T.i32(), v).result + return v + + @traced_op def crd2idx(crd, layout, loc=None, ip=None): if not isinstance(crd, ir.Value): + if isinstance(crd, (list, tuple)): + crd = tuple(_to_i32(c) for c in crd) crdTy, dyncElems = fly.infer_int_tuple_type(crd) crd = fly.make_coord(crdTy, dyncElems, loc=loc, ip=ip) return fly.crd2idx(crd, layout, loc=loc, ip=ip) @@ -430,7 +439,8 @@ def crd2idx(crd, layout, loc=None, ip=None): @traced_op def idx2crd(idx, layout, loc=None, ip=None): if isinstance(idx, ir.Value) and not str(idx.type).startswith("!fly.int_tuple"): - IntTupleTy, dyncElems = fly.infer_int_tuple_type((idx,)) + idx = _to_i32(idx) + IntTupleTy, dyncElems = fly.infer_int_tuple_type(idx) idx = fly.make_int_tuple(IntTupleTy, dyncElems, loc=loc, ip=ip) return fly.idx2crd(idx, layout, loc=loc, ip=ip)