Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions commonir/src/target/codegen_commonir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,16 @@ void CodeGenTileLangCOMMONIR::VisitExpr_(const CallNode *op, std::ostream &os) {
InfinityCodegen(op, os);
} else if (op->op.same_as(Op::Get("tl.tileop.reduce"))) {
ReduceCodegen(op, os);
} else if (op->op.same_as(Op::Get("tir.fabs"))) {
// 添加 abs 支持 (tir.abs 内部是 tir.fabs)
ICHECK_EQ(op->args.size(), 1) << "abs expects 1 argument";
std::string operand = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
if (op->dtype.is_float()) {
os << "math.absf %" << operand << " : ";
} else {
os << "math.absi %" << operand << " : ";
}
PrintType(op->dtype, os);
} else if (op->op.same_as(Op::Get("tir.rsqrt"))) {
StubCodegen(op, os, "tir.rsqrt");
} else if (op->op.same_as(Op::Get("tir.sigmoid"))) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -62,6 +63,20 @@
<< "\n[VectorizeParallelLoop] >>> Start matching scf.parallel at "
<< op.getLoc() << "\n");

// 预先检查:如果包含 math dialect 操作,暂时拒绝向量化
bool hasMathOps = false;
for (Operation &inst : op.getBody()->getOperations()) {
if (inst.getDialect() && inst.getDialect()->getNamespace() == "math") {
LLVM_DEBUG(llvm::dbgs() << "[VectorizeParallelLoop] Contains math op: "
<< inst.getName() << ", skipping vectorization for now.\n");
hasMathOps = true;
break;
}
}
if (hasMathOps) {
return failure(); // 拒绝向量化,保持原始循环
}

// 1. 检查循环结构
if (op.getNumLoops() != 1) {
LLVM_DEBUG(llvm::dbgs()
Expand Down Expand Up @@ -228,7 +243,9 @@
inst.getNumOperands() == 2 &&
(isa<arith::AddFOp, arith::MulFOp, arith::AddIOp, arith::MulIOp,
arith::SubFOp, arith::SubIOp, arith::DivFOp, arith::DivSIOp,
arith::DivUIOp>(inst));
arith::DivUIOp, arith::MinSIOp, arith::MinUIOp, arith::MinNumFOp,
arith::MinimumFOp, arith::MaxSIOp, arith::MaxUIOp,
arith::MaxNumFOp, arith::MaximumFOp>(inst));

if (isBinaryOp) {
LLVM_DEBUG(llvm::dbgs() << " [Action] Processing Binary ArithOp: "
Expand Down Expand Up @@ -328,6 +345,51 @@
continue;
}

// --- Case C2: 一元数学操作 (Unary Math Operations -> Vector Unary Operations) ---
bool isUnaryMathOp =
inst.getNumOperands() == 1 &&
(isa<math::AbsFOp, math::AbsIOp>(inst));

if (isUnaryMathOp) {
LLVM_DEBUG(llvm::dbgs() << " [Action] Processing Unary Math Op: "
<< inst.getName() << "\n");

Value operand = inst.getOperand(0);

// 检查操作数是否已向量化
if (scalarToTensorMap.count(operand)) {
Value vecOperand = scalarToTensorMap[operand];
LLVM_DEBUG(llvm::dbgs() << " Operand is vectorized.\n");

// 创建向量化的一元操作
OperationState state(op.getLoc(), inst.getName().getStringRef());
state.addOperands({vecOperand});

// 结果类型:保持元素类型,但变为 tensor
Type scalarResultType = inst.getResult(0).getType();
Type vecResultType = RankedTensorType::get({size}, scalarResultType);
state.addTypes({vecResultType});

// 复制属性(如果有)
for (const auto &attr : inst.getAttrs()) {
state.addAttribute(attr.getName(), attr.getValue());
}

auto newOp = rewriter.create(state);
scalarToTensorMap[inst.getResult(0)] = newOp->getResult(0);

LLVM_DEBUG(llvm::dbgs() << " Created Vector Unary Operation: "
<< inst.getName() << "\n");
LLVM_DEBUG(llvm::dbgs() << " Result Type: "
<< newOp->getResult(0).getType() << "\n");
} else {
LLVM_DEBUG(llvm::dbgs() << " WARNING: Operand not vectorized, "
"cloning scalar op.\n");
rewriter.clone(inst, mapper);
}
continue;
}

// --- Case D: 写回逻辑 (Materialize) ---
if (auto matOp =
dyn_cast<bufferization::MaterializeInDestinationOp>(inst)) {
Expand Down Expand Up @@ -526,6 +588,8 @@
LLVM_DEBUG(llvm::dbgs()
<< " [Unhandled] Operation not handled specifically: "
<< inst.getName() << "\n");
// 对于未处理的操作,至少克隆它们以保持 IR 完整性
rewriter.clone(inst, mapper);
}

// 打印当前op
Expand All @@ -537,8 +601,8 @@
// 打印映射表
LLVM_DEBUG({
llvm::dbgs() << "[VectorizeParallelLoop] Scalar to Tensor Map:\n";
for (auto &[scalar, tensor] : scalarToTensorMap) {
llvm::dbgs() << " " << scalar << " -> " << tensor << "\n";
for (const auto &kv : scalarToTensorMap) {
llvm::dbgs() << " " << kv.first << " -> " << kv.second << "\n";
}
});

Expand Down
121 changes: 121 additions & 0 deletions test/commonir/ascend/test_min_max_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import torch
import tilelang
import tilelang.language as T

dtype = "float32"


def test_elementwise_max():
    """Check an element-wise ``T.max(A, B)`` kernel against ``torch.maximum``."""
    N = 1024

    @T.prim_func
    def kernel(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N,), dtype),
        C: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1, 1) as (tid, _):
            for i in T.Parallel(N):
                C[i] = T.max(A[i], B[i])

    # Inputs are drawn in the same order as before (A, then B); the output
    # buffer starts zeroed and is filled in by the kernel.
    lhs = torch.randn(N, dtype=torch.float32).npu()
    rhs = torch.randn(N, dtype=torch.float32).npu()
    out = torch.zeros(N, dtype=torch.float32).npu()

    run_kernel = tilelang.compile(kernel)
    run_kernel(lhs, rhs, out)

    reference = torch.maximum(lhs, rhs)
    worst_err = (out - reference).abs().max()

    # Diagnostics are printed before the assertion so they show up on failure.
    print(f"elementwise_max: max_diff = {worst_err.item():.8f}")
    print(f" result sample: {out[:5]}")
    print(f" expected sample: {reference[:5]}")

    torch.testing.assert_close(out, reference, atol=1e-6, rtol=1e-5)
    print("✓ test_elementwise_max passed")


def test_elementwise_min():
    """Check an element-wise ``T.min(A, B)`` kernel against ``torch.minimum``."""
    N = 1024

    @T.prim_func
    def kernel(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N,), dtype),
        C: T.Tensor((N,), dtype),
    ):
        with T.Kernel(1, 1) as (tid, _):
            for i in T.Parallel(N):
                C[i] = T.min(A[i], B[i])

    # Inputs are drawn in the same order as before (A, then B); the output
    # buffer starts zeroed and is filled in by the kernel.
    lhs = torch.randn(N, dtype=torch.float32).npu()
    rhs = torch.randn(N, dtype=torch.float32).npu()
    out = torch.zeros(N, dtype=torch.float32).npu()

    run_kernel = tilelang.compile(kernel)
    run_kernel(lhs, rhs, out)

    reference = torch.minimum(lhs, rhs)
    worst_err = (out - reference).abs().max()

    # Diagnostics are printed before the assertion so they show up on failure.
    print(f"elementwise_min: max_diff = {worst_err.item():.8f}")
    print(f" result sample: {out[:5]}")
    print(f" expected sample: {reference[:5]}")

    torch.testing.assert_close(out, reference, atol=1e-6, rtol=1e-5)
    print("✓ test_elementwise_min passed")


