Skip to content

Commit fe7ec4a

Browse files
committed
feat: clear out unwanted no-ops from loop body
1 parent e28316e commit fe7ec4a

File tree

3 files changed

+60
-2
lines changed

3 files changed

+60
-2
lines changed

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,8 @@ bool GreedyWhileLoopBatchFission::liftElementwiseOp(
831831
if (elem->getNumOperands() != 1) {
832832
// For non-unary elementwise ops, we can
833833
// 1. lift operands if they are produced by DSs in the list
834-
// 2. defined outside the loop body, in which case we simply do a broadcast_in_dim
834+
// 2. defined outside the loop body, in which case we simply do a
835+
// broadcast_in_dim
835836
return false;
836837
}
837838

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@
2222
#include "mlir/IR/IRMapping.h"
2323
#include "mlir/IR/Matchers.h"
2424
#include "mlir/IR/PatternMatch.h"
25+
#include "mlir/IR/Visitors.h"
2526
#include "mlir/Pass/PassManager.h"
2627
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2728
#include "shardy/dialect/sdy/ir/utils.h"
2829
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
2930
#include "src/enzyme_ad/jax/Dialect/Ops.h"
31+
#include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h"
3032
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h"
3133
#include "src/enzyme_ad/jax/Passes/Passes.h"
3234
#include "src/enzyme_ad/jax/Passes/StructuredTensors.h"
@@ -23215,6 +23217,55 @@ struct CaseToIf : public CheckedOpRewritePattern<stablehlo::CaseOp, CaseToIf> {
2321523217
}
2321623218
};
2321723219

23220+
struct RemoveNoOpsFromWhileLoop
23221+
: public CheckedOpRewritePattern<stablehlo::WhileOp,
23222+
RemoveNoOpsFromWhileLoop> {
23223+
using CheckedOpRewritePattern<
23224+
stablehlo::WhileOp, RemoveNoOpsFromWhileLoop>::CheckedOpRewritePattern;
23225+
23226+
LogicalResult matchAndRewriteImpl(stablehlo::WhileOp whileOp,
23227+
PatternRewriter &rewriter) const {
23228+
auto info = WhileLoopInfo(whileOp);
23229+
auto computeInfoSuccess = info.computeInfo();
23230+
if (computeInfoSuccess.failed())
23231+
return computeInfoSuccess;
23232+
23233+
if (!info.isValid() || !info.isConstant())
23234+
return failure();
23235+
23236+
auto &whileBody = whileOp.getBody().front();
23237+
auto inductionVar = whileBody.getArgument(0);
23238+
23239+
bool anyNoOpRemoved = false;
23240+
23241+
auto limit = info.getConstantLimit().value();
23242+
auto start = info.getConstantStart().value();
23243+
auto step = info.getConstantStep().value();
23244+
23245+
// currently only removes remainder of induction variable
23246+
whileBody.walk([&](stablehlo::RemOp remOp) -> WalkResult {
23247+
if (remOp.getLhs() == inductionVar) {
23248+
SplatElementsAttr remRhsAttr;
23249+
if (matchPattern(remOp.getRhs(), m_Constant(&remRhsAttr))) {
23250+
auto attr = remRhsAttr.getSplatValue<Attribute>();
23251+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
23252+
auto rhsRemValue = intAttr.getValue();
23253+
APInt startAPInt(rhsRemValue.getBitWidth(), start, true);
23254+
APInt limitAPInt(rhsRemValue.getBitWidth(), limit, true);
23255+
if (rhsRemValue.sge(limitAPInt) && startAPInt.slt(rhsRemValue)) {
23256+
rewriter.replaceOp(remOp, remOp.getLhs());
23257+
anyNoOpRemoved = true;
23258+
}
23259+
}
23260+
}
23261+
}
23262+
return WalkResult::advance();
23263+
});
23264+
23265+
return anyNoOpRemoved ? success() : failure();
23266+
}
23267+
};
23268+
2321823269
/////////////// End Imported from stablehlo
2321923270

2322023271
// clang-format off
@@ -23826,7 +23877,8 @@ struct EnzymeHLOOptPass
2382623877
MulReduceSliceFusion,
2382723878
MinReduceSliceFusion,
2382823879
MaxReduceSliceFusion,
23829-
CaseToIf
23880+
CaseToIf,
23881+
RemoveNoOpsFromWhileLoop
2383023882
>(context);
2383123883

2383223884
patterns.add<

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2378,3 +2378,8 @@ def ApplyElementwiseSliceToBatchPatterns : EnzymeHLOPatternOp<
23782378
def ApplyCaseToIfPatterns : EnzymeHLOPatternOp<"case_to_if"> {
23792379
let patterns = ["CaseToIf"];
23802380
}
2381+
2382+
def ApplyRemoveNoOpsFromWhileLoopPatterns : EnzymeHLOPatternOp<
2383+
"remove_no_ops_from_while_loop"> {
2384+
let patterns = ["RemoveNoOpsFromWhileLoop"];
2385+
}

0 commit comments

Comments
 (0)