From 498e7f849a778d890580b552fe29c9e2799ea263 Mon Sep 17 00:00:00 2001 From: mouliangyu Date: Fri, 15 May 2026 10:28:21 +0800 Subject: [PATCH] feat: enable assertion on ci --- .github/workflows/ci.yml | 37 ++++++++++++++++--- lib/PTO/IR/PTO.cpp | 13 ++++--- .../Transforms/GraphSyncSolver/MemInfo.cpp | 7 ++-- .../Transforms/InsertSync/PTOIRTranslator.cpp | 19 +++++++--- lib/PTO/Transforms/VPTOExpandWrapperOps.cpp | 16 ++++---- lib/PTO/Transforms/VPTOLLVMEmitter.cpp | 8 ++-- 6 files changed, 67 insertions(+), 33 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 65a799629..386b20560 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -157,7 +157,7 @@ jobs: with: path: | llvm-project/llvm/build-shared - key: llvm-${{ runner.os }}-${{ env.LLVM_SOURCE_SHA }}-shared-mlirpy-hardening-v2 + key: llvm-${{ runner.os }}-${{ env.LLVM_SOURCE_SHA }}-shared-mlirpy-assertions-v2 - name: Prepare LLVM source (no rebuild) run: | @@ -183,10 +183,17 @@ jobs: -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=python3 \ -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_TARGETS_TO_BUILD="host" ninja -C llvm/build-shared - + + - name: Verify LLVM assertions + shell: bash + run: | + set -euo pipefail + grep -q '^set(LLVM_ENABLE_ASSERTIONS ON)' \ + "${LLVM_DIR}/lib/cmake/llvm/LLVMConfig.cmake" # LLVM build 完成后立即保存缓存,避免后续测试影响缓存内容 - name: Save LLVM build cache @@ -195,7 +202,7 @@ jobs: with: path: | llvm-project/llvm/build-shared - key: llvm-${{ runner.os }}-${{ env.LLVM_SOURCE_SHA }}-shared-mlirpy-hardening-v2 + key: llvm-${{ runner.os }}-${{ env.LLVM_SOURCE_SHA }}-shared-mlirpy-assertions-v2 - name: Build PTOAS run: | @@ -209,7 +216,13 @@ jobs: -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DMLIR_PYTHON_PACKAGE_DIR="${LLVM_DIR}/tools/mlir/python_packages/mlir_core" \ -DCMAKE_INSTALL_PREFIX="${PTO_INSTALL_DIR}" \ - -DCMAKE_BUILD_TYPE=Release + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_FLAGS_RELEASE="-O3" \ + -DCMAKE_CXX_FLAGS_RELEASE="-O3" + if grep -E '^CMAKE_(C|CXX)_FLAGS_RELEASE:.*-DNDEBUG' build/CMakeCache.txt; then + echo "ERROR: PTOAS Release flags still define NDEBUG; assertions are disabled" >&2 + exit 1 + fi ninja -C build ptoas ninja -C build ptobc ninja -C build install @@ -389,10 +402,18 @@ jobs: -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DPython3_EXECUTABLE=python3 \ -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_TARGETS_TO_BUILD="host" ninja -C llvm/build-shared + - name: Verify LLVM assertions + shell: bash + run: | + set -euo pipefail + grep -q '^set(LLVM_ENABLE_ASSERTIONS ON)' \ + "${LLVM_DIR}/lib/cmake/llvm/LLVMConfig.cmake" + - name: Build PTOAS shell: bash run: | @@ -406,7 +427,13 @@ jobs: -Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DMLIR_PYTHON_PACKAGE_DIR="${LLVM_DIR}/tools/mlir/python_packages/mlir_core" \ - -DCMAKE_BUILD_TYPE=Release + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_FLAGS_RELEASE="-O3" \ + -DCMAKE_CXX_FLAGS_RELEASE="-O3" + if grep -E '^CMAKE_(C|CXX)_FLAGS_RELEASE:.*-DNDEBUG' build/CMakeCache.txt; then + echo "ERROR: PTOAS Release flags still define NDEBUG; assertions are disabled" >&2 + exit 1 + fi ninja -C build ptoas - name: Resolve simulator environment diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 68ded38e7..6f12981d7 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -3608,9 +3608,10 @@ static LogicalResult verifyTColArgReductionOpCommon(Operation *op, Type srcTy, /*requireNonZeroSrc=*/true))) return failure(); Type srcElemTy = getElemTy(srcTy); - unsigned srcElemBits = srcElemTy ? srcElemTy.getIntOrFloatBitWidth() : 0; if (!(mlir::isa(srcElemTy) && - (srcElemBits == 8 || srcElemBits == 16 || srcElemBits == 32))) + (getPTOStorageElemByteSize(srcElemTy) == 1 || + getPTOStorageElemByteSize(srcElemTy) == 2 || + getPTOStorageElemByteSize(srcElemTy) == 4))) return op->emitOpError( "expects src/tmp element type to be 1, 2, or 4 bytes wide"); auto dstInt = dyn_cast(getElemTy(dstTy)); @@ -5671,8 +5672,8 @@ llvm::LogicalResult mlir::pto::TGatherOp::verify() { if (!srcSpace || !dstSpace || *srcSpace != pto::AddressSpace::VEC || *dstSpace != pto::AddressSpace::VEC) return emitOpError("expects src and dst to be in the vec address space"); - unsigned srcElemBytes = srcElem.getIntOrFloatBitWidth() / 8; - unsigned dstElemBytes = dstElem.getIntOrFloatBitWidth() / 8; + unsigned srcElemBytes = getPTOStorageElemByteSize(srcElem); + unsigned dstElemBytes = getPTOStorageElemByteSize(dstElem); if (srcElemBytes != dstElemBytes) return emitOpError("expects src and dst element sizes to match"); @@ -9432,7 +9433,7 @@ mlir::LogicalResult mlir::pto::TTransOp::verify() { if (srcTb.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor)) return emitOpError() << "expects A2/A3 transpose src to use the row_major blayout"; } - unsigned elemBytes = srcElem.getIntOrFloatBitWidth() / 8; + unsigned elemBytes = getPTOStorageElemByteSize(srcElem); if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4) return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes"; auto isAllowedWidthType = [&](Type ty) { @@ -9459,7 +9460,7 @@ mlir::LogicalResult mlir::pto::TTransOp::verify() { Type dstElem = getElemTy(dstTy); if (!srcElem || !tmpElem || !dstElem || srcElem != dstElem || srcElem != tmpElem) return emitOpError() << "expects src, tmp, and dst to have the same element type"; - unsigned elemBytes = srcElem.getIntOrFloatBitWidth() / 8; + unsigned elemBytes = getPTOStorageElemByteSize(srcElem); if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4) return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes"; auto isAllowedWidthType = [&](Type ty) { diff --git a/lib/PTO/Transforms/GraphSyncSolver/MemInfo.cpp b/lib/PTO/Transforms/GraphSyncSolver/MemInfo.cpp index 50d0024cb..1e065bf64 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/MemInfo.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/MemInfo.cpp @@ -11,6 +11,7 @@ #include "PTO/Transforms/GraphSyncSolver/MemInfo.h" #include "PTO/IR/PTO.h" +#include "PTO/IR/PTOTypeUtils.h" #include "../Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -29,11 +30,11 @@ static std::optional getBufferBitSize(Value value) { return ShapedType::kDynamic; } Type elementType = shaped.getElementType(); - auto bitWidth = elementType.getIntOrFloatBitWidth(); - if (bitWidth == 0) { + auto byteWidth = pto::getPTOStorageElemByteSize(elementType); + if (byteWidth == 0) { return ShapedType::kDynamic; } - return shaped.getNumElements() * bitWidth; + return shaped.getNumElements() * byteWidth * pto::kBitsToByte; } llvm::SmallVector getAddresses(const llvm::SmallVector &addrs) { diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index 8ba4f265b..c5cd705cd 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -12,6 +12,7 @@ // See LICENSE in the root of the software repository for the full text of the License. #include "PTO/Transforms/InsertSync/PTOIRTranslator.h" +#include "PTO/IR/PTOTypeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -35,7 +36,7 @@ static std::pair getStaticOffsetAndSize(Operation *op, Value s auto srcType = dyn_cast(src.getType()); if (!srcType) return {0, 0}; - int64_t elemSize = srcType.getElementType().getIntOrFloatBitWidth() / 8; + int64_t elemSize = pto::getPTOStorageElemByteSize(srcType.getElementType()); if (elemSize == 0) elemSize = 1; // === Case 1: memref.subview === @@ -235,7 +236,10 @@ LogicalResult PTOIRTranslator::UpdateAllocTileOpMemInfo(pto::AllocTileOp op) { } if (isStatic) { - int64_t elemSize = tileType.getElementType().getIntOrFloatBitWidth() / 8; + int64_t elemSize = + pto::getPTOStorageElemByteSize(tileType.getElementType()); + if (elemSize == 0) + elemSize = 1; int64_t numElements = 1; for (auto dim : shape) numElements *= dim; sizeInBytes = numElements * elemSize; @@ -280,7 +284,10 @@ LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op) uint64_t sizeInBytes = 0; if (memRefType.hasStaticShape()) { - int64_t elemSize = memRefType.getElementType().getIntOrFloatBitWidth() / 8; + int64_t elemSize = + pto::getPTOStorageElemByteSize(memRefType.getElementType()); + if (elemSize == 0) + elemSize = 1; int64_t numElements = 1; for (auto dim : memRefType.getShape()) numElements *= dim; sizeInBytes = numElements * elemSize; @@ -314,7 +321,8 @@ PTOIRTranslator::UpdateDeclareTileMemRefOpMemInfo(pto::DeclareTileMemRefOp op) { uint64_t sizeInBytes = 0; if (memRefType.hasStaticShape()) { - int64_t elemSize = memRefType.getElementType().getIntOrFloatBitWidth() / 8; + int64_t elemSize = + pto::getPTOStorageElemByteSize(memRefType.getElementType()); if (elemSize == 0) elemSize = 1; @@ -593,7 +601,8 @@ LogicalResult PTOIRTranslator::UpdateMemrefAllocOpMemInfo(memref::AllocOp op) { // 1. 计算大小 (Bytes) uint64_t sizeInBytes = 0; if (memRefType.hasStaticShape()) { - int64_t elemSize = memRefType.getElementType().getIntOrFloatBitWidth() / 8; + int64_t elemSize = + pto::getPTOStorageElemByteSize(memRefType.getElementType()); if (elemSize == 0) elemSize = 1; // bool case int64_t numElements = 1; diff --git a/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp b/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp index e4871a630..f7c64f2b5 100644 --- a/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp +++ b/lib/PTO/Transforms/VPTOExpandWrapperOps.cpp @@ -710,10 +710,9 @@ struct LoadCbufToCbControl { static FailureOr deriveLoadCbufToCbControl(Location loc, Value k, Value n, Type elementType, bool transpose, PatternRewriter &rewriter) { - unsigned elemBitWidth = elementType.getIntOrFloatBitWidth(); - if (elemBitWidth == 0 || (elemBitWidth % 8) != 0) + unsigned elemBytes = pto::getPTOStorageElemByteSize(elementType); + if (elemBytes == 0) return failure(); - uint64_t elemBytes = elemBitWidth / 8; auto constant = [&](uint64_t value) -> Value { return rewriter.create(loc, value, 64); @@ -749,10 +748,9 @@ deriveLoadCbufToCbControl(Location loc, Value k, Value n, Type elementType, static FailureOr deriveLoadCbufToCaControl(Location loc, Value m, Value k, Type elementType, bool transpose, PatternRewriter &rewriter) { - unsigned elemBitWidth = elementType.getIntOrFloatBitWidth(); - if (elemBitWidth == 0 || (elemBitWidth % 8) != 0) + unsigned elemBytes = pto::getPTOStorageElemByteSize(elementType); + if (elemBytes == 0) return failure(); - uint64_t elemBytes = elemBitWidth / 8; auto constant = [&](uint64_t value) -> Value { return rewriter.create(loc, value, 64); @@ -1324,10 +1322,10 @@ struct ExpandRightLoadMxPattern : public OpRewritePattern { if (!sourceType) return rewriter.notifyMatchFailure(op, "expected typed L1 source"); - unsigned elemBitWidth = sourceType.getElementType().getIntOrFloatBitWidth(); - if (elemBitWidth == 0 || (elemBitWidth % 8) != 0) + unsigned elemBytes = + pto::getPTOStorageElemByteSize(sourceType.getElementType()); + if (elemBytes == 0) return rewriter.notifyMatchFailure(op, "unsupported element type"); - uint64_t elemBytes = elemBitWidth / 8; auto constant = [&](uint64_t value) -> Value { return rewriter.create(loc, value, 64); diff --git a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp index 8680a03e4..37306535e 100644 --- a/lib/PTO/Transforms/VPTOLLVMEmitter.cpp +++ b/lib/PTO/Transforms/VPTOLLVMEmitter.cpp @@ -4068,11 +4068,10 @@ class LowerLoadCbufToCaMxOpPattern final return rewriter.notifyMatchFailure(op, "failed to map cbuf/ca pointer spaces"); Type sourceElemType = cast(op.getSource().getType()).getElementType(); - unsigned elemBitWidth = sourceElemType.getIntOrFloatBitWidth(); - if (elemBitWidth == 0 || (elemBitWidth % 8) != 0) + unsigned elemBytes = pto::getPTOStorageElemByteSize(sourceElemType); + if (elemBytes == 0) return rewriter.notifyMatchFailure(op, "unsupported load_cbuf_to_ca_mx element type"); - uint64_t elemBytes = elemBitWidth / 8; Location loc = op.getLoc(); auto constant = [&](uint64_t value) -> Value { return rewriter.create(loc, value, 64); @@ -4146,8 +4145,7 @@ class LowerLoadCbufToCbMxOpPattern final return rewriter.notifyMatchFailure(op, "failed to map cbuf/cb pointer spaces"); Type sourceElemType = cast(op.getSource().getType()).getElementType(); - unsigned elemBitWidth = sourceElemType.getIntOrFloatBitWidth(); - if (elemBitWidth == 0 || (elemBitWidth % 8) != 0) + if (pto::getPTOStorageElemByteSize(sourceElemType) == 0) return rewriter.notifyMatchFailure(op, "unsupported load_cbuf_to_cb_mx element type"); FailureOr config0 =