diff --git a/commonir/src/target/codegen_commonir.cc b/commonir/src/target/codegen_commonir.cc index a000b1bd..da955cc5 100644 --- a/commonir/src/target/codegen_commonir.cc +++ b/commonir/src/target/codegen_commonir.cc @@ -844,6 +844,16 @@ void CodeGenTileLangCOMMONIR::VisitExpr_(const CallNode *op, std::ostream &os) { InfinityCodegen(op, os); } else if (op->op.same_as(Op::Get("tl.tileop.reduce"))) { ReduceCodegen(op, os); + } else if (op->op.same_as(Op::Get("tir.fabs"))) { + // 添加 abs 支持 (tir.abs 内部是 tir.fabs) + ICHECK_EQ(op->args.size(), 1) << "abs expects 1 argument"; + std::string operand = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype); + if (op->dtype.is_float()) { + os << "math.absf %" << operand << " : "; + } else { + os << "math.absi %" << operand << " : "; + } + PrintType(op->dtype, os); } else if (op->op.same_as(Op::Get("tir.rsqrt"))) { StubCodegen(op, os, "tir.rsqrt"); } else if (op->op.same_as(Op::Get("tir.sigmoid"))) { diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/VectorizeParallelLoopPass.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/VectorizeParallelLoopPass.cpp index 026b6f9b..93617d26 100644 --- a/compiler/lib/Dialect/LinalgExt/Transforms/VectorizeParallelLoopPass.cpp +++ b/compiler/lib/Dialect/LinalgExt/Transforms/VectorizeParallelLoopPass.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -62,6 +63,20 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern { << "\n[VectorizeParallelLoop] >>> Start matching scf.parallel at " << op.getLoc() << "\n"); + // 预先检查:如果包含 math dialect 操作,暂时拒绝向量化 + bool hasMathOps = false; + for (Operation &inst : op.getBody()->getOperations()) { + if (inst.getDialect() && inst.getDialect()->getNamespace() == "math") { + LLVM_DEBUG(llvm::dbgs() << "[VectorizeParallelLoop] Contains math op: " + << inst.getName() << ", skipping vectorization for now.\n"); + hasMathOps = true; + break; + } + } + if (hasMathOps) { + return failure(); // 拒绝向量化,保持原始循环 + } + // 1. 检查循环结构 if (op.getNumLoops() != 1) { LLVM_DEBUG(llvm::dbgs() @@ -228,7 +243,9 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern { inst.getNumOperands() == 2 && (isa(inst)); + arith::DivUIOp, arith::MinSIOp, arith::MinUIOp, arith::MinNumFOp, + arith::MinimumFOp, arith::MaxSIOp, arith::MaxUIOp, + arith::MaxNumFOp, arith::MaximumFOp>(inst)); if (isBinaryOp) { LLVM_DEBUG(llvm::dbgs() << " [Action] Processing Binary ArithOp: " @@ -328,6 +345,51 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern { continue; } + // --- Case C2: 一元数学操作 (Unary Math Operations -> Vector Unary Operations) --- + bool isUnaryMathOp = + inst.getNumOperands() == 1 && + (isa(inst)); + + if (isUnaryMathOp) { + LLVM_DEBUG(llvm::dbgs() << " [Action] Processing Unary Math Op: " + << inst.getName() << "\n"); + + Value operand = inst.getOperand(0); + + // 检查操作数是否已向量化 + if (scalarToTensorMap.count(operand)) { + Value vecOperand = scalarToTensorMap[operand]; + LLVM_DEBUG(llvm::dbgs() << " Operand is vectorized.\n"); + + // 创建向量化的一元操作 + OperationState state(op.getLoc(), inst.getName().getStringRef()); + state.addOperands({vecOperand}); + + // 结果类型:保持元素类型,但变为 tensor + Type scalarResultType = inst.getResult(0).getType(); + Type vecResultType = RankedTensorType::get({size}, scalarResultType); + state.addTypes({vecResultType}); + + // 复制属性(如果有) + for (const auto &attr : inst.getAttrs()) { + state.addAttribute(attr.getName(), attr.getValue()); + } + + auto newOp = rewriter.create(state); + scalarToTensorMap[inst.getResult(0)] = newOp->getResult(0); + + LLVM_DEBUG(llvm::dbgs() << " Created Vector Unary Operation: " + << inst.getName() << "\n"); + LLVM_DEBUG(llvm::dbgs() << " Result Type: " + << newOp->getResult(0).getType() << "\n"); + } else { + LLVM_DEBUG(llvm::dbgs() << " WARNING: Operand not vectorized, " + "cloning scalar op.\n"); + rewriter.clone(inst, mapper); + } + continue; + } + // --- Case D: 写回逻辑 (Materialize) --- if (auto matOp = dyn_cast(inst)) { @@ -526,6 +588,8 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern { LLVM_DEBUG(llvm::dbgs() << " [Unhandled] Operation not handled specifically: " << inst.getName() << "\n"); + // 对于未处理的操作,至少克隆它们以保持 IR 完整性 + rewriter.clone(inst, mapper); } // 打印当前op @@ -537,8 +601,8 @@ struct VectorizeParallelLoopPattern : public OpRewritePattern { // 打印映射表 LLVM_DEBUG({ llvm::dbgs() << "[VectorizeParallelLoop] Scalar to Tensor Map:\n"; - for (auto &[scalar, tensor] : scalarToTensorMap) { - llvm::dbgs() << " " << scalar << " -> " << tensor << "\n"; + for (const auto &kv : scalarToTensorMap) { + llvm::dbgs() << " " << kv.first << " -> " << kv.second << "\n"; } }); diff --git a/test/commonir/ascend/test_min_max_ops.py b/test/commonir/ascend/test_min_max_ops.py new file mode 100644 index 00000000..995008ce --- /dev/null +++ b/test/commonir/ascend/test_min_max_ops.py @@ -0,0 +1,121 @@ +import torch +import tilelang +import tilelang.language as T + +dtype = "float32" + + +def test_elementwise_max(): + """测试:Element-wise max(A, B)""" + N = 1024 + + @T.prim_func + def kernel( + A: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype) + ): + with T.Kernel(1, 1) as (tid, _): + for i in T.Parallel(N): + C[i] = T.max(A[i], B[i]) + + a = torch.randn(N, dtype=torch.float32).npu() + b = torch.randn(N, dtype=torch.float32).npu() + c = torch.zeros(N, dtype=torch.float32).npu() + + compiled = tilelang.compile(kernel) + compiled(a, b, c) + + expected = torch.maximum(a, b) + max_diff = torch.max(torch.abs(c - expected)) + print(f"elementwise_max: max_diff = {max_diff.item():.8f}") + print(f" result sample: {c[:5]}") + print(f" expected sample: {expected[:5]}") + torch.testing.assert_close(c, expected, atol=1e-6, rtol=1e-5) + print("✓ test_elementwise_max passed") + + +def test_elementwise_min(): + """测试:Element-wise min(A, B)""" + N = 1024 + + @T.prim_func + def kernel( + A: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype) + ): + with T.Kernel(1, 1) as (tid, _): + for i in T.Parallel(N): + C[i] = T.min(A[i], B[i]) + + a = torch.randn(N, dtype=torch.float32).npu() + b = torch.randn(N, dtype=torch.float32).npu() + c = torch.zeros(N, dtype=torch.float32).npu() + + compiled = tilelang.compile(kernel) + compiled(a, b, c) + + expected = torch.minimum(a, b) + max_diff = torch.max(torch.abs(c - expected)) + print(f"elementwise_min: max_diff = {max_diff.item():.8f}") + print(f" result sample: {c[:5]}") + print(f" expected sample: {expected[:5]}") + torch.testing.assert_close(c, expected, atol=1e-6, rtol=1e-5) + print("✓ test_elementwise_min passed") + + +def test_relu(): + """测试:ReLU = max(x, 0)""" + N = 512 + + @T.prim_func + def kernel(A: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype)): + with T.Kernel(1, 1) as (tid, _): + for i in T.Parallel(N): + # 直接使用常量,避免局部变量的 codegen bug + C[i] = T.max(A[i], T.float32(0.0)) + + a = torch.randn(N, dtype=torch.float32).npu() + c = torch.zeros(N, dtype=torch.float32).npu() + + compiled = tilelang.compile(kernel) + compiled(a, c) + + expected = torch.relu(a) + max_diff = torch.max(torch.abs(c - expected)) + print(f"relu: max_diff = {max_diff.item():.8f}") + torch.testing.assert_close(c, expected, atol=1e-6, rtol=1e-5) + print("✓ test_relu passed") + + +def test_clip(): + """测试:clip(x, 0, 1) = min(max(x, 0), 1)""" + N = 256 + + @T.prim_func + def kernel(A: T.Tensor((N,), dtype), C: T.Tensor((N,), dtype)): + with T.Kernel(1, 1) as (tid, _): + for i in T.Parallel(N): + # 避免局部变量,直接使用常量 + temp = T.max(A[i], T.float32(0.0)) + C[i] = T.min(temp, T.float32(1.0)) + + a = torch.randn(N, dtype=torch.float32).npu() + c = torch.zeros(N, dtype=torch.float32).npu() + + compiled = tilelang.compile(kernel) + compiled(a, c) + + expected = torch.clip(a, 0.0, 1.0) + max_diff = torch.max(torch.abs(c - expected)) + print(f"clip: max_diff = {max_diff.item():.8f}") + torch.testing.assert_close(c, expected, atol=1e-6, rtol=1e-5) + print("✓ test_clip passed") + + +if __name__ == "__main__": + print("Testing Element-wise Min/Max Operations") + print("=" * 60) + test_elementwise_max() + test_elementwise_min() + test_relu() + test_clip() + print("=" * 60) + print("All element-wise Min/Max tests passed! ✓")