Skip to content

Commit 407ce70

Browse files
committed
Fix invalid IR from scalarrepl-param-hlsl in ReplaceConstantWithInst
ReplaceConstantWithInst(C, V) replaces uses of C in the current function with V. If such a use C is an instruction I, the it replaces uses of C in I with V. However, this function did not make sure to only perform this replacement if V dominates I. As a result, it may end up replacing uses of C in instructions before the definition of V. The fix is to lazily compute the dominator tree in ReplaceConstantWithInst so that we can guard the replacement with that dominance check.
1 parent 7cf175f commit 407ce70

5 files changed

+328
-7
lines changed

lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3271,15 +3271,34 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
32713271
return true;
32723272
}
32733273

3274-
static void ReplaceConstantWithInst(Constant *C, Value *V,
3274+
// Replaces uses of constant C in the current function
3275+
// with V, when those uses are dominated by V.
3276+
// Returns true if it was completely replaced.
3277+
static bool ReplaceConstantWithInst(Constant *C, Value *V,
32753278
IRBuilder<> &Builder) {
3279+
bool bReplacedAll = true;
32763280
Function *F = Builder.GetInsertBlock()->getParent();
3281+
Instruction *VInst = dyn_cast<Instruction>(V);
3282+
// Lazily calculate dominance
3283+
DominatorTree DT;
3284+
bool Calculated = false;
3285+
auto Dominates = [&](llvm::Instruction *Def, llvm::Instruction *User) {
3286+
if (!Calculated) {
3287+
DT.recalculate(*F);
3288+
Calculated = true;
3289+
}
3290+
return DT.dominates(Def, User);
3291+
};
3292+
32773293
for (auto it = C->user_begin(); it != C->user_end();) {
32783294
User *U = *(it++);
32793295
if (Instruction *I = dyn_cast<Instruction>(U)) {
32803296
if (I->getParent()->getParent() != F)
32813297
continue;
3282-
I->replaceUsesOfWith(C, V);
3298+
if (VInst && Dominates(VInst, I))
3299+
I->replaceUsesOfWith(C, V);
3300+
else
3301+
bReplacedAll = false;
32833302
} else {
32843303
// Skip unused ConstantExpr.
32853304
if (U->user_empty())
@@ -3288,10 +3307,12 @@ static void ReplaceConstantWithInst(Constant *C, Value *V,
32883307
Instruction *Inst = CE->getAsInstruction();
32893308
Builder.Insert(Inst);
32903309
Inst->replaceUsesOfWith(C, V);
3291-
ReplaceConstantWithInst(CE, Inst, Builder);
3310+
if (!ReplaceConstantWithInst(CE, Inst, Builder))
3311+
bReplacedAll = false;
32923312
}
32933313
}
32943314
C->removeDeadConstantUsers();
3315+
return bReplacedAll;
32953316
}
32963317

32973318
static void ReplaceUnboundedArrayUses(Value *V, Value *Src) {
@@ -3531,15 +3552,17 @@ static bool ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
35313552
} else {
35323553
// Replace Constant with a non-Constant.
35333554
IRBuilder<> Builder(MC);
3534-
ReplaceConstantWithInst(C, Src, Builder);
3555+
if (!ReplaceConstantWithInst(C, Src, Builder))
3556+
return false;
35353557
}
35363558
} else {
35373559
// Try convert special pattern for cbuffer which copy array of float4 to
35383560
// array of float.
35393561
if (!tryToReplaceCBVec4ArrayToScalarArray(V, TyV, Src, TySrc, MC, DL)) {
35403562
IRBuilder<> Builder(MC);
35413563
Src = Builder.CreateBitCast(Src, V->getType());
3542-
ReplaceConstantWithInst(C, Src, Builder);
3564+
if (!ReplaceConstantWithInst(C, Src, Builder))
3565+
return false;
35433566
}
35443567
}
35453568
} else {
@@ -5449,9 +5472,9 @@ void SROA_Parameter_HLSL::flattenArgument(
54495472
if (Ty->isPointerTy())
54505473
Ty = Ty->getPointerElementType();
54515474
unsigned size = DL.getTypeAllocSize(Ty);
5452-
#if 0 // HLSL Change
5475+
#if 0 // HLSL Change
54535476
DIExpression *DDIExp = DIB.createBitPieceExpression(debugOffset, size);
5454-
#else // HLSL Change
5477+
#else // HLSL Change
54555478
Type *argTy = Arg->getType();
54565479
if (argTy->isPointerTy())
54575480
argTy = argTy->getPointerElementType();
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: not %dxc -T ps_6_2 %s 2>&1 | FileCheck %s
2+
3+
// Validate that copying from static array to local, then back to static
4+
// array does not crash the compiler. This was resulting in an invalid
5+
// ReplaceConstantWithInst from ScalarReplAggregatesHLSL, which would
6+
// result in referenced deleted instruction in a later pass.
7+
8+
// CHECK: error: Loop must have break.
9+
10+
static int arr1[10] = (int[10])0;
11+
static int arr2[10] = (int[10])0;
12+
static float result = 0;
13+
ByteAddressBuffer buff : register(t0);
14+
15+
void foo() {
16+
int i = 0;
17+
if (buff.Load(0u)) {
18+
return;
19+
}
20+
arr2[i] = arr1[i];
21+
result = float(arr1[0]);
22+
}
23+
24+
struct tint_symbol {
25+
float4 value : SV_Target0;
26+
};
27+
28+
float main_inner() {
29+
foo();
30+
bool cond = false;
31+
while (true) {
32+
if (cond) { break; }
33+
}
34+
int arr1_copy[10] = arr1; // constant to local
35+
arr1 = arr1_copy; // local to constant
36+
foo();
37+
return result;
38+
}
39+
40+
tint_symbol main() {
41+
float inner_result = main_inner();
42+
tint_symbol wrapper_result = (tint_symbol)0;
43+
wrapper_result.value.x = inner_result;
44+
return wrapper_result;
45+
}

0 commit comments

Comments
 (0)