@@ -123,7 +123,9 @@ class VectorCombine {
123123 bool foldExtractedCmps (Instruction &I);
124124 bool foldBinopOfReductions (Instruction &I);
125125 bool foldSingleElementStore (Instruction &I);
126- bool scalarizeLoadExtract (Instruction &I);
126+ bool scalarizeLoad (Instruction &I);
127+ bool scalarizeLoadExtract (LoadInst *LI, VectorType *VecTy, Value *Ptr);
128+ bool scalarizeLoadBitcast (LoadInst *LI, VectorType *VecTy, Value *Ptr);
127129 bool scalarizeExtExtract (Instruction &I);
128130 bool foldConcatOfBoolMasks (Instruction &I);
129131 bool foldPermuteOfBinops (Instruction &I);
@@ -1664,8 +1666,9 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
16641666 return false ;
16651667}
16661668
1667- // / Try to scalarize vector loads feeding extractelement instructions.
1668- bool VectorCombine::scalarizeLoadExtract (Instruction &I) {
1669+ // / Try to scalarize vector loads feeding extractelement or bitcast
1670+ // / instructions.
1671+ bool VectorCombine::scalarizeLoad (Instruction &I) {
16691672 Value *Ptr;
16701673 if (!match (&I, m_Load (m_Value (Ptr))))
16711674 return false ;
@@ -1675,35 +1678,30 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
16751678 if (LI->isVolatile () || !DL->typeSizeEqualsStoreSize (VecTy->getScalarType ()))
16761679 return false ;
16771680
1678- InstructionCost OriginalCost =
1679- TTI.getMemoryOpCost (Instruction::Load, VecTy, LI->getAlign (),
1680- LI->getPointerAddressSpace (), CostKind);
1681- InstructionCost ScalarizedCost = 0 ;
1682-
1681+ bool AllExtracts = true ;
1682+ bool AllBitcasts = true ;
16831683 Instruction *LastCheckedInst = LI;
16841684 unsigned NumInstChecked = 0 ;
1685- DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1686- auto FailureGuard = make_scope_exit ([&]() {
1687- // If the transform is aborted, discard the ScalarizationResults.
1688- for (auto &Pair : NeedFreeze)
1689- Pair.second .discard ();
1690- });
16911685
1692- // Check if all users of the load are extracts with no memory modifications
1693- // between the load and the extract. Compute the cost of both the original
1694- // code and the scalarized version .
1686+ // Check what type of users we have (must either all be extracts or
1687+ // bitcasts) and ensure no memory modifications between the load and
1688+ // its users .
16951689 for (User *U : LI->users ()) {
1696- auto *UI = dyn_cast<ExtractElementInst >(U);
1690+ auto *UI = dyn_cast<Instruction >(U);
16971691 if (!UI || UI->getParent () != LI->getParent ())
16981692 return false ;
16991693
1700- // If any extract is waiting to be erased, then bail out as this will
1694+ // If any user is waiting to be erased, then bail out as this will
17011695 // distort the cost calculation and possibly lead to infinite loops.
17021696 if (UI->use_empty ())
17031697 return false ;
17041698
1705- // Check if any instruction between the load and the extract may modify
1706- // memory.
1699+ if (!isa<ExtractElementInst>(UI))
1700+ AllExtracts = false ;
1701+ if (!isa<BitCastInst>(UI))
1702+ AllBitcasts = false ;
1703+
1704+ // Check if any instruction between the load and the user may modify memory.
17071705 if (LastCheckedInst->comesBefore (UI)) {
17081706 for (Instruction &I :
17091707 make_range (std::next (LI->getIterator ()), UI->getIterator ())) {
@@ -1715,6 +1713,33 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
17151713 }
17161714 LastCheckedInst = UI;
17171715 }
1716+ }
1717+
1718+ if (AllExtracts)
1719+ return scalarizeLoadExtract (LI, VecTy, Ptr);
1720+ if (AllBitcasts)
1721+ return scalarizeLoadBitcast (LI, VecTy, Ptr);
1722+ return false ;
1723+ }
1724+
1725+ // / Try to scalarize vector loads feeding extractelement instructions.
1726+ bool VectorCombine::scalarizeLoadExtract (LoadInst *LI, VectorType *VecTy,
1727+ Value *Ptr) {
1728+
1729+ DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1730+ auto FailureGuard = make_scope_exit ([&]() {
1731+ // If the transform is aborted, discard the ScalarizationResults.
1732+ for (auto &Pair : NeedFreeze)
1733+ Pair.second .discard ();
1734+ });
1735+
1736+ InstructionCost OriginalCost =
1737+ TTI.getMemoryOpCost (Instruction::Load, VecTy, LI->getAlign (),
1738+ LI->getPointerAddressSpace (), CostKind);
1739+ InstructionCost ScalarizedCost = 0 ;
1740+
1741+ for (User *U : LI->users ()) {
1742+ auto *UI = cast<ExtractElementInst>(U);
17181743
17191744 auto ScalarIdx =
17201745 canScalarizeAccess (VecTy, UI->getIndexOperand (), LI, AC, DT);
@@ -1735,7 +1760,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
17351760 ScalarizedCost += TTI.getAddressComputationCost (VecTy->getElementType ());
17361761 }
17371762
1738- LLVM_DEBUG (dbgs () << " Found all extractions of a vector load: " << I
1763+ LLVM_DEBUG (dbgs () << " Found all extractions of a vector load: " << *LI
17391764 << " \n LoadExtractCost: " << OriginalCost
17401765 << " vs ScalarizedCost: " << ScalarizedCost << " \n " );
17411766
@@ -1773,6 +1798,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
17731798 return true ;
17741799}
17751800
1801+ // / Try to scalarize vector loads feeding bitcast instructions.
1802+ bool VectorCombine::scalarizeLoadBitcast (LoadInst *LI, VectorType *VecTy,
1803+ Value *Ptr) {
1804+ InstructionCost OriginalCost =
1805+ TTI.getMemoryOpCost (Instruction::Load, VecTy, LI->getAlign (),
1806+ LI->getPointerAddressSpace (), CostKind);
1807+
1808+ Type *TargetScalarType = nullptr ;
1809+ unsigned VecBitWidth = DL->getTypeSizeInBits (VecTy);
1810+
1811+ for (User *U : LI->users ()) {
1812+ auto *BC = cast<BitCastInst>(U);
1813+
1814+ Type *DestTy = BC->getDestTy ();
1815+ if (!DestTy->isIntegerTy () && !DestTy->isFloatingPointTy ())
1816+ return false ;
1817+
1818+ unsigned DestBitWidth = DL->getTypeSizeInBits (DestTy);
1819+ if (DestBitWidth != VecBitWidth)
1820+ return false ;
1821+
1822+ // All bitcasts must target the same scalar type.
1823+ if (!TargetScalarType)
1824+ TargetScalarType = DestTy;
1825+ else if (TargetScalarType != DestTy)
1826+ return false ;
1827+
1828+ OriginalCost +=
1829+ TTI.getCastInstrCost (Instruction::BitCast, TargetScalarType, VecTy,
1830+ TTI.getCastContextHint (BC), CostKind, BC);
1831+ }
1832+
1833+ if (!TargetScalarType)
1834+ return false ;
1835+
1836+ assert (!LI->user_empty () && " Unexpected load without bitcast users" );
1837+ InstructionCost ScalarizedCost =
1838+ TTI.getMemoryOpCost (Instruction::Load, TargetScalarType, LI->getAlign (),
1839+ LI->getPointerAddressSpace (), CostKind);
1840+
1841+ LLVM_DEBUG (dbgs () << " Found vector load feeding only bitcasts: " << *LI
1842+ << " \n OriginalCost: " << OriginalCost
1843+ << " vs ScalarizedCost: " << ScalarizedCost << " \n " );
1844+
1845+ if (ScalarizedCost >= OriginalCost)
1846+ return false ;
1847+
1848+ // Ensure we add the load back to the worklist BEFORE its users so they can
1849+ // erased in the correct order.
1850+ Worklist.push (LI);
1851+
1852+ Builder.SetInsertPoint (LI);
1853+ auto *ScalarLoad =
1854+ Builder.CreateLoad (TargetScalarType, Ptr, LI->getName () + " .scalar" );
1855+ ScalarLoad->setAlignment (LI->getAlign ());
1856+ ScalarLoad->copyMetadata (*LI);
1857+
1858+ // Replace all bitcast users with the scalar load.
1859+ for (User *U : LI->users ()) {
1860+ auto *BC = cast<BitCastInst>(U);
1861+ replaceValue (*BC, *ScalarLoad);
1862+ }
1863+
1864+ return true ;
1865+ }
1866+
17761867bool VectorCombine::scalarizeExtExtract (Instruction &I) {
17771868 auto *Ext = dyn_cast<ZExtInst>(&I);
17781869 if (!Ext)
@@ -1822,8 +1913,31 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) {
18221913
18231914 Value *ScalarV = Ext->getOperand (0 );
18241915 if (!isGuaranteedNotToBePoison (ScalarV, &AC, dyn_cast<Instruction>(ScalarV),
1825- &DT))
1826- ScalarV = Builder.CreateFreeze (ScalarV);
1916+ &DT)) {
1917+ // Check whether all lanes are extracted, all extracts trigger UB
1918+ // on poison, and the last extract (and hence all previous ones)
1919+ // are guaranteed to execute if Ext executes. If so, we do not
1920+ // need to insert a freeze.
1921+ SmallDenseSet<ConstantInt *, 8 > ExtractedLanes;
1922+ bool AllExtractsTriggerUB = true ;
1923+ ExtractElementInst *LastExtract = nullptr ;
1924+ BasicBlock *ExtBB = Ext->getParent ();
1925+ for (User *U : Ext->users ()) {
1926+ auto *Extract = cast<ExtractElementInst>(U);
1927+ if (Extract->getParent () != ExtBB || !programUndefinedIfPoison (Extract)) {
1928+ AllExtractsTriggerUB = false ;
1929+ break ;
1930+ }
1931+ ExtractedLanes.insert (cast<ConstantInt>(Extract->getIndexOperand ()));
1932+ if (!LastExtract || LastExtract->comesBefore (Extract))
1933+ LastExtract = Extract;
1934+ }
1935+ if (ExtractedLanes.size () != DstTy->getNumElements () ||
1936+ !AllExtractsTriggerUB ||
1937+ !isGuaranteedToTransferExecutionToSuccessor (Ext->getIterator (),
1938+ LastExtract->getIterator ()))
1939+ ScalarV = Builder.CreateFreeze (ScalarV);
1940+ }
18271941 ScalarV = Builder.CreateBitCast (
18281942 ScalarV,
18291943 IntegerType::get (SrcTy->getContext (), DL->getTypeSizeInBits (SrcTy)));
@@ -3734,7 +3848,7 @@ bool VectorCombine::run() {
37343848 // TODO: Identify and allow other scalable transforms
37353849 if (IsVectorType) {
37363850 MadeChange |= scalarizeOpOrCmp (I);
3737- MadeChange |= scalarizeLoadExtract (I);
3851+ MadeChange |= scalarizeLoad (I);
37383852 MadeChange |= scalarizeExtExtract (I);
37393853 MadeChange |= scalarizeVPIntrinsic (I);
37403854 MadeChange |= foldInterleaveIntrinsics (I);
0 commit comments