Skip to content

Commit 2499cef

Browse files
committed
[GVN] Support rnflow pattern matching and transform
1 parent 3f52eef commit 2499cef

File tree

3 files changed

+185
-0
lines changed

3 files changed

+185
-0
lines changed

llvm/include/llvm/Transforms/Scalar/GVN.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/IR/Dominators.h"
2323
#include "llvm/IR/InstrTypes.h"
2424
#include "llvm/IR/PassManager.h"
25+
#include "llvm/Analysis/LoopInfo.h"
2526
#include "llvm/IR/ValueHandle.h"
2627
#include "llvm/Support/Allocator.h"
2728
#include "llvm/Support/Compiler.h"
@@ -44,6 +45,7 @@ class FunctionPass;
4445
class GetElementPtrInst;
4546
class ImplicitControlFlowTracking;
4647
class LoadInst;
48+
class SelectInst;
4749
class LoopInfo;
4850
class MemDepResult;
4951
class MemoryAccess;
@@ -409,6 +411,8 @@ class GVNPass : public PassInfoMixin<GVNPass> {
409411
void addDeadBlock(BasicBlock *BB);
410412
void assignValNumForDeadCode();
411413
void assignBlockRPONumber(Function &F);
414+
415+
/// Try to simplify a loop that tracks the index of a minimum element.
///
/// \p Select is inspected for the shape "select (less-than compare of two
/// loads), ..." inside a loop that has a preheader. On a match, the load that
/// feeds the right-hand side of the compare is hoisted into the loop
/// preheader and the in-loop value is carried in a new PHI node instead of
/// being reloaded on every iteration.
///
/// \returns true if the IR was changed, false if the pattern did not match.
bool optimizeMinMaxFindingSelectPattern(SelectInst *Select);
412416
};
413417

414418
/// Create a legacy GVN pass.

llvm/lib/Transforms/Scalar/GVN.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2818,6 +2818,10 @@ bool GVNPass::processInstruction(Instruction *I) {
28182818
}
28192819
return Changed;
28202820
}
2821+
if (SelectInst *Select = dyn_cast<SelectInst>(I)) {
2822+
if (optimizeMinMaxFindingSelectPattern(Select))
2823+
return true;
2824+
}
28212825

28222826
// Instructions with void type don't return a value, so there's
28232827
// no point in trying to find redundancies in them.
@@ -3410,6 +3414,124 @@ void GVNPass::assignValNumForDeadCode() {
34103414
}
34113415
}
34123416

3417+
bool GVNPass::optimizeMinMaxFindingSelectPattern(SelectInst *Select) {
3418+
LLVM_DEBUG(
3419+
dbgs()
3420+
<< "GVN: Analyzing select instruction for minimum finding pattern\n");
3421+
LLVM_DEBUG(dbgs() << "GVN: Select: " << *Select << "\n");
3422+
Value *Condition = Select->getCondition();
3423+
CmpInst *Comparison = dyn_cast<CmpInst>(Condition);
3424+
if (!Comparison) {
3425+
LLVM_DEBUG(dbgs() << "GVN: Condition is not a comparison\n");
3426+
return false;
3427+
}
3428+
3429+
// Check if this is ULT comparison.
3430+
CmpInst::Predicate Pred = Comparison->getPredicate();
3431+
if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT &&
3432+
Pred != CmpInst::FCMP_OLT && Pred != CmpInst::FCMP_ULT) {
3433+
LLVM_DEBUG(dbgs() << "GVN: Not a less-than comparison, predicate: " << Pred
3434+
<< "\n");
3435+
return false;
3436+
}
3437+
3438+
// Check that both operands are loads.
3439+
Value *LHS = Comparison->getOperand(0);
3440+
Value *RHS = Comparison->getOperand(1);
3441+
if (!isa<LoadInst>(LHS) || !isa<LoadInst>(RHS)) {
3442+
LLVM_DEBUG(dbgs() << "GVN: Not both operands are loads\n");
3443+
return false;
3444+
}
3445+
3446+
LLVM_DEBUG(dbgs() << "GVN: Found minimum finding pattern in Block: "
3447+
<< Select->getParent()->getName() << "\n");
3448+
3449+
// Transform the pattern.
3450+
// Hoist the chain of operations for the second load to preheader.
3451+
// Get predecessor of the block containing the select instruction.
3452+
BasicBlock *BB = Select->getParent();
3453+
3454+
// Get preheader of the loop.
3455+
Loop *L = LI->getLoopFor(BB);
3456+
if (!L) {
3457+
LLVM_DEBUG(dbgs() << "GVN: Could not find loop\n");
3458+
return false;
3459+
}
3460+
BasicBlock *Preheader = L->getLoopPreheader();
3461+
if (!Preheader) {
3462+
LLVM_DEBUG(dbgs() << "GVN: Could not find loop preheader\n");
3463+
return false;
3464+
}
3465+
3466+
// Hoist the chain of operations for the second load to preheader.
3467+
// %90 = sext i32 %.05.i to i64
3468+
// %91 = getelementptr float, ptr %0, i64 %90 ; %0 + (sext i32 %85 to i64)*4
3469+
// %92 = getelementptr i8, ptr %91, i64 -4 ; %0 + (sext i32 %85 to i64)*4 - 4
3470+
// %93 = load float, ptr %92, align 4
3471+
3472+
Value *BasePtr = nullptr, *IndexVal = nullptr, *OffsetVal = nullptr;
3473+
IRBuilder<> Builder(Preheader->getTerminator());
3474+
if (match(RHS,
3475+
m_Load(m_GEP(m_GEP(m_Value(BasePtr), m_SExt(m_Value(IndexVal))),
3476+
m_Value(OffsetVal))))) {
3477+
LLVM_DEBUG(dbgs() << "GVN: Found pattern: " << *RHS << "\n");
3478+
LLVM_DEBUG(dbgs() << "GVN: Found pattern: " << "\n");
3479+
3480+
PHINode *Phi = dyn_cast<PHINode>(IndexVal);
3481+
if (!Phi) {
3482+
LLVM_DEBUG(dbgs() << "GVN: IndexVal is not a PHI node\n");
3483+
return false;
3484+
}
3485+
Value *InitialMinIndex = Phi->getIncomingValueForBlock(Preheader);
3486+
3487+
// Insert PHI node at the top of this block.
3488+
PHINode *KnownMinPhi =
3489+
PHINode::Create(Builder.getFloatTy(), 2, "known_min", BB->begin());
3490+
3491+
// Build the GEP chain in the preheader.
3492+
// 1. hoist_0 = sext i32 to i64
3493+
Value *HoistedSExt =
3494+
Builder.CreateSExt(InitialMinIndex, Builder.getInt64Ty(), "hoist_sext");
3495+
3496+
// 2. hoist_gep1 = getelementptr float, ptr BasePtr, i64 HoistedSExt
3497+
Value *HoistedGEP1 = Builder.CreateGEP(Builder.getFloatTy(), BasePtr,
3498+
HoistedSExt, "hoist_gep1");
3499+
3500+
// 3. hoist_gep2 = getelementptr i8, ptr HoistedGEP1, i64 OffsetVal
3501+
Value *HoistedGEP2 = Builder.CreateGEP(Builder.getInt8Ty(), HoistedGEP1,
3502+
OffsetVal, "hoist_gep2");
3503+
3504+
// 4. hoisted_load = load float, ptr HoistedGEP2
3505+
LoadInst *NewLoad =
3506+
Builder.CreateLoad(Builder.getFloatTy(), HoistedGEP2, "hoisted_load");
3507+
3508+
// Replace all uses of load with new load.
3509+
RHS->replaceAllUsesWith(NewLoad);
3510+
dyn_cast<LoadInst>(RHS)->eraseFromParent();
3511+
3512+
// Replace second operand of comparison with KnownMinPhi.
3513+
Comparison->setOperand(1, KnownMinPhi);
3514+
3515+
// Create new select instruction for selecting the minimum value.
3516+
IRBuilder<> SelectBuilder(BB->getTerminator());
3517+
SelectInst *CurrentMinSelect =
3518+
dyn_cast<SelectInst>(SelectBuilder.CreateSelect(
3519+
Comparison, LHS, KnownMinPhi, "current_min"));
3520+
3521+
// Populate PHI node.
3522+
KnownMinPhi->addIncoming(NewLoad, Preheader);
3523+
KnownMinPhi->addIncoming(CurrentMinSelect, BB);
3524+
std::cout << "Transformed the code\n";
3525+
return true;
3526+
} else {
3527+
LLVM_DEBUG(dbgs() << "GVN: Could not find pattern: " << *RHS << "\n");
3528+
std::cout << "GVN: Could not find pattern: " << "\n";
3529+
return false;
3530+
}
3531+
return false;
3532+
}
3533+
3534+
34133535
class llvm::gvn::GVNLegacyPass : public FunctionPass {
34143536
public:
34153537
static char ID; // Pass identification, replacement for typeid.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
2+
; Minimal test case containing only the .lr.ph.i basic block
3+
; RUN: opt -passes=gvn -S < %s | FileCheck %s
4+
5+
; Single-block loop (%.lr.ph.i is its own latch). Each iteration compares the
; float at %0 + 4 * %indvars.iv.i - 8 with the float at
; %0 + 4 * sext(%.05.i) - 4 and, when the former is smaller, sets %.05.i to
; trunc(%indvars.iv.i - 1). The CHECK lines expect the second load to be
; replaced by the phi KNOWN_MIN, seeded by a load hoisted into the entry block
; and updated through the select CURRENT_MIN.
define void @test_lr_ph_i(ptr %0) {
6+
; CHECK-LABEL: define void @test_lr_ph_i(
7+
; CHECK-SAME: ptr [[TMP0:%.*]]) {
8+
; CHECK-NEXT: [[ENTRY:.*]]:
9+
; CHECK-NEXT: [[HOIST_GEP1:%.*]] = getelementptr float, ptr [[TMP0]], i64 1
10+
; CHECK-NEXT: [[HOIST_GEP2:%.*]] = getelementptr i8, ptr [[HOIST_GEP1]], i64 -4
11+
; CHECK-NEXT: [[HOISTED_LOAD:%.*]] = load float, ptr [[HOIST_GEP2]], align 4
12+
; CHECK-NEXT: br label %[[DOTLR_PH_I:.*]]
13+
; CHECK: [[_LR_PH_I:.*:]]
14+
; CHECK-NEXT: [[KNOWN_MIN:%.*]] = phi float [ [[HOISTED_LOAD]], %[[ENTRY]] ], [ [[CURRENT_MIN:%.*]], %[[DOTLR_PH_I]] ]
15+
; CHECK-NEXT: [[INDVARS_IV_I:%.*]] = phi i64 [ 1, %[[ENTRY]] ], [ [[INDVARS_IV_NEXT_I:%.*]], %[[DOTLR_PH_I]] ]
16+
; CHECK-NEXT: [[TMP1:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[TMP10:%.*]], %[[DOTLR_PH_I]] ]
17+
; CHECK-NEXT: [[DOT05_I:%.*]] = phi i32 [ 1, %[[ENTRY]] ], [ [[DOT1_I:%.*]], %[[DOTLR_PH_I]] ]
18+
; CHECK-NEXT: [[INDVARS_IV_NEXT_I]] = add nsw i64 [[INDVARS_IV_I]], -1
19+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr float, ptr [[TMP0]], i64 [[INDVARS_IV_I]]
20+
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP2]], i64 -8
21+
; CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[TMP3]], align 4
22+
; CHECK-NEXT: [[TMP5:%.*]] = sext i32 [[DOT05_I]] to i64
23+
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr float, ptr [[TMP0]], i64 [[TMP5]]
24+
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i64 -4
25+
; CHECK-NEXT: [[TMP8:%.*]] = fcmp contract olt float [[TMP4]], [[KNOWN_MIN]]
26+
; CHECK-NEXT: [[TMP9:%.*]] = trunc nsw i64 [[INDVARS_IV_NEXT_I]] to i32
27+
; CHECK-NEXT: [[DOT1_I]] = select i1 [[TMP8]], i32 [[TMP9]], i32 [[DOT05_I]]
28+
; CHECK-NEXT: [[TMP10]] = add nsw i64 [[TMP1]], -1
29+
; CHECK-NEXT: [[TMP11:%.*]] = icmp samesign ugt i64 [[TMP1]], 1
30+
; CHECK-NEXT: [[CURRENT_MIN]] = select i1 [[TMP8]], float [[TMP4]], float [[KNOWN_MIN]]
31+
; CHECK-NEXT: br i1 [[TMP11]], label %[[DOTLR_PH_I]], label %[[EXIT:.*]]
32+
; CHECK: [[EXIT]]:
33+
; CHECK-NEXT: ret void
34+
;
35+
entry:
36+
br label %.lr.ph.i
37+
38+
.lr.ph.i: ; preds = %.lr.ph.i, %entry
39+
%indvars.iv.i = phi i64 [ 1, %entry ], [ %indvars.iv.next.i, %.lr.ph.i ]
40+
%86 = phi i64 [ 0, %entry ], [ %96, %.lr.ph.i ]
41+
%.05.i = phi i32 [ 1, %entry ], [ %.1.i, %.lr.ph.i ]
42+
%indvars.iv.next.i = add nsw i64 %indvars.iv.i, -1
43+
%87 = getelementptr float, ptr %0, i64 %indvars.iv.i
44+
%88 = getelementptr i8, ptr %87, i64 -8 ; first load: %0 + 4 * %indvars.iv.i - 8 (%0 + 4 * 1 - 8 on the first iteration)
45+
%89 = load float, ptr %88, align 4
46+
%90 = sext i32 %.05.i to i64
47+
%91 = getelementptr float, ptr %0, i64 %90 ; %0 + 4 * sext(%.05.i) (%0 + 4 * 1 on the first iteration)
48+
%92 = getelementptr i8, ptr %91, i64 -4 ; second load: %0 + 4 * sext(%.05.i) - 4 (%0 + 4 * 1 - 4 on the first iteration)
49+
%93 = load float, ptr %92, align 4
50+
%94 = fcmp contract olt float %89, %93
51+
%95 = trunc nsw i64 %indvars.iv.next.i to i32
52+
%.1.i = select i1 %94, i32 %95, i32 %.05.i
53+
%96 = add nsw i64 %86, -1
54+
%97 = icmp samesign ugt i64 %86, 1
55+
br i1 %97, label %.lr.ph.i, label %exit
56+
57+
exit:
58+
ret void
59+
}

0 commit comments

Comments
 (0)