Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ jobs:
with:
path: |
llvm-project/llvm/build-shared
key: llvm-${{ runner.os }}-${{ env.LLVM_SOURCE_SHA }}-shared-mlirpy-hardening-v2
key: llvm-${{ runner.os }}-${{ env.LLVM_SOURCE_SHA }}-shared-mlirpy-assertions-v2

- name: Prepare LLVM source (no rebuild)
run: |
Expand All @@ -183,10 +183,17 @@ jobs:
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DPython3_EXECUTABLE=python3 \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_TARGETS_TO_BUILD="host"

ninja -C llvm/build-shared


- name: Verify LLVM assertions
shell: bash
run: |
set -euo pipefail
grep -q '^set(LLVM_ENABLE_ASSERTIONS ON)' \
"${LLVM_DIR}/lib/cmake/llvm/LLVMConfig.cmake"

# LLVM build 完成后立即保存缓存,避免后续测试影响缓存内容
- name: Save LLVM build cache
Expand All @@ -195,7 +202,7 @@ jobs:
with:
path: |
llvm-project/llvm/build-shared
key: llvm-${{ runner.os }}-${{ env.LLVM_SOURCE_SHA }}-shared-mlirpy-hardening-v2
key: llvm-${{ runner.os }}-${{ env.LLVM_SOURCE_SHA }}-shared-mlirpy-assertions-v2

- name: Build PTOAS
run: |
Expand All @@ -209,7 +216,13 @@ jobs:
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DMLIR_PYTHON_PACKAGE_DIR="${LLVM_DIR}/tools/mlir/python_packages/mlir_core" \
-DCMAKE_INSTALL_PREFIX="${PTO_INSTALL_DIR}" \
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_C_FLAGS_RELEASE="-O3" \
-DCMAKE_CXX_FLAGS_RELEASE="-O3"
if grep -E '^CMAKE_(C|CXX)_FLAGS_RELEASE:.*-DNDEBUG' build/CMakeCache.txt; then
echo "ERROR: PTOAS Release flags still define NDEBUG; assertions are disabled" >&2
exit 1
fi
ninja -C build ptoas
ninja -C build ptobc
ninja -C build install
Expand Down Expand Up @@ -389,10 +402,18 @@ jobs:
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DPython3_EXECUTABLE=python3 \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_TARGETS_TO_BUILD="host"

ninja -C llvm/build-shared

- name: Verify LLVM assertions
shell: bash
run: |
set -euo pipefail
grep -q '^set(LLVM_ENABLE_ASSERTIONS ON)' \
"${LLVM_DIR}/lib/cmake/llvm/LLVMConfig.cmake"

- name: Build PTOAS
shell: bash
run: |
Expand All @@ -406,7 +427,13 @@ jobs:
-Dpybind11_DIR="${PYBIND11_CMAKE_DIR}" \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DMLIR_PYTHON_PACKAGE_DIR="${LLVM_DIR}/tools/mlir/python_packages/mlir_core" \
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_C_FLAGS_RELEASE="-O3" \
-DCMAKE_CXX_FLAGS_RELEASE="-O3"
if grep -E '^CMAKE_(C|CXX)_FLAGS_RELEASE:.*-DNDEBUG' build/CMakeCache.txt; then
echo "ERROR: PTOAS Release flags still define NDEBUG; assertions are disabled" >&2
exit 1
fi
ninja -C build ptoas

