diff --git a/llvm/include/llvm/Transforms/Scalar/GVN.h b/llvm/include/llvm/Transforms/Scalar/GVN.h
index bc0f108ac8260..c35f9a03112cb 100644
--- a/llvm/include/llvm/Transforms/Scalar/GVN.h
+++ b/llvm/include/llvm/Transforms/Scalar/GVN.h
@@ -45,6 +45,8 @@ class FunctionPass;
 class GetElementPtrInst;
 class ImplicitControlFlowTracking;
 class LoadInst;
+class SelectInst;
+class Loop;
 class LoopInfo;
 class MemDepResult;
 class MemoryAccess;
@@ -405,6 +407,22 @@ class GVNPass : public PassInfoMixin<GVNPass> {
   void addDeadBlock(BasicBlock *BB);
   void assignValNumForDeadCode();
   void assignBlockRPONumber(Function &F);
+
+  /// If \p Select is the index update of a minimum-finding loop (a select on a
+  /// less-than comparison of two loads, the second one addressed through a
+  /// sign-extended index), rewrite the loop via
+  /// transformMinFindingSelectPattern(). Returns true if the IR was changed.
+  bool recognizeMinFindingSelectPattern(SelectInst *Select);
+
+  /// Hoist the load \p RHS of the current minimum into \p Preheader and carry
+  /// the running minimum in a new PHI in \p BB instead of reloading it on every
+  /// iteration. Returns true if the IR was changed.
+  bool transformMinFindingSelectPattern(Loop *L, Type *LoadType,
+                                        BasicBlock *Preheader, BasicBlock *BB,
+                                        Value *LHS, Value *RHS,
+                                        CmpInst *Comparison, SelectInst *Select,
+                                        Value *BasePtr, Value *IndexVal,
+                                        Value *OffsetVal);
 };
 
 /// Create a legacy GVN pass.
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index 72e1131a54a86..c41518f874623 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -2743,6 +2743,10 @@ bool GVNPass::processInstruction(Instruction *I) {
     }
     return Changed;
   }
+  // Selects may be the index update of a minimum-finding loop; try that first.
+  if (auto *Select = dyn_cast<SelectInst>(I))
+    if (recognizeMinFindingSelectPattern(Select))
+      return true;
 
   // Instructions with void type don't return a value, so there's
   // no point in trying to find redundancies in them.
@@ -3330,6 +3334,269 @@ void GVNPass::assignValNumForDeadCode() {
   }
 }
 
+/// Rewrite a recognized minimum-index loop so that the running minimum is kept
+/// in a PHI instead of being reloaded through the index on every iteration.
+///
+/// \p RHS is the reloaded minimum; its address has the form
+///   gep i8 (gep <elt>, BasePtr, sext(IndexVal)), OffsetVal
+/// The load is re-created once in \p Preheader from the initial value of the
+/// index PHI \p IndexVal, and inside the single-block loop \p BB it is replaced
+/// by
+///   %known_min   = phi [ %hoisted_load, Preheader ], [ %current_min, BB ]
+///   %current_min = select Comparison, LHS, %known_min
+///
+/// The caller must have shown that the replaced load always yields the value
+/// tracked by %known_min; this function only checks what it needs in order to
+/// emit valid IR. Nothing is modified when it bails out.
+///
+/// \returns true if the IR was changed.
+bool GVNPass::transformMinFindingSelectPattern(
+    Loop *L, Type *LoadType, BasicBlock *Preheader, BasicBlock *BB, Value *LHS,
+    Value *RHS, CmpInst *Comparison, SelectInst *Select, Value *BasePtr,
+    Value *IndexVal, Value *OffsetVal) {
+  // Validate everything up front so a bail-out never leaves partial IR behind.
+  auto *IndexPhi = dyn_cast<PHINode>(IndexVal);
+  if (!IndexPhi) {
+    LLVM_DEBUG(dbgs() << "GVN: IndexVal is not a PHI node\n");
+    return false;
+  }
+
+  // The new PHI gets exactly two incoming edges (preheader and back edge), so
+  // BB must be the entire loop and the index must be a header PHI whose
+  // back-edge value is Select.
+  if (L->getNumBlocks() != 1 || L->getLoopLatch() != BB ||
+      IndexPhi->getParent() != BB ||
+      IndexPhi->getBasicBlockIndex(Preheader) < 0 ||
+      IndexPhi->getIncomingValueForBlock(BB) != Select) {
+    LLVM_DEBUG(dbgs() << "GVN: Unsupported loop or index PHI shape.\n");
+    return false;
+  }
+
+  // Everything the hoisted address is built from must be available in the
+  // preheader.
+  if (!L->isLoopInvariant(BasePtr) || !L->isLoopInvariant(OffsetVal)) {
+    LLVM_DEBUG(dbgs() << "GVN: Address operands are not loop invariant.\n");
+    return false;
+  }
+
+  // Recover the original address chain so that the hoisted copy uses the same
+  // index type, element types and alignment instead of hard-coded ones.
+  auto *MinLoad = dyn_cast<LoadInst>(RHS);
+  auto *ByteGEP =
+      MinLoad ? dyn_cast<GetElementPtrInst>(MinLoad->getPointerOperand())
+              : nullptr;
+  auto *ElemGEP =
+      ByteGEP ? dyn_cast<GetElementPtrInst>(ByteGEP->getPointerOperand())
+              : nullptr;
+  if (!ElemGEP || ElemGEP->getNumIndices() != 1 ||
+      ByteGEP->getNumIndices() != 1) {
+    LLVM_DEBUG(dbgs() << "GVN: Unexpected address computation.\n");
+    return false;
+  }
+
+  // Hoist the chain of operations for the second load to preheader.
+  //   %min.idx.ext = sext i32 %min.idx to i64
+  //   %ptr.float.min = getelementptr float, ptr %0, i64 %min.idx.ext
+  //   %ptr.second.load = getelementptr i8, ptr %ptr.float.min, i64 -4
+  //   %val.current.min = load float, ptr %ptr.second.load, align 4
+  IRBuilder<> Builder(Preheader->getTerminator());
+  Value *InitialMinIndex = IndexPhi->getIncomingValueForBlock(Preheader);
+  Value *HoistedSExt = Builder.CreateSExt(
+      InitialMinIndex, ElemGEP->getOperand(1)->getType(), "hoist_sext");
+  Value *HoistedGEP1 = Builder.CreateGEP(ElemGEP->getSourceElementType(),
+                                         BasePtr, HoistedSExt, "hoist_gep1");
+  Value *HoistedGEP2 = Builder.CreateGEP(ByteGEP->getSourceElementType(),
+                                         HoistedGEP1, OffsetVal, "hoist_gep2");
+  LoadInst *NewLoad = Builder.CreateAlignedLoad(
+      LoadType, HoistedGEP2, MinLoad->getAlign(), "hoisted_load");
+
+  // This PHI memoizes the minimum value seen so far. It is fed by the hoisted
+  // load on entry and by the value selected in the previous iteration.
+  PHINode *KnownMinPhi = PHINode::Create(LoadType, 2, "known_min", BB->begin());
+  IRBuilder<> SelectBuilder(BB->getTerminator());
+  Value *CurrentMin =
+      SelectBuilder.CreateSelect(Comparison, LHS, KnownMinPhi, "current_min");
+  KnownMinPhi->addIncoming(NewLoad, Preheader);
+  KnownMinPhi->addIncoming(CurrentMin, BB);
+
+  // Inside the loop the old load is equivalent to the PHI, not to the hoisted
+  // load, so every user (including Comparison) must switch to the PHI. The
+  // load is then removed through the pass's own removal helper rather than a
+  // raw eraseFromParent() so GVN can drop any state it keeps for it.
+  MinLoad->replaceAllUsesWith(KnownMinPhi);
+  salvageAndRemoveInstruction(MinLoad);
+
+  LLVM_DEBUG(dbgs() << "Transformed the code\n");
+  return true;
+}
+
+/// Detect the index update of a "position of the minimum" loop and rewrite it
+/// with transformMinFindingSelectPattern().
+///
+/// The accepted shape is a single-block loop containing
+///   %cur  = load T, gep i8 (gep T, Base, %iv), CurOff
+///   %min  = load T, gep i8 (gep T, Base, sext(%idx)), MinOff
+///   %cmp  = cmp <less-than> %cur, %min
+///   %next = select %cmp, trunc nsw (add nsw %iv, Step), %idx
+/// where %idx is a header PHI fed by %next and
+///   CurOff == MinOff + Step * sizeof(T).
+/// That equation means that whenever the index is updated, the element it now
+/// designates is the value %cur that was just loaded. Combined with the absence
+/// of writes in the loop, %min is then a running minimum that can live in a
+/// PHI rather than being reloaded.
+///
+/// \returns true if the IR was changed.
+bool GVNPass::recognizeMinFindingSelectPattern(SelectInst *Select) {
+  LLVM_DEBUG(
+      dbgs()
+      << "GVN: Analyzing select instruction for minimum finding pattern.\n");
+  LLVM_DEBUG(dbgs() << "GVN: Select: " << *Select << "\n");
+  BasicBlock *BB = Select->getParent();
+
+  // If the block is not in a loop, bail out.
+  Loop *L = LI ? LI->getLoopFor(BB) : nullptr;
+  if (!L) {
+    LLVM_DEBUG(dbgs() << "GVN: Could not find loop.\n");
+    return false;
+  }
+
+  // If preheader of the loop is not found, bail out.
+  BasicBlock *Preheader = L->getLoopPreheader();
+  if (!Preheader) {
+    LLVM_DEBUG(dbgs() << "GVN: Could not find loop preheader.\n");
+    return false;
+  }
+
+  // Only single-block loops are handled: the rewrite adds a PHI to BB with one
+  // entry for the preheader and one for the back edge.
+  if (L->getNumBlocks() != 1) {
+    LLVM_DEBUG(dbgs() << "GVN: Loop is not a single block.\n");
+    return false;
+  }
+
+  auto *Comparison = dyn_cast<CmpInst>(Select->getCondition());
+  if (!Comparison) {
+    LLVM_DEBUG(dbgs() << "GVN: Condition is not a comparison.\n");
+    return false;
+  }
+
+  // Check if this is less-than comparison.
+  CmpInst::Predicate Pred = Comparison->getPredicate();
+  if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT &&
+      Pred != CmpInst::FCMP_OLT && Pred != CmpInst::FCMP_ULT) {
+    LLVM_DEBUG(dbgs() << "GVN: Not a less-than comparison, predicate: " << Pred
+                      << "\n");
+    return false;
+  }
+
+  // Check that both operands are loads.
+  Value *LHS = Comparison->getOperand(0);
+  Value *RHS = Comparison->getOperand(1);
+  auto *CurLoad = dyn_cast<LoadInst>(LHS);
+  auto *MinLoad = dyn_cast<LoadInst>(RHS);
+  if (!CurLoad || !MinLoad) {
+    LLVM_DEBUG(dbgs() << "GVN: Not both operands are loads.\n");
+    return false;
+  }
+
+  Value *BasePtr = nullptr, *IndexVal = nullptr, *OffsetVal = nullptr;
+  if (!match(RHS,
+             m_Load(m_GEP(m_GEP(m_Value(BasePtr), m_SExt(m_Value(IndexVal))),
+                          m_Value(OffsetVal))))) {
+    LLVM_DEBUG(dbgs() << "GVN: Not a required load pattern.\n");
+    return false;
+  }
+
+  // The reloaded minimum is re-created in the preheader and deleted from the
+  // loop, so it must be a plain (non-volatile, non-atomic) load that lives in
+  // the loop block together with the comparison.
+  if (!MinLoad->isSimple() || MinLoad->getParent() != BB ||
+      Comparison->getParent() != BB) {
+    LLVM_DEBUG(dbgs() << "GVN: Minimum load cannot be hoisted.\n");
+    return false;
+  }
+
+  // Select must keep the old index when the comparison fails and otherwise
+  // pick trunc(IV + Step), where IV is the index the LHS load is addressed by.
+  Value *IV = nullptr;
+  const APInt *CurOff = nullptr, *MinOff = nullptr, *Step = nullptr;
+  if (Select->getFalseValue() != IndexVal ||
+      !match(LHS, m_Load(m_GEP(m_GEP(m_Specific(BasePtr), m_Value(IV)),
+                               m_APInt(CurOff)))) ||
+      !match(OffsetVal, m_APInt(MinOff)) ||
+      !match(Select->getTrueValue(),
+             m_Trunc(m_Add(m_Specific(IV), m_APInt(Step))))) {
+    LLVM_DEBUG(dbgs() << "GVN: Select does not update a minimum index.\n");
+    return false;
+  }
+
+  // sext(trunc(IV + Step)) only equals IV + Step if neither the add nor the
+  // trunc can wrap, so both must carry nsw.
+  auto *NewIdx = dyn_cast<TruncInst>(Select->getTrueValue());
+  auto *Advance =
+      NewIdx ? dyn_cast<BinaryOperator>(NewIdx->getOperand(0)) : nullptr;
+  if (!Advance || !NewIdx->hasNoSignedWrap() || !Advance->hasNoSignedWrap()) {
+    LLVM_DEBUG(dbgs() << "GVN: New index may wrap.\n");
+    return false;
+  }
+
+  // Both addresses must be byte offsets from elements of the same array type,
+  // indexed with the same integer type, for the offset equation to be valid.
+  auto *MinByteGEP = dyn_cast<GetElementPtrInst>(MinLoad->getPointerOperand());
+  auto *CurByteGEP = dyn_cast<GetElementPtrInst>(CurLoad->getPointerOperand());
+  auto *MinElemGEP =
+      MinByteGEP ? dyn_cast<GetElementPtrInst>(MinByteGEP->getPointerOperand())
+                 : nullptr;
+  auto *CurElemGEP =
+      CurByteGEP ? dyn_cast<GetElementPtrInst>(CurByteGEP->getPointerOperand())
+                 : nullptr;
+  if (!MinElemGEP || !CurElemGEP ||
+      !MinByteGEP->getSourceElementType()->isIntegerTy(8) ||
+      !CurByteGEP->getSourceElementType()->isIntegerTy(8) ||
+      MinElemGEP->getSourceElementType() !=
+          CurElemGEP->getSourceElementType() ||
+      MinElemGEP->getOperand(1)->getType() != IV->getType()) {
+    LLVM_DEBUG(dbgs() << "GVN: Load addresses are not comparable.\n");
+    return false;
+  }
+
+  // Check CurOff == MinOff + Step * sizeof(element), in 64-bit arithmetic that
+  // cannot overflow because every term is limited to 32 bits first.
+  Type *ElemTy = MinElemGEP->getSourceElementType();
+  TypeSize ElemSize = Select->getDataLayout().getTypeAllocSize(ElemTy);
+  if (ElemSize.isScalable() || !isUInt<31>(ElemSize.getFixedValue()) ||
+      !CurOff->isSignedIntN(32) || !MinOff->isSignedIntN(32) ||
+      !Step->isSignedIntN(32)) {
+    LLVM_DEBUG(dbgs() << "GVN: Offsets are out of the supported range.\n");
+    return false;
+  }
+  int64_t Stride = ElemSize.getFixedValue();
+  if (CurOff->getSExtValue() !=
+      MinOff->getSExtValue() + Step->getSExtValue() * Stride) {
+    LLVM_DEBUG(dbgs() << "GVN: New index does not match loaded value.\n");
+    return false;
+  }
+
+  // The tracked value is only correct if nothing in the loop can modify the
+  // array, and the load may only be hoisted if it was executed unconditionally
+  // on entry to the block.
+  if (ICF->isDominatedByICFIFromSameBlock(MinLoad) ||
+      any_of(*BB, [](const Instruction &I) { return I.mayWriteToMemory(); })) {
+    LLVM_DEBUG(dbgs() << "GVN: Loop may write memory or exit early.\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << "GVN: Found minimum finding pattern in Block: "
+                    << Select->getParent()->getName() << ".\n");
+
+  // Get type of load.
+  Type *LoadType = CurLoad->getType();
+  LLVM_DEBUG(dbgs() << "GVN: Transforming minimum finding pattern.\n");
+  return transformMinFindingSelectPattern(L, LoadType, Preheader, BB, LHS, RHS,
+                                          Comparison, Select, BasePtr, IndexVal,
+                                          OffsetVal);
+}
+
 class llvm::gvn::GVNLegacyPass : public FunctionPass {
 public:
   static char ID; // Pass identification, replacement for typeid.
diff --git a/llvm/test/Transforms/GVN/PRE/gvn-min-pre.ll b/llvm/test/Transforms/GVN/PRE/gvn-min-pre.ll
new file mode 100644
index 0000000000000..19fec514b28fe
--- /dev/null
+++ b/llvm/test/Transforms/GVN/PRE/gvn-min-pre.ll
@@ -0,0 +1,58 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt -passes=gvn -S < %s | FileCheck %s
+
+define void @test_gvn_min_pattern(ptr %0) {
+; CHECK-LABEL: define void @test_gvn_min_pattern(
+; CHECK-SAME: ptr [[TMP0:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    [[HOIST_GEP1:%.*]] = getelementptr float, ptr [[TMP0]], i64 1
+; CHECK-NEXT:    [[HOIST_GEP2:%.*]] = getelementptr i8, ptr [[HOIST_GEP1]], i64 -4
+; CHECK-NEXT:    [[HOISTED_LOAD:%.*]] = load float, ptr [[HOIST_GEP2]], align 4
+; CHECK-NEXT:    br label %[[LOOP:.*]]
+; CHECK:       [[LOOP]]:
+; CHECK-NEXT:    [[KNOWN_MIN:%.*]] = phi float [ [[HOISTED_LOAD]], %[[ENTRY]] ], [ [[CURRENT_MIN:%.*]], %[[LOOP]] ]
+; CHECK-NEXT:    [[INDVARS_IV_I:%.*]] = phi i64 [ 1, %[[ENTRY]] ], [ [[INDVARS_IV_NEXT_I:%.*]], %[[LOOP]] ]
+; CHECK-NEXT:    [[LOOP_COUNTER:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[LOOP_COUNTER_NEXT:%.*]], %[[LOOP]] ]
+; CHECK-NEXT:    [[MIN_IDX:%.*]] = phi i32 [ 1, %[[ENTRY]] ], [ [[MIN_IDX_NEXT:%.*]], %[[LOOP]] ]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT_I]] = add nsw i64 [[INDVARS_IV_I]], -1
+; CHECK-NEXT:    [[PTR_FLOAT_IV:%.*]] = getelementptr float, ptr [[TMP0]], i64 [[INDVARS_IV_I]]
+; CHECK-NEXT:    [[PTR_FIRST_LOAD:%.*]] = getelementptr i8, ptr [[PTR_FLOAT_IV]], i64 -8
+; CHECK-NEXT:    [[VAL_FIRST:%.*]] = load float, ptr [[PTR_FIRST_LOAD]], align 4
+; CHECK-NEXT:    [[MIN_IDX_EXT:%.*]] = sext i32 [[MIN_IDX]] to i64
+; CHECK-NEXT:    [[PTR_FLOAT_MIN:%.*]] = getelementptr float, ptr [[TMP0]], i64 [[MIN_IDX_EXT]]
+; CHECK-NEXT:    [[PTR_SECOND_LOAD:%.*]] = getelementptr i8, ptr [[PTR_FLOAT_MIN]], i64 -4
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp contract olt float [[VAL_FIRST]], [[KNOWN_MIN]]
+; CHECK-NEXT:    [[NEXT_IDX_TRUNC:%.*]] = trunc nsw i64 [[INDVARS_IV_NEXT_I]] to i32
+; CHECK-NEXT:    [[MIN_IDX_NEXT]] = select i1 [[CMP]], i32 [[NEXT_IDX_TRUNC]], i32 [[MIN_IDX]]
+; CHECK-NEXT:    [[LOOP_COUNTER_NEXT]] = add nsw i64 [[LOOP_COUNTER]], -1
+; CHECK-NEXT:    [[LOOP_CONTINUE:%.*]] = icmp samesign ugt i64 [[LOOP_COUNTER]], 1
+; CHECK-NEXT:    [[CURRENT_MIN]] = select i1 [[CMP]], float [[VAL_FIRST]], float [[KNOWN_MIN]]
+; CHECK-NEXT:    br i1 [[LOOP_CONTINUE]], label %[[LOOP]], label %[[EXIT:.*]]
+; CHECK:       [[EXIT]]:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop
+
+loop:                                             ; preds = %loop, %entry
+  %indvars.iv.i = phi i64 [ 1, %entry ], [ %indvars.iv.next.i, %loop ]
+  %loop.counter = phi i64 [ 0, %entry ], [ %loop.counter.next, %loop ]
+  %min.idx = phi i32 [ 1, %entry ], [ %min.idx.next, %loop ]
+  %indvars.iv.next.i = add nsw i64 %indvars.iv.i, -1
+  %ptr.float.iv = getelementptr float, ptr %0, i64 %indvars.iv.i
+  %ptr.first.load = getelementptr i8, ptr %ptr.float.iv, i64 -8
+  %val.first = load float, ptr %ptr.first.load, align 4
+  %min.idx.ext = sext i32 %min.idx to i64
+  %ptr.float.min = getelementptr float, ptr %0, i64 %min.idx.ext
+  %ptr.second.load = getelementptr i8, ptr %ptr.float.min, i64 -4
+  %val.current.min = load float, ptr %ptr.second.load, align 4
+  %cmp = fcmp contract olt float %val.first, %val.current.min
+  %next.idx.trunc = trunc nsw i64 %indvars.iv.next.i to i32
+  %min.idx.next = select i1 %cmp, i32 %next.idx.trunc, i32 %min.idx
+  %loop.counter.next = add nsw i64 %loop.counter, -1
+  %loop.continue = icmp samesign ugt i64 %loop.counter, 1
+  br i1 %loop.continue, label %loop, label %exit
+
+exit:
+  ret void
+}