Skip to content

Commit

Permalink
[Coroutines] Enhance symmetric transfer for constant CmpInst
Browse files Browse the repository at this point in the history
This fixes bug52896.

Simply, some symmetric transfer optimization chances get invalided due
to we delete some inlined optimization passes in 822b92a. This would
cause stack-overflow in some situations which should be avoided by the
design of coroutine. This patch tries to fix this by transforming the
constant CmpInst instruction which was done in the deleted passes.

Reviewed By: rjmccall, junparser

Differential Revision: https://reviews.llvm.org/D116327

(cherry picked from commit 403772f)
  • Loading branch information
ChuanqiXu9 authored and tstellar committed Jan 12, 2022
1 parent 9d9efb1 commit b9a243d
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 35 deletions.
98 changes: 63 additions & 35 deletions llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
Expand Down Expand Up @@ -1174,6 +1175,15 @@ scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
DenseMap<Value *, Value *> ResolvedValues;
BasicBlock *UnconditionalSucc = nullptr;
assert(InitialInst->getModule());
const DataLayout &DL = InitialInst->getModule()->getDataLayout();

auto TryResolveConstant = [&ResolvedValues](Value *V) {
auto It = ResolvedValues.find(V);
if (It != ResolvedValues.end())
V = It->second;
return dyn_cast<ConstantInt>(V);
};

Instruction *I = InitialInst;
while (I->isTerminator() ||
Expand All @@ -1190,47 +1200,65 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
}
if (auto *BR = dyn_cast<BranchInst>(I)) {
if (BR->isUnconditional()) {
BasicBlock *BB = BR->getSuccessor(0);
BasicBlock *Succ = BR->getSuccessor(0);
if (I == InitialInst)
UnconditionalSucc = BB;
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
I = BB->getFirstNonPHIOrDbgOrLifetime();
UnconditionalSucc = Succ;
scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
I = Succ->getFirstNonPHIOrDbgOrLifetime();
continue;
}

BasicBlock *BB = BR->getParent();
// Handle the case the condition of the conditional branch is constant.
// e.g.,
//
// br i1 false, label %cleanup, label %CoroEnd
//
// It is possible during the transformation. We could continue the
// simplifying in this case.
if (ConstantFoldTerminator(BB, /*DeleteDeadConditions=*/true)) {
// Handle this branch in next iteration.
I = BB->getTerminator();
continue;
}
} else if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
// If the case number of suspended switch instruction is reduced to
// 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
auto *BR = dyn_cast<BranchInst>(I->getNextNode());
if (BR && BR->isConditional() && CondCmp == BR->getCondition()) {
// If the case number of suspended switch instruction is reduced to
// 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
// And the comparsion looks like : %cond = icmp eq i8 %V, constant.
ConstantInt *CondConst = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
if (CondConst && CondCmp->getPredicate() == CmpInst::ICMP_EQ) {
Value *V = CondCmp->getOperand(0);
auto it = ResolvedValues.find(V);
if (it != ResolvedValues.end())
V = it->second;

if (ConstantInt *Cond0 = dyn_cast<ConstantInt>(V)) {
BasicBlock *BB = Cond0->equalsInt(CondConst->getZExtValue())
? BR->getSuccessor(0)
: BR->getSuccessor(1);
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
I = BB->getFirstNonPHIOrDbgOrLifetime();
continue;
}
}
}
if (!BR || !BR->isConditional() || CondCmp != BR->getCondition())
return false;

// And the comparsion looks like : %cond = icmp eq i8 %V, constant.
// So we try to resolve constant for the first operand only since the
// second operand should be literal constant by design.
ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0));
auto *Cond1 = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
if (!Cond0 || !Cond1)
return false;

// Both operands of the CmpInst are Constant. So that we could evaluate
// it immediately to get the destination.
auto *ConstResult =
dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
CondCmp->getPredicate(), Cond0, Cond1, DL));
if (!ConstResult)
return false;

CondCmp->replaceAllUsesWith(ConstResult);
CondCmp->eraseFromParent();

// Handle this branch in next iteration.
I = BR;
continue;
} else if (auto *SI = dyn_cast<SwitchInst>(I)) {
Value *V = SI->getCondition();
auto it = ResolvedValues.find(V);
if (it != ResolvedValues.end())
V = it->second;
if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
I = BB->getFirstNonPHIOrDbgOrLifetime();
continue;
}
ConstantInt *Cond = TryResolveConstant(SI->getCondition());
if (!Cond)
return false;

BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
I = BB->getFirstNonPHIOrDbgOrLifetime();
continue;
}
return false;
}
Expand Down
65 changes: 65 additions & 0 deletions llvm/test/Transforms/Coroutines/coro-split-musttail4.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
; Tests that coro-split will convert a call before coro.suspend to a musttail call
; while the user of the coro.suspend is a icmpinst.
; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s

define void @fakeresume1(i8*) {
entry:
ret void;
}

define void @f() #0 {
entry:
%id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
%alloc = call i8* @malloc(i64 16) #3
%vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)

%save = call token @llvm.coro.save(i8* null)

%init_suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
switch i8 %init_suspend, label %coro.end [
i8 0, label %await.ready
i8 1, label %coro.end
]
await.ready:
%save2 = call token @llvm.coro.save(i8* null)

call fastcc void @fakeresume1(i8* align 8 null)
%suspend = call i8 @llvm.coro.suspend(token %save2, i1 true)
%switch = icmp ult i8 %suspend, 2
br i1 %switch, label %cleanup, label %coro.end

cleanup:
%free.handle = call i8* @llvm.coro.free(token %id, i8* %vFrame)
%.not = icmp eq i8* %free.handle, null
br i1 %.not, label %coro.end, label %coro.free

coro.free:
call void @delete(i8* nonnull %free.handle) #2
br label %coro.end

coro.end:
call i1 @llvm.coro.end(i8* null, i1 false)
ret void
}

; CHECK-LABEL: @f.resume(
; CHECK: musttail call fastcc void @fakeresume1(
; CHECK-NEXT: ret void

declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1
declare i1 @llvm.coro.alloc(token) #2
declare i64 @llvm.coro.size.i64() #3
declare i8* @llvm.coro.begin(token, i8* writeonly) #2
declare token @llvm.coro.save(i8*) #2
declare i8* @llvm.coro.frame() #3
declare i8 @llvm.coro.suspend(token, i1) #2
declare i8* @llvm.coro.free(token, i8* nocapture readonly) #1
declare i1 @llvm.coro.end(i8*, i1) #2
declare i8* @llvm.coro.subfn.addr(i8* nocapture readonly, i8) #1
declare i8* @malloc(i64)
declare void @delete(i8* nonnull) #2

attributes #0 = { "coroutine.presplit"="1" }
attributes #1 = { argmemonly nounwind readonly }
attributes #2 = { nounwind }
attributes #3 = { nounwind readnone }

0 comments on commit b9a243d

Please sign in to comment.