- name: Resolve simulator environment
Expand Down
13 changes: 7 additions & 6 deletions lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3608,9 +3608,10 @@ static LogicalResult verifyTColArgReductionOpCommon(Operation *op, Type srcTy,
/*requireNonZeroSrc=*/true)))
return failure();
Type srcElemTy = getElemTy(srcTy);
unsigned srcElemBits = srcElemTy ? srcElemTy.getIntOrFloatBitWidth() : 0;
if (!(mlir::isa<IntegerType, FloatType>(srcElemTy) &&
(srcElemBits == 8 || srcElemBits == 16 || srcElemBits == 32)))
(getPTOStorageElemByteSize(srcElemTy) == 1 ||
getPTOStorageElemByteSize(srcElemTy) == 2 ||
getPTOStorageElemByteSize(srcElemTy) == 4)))
return op->emitOpError(
"expects src/tmp element type to be 1, 2, or 4 bytes wide");
auto dstInt = dyn_cast<IntegerType>(getElemTy(dstTy));
Expand Down Expand Up @@ -5671,8 +5672,8 @@ llvm::LogicalResult mlir::pto::TGatherOp::verify() {
if (!srcSpace || !dstSpace || *srcSpace != pto::AddressSpace::VEC ||
*dstSpace != pto::AddressSpace::VEC)
return emitOpError("expects src and dst to be in the vec address space");
unsigned srcElemBytes = srcElem.getIntOrFloatBitWidth() / 8;
unsigned dstElemBytes = dstElem.getIntOrFloatBitWidth() / 8;
unsigned srcElemBytes = getPTOStorageElemByteSize(srcElem);
unsigned dstElemBytes = getPTOStorageElemByteSize(dstElem);
if (srcElemBytes != dstElemBytes)
return emitOpError("expects src and dst element sizes to match");

Expand Down Expand Up @@ -9432,7 +9433,7 @@ mlir::LogicalResult mlir::pto::TTransOp::verify() {
if (srcTb.getBLayoutValueI32() != static_cast<int32_t>(pto::BLayout::RowMajor))
return emitOpError() << "expects A2/A3 transpose src to use the row_major blayout";
}
unsigned elemBytes = srcElem.getIntOrFloatBitWidth() / 8;
unsigned elemBytes = getPTOStorageElemByteSize(srcElem);
if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4)
return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes";
auto isAllowedWidthType = [&](Type ty) {
Expand All @@ -9459,7 +9460,7 @@ mlir::LogicalResult mlir::pto::TTransOp::verify() {
Type dstElem = getElemTy(dstTy);
if (!srcElem || !tmpElem || !dstElem || srcElem != dstElem || srcElem != tmpElem)
return emitOpError() << "expects src, tmp, and dst to have the same element type";
unsigned elemBytes = srcElem.getIntOrFloatBitWidth() / 8;
unsigned elemBytes = getPTOStorageElemByteSize(srcElem);
if (elemBytes != 1 && elemBytes != 2 && elemBytes != 4)
return emitOpError() << "expects transpose element size to be 1, 2, or 4 bytes";
auto isAllowedWidthType = [&](Type ty) {
Expand Down
7 changes: 4 additions & 3 deletions lib/PTO/Transforms/GraphSyncSolver/MemInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "PTO/Transforms/GraphSyncSolver/MemInfo.h"
#include "PTO/IR/PTO.h"
#include "PTO/IR/PTOTypeUtils.h"
#include "../Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
Expand All @@ -29,11 +30,11 @@ static std::optional<int64_t> getBufferBitSize(Value value) {
return ShapedType::kDynamic;
}
Type elementType = shaped.getElementType();
auto bitWidth = elementType.getIntOrFloatBitWidth();
if (bitWidth == 0) {
auto byteWidth = pto::getPTOStorageElemByteSize(elementType);
if (byteWidth == 0) {
return ShapedType::kDynamic;
}
return shaped.getNumElements() * bitWidth;
return shaped.getNumElements() * byteWidth * pto::kBitsToByte;
}

llvm::SmallVector<int64_t> getAddresses(const llvm::SmallVector<Value> &addrs) {
Expand Down
19 changes: 14 additions & 5 deletions lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See LICENSE in the root of the software repository for the full text of the License.

#include "PTO/Transforms/InsertSync/PTOIRTranslator.h"
#include "PTO/IR/PTOTypeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand All @@ -35,7 +36,7 @@ static std::pair<int64_t, int64_t> getStaticOffsetAndSize(Operation *op, Value s
auto srcType = dyn_cast<MemRefType>(src.getType());
if (!srcType) return {0, 0};

int64_t elemSize = srcType.getElementType().getIntOrFloatBitWidth() / 8;
int64_t elemSize = pto::getPTOStorageElemByteSize(srcType.getElementType());
if (elemSize == 0) elemSize = 1;

// === Case 1: memref.subview ===
Expand Down Expand Up @@ -235,7 +236,10 @@ LogicalResult PTOIRTranslator::UpdateAllocTileOpMemInfo(pto::AllocTileOp op) {
}

if (isStatic) {
int64_t elemSize = tileType.getElementType().getIntOrFloatBitWidth() / 8;
int64_t elemSize =
pto::getPTOStorageElemByteSize(tileType.getElementType());
if (elemSize == 0)
elemSize = 1;
int64_t numElements = 1;
for (auto dim : shape) numElements *= dim;
sizeInBytes = numElements * elemSize;
Expand Down Expand Up @@ -280,7 +284,10 @@ LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op)

uint64_t sizeInBytes = 0;
if (memRefType.hasStaticShape()) {
int64_t elemSize = memRefType.getElementType().getIntOrFloatBitWidth() / 8;
int64_t elemSize =
pto::getPTOStorageElemByteSize(memRefType.getElementType());
if (elemSize == 0)
elemSize = 1;
int64_t numElements = 1;
for (auto dim : memRefType.getShape()) numElements *= dim;
sizeInBytes = numElements * elemSize;
Expand Down Expand Up @@ -314,7 +321,8 @@ PTOIRTranslator::UpdateDeclareTileMemRefOpMemInfo(pto::DeclareTileMemRefOp op) {

uint64_t sizeInBytes = 0;
if (memRefType.hasStaticShape()) {
int64_t elemSize = memRefType.getElementType().getIntOrFloatBitWidth() / 8;
int64_t elemSize =
pto::getPTOStorageElemByteSize(memRefType.getElementType());
if (elemSize == 0)
elemSize = 1;

Expand Down Expand Up @@ -593,7 +601,8 @@ LogicalResult PTOIRTranslator::UpdateMemrefAllocOpMemInfo(memref::AllocOp op) {
// 1. 计算大小 (Bytes)
uint64_t sizeInBytes = 0;
if (memRefType.hasStaticShape()) {
int64_t elemSize = memRefType.getElementType().getIntOrFloatBitWidth() / 8;
int64_t elemSize =
pto::getPTOStorageElemByteSize(memRefType.getElementType());
if (elemSize == 0) elemSize = 1; // bool case

int64_t numElements = 1;
Expand Down
16 changes: 7 additions & 9 deletions lib/PTO/Transforms/VPTOExpandWrapperOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -710,10 +710,9 @@ struct LoadCbufToCbControl {
static FailureOr<LoadCbufToCbControl>
deriveLoadCbufToCbControl(Location loc, Value k, Value n, Type elementType,
bool transpose, PatternRewriter &rewriter) {
unsigned elemBitWidth = elementType.getIntOrFloatBitWidth();
if (elemBitWidth == 0 || (elemBitWidth % 8) != 0)
unsigned elemBytes = pto::getPTOStorageElemByteSize(elementType);
if (elemBytes == 0)
return failure();
uint64_t elemBytes = elemBitWidth / 8;

auto constant = [&](uint64_t value) -> Value {
return rewriter.create<arith::ConstantIntOp>(loc, value, 64);
Expand Down Expand Up @@ -749,10 +748,9 @@ deriveLoadCbufToCbControl(Location loc, Value k, Value n, Type elementType,
static FailureOr<LoadCbufToCbControl>
deriveLoadCbufToCaControl(Location loc, Value m, Value k, Type elementType,
bool transpose, PatternRewriter &rewriter) {
unsigned elemBitWidth = elementType.getIntOrFloatBitWidth();
if (elemBitWidth == 0 || (elemBitWidth % 8) != 0)
unsigned elemBytes = pto::getPTOStorageElemByteSize(elementType);
if (elemBytes == 0)
return failure();
uint64_t elemBytes = elemBitWidth / 8;

auto constant = [&](uint64_t value) -> Value {
return rewriter.create<arith::ConstantIntOp>(loc, value, 64);
Expand Down Expand Up @@ -1324,10 +1322,10 @@ struct ExpandRightLoadMxPattern : public OpRewritePattern<pto::MteL1L0bMxOp> {
if (!sourceType)
return rewriter.notifyMatchFailure(op, "expected typed L1 source");

unsigned elemBitWidth = sourceType.getElementType().getIntOrFloatBitWidth();
if (elemBitWidth == 0 || (elemBitWidth % 8) != 0)
unsigned elemBytes =
pto::getPTOStorageElemByteSize(sourceType.getElementType());
if (elemBytes == 0)
return rewriter.notifyMatchFailure(op, "unsupported element type");
uint64_t elemBytes = elemBitWidth / 8;

auto constant = [&](uint64_t value) -> Value {
return rewriter.create<arith::ConstantIntOp>(loc, value, 64);
Expand Down
8 changes: 3 additions & 5 deletions lib/PTO/Transforms/VPTOLLVMEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4068,11 +4068,10 @@ class LowerLoadCbufToCaMxOpPattern final
return rewriter.notifyMatchFailure(op, "failed to map cbuf/ca pointer spaces");

Type sourceElemType = cast<pto::PtrType>(op.getSource().getType()).getElementType();
unsigned elemBitWidth = sourceElemType.getIntOrFloatBitWidth();
if (elemBitWidth == 0 || (elemBitWidth % 8) != 0)
unsigned elemBytes = pto::getPTOStorageElemByteSize(sourceElemType);
if (elemBytes == 0)
return rewriter.notifyMatchFailure(op,
"unsupported load_cbuf_to_ca_mx element type");
uint64_t elemBytes = elemBitWidth / 8;
Location loc = op.getLoc();
auto constant = [&](uint64_t value) -> Value {
return rewriter.create<arith::ConstantIntOp>(loc, value, 64);
Expand Down Expand Up @@ -4146,8 +4145,7 @@ class LowerLoadCbufToCbMxOpPattern final
return rewriter.notifyMatchFailure(op, "failed to map cbuf/cb pointer spaces");

Type sourceElemType = cast<pto::PtrType>(op.getSource().getType()).getElementType();
unsigned elemBitWidth = sourceElemType.getIntOrFloatBitWidth();
if (elemBitWidth == 0 || (elemBitWidth % 8) != 0)
if (pto::getPTOStorageElemByteSize(sourceElemType) == 0)
return rewriter.notifyMatchFailure(op,
"unsupported load_cbuf_to_cb_mx element type");
FailureOr<Value> config0 =
Expand Down
Loading