
Commit 8280070

[VectorCombine] Try to scalarize vector loads feeding bitcast instructions. (#164682)
This change aims to convert vector loads to scalar loads if they are only converted to scalars afterwards anyway. Alive2 proof: https://alive2.llvm.org/ce/z/U_rvht
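As a minimal sketch of the rewrite (based on the simplest pattern exercised by the new vector-combine test below), a vector load whose only users bitcast it to a same-sized scalar becomes a single scalar load; names here are illustrative:

Before:
  %lv = load <4 x i8>, ptr %x
  %r = bitcast <4 x i8> %lv to i32

After:
  %r = load i32, ptr %x, align 4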
1 parent 5932477 commit 8280070


3 files changed: +284, -26 lines changed


llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 116 additions & 26 deletions
@@ -129,7 +129,9 @@ class VectorCombine {
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
-  bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeLoad(Instruction &I);
+  bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
+  bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
   bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
@@ -1852,11 +1854,9 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
   return false;
 }
 
-/// Try to scalarize vector loads feeding extractelement instructions.
-bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
-  if (!TTI.allowVectorElementIndexingUsingGEP())
-    return false;
-
+/// Try to scalarize vector loads feeding extractelement or bitcast
+/// instructions.
+bool VectorCombine::scalarizeLoad(Instruction &I) {
   Value *Ptr;
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
@@ -1866,35 +1866,30 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
     return false;
 
-  InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
-                          LI->getPointerAddressSpace(), CostKind);
-  InstructionCost ScalarizedCost = 0;
-
+  bool AllExtracts = true;
+  bool AllBitcasts = true;
   Instruction *LastCheckedInst = LI;
   unsigned NumInstChecked = 0;
-  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
-  auto FailureGuard = make_scope_exit([&]() {
-    // If the transform is aborted, discard the ScalarizationResults.
-    for (auto &Pair : NeedFreeze)
-      Pair.second.discard();
-  });
 
-  // Check if all users of the load are extracts with no memory modifications
-  // between the load and the extract. Compute the cost of both the original
-  // code and the scalarized version.
+  // Check what type of users we have (must either all be extracts or
+  // bitcasts) and ensure no memory modifications between the load and
+  // its users.
   for (User *U : LI->users()) {
-    auto *UI = dyn_cast<ExtractElementInst>(U);
+    auto *UI = dyn_cast<Instruction>(U);
     if (!UI || UI->getParent() != LI->getParent())
       return false;
 
-    // If any extract is waiting to be erased, then bail out as this will
+    // If any user is waiting to be erased, then bail out as this will
     // distort the cost calculation and possibly lead to infinite loops.
     if (UI->use_empty())
      return false;
 
-    // Check if any instruction between the load and the extract may modify
-    // memory.
+    if (!isa<ExtractElementInst>(UI))
+      AllExtracts = false;
+    if (!isa<BitCastInst>(UI))
+      AllBitcasts = false;
+
+    // Check if any instruction between the load and the user may modify memory.
     if (LastCheckedInst->comesBefore(UI)) {
       for (Instruction &I :
            make_range(std::next(LI->getIterator()), UI->getIterator())) {
@@ -1906,6 +1901,35 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       }
       LastCheckedInst = UI;
     }
+  }
+
+  if (AllExtracts)
+    return scalarizeLoadExtract(LI, VecTy, Ptr);
+  if (AllBitcasts)
+    return scalarizeLoadBitcast(LI, VecTy, Ptr);
+  return false;
+}
+
+/// Try to scalarize vector loads feeding extractelement instructions.
+bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  if (!TTI.allowVectorElementIndexingUsingGEP())
+    return false;
+
+  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
+  auto FailureGuard = make_scope_exit([&]() {
+    // If the transform is aborted, discard the ScalarizationResults.
+    for (auto &Pair : NeedFreeze)
+      Pair.second.discard();
+  });
+
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+  InstructionCost ScalarizedCost = 0;
+
+  for (User *U : LI->users()) {
+    auto *UI = cast<ExtractElementInst>(U);
 
     auto ScalarIdx =
         canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
@@ -1927,7 +1951,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
                                nullptr, nullptr, CostKind);
   }
 
-  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
+  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
                     << "\n LoadExtractCost: " << OriginalCost
                     << " vs ScalarizedCost: " << ScalarizedCost << "\n");
 
@@ -1973,6 +1997,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to scalarize vector loads feeding bitcast instructions.
+bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  Type *TargetScalarType = nullptr;
+  unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);
+
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+
+    Type *DestTy = BC->getDestTy();
+    if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
+      return false;
+
+    unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
+    if (DestBitWidth != VecBitWidth)
+      return false;
+
+    // All bitcasts must target the same scalar type.
+    if (!TargetScalarType)
+      TargetScalarType = DestTy;
+    else if (TargetScalarType != DestTy)
+      return false;
+
+    OriginalCost +=
+        TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
+                             TTI.getCastContextHint(BC), CostKind, BC);
+  }
+
+  if (!TargetScalarType)
+    return false;
+
+  assert(!LI->user_empty() && "Unexpected load without bitcast users");
+  InstructionCost ScalarizedCost =
+      TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
+                    << "\n OriginalCost: " << OriginalCost
+                    << " vs ScalarizedCost: " << ScalarizedCost << "\n");
+
+  if (ScalarizedCost >= OriginalCost)
+    return false;
+
+  // Ensure we add the load back to the worklist BEFORE its users so they can
+  // erased in the correct order.
+  Worklist.push(LI);
+
+  Builder.SetInsertPoint(LI);
+  auto *ScalarLoad =
+      Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
+  ScalarLoad->setAlignment(LI->getAlign());
+  ScalarLoad->copyMetadata(*LI);
+
+  // Replace all bitcast users with the scalar load.
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+    replaceValue(*BC, *ScalarLoad, false);
+  }
+
+  return true;
+}
+
 bool VectorCombine::scalarizeExtExtract(Instruction &I) {
   if (!TTI.allowVectorElementIndexingUsingGEP())
     return false;
@@ -4585,7 +4675,7 @@ bool VectorCombine::run() {
     if (IsVectorType) {
       if (scalarizeOpOrCmp(I))
         return true;
-      if (scalarizeLoadExtract(I))
+      if (scalarizeLoad(I))
         return true;
       if (scalarizeExtExtract(I))
         return true;
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt -O3 -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+
+define noundef i32 @load_ext_extract(ptr %src) {
+; CHECK-LABEL: define noundef range(i32 0, 1021) i32 @load_ext_extract(
+; CHECK-SAME: ptr readonly captures(none) [[SRC:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP14:%.*]] = load i32, ptr [[SRC]], align 4
+; CHECK-NEXT:    [[TMP15:%.*]] = lshr i32 [[TMP14]], 24
+; CHECK-NEXT:    [[TMP16:%.*]] = lshr i32 [[TMP14]], 16
+; CHECK-NEXT:    [[TMP17:%.*]] = and i32 [[TMP16]], 255
+; CHECK-NEXT:    [[TMP18:%.*]] = lshr i32 [[TMP14]], 8
+; CHECK-NEXT:    [[TMP19:%.*]] = and i32 [[TMP18]], 255
+; CHECK-NEXT:    [[TMP20:%.*]] = and i32 [[TMP14]], 255
+; CHECK-NEXT:    [[ADD1:%.*]] = add nuw nsw i32 [[TMP20]], [[TMP19]]
+; CHECK-NEXT:    [[ADD2:%.*]] = add nuw nsw i32 [[ADD1]], [[TMP17]]
+; CHECK-NEXT:    [[ADD3:%.*]] = add nuw nsw i32 [[ADD2]], [[TMP15]]
+; CHECK-NEXT:    ret i32 [[ADD3]]
+;
+entry:
+  %x = load <4 x i8>, ptr %src, align 4
+  %ext = zext nneg <4 x i8> %x to <4 x i32>
+  %ext.0 = extractelement <4 x i32> %ext, i64 0
+  %ext.1 = extractelement <4 x i32> %ext, i64 1
+  %ext.2 = extractelement <4 x i32> %ext, i64 2
+  %ext.3 = extractelement <4 x i32> %ext, i64 3
+
+  %add1 = add i32 %ext.0, %ext.1
+  %add2 = add i32 %add1, %ext.2
+  %add3 = add i32 %add2, %ext.3
+  ret i32 %add3
+}
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt -passes=vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+
+define i32 @load_v4i8_bitcast_to_i32(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_to_i32(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    ret i32 [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to i32
+  ret i32 %r
+}
+
+define i64 @load_v2i32_bitcast_to_i64(ptr %x) {
+; CHECK-LABEL: define i64 @load_v2i32_bitcast_to_i64(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i64, ptr [[X]], align 8
+; CHECK-NEXT:    ret i64 [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to i64
+  ret i64 %r
+}
+
+define float @load_v4i8_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v4i8_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to float
+  ret float %r
+}
+
+define float @load_v2i16_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v2i16_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <2 x i16>, ptr %x
+  %r = bitcast <2 x i16> %lv to float
+  ret float %r
+}
+
+define double @load_v4i16_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v4i16_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i16>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <4 x i16> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <4 x i16>, ptr %x
+  %r = bitcast <4 x i16> %lv to double
+  ret double %r
+}
+
+define double @load_v2i32_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v2i32_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <2 x i32>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <2 x i32> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to double
+  ret double %r
+}
+
+; Multiple users with the same bitcast type should be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_same_type(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_same_type(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[LV_SCALAR]], [[LV_SCALAR]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to i32
+  %add = add i32 %r1, %r2
+  ret i32 %add
+}
+
+; Different bitcast types should not be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_different_types(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_different_types(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = bitcast <4 x i8> [[LV]] to float
+; CHECK-NEXT:    [[R2_INT:%.*]] = bitcast float [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_INT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to float
+  %r2.int = bitcast float %r2 to i32
+  %add = add i32 %r1, %r2.int
+  ret i32 %add
+}
+
+; Bitcast to vector should not be scalarized.
+define <2 x i16> @load_v4i8_bitcast_to_vector(ptr %x) {
+; CHECK-LABEL: define <2 x i16> @load_v4i8_bitcast_to_vector(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R:%.*]] = bitcast <4 x i8> [[LV]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[R]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to <2 x i16>
+  ret <2 x i16> %r
+}
+
+; Load with both bitcast users and other users should not be scalarized.
+define i32 @load_v4i8_mixed_users(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_mixed_users(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = extractelement <4 x i8> [[LV]], i32 0
+; CHECK-NEXT:    [[R2_EXT:%.*]] = zext i8 [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_EXT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = extractelement <4 x i8> %lv, i32 0
+  %r2.ext = zext i8 %r2 to i32
+  %add = add i32 %r1, %r2.ext
+  ret i32 %add
+}