def test_relu():
    """Check ReLU written as ``max(x, 0)`` against ``torch.relu``."""
    N = 512

    @T.prim_func
    def kernel(A: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype)):
        with T.Kernel(1, 1) as (tid, _):
            for i in T.Parallel(N):
                # The zero constant is passed inline; per the original
                # author's note, binding it to a local variable first hits a
                # codegen bug.
                C[i] = T.max(A[i], T.float32(0.0))

    src = torch.randn(N, dtype=torch.float32).npu()
    out = torch.zeros(N, dtype=torch.float32).npu()

    run_kernel = tilelang.compile(kernel)
    run_kernel(src, out)

    reference = torch.relu(src)
    worst_err = (out - reference).abs().max()
    print(f"relu: max_diff = {worst_err.item():.8f}")

    torch.testing.assert_close(out, reference, atol=1e-6, rtol=1e-5)
    print("✓ test_relu passed")


def test_clip():
    """Check ``clip(x, 0, 1)`` written as ``min(max(x, 0), 1)`` against ``torch.clip``."""
    N = 256

    @T.prim_func
    def kernel(A: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype)):
        with T.Kernel(1, 1) as (tid, _):
            for i in T.Parallel(N):
                # Both bounds are passed as inline constants.
                # NOTE(review): the original comment also said to avoid local
                # variables, yet `temp` is one -- confirm this is intended.
                temp = T.max(A[i], T.float32(0.0))
                C[i] = T.min(temp, T.float32(1.0))

    src = torch.randn(N, dtype=torch.float32).npu()
    out = torch.zeros(N, dtype=torch.float32).npu()

    run_kernel = tilelang.compile(kernel)
    run_kernel(src, out)

    reference = torch.clip(src, 0.0, 1.0)
    worst_err = (out - reference).abs().max()
    print(f"clip: max_diff = {worst_err.item():.8f}")

    torch.testing.assert_close(out, reference, atol=1e-6, rtol=1e-5)
    print("✓ test_clip passed")


if __name__ == "__main__":
    # Script entry point: run every test in declaration order between two
    # separator rules; any failing assertion aborts before the final message.
    rule = "=" * 60
    print("Testing Element-wise Min/Max Operations")
    print(rule)
    for case in (
        test_elementwise_max,
        test_elementwise_min,
        test_relu,
        test_clip,
    ):
        case()
    print(rule)
    print("All element-wise Min/Max tests passed! ✓")
Loading