@@ -129,7 +129,9 @@ class VectorCombine {
129129 bool foldExtractedCmps (Instruction &I);
130130 bool foldBinopOfReductions (Instruction &I);
131131 bool foldSingleElementStore (Instruction &I);
132- bool scalarizeLoadExtract (Instruction &I);
132+ bool scalarizeLoad (Instruction &I);
133+ bool scalarizeLoadExtract (LoadInst *LI, VectorType *VecTy, Value *Ptr);
134+ bool scalarizeLoadBitcast (LoadInst *LI, VectorType *VecTy, Value *Ptr);
133135 bool scalarizeExtExtract (Instruction &I);
134136 bool foldConcatOfBoolMasks (Instruction &I);
135137 bool foldPermuteOfBinops (Instruction &I);
@@ -1852,11 +1854,9 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
18521854 return false ;
18531855}
18541856
1855- // / Try to scalarize vector loads feeding extractelement instructions.
1856- bool VectorCombine::scalarizeLoadExtract (Instruction &I) {
1857- if (!TTI.allowVectorElementIndexingUsingGEP ())
1858- return false ;
1859-
1857+ // / Try to scalarize vector loads feeding extractelement or bitcast
1858+ // / instructions.
1859+ bool VectorCombine::scalarizeLoad (Instruction &I) {
18601860 Value *Ptr;
18611861 if (!match (&I, m_Load (m_Value (Ptr))))
18621862 return false ;
@@ -1866,35 +1866,30 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
18661866 if (LI->isVolatile () || !DL->typeSizeEqualsStoreSize (VecTy->getScalarType ()))
18671867 return false ;
18681868
1869- InstructionCost OriginalCost =
1870- TTI.getMemoryOpCost (Instruction::Load, VecTy, LI->getAlign (),
1871- LI->getPointerAddressSpace (), CostKind);
1872- InstructionCost ScalarizedCost = 0 ;
1873-
1869+ bool AllExtracts = true ;
1870+ bool AllBitcasts = true ;
18741871 Instruction *LastCheckedInst = LI;
18751872 unsigned NumInstChecked = 0 ;
1876- DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1877- auto FailureGuard = make_scope_exit ([&]() {
1878- // If the transform is aborted, discard the ScalarizationResults.
1879- for (auto &Pair : NeedFreeze)
1880- Pair.second .discard ();
1881- });
18821873
1883- // Check if all users of the load are extracts with no memory modifications
1884- // between the load and the extract. Compute the cost of both the original
1885- // code and the scalarized version .
1874+ // Check what type of users we have (must either all be extracts or
1875+ // bitcasts) and ensure no memory modifications between the load and
1876+ // its users .
18861877 for (User *U : LI->users ()) {
1887- auto *UI = dyn_cast<ExtractElementInst >(U);
1878+ auto *UI = dyn_cast<Instruction >(U);
18881879 if (!UI || UI->getParent () != LI->getParent ())
18891880 return false ;
18901881
1891- // If any extract is waiting to be erased, then bail out as this will
1882+ // If any user is waiting to be erased, then bail out as this will
18921883 // distort the cost calculation and possibly lead to infinite loops.
18931884 if (UI->use_empty ())
18941885 return false ;
18951886
1896- // Check if any instruction between the load and the extract may modify
1897- // memory.
1887+ if (!isa<ExtractElementInst>(UI))
1888+ AllExtracts = false ;
1889+ if (!isa<BitCastInst>(UI))
1890+ AllBitcasts = false ;
1891+
1892+ // Check if any instruction between the load and the user may modify memory.
18981893 if (LastCheckedInst->comesBefore (UI)) {
18991894 for (Instruction &I :
19001895 make_range (std::next (LI->getIterator ()), UI->getIterator ())) {
@@ -1906,6 +1901,35 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
19061901 }
19071902 LastCheckedInst = UI;
19081903 }
1904+ }
1905+
1906+ if (AllExtracts)
1907+ return scalarizeLoadExtract (LI, VecTy, Ptr);
1908+ if (AllBitcasts)
1909+ return scalarizeLoadBitcast (LI, VecTy, Ptr);
1910+ return false ;
1911+ }
1912+
1913+ // / Try to scalarize vector loads feeding extractelement instructions.
1914+ bool VectorCombine::scalarizeLoadExtract (LoadInst *LI, VectorType *VecTy,
1915+ Value *Ptr) {
1916+ if (!TTI.allowVectorElementIndexingUsingGEP ())
1917+ return false ;
1918+
1919+ DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1920+ auto FailureGuard = make_scope_exit ([&]() {
1921+ // If the transform is aborted, discard the ScalarizationResults.
1922+ for (auto &Pair : NeedFreeze)
1923+ Pair.second .discard ();
1924+ });
1925+
1926+ InstructionCost OriginalCost =
1927+ TTI.getMemoryOpCost (Instruction::Load, VecTy, LI->getAlign (),
1928+ LI->getPointerAddressSpace (), CostKind);
1929+ InstructionCost ScalarizedCost = 0 ;
1930+
1931+ for (User *U : LI->users ()) {
1932+ auto *UI = cast<ExtractElementInst>(U);
19091933
19101934 auto ScalarIdx =
19111935 canScalarizeAccess (VecTy, UI->getIndexOperand (), LI, AC, DT);
@@ -1927,7 +1951,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
19271951 nullptr , nullptr , CostKind);
19281952 }
19291953
1930- LLVM_DEBUG (dbgs () << " Found all extractions of a vector load: " << I
1954+ LLVM_DEBUG (dbgs () << " Found all extractions of a vector load: " << *LI
19311955 << " \n LoadExtractCost: " << OriginalCost
19321956 << " vs ScalarizedCost: " << ScalarizedCost << " \n " );
19331957
@@ -1973,6 +1997,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
19731997 return true ;
19741998}
19751999
2000+ // / Try to scalarize vector loads feeding bitcast instructions.
2001+ bool VectorCombine::scalarizeLoadBitcast (LoadInst *LI, VectorType *VecTy,
2002+ Value *Ptr) {
2003+ InstructionCost OriginalCost =
2004+ TTI.getMemoryOpCost (Instruction::Load, VecTy, LI->getAlign (),
2005+ LI->getPointerAddressSpace (), CostKind);
2006+
2007+ Type *TargetScalarType = nullptr ;
2008+ unsigned VecBitWidth = DL->getTypeSizeInBits (VecTy);
2009+
2010+ for (User *U : LI->users ()) {
2011+ auto *BC = cast<BitCastInst>(U);
2012+
2013+ Type *DestTy = BC->getDestTy ();
2014+ if (!DestTy->isIntegerTy () && !DestTy->isFloatingPointTy ())
2015+ return false ;
2016+
2017+ unsigned DestBitWidth = DL->getTypeSizeInBits (DestTy);
2018+ if (DestBitWidth != VecBitWidth)
2019+ return false ;
2020+
2021+ // All bitcasts must target the same scalar type.
2022+ if (!TargetScalarType)
2023+ TargetScalarType = DestTy;
2024+ else if (TargetScalarType != DestTy)
2025+ return false ;
2026+
2027+ OriginalCost +=
2028+ TTI.getCastInstrCost (Instruction::BitCast, TargetScalarType, VecTy,
2029+ TTI.getCastContextHint (BC), CostKind, BC);
2030+ }
2031+
2032+ if (!TargetScalarType)
2033+ return false ;
2034+
2035+ assert (!LI->user_empty () && " Unexpected load without bitcast users" );
2036+ InstructionCost ScalarizedCost =
2037+ TTI.getMemoryOpCost (Instruction::Load, TargetScalarType, LI->getAlign (),
2038+ LI->getPointerAddressSpace (), CostKind);
2039+
2040+ LLVM_DEBUG (dbgs () << " Found vector load feeding only bitcasts: " << *LI
2041+ << " \n OriginalCost: " << OriginalCost
2042+ << " vs ScalarizedCost: " << ScalarizedCost << " \n " );
2043+
2044+ if (ScalarizedCost >= OriginalCost)
2045+ return false ;
2046+
2047+ // Ensure we add the load back to the worklist BEFORE its users so they can
2048+ // erased in the correct order.
2049+ Worklist.push (LI);
2050+
2051+ Builder.SetInsertPoint (LI);
2052+ auto *ScalarLoad =
2053+ Builder.CreateLoad (TargetScalarType, Ptr, LI->getName () + " .scalar" );
2054+ ScalarLoad->setAlignment (LI->getAlign ());
2055+ ScalarLoad->copyMetadata (*LI);
2056+
2057+ // Replace all bitcast users with the scalar load.
2058+ for (User *U : LI->users ()) {
2059+ auto *BC = cast<BitCastInst>(U);
2060+ replaceValue (*BC, *ScalarLoad, false );
2061+ }
2062+
2063+ return true ;
2064+ }
2065+
19762066bool VectorCombine::scalarizeExtExtract (Instruction &I) {
19772067 if (!TTI.allowVectorElementIndexingUsingGEP ())
19782068 return false ;
@@ -4585,7 +4675,7 @@ bool VectorCombine::run() {
45854675 if (IsVectorType) {
45864676 if (scalarizeOpOrCmp (I))
45874677 return true ;
4588- if (scalarizeLoadExtract (I))
4678+ if (scalarizeLoad (I))
45894679 return true ;
45904680 if (scalarizeExtExtract (I))
45914681 return true ;
0 commit comments