Skip to content

[CIR][ThroughMLIR] Lower WhileOp with break #1735

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
for (auto continueOp : continues) {
bool nested = false;
// When there is another loop between this WhileOp and the ContinueOp,
// we shouldn't change that loop instead.
// we should change that loop instead.
for (mlir::Operation *parent = continueOp->getParentOp();
parent != whileOp; parent = parent->getParentOp()) {
if (isa<WhileOp>(parent)) {
Expand Down Expand Up @@ -570,6 +570,73 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
}
}

void rewriteBreak(mlir::scf::WhileOp whileOp,
mlir::ConversionPatternRewriter &rewriter) const {
// Collect all BreakOp inside this while.
llvm::SmallVector<cir::BreakOp> breaks;
whileOp->walk([&](mlir::Operation *op) {
if (auto breakOp = dyn_cast<BreakOp>(op))
breaks.push_back(breakOp);
});

if (breaks.empty())
return;

for (auto breakOp : breaks) {
// When there is another loop between this WhileOp and the BreakOp,
// we should change that loop instead.
if (breakOp->getParentOfType<mlir::scf::WhileOp>() != whileOp)
continue;

// Similar to the case of ContinueOp, when there is an `IfOp`,
// we need to take special care.
for (mlir::Operation *parent = breakOp->getParentOp(); parent != whileOp;
parent = parent->getParentOp()) {
if (auto ifOp = dyn_cast<cir::IfOp>(parent))
llvm_unreachable("NYI");
}

// Operations after this BreakOp has to be removed.
for (mlir::Operation *runner = breakOp->getNextNode(); runner;) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you need to remove this? Looks like we should just split the block at this point and create an unrecheable block (which should get later DCE'd by canonicalizer)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know whether it's allowed to delete a non-empty loop (as I'm deleting the loop below). I'll change it to splitting the block if it doesn't matter or we can find a better way below.

mlir::Operation *next = runner->getNextNode();
runner->erase();
runner = next;
}

// Blocks after this BreakOp also has to be removed.
for (mlir::Block *block = breakOp->getBlock()->getNextNode(); block;) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ties back to the other comment

mlir::Block *next = block->getNextNode();
block->erase();
block = next;
}

// We know this BreakOp isn't nested in any IfOp.
// Therefore, the loop is executed only once.
// We pull everything out of the loop.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like your are optimizing while you are lowering, why isn't this a separate pass?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to erase the BreakOp because scf doesn't support it, so we have to rewrite the loop in some way to preserve semantics. Deleting the loop is the most straightforward way I can think of. Could you suggest better ways of rewriting?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point! This is good for now.

Thinking more about this it seems like these things should actually be handled in a pass before the actual SCF lowering (something like a CoreDialectPrepare kinda thing) - just like we have a CFGFlatten pre LLVM we could have something that massage CIR loops into CIR loops suitable for more direct SCF translation. This is more food for thought for the future, where you will probably have gathered more examples (cir.continue is slightly different but suffers from a similar issue?).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea. I think I'll do the cir.switch handling (which I'm working on now) in a new pass, and move these break/continue lowering to there afterwards.


auto &beforeOps = whileOp.getBeforeBody()->getOperations();
for (mlir::Operation *op = &*beforeOps.begin(); op;) {
if (isa<ConditionOp>(op))
break;
auto *next = op->getNextNode();
op->moveBefore(whileOp);
op = next;
}

auto &afterOps = whileOp.getAfterBody()->getOperations();
for (mlir::Operation *op = &*afterOps.begin(); op;) {
if (isa<YieldOp>(op))
break;
auto *next = op->getNextNode();
op->moveBefore(whileOp);
op = next;
}

// The loop itself should now be removed.
rewriter.eraseOp(whileOp);
}
}

public:
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;

Expand All @@ -579,6 +646,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
SCFWhileLoop loop(op, adaptor, &rewriter);
auto whileOp = loop.transferToSCFWhileOp();
rewriteContinue(whileOp, rewriter);
rewriteBreak(whileOp, rewriter);
rewriter.eraseOp(op);
return mlir::success();
}
Expand Down
25 changes: 25 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/while-with-break.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

void while_break() {
int i = 0;
while (i < 100) {
i++;
break;
i++;
}
// This should be compiled into the condition `i < 100` and a single `i++`,
// without the while-loop.

// CHECK: memref.alloca_scope {
// CHECK: %[[IV:.+]] = memref.load %alloca[]
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
// CHECK: %[[_:.+]] = arith.cmpi slt, %[[IV]], %[[HUNDRED]]
// CHECK: memref.alloca_scope {
// CHECK: %[[IV2:.+]] = memref.load %alloca[]
// CHECK: %[[ONE:.+]] = arith.constant 1
// CHECK: %[[INCR:.+]] = arith.addi %[[IV2]], %[[ONE]]
// CHECK: memref.store %[[INCR]], %alloca[]
// CHECK: }
// CHECK: }
}
Loading