
Commit 2e3c3af

[VectorCombine] Pick changes for load scalarization. (#11884)
* [VectorCombine] Avoid inserting freeze when scalarizing extend-extract if all extracts would lead to UB on poison. (llvm#164683)

  This change avoids inserting a freeze instruction between the load and the bitcast when scalarizing extend-extract. This is particularly useful in combination with llvm#164682, which can then potentially scalarize further, provided there is no freeze.

  alive2 proof: https://alive2.llvm.org/ce/z/W-GD88

  (cherry picked from commit 28a20b4)

* [VectorCombine] Try to scalarize vector loads feeding bitcast instructions. (llvm#164682)

  This change converts vector loads to scalar loads when the loaded values are only converted to scalars afterwards anyway.

  alive2 proof: https://alive2.llvm.org/ce/z/U_rvht

  (cherry picked from commit 8280070)
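For intuition, a minimal before/after sketch of the bitcast scalarization (hypothetical IR, not taken from the commit; %p, %v, and %i are illustrative names):

; before: vector load used only by a full-width bitcast
%v = load <4 x i8>, ptr %p, align 4
%i = bitcast <4 x i8> %v to i32

; after scalarizeLoadBitcast, when the cost model favors the scalar load:
%i = load i32, ptr %p, align 4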
1 parent 1bab8f2 commit 2e3c3af

File tree

4 files changed: +493 −25 lines changed


llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 139 additions & 25 deletions
@@ -123,7 +123,9 @@ class VectorCombine {
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
-  bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeLoad(Instruction &I);
+  bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
+  bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
   bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
@@ -1664,8 +1666,9 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
   return false;
 }
 
-/// Try to scalarize vector loads feeding extractelement instructions.
-bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
+/// Try to scalarize vector loads feeding extractelement or bitcast
+/// instructions.
+bool VectorCombine::scalarizeLoad(Instruction &I) {
   Value *Ptr;
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
@@ -1675,35 +1678,30 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
     return false;
 
-  InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
-                          LI->getPointerAddressSpace(), CostKind);
-  InstructionCost ScalarizedCost = 0;
-
+  bool AllExtracts = true;
+  bool AllBitcasts = true;
   Instruction *LastCheckedInst = LI;
   unsigned NumInstChecked = 0;
-  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
-  auto FailureGuard = make_scope_exit([&]() {
-    // If the transform is aborted, discard the ScalarizationResults.
-    for (auto &Pair : NeedFreeze)
-      Pair.second.discard();
-  });
 
-  // Check if all users of the load are extracts with no memory modifications
-  // between the load and the extract. Compute the cost of both the original
-  // code and the scalarized version.
+  // Check what type of users we have (must either all be extracts or
+  // bitcasts) and ensure no memory modifications between the load and
+  // its users.
   for (User *U : LI->users()) {
-    auto *UI = dyn_cast<ExtractElementInst>(U);
+    auto *UI = dyn_cast<Instruction>(U);
     if (!UI || UI->getParent() != LI->getParent())
       return false;
 
-    // If any extract is waiting to be erased, then bail out as this will
+    // If any user is waiting to be erased, then bail out as this will
     // distort the cost calculation and possibly lead to infinite loops.
     if (UI->use_empty())
      return false;
 
-    // Check if any instruction between the load and the extract may modify
-    // memory.
+    if (!isa<ExtractElementInst>(UI))
+      AllExtracts = false;
+    if (!isa<BitCastInst>(UI))
+      AllBitcasts = false;
+
+    // Check if any instruction between the load and the user may modify memory.
     if (LastCheckedInst->comesBefore(UI)) {
       for (Instruction &I :
            make_range(std::next(LI->getIterator()), UI->getIterator())) {
@@ -1715,6 +1713,33 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       }
       LastCheckedInst = UI;
     }
+  }
+
+  if (AllExtracts)
+    return scalarizeLoadExtract(LI, VecTy, Ptr);
+  if (AllBitcasts)
+    return scalarizeLoadBitcast(LI, VecTy, Ptr);
+  return false;
+}
+
+/// Try to scalarize vector loads feeding extractelement instructions.
+bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+
+  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
+  auto FailureGuard = make_scope_exit([&]() {
+    // If the transform is aborted, discard the ScalarizationResults.
+    for (auto &Pair : NeedFreeze)
+      Pair.second.discard();
+  });
+
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+  InstructionCost ScalarizedCost = 0;
+
+  for (User *U : LI->users()) {
+    auto *UI = cast<ExtractElementInst>(U);
 
     auto ScalarIdx =
         canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
@@ -1735,7 +1760,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
   }
 
-  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
+  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
                     << "\n LoadExtractCost: " << OriginalCost
                     << " vs ScalarizedCost: " << ScalarizedCost << "\n");
 
@@ -1773,6 +1798,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to scalarize vector loads feeding bitcast instructions.
+bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  Type *TargetScalarType = nullptr;
+  unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);
+
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+
+    Type *DestTy = BC->getDestTy();
+    if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
+      return false;
+
+    unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
+    if (DestBitWidth != VecBitWidth)
+      return false;
+
+    // All bitcasts must target the same scalar type.
+    if (!TargetScalarType)
+      TargetScalarType = DestTy;
+    else if (TargetScalarType != DestTy)
+      return false;
+
+    OriginalCost +=
+        TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
+                             TTI.getCastContextHint(BC), CostKind, BC);
+  }
+
+  if (!TargetScalarType)
+    return false;
+
+  assert(!LI->user_empty() && "Unexpected load without bitcast users");
+  InstructionCost ScalarizedCost =
+      TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
+                    << "\n OriginalCost: " << OriginalCost
+                    << " vs ScalarizedCost: " << ScalarizedCost << "\n");
+
+  if (ScalarizedCost >= OriginalCost)
+    return false;
+
+  // Ensure we add the load back to the worklist BEFORE its users so they can
+  // be erased in the correct order.
+  Worklist.push(LI);
+
+  Builder.SetInsertPoint(LI);
+  auto *ScalarLoad =
+      Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
+  ScalarLoad->setAlignment(LI->getAlign());
+  ScalarLoad->copyMetadata(*LI);
+
+  // Replace all bitcast users with the scalar load.
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+    replaceValue(*BC, *ScalarLoad);
+  }
+
+  return true;
+}
+
 bool VectorCombine::scalarizeExtExtract(Instruction &I) {
   auto *Ext = dyn_cast<ZExtInst>(&I);
   if (!Ext)
@@ -1822,8 +1913,31 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) {
 
   Value *ScalarV = Ext->getOperand(0);
   if (!isGuaranteedNotToBePoison(ScalarV, &AC, dyn_cast<Instruction>(ScalarV),
-                                 &DT))
-    ScalarV = Builder.CreateFreeze(ScalarV);
+                                 &DT)) {
+    // Check whether all lanes are extracted, all extracts trigger UB
+    // on poison, and the last extract (and hence all previous ones)
+    // is guaranteed to execute if Ext executes. If so, we do not
+    // need to insert a freeze.
+    SmallDenseSet<ConstantInt *, 8> ExtractedLanes;
+    bool AllExtractsTriggerUB = true;
+    ExtractElementInst *LastExtract = nullptr;
+    BasicBlock *ExtBB = Ext->getParent();
+    for (User *U : Ext->users()) {
+      auto *Extract = cast<ExtractElementInst>(U);
+      if (Extract->getParent() != ExtBB || !programUndefinedIfPoison(Extract)) {
+        AllExtractsTriggerUB = false;
+        break;
+      }
+      ExtractedLanes.insert(cast<ConstantInt>(Extract->getIndexOperand()));
+      if (!LastExtract || LastExtract->comesBefore(Extract))
+        LastExtract = Extract;
+    }
+    if (ExtractedLanes.size() != DstTy->getNumElements() ||
+        !AllExtractsTriggerUB ||
+        !isGuaranteedToTransferExecutionToSuccessor(Ext->getIterator(),
+                                                    LastExtract->getIterator()))
+      ScalarV = Builder.CreateFreeze(ScalarV);
+  }
   ScalarV = Builder.CreateBitCast(
       ScalarV,
       IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
@@ -3734,7 +3848,7 @@ bool VectorCombine::run() {
     // TODO: Identify and allow other scalable transforms
     if (IsVectorType) {
       MadeChange |= scalarizeOpOrCmp(I);
-      MadeChange |= scalarizeLoadExtract(I);
+      MadeChange |= scalarizeLoad(I);
       MadeChange |= scalarizeExtExtract(I);
       MadeChange |= scalarizeVPIntrinsic(I);
       MadeChange |= foldInterleaveIntrinsics(I);
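To see how the two changes compose (a hedged sketch with illustrative IR, not taken from the commit): once scalarizeExtExtract proves that every lane is extracted, that each extract triggers UB on poison, and that execution is guaranteed to reach the last extract, it no longer inserts a freeze between the load and the bitcast, which in turn lets scalarizeLoadBitcast fold the load/bitcast pair into a scalar load:

; scalarizeExtExtract output without the freeze (lane 0 shown; lanes 1-3 use lshr/and):
%x = load <4 x i8>, ptr %src, align 4
%x.cast = bitcast <4 x i8> %x to i32
%ext.0 = and i32 %x.cast, 255

; scalarizeLoadBitcast then rewrites the load/bitcast pair:
%x.cast = load i32, ptr %src, align 4
%ext.0 = and i32 %x.cast, 255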
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt -O3 -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+
+define noundef i32 @load_ext_extract(ptr %src) {
+; CHECK-LABEL: define noundef range(i32 0, 1021) i32 @load_ext_extract(
+; CHECK-SAME: ptr readonly captures(none) [[SRC:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP14:%.*]] = load i32, ptr [[SRC]], align 4
+; CHECK-NEXT:    [[TMP15:%.*]] = lshr i32 [[TMP14]], 24
+; CHECK-NEXT:    [[TMP16:%.*]] = lshr i32 [[TMP14]], 16
+; CHECK-NEXT:    [[TMP17:%.*]] = and i32 [[TMP16]], 255
+; CHECK-NEXT:    [[TMP18:%.*]] = lshr i32 [[TMP14]], 8
+; CHECK-NEXT:    [[TMP19:%.*]] = and i32 [[TMP18]], 255
+; CHECK-NEXT:    [[TMP20:%.*]] = and i32 [[TMP14]], 255
+; CHECK-NEXT:    [[ADD1:%.*]] = add nuw nsw i32 [[TMP20]], [[TMP19]]
+; CHECK-NEXT:    [[ADD2:%.*]] = add nuw nsw i32 [[ADD1]], [[TMP17]]
+; CHECK-NEXT:    [[ADD3:%.*]] = add nuw nsw i32 [[ADD2]], [[TMP15]]
+; CHECK-NEXT:    ret i32 [[ADD3]]
+;
+entry:
+  %x = load <4 x i8>, ptr %src, align 4
+  %ext = zext nneg <4 x i8> %x to <4 x i32>
+  %ext.0 = extractelement <4 x i32> %ext, i64 0
+  %ext.1 = extractelement <4 x i32> %ext, i64 1
+  %ext.2 = extractelement <4 x i32> %ext, i64 2
+  %ext.3 = extractelement <4 x i32> %ext, i64 3
+
+  %add1 = add i32 %ext.0, %ext.1
+  %add2 = add i32 %add1, %ext.2
+  %add3 = add i32 %add2, %ext.3
+  ret i32 %add3
+}
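The test above runs the full -O3 pipeline. To exercise only the new combines, the pass can also be run standalone (an assumed invocation; the input file name is a placeholder):

opt -passes=vector-combine -S test.ll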
