Skip to content

Commit

Permalink
Fix invalid IR from scalarrepl-param-hlsl in ReplaceConstantWithInst
Browse files Browse the repository at this point in the history
ReplaceConstantWithInst(C, V) replaces uses of C in the current function with V.
If such a use C is an instruction I, the it replaces uses of C in I with V.
However, this function did not make sure to only perform this replacement if V
dominates I. As a result, it may end up replacing uses of C in instructions
before the definition of V.

The fix is to lazily compute the dominator tree in ReplaceConstantWithInst so
that we can guard the replacement with that dominance check.
  • Loading branch information
amaiorano committed Apr 22, 2024
1 parent 7cf175f commit dd202d2
Show file tree
Hide file tree
Showing 3 changed files with 337 additions and 9 deletions.
47 changes: 38 additions & 9 deletions lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3271,15 +3271,35 @@ bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV,
return true;
}

static void ReplaceConstantWithInst(Constant *C, Value *V,
// Replaces uses of constant C in the current function
// with V, when those uses are dominated by V.
// Returns true if it was completely replaced.
static bool ReplaceConstantWithInst(Constant *C, Value *V,
IRBuilder<> &Builder) {
bool bReplacedAll = true;
Function *F = Builder.GetInsertBlock()->getParent();
Instruction *VInst = dyn_cast<Instruction>(V);
// Lazily calculate dominance
DominatorTree DT;
bool Calculated = false;
auto Dominates = [&](llvm::Instruction *Def, llvm::Instruction *User) {
if (!Calculated) {
DT.recalculate(*F);
Calculated = true;
}
return DT.dominates(Def, User);
};

for (auto it = C->user_begin(); it != C->user_end();) {
User *U = *(it++);
if (Instruction *I = dyn_cast<Instruction>(U)) {
if (I->getParent()->getParent() != F)
continue;
I->replaceUsesOfWith(C, V);
if (VInst && Dominates(VInst, I)) {
I->replaceUsesOfWith(C, V);
} else {
bReplacedAll = false;
}
} else {
// Skip unused ConstantExpr.
if (U->user_empty())
Expand All @@ -3288,10 +3308,13 @@ static void ReplaceConstantWithInst(Constant *C, Value *V,
Instruction *Inst = CE->getAsInstruction();
Builder.Insert(Inst);
Inst->replaceUsesOfWith(C, V);
ReplaceConstantWithInst(CE, Inst, Builder);
if (!ReplaceConstantWithInst(CE, Inst, Builder)) {
bReplacedAll = false;
}
}
}
C->removeDeadConstantUsers();
return bReplacedAll;
}

static void ReplaceUnboundedArrayUses(Value *V, Value *Src) {
Expand Down Expand Up @@ -3531,15 +3554,19 @@ static bool ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC,
} else {
// Replace Constant with a non-Constant.
IRBuilder<> Builder(MC);
ReplaceConstantWithInst(C, Src, Builder);
if (!ReplaceConstantWithInst(C, Src, Builder)) {
return false;
}
}
} else {
// Try convert special pattern for cbuffer which copy array of float4 to
// array of float.
if (!tryToReplaceCBVec4ArrayToScalarArray(V, TyV, Src, TySrc, MC, DL)) {
IRBuilder<> Builder(MC);
Src = Builder.CreateBitCast(Src, V->getType());
ReplaceConstantWithInst(C, Src, Builder);
if (!ReplaceConstantWithInst(C, Src, Builder)) {
return false;
}
}
}
} else {
Expand Down Expand Up @@ -3678,13 +3705,15 @@ static bool ReplaceUseOfZeroInit(Instruction *I, Value *V, DominatorTree &DT,
continue;

// Skip properly dominated users
if (DT.properlyDominates(BB, UI->getParent()))
if (DT.properlyDominates(BB, UI->getParent())) {
continue;
}

// If user is found in memcpy successor list
// then the user is not safe to replace with zeroinitializer.
if (Reachable.count(UI->getParent()))
if (Reachable.count(UI->getParent())) {
return false;
}

// Remaining cases are where I:
// - is at the end of the same block
Expand Down Expand Up @@ -5449,9 +5478,9 @@ void SROA_Parameter_HLSL::flattenArgument(
if (Ty->isPointerTy())
Ty = Ty->getPointerElementType();
unsigned size = DL.getTypeAllocSize(Ty);
#if 0 // HLSL Change
#if 0 // HLSL Change
DIExpression *DDIExp = DIB.createBitPieceExpression(debugOffset, size);
#else // HLSL Change
#else // HLSL Change
Type *argTy = Arg->getType();
if (argTy->isPointerTy())
argTy = argTy->getPointerElementType();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: %dxc -T ps_6_2 %s | FileCheck %s

// Validate that copying from static array to local, then back to static
// array does not crash the compiler. This was resulting in an invalid
// ReplaceConstantWithInst from ScalarReplAggregatesHLSL, which would
// result in referenced deleted instruction in a later pass.

// This test is expected to fail with "error: Loop must have a break."
// XFAIL: *

static int arr1[10] = (int[10])0;
static int arr2[10] = (int[10])0;
static float result = 0;
ByteAddressBuffer buff : register(t0);

void foo() {
int i = 0;
if (buff.Load(0u)) {
return;
}
arr2[i] = arr1[i];
result = float(arr1[0]);
}

struct tint_symbol {
float4 value : SV_Target0;
};

float main_inner() {
foo();
bool cond = false;
while (true) {
if (cond) { break; }
}
int arr1_copy[10] = arr1; // constant to local
arr1 = arr1_copy; // local to constant
foo();
return result;
}

tint_symbol main() {
float inner_result = main_inner();
tint_symbol wrapper_result = (tint_symbol)0;
wrapper_result.value.x = inner_result;
return wrapper_result;
}
Loading

0 comments on commit dd202d2

Please sign in to comment.