|
22 | 22 | #include "mlir/IR/IRMapping.h" |
23 | 23 | #include "mlir/IR/Matchers.h" |
24 | 24 | #include "mlir/IR/PatternMatch.h" |
| 25 | +#include "mlir/IR/Visitors.h" |
25 | 26 | #include "mlir/Pass/PassManager.h" |
26 | 27 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
27 | 28 | #include "shardy/dialect/sdy/ir/utils.h" |
28 | 29 | #include "src/enzyme_ad/jax/Dialect/Dialect.h" |
29 | 30 | #include "src/enzyme_ad/jax/Dialect/Ops.h" |
| 31 | +#include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h" |
30 | 32 | #include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h" |
31 | 33 | #include "src/enzyme_ad/jax/Passes/Passes.h" |
32 | 34 | #include "src/enzyme_ad/jax/Passes/StructuredTensors.h" |
@@ -23215,6 +23217,55 @@ struct CaseToIf : public CheckedOpRewritePattern<stablehlo::CaseOp, CaseToIf> { |
23215 | 23217 | } |
23216 | 23218 | }; |
23217 | 23219 |
|
| 23220 | +struct RemoveNoOpsFromWhileLoop |
| 23221 | + : public CheckedOpRewritePattern<stablehlo::WhileOp, |
| 23222 | + RemoveNoOpsFromWhileLoop> { |
| 23223 | + using CheckedOpRewritePattern< |
| 23224 | + stablehlo::WhileOp, RemoveNoOpsFromWhileLoop>::CheckedOpRewritePattern; |
| 23225 | + |
| 23226 | + LogicalResult matchAndRewriteImpl(stablehlo::WhileOp whileOp, |
| 23227 | + PatternRewriter &rewriter) const { |
| 23228 | + auto info = WhileLoopInfo(whileOp); |
| 23229 | + auto computeInfoSuccess = info.computeInfo(); |
| 23230 | + if (computeInfoSuccess.failed()) |
| 23231 | + return computeInfoSuccess; |
| 23232 | + |
| 23233 | + if (!info.isValid() || !info.isConstant()) |
| 23234 | + return failure(); |
| 23235 | + |
| 23236 | + auto &whileBody = whileOp.getBody().front(); |
| 23237 | + auto inductionVar = whileBody.getArgument(0); |
| 23238 | + |
| 23239 | + bool anyNoOpRemoved = false; |
| 23240 | + |
| 23241 | + auto limit = info.getConstantLimit().value(); |
| 23242 | + auto start = info.getConstantStart().value(); |
| 23243 | + auto step = info.getConstantStep().value(); |
| 23244 | + |
| 23245 | + // currently only removes remainder of induction variable |
| 23246 | + whileBody.walk([&](stablehlo::RemOp remOp) -> WalkResult { |
| 23247 | + if (remOp.getLhs() == inductionVar) { |
| 23248 | + SplatElementsAttr remRhsAttr; |
| 23249 | + if (matchPattern(remOp.getRhs(), m_Constant(&remRhsAttr))) { |
| 23250 | + auto attr = remRhsAttr.getSplatValue<Attribute>(); |
| 23251 | + if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { |
| 23252 | + auto rhsRemValue = intAttr.getValue(); |
| 23253 | + APInt startAPInt(rhsRemValue.getBitWidth(), start, true); |
| 23254 | + APInt limitAPInt(rhsRemValue.getBitWidth(), limit, true); |
| 23255 | + if (rhsRemValue.sge(limitAPInt) && startAPInt.slt(rhsRemValue)) { |
| 23256 | + rewriter.replaceOp(remOp, remOp.getLhs()); |
| 23257 | + anyNoOpRemoved = true; |
| 23258 | + } |
| 23259 | + } |
| 23260 | + } |
| 23261 | + } |
| 23262 | + return WalkResult::advance(); |
| 23263 | + }); |
| 23264 | + |
| 23265 | + return anyNoOpRemoved ? success() : failure(); |
| 23266 | + } |
| 23267 | +}; |
| 23268 | + |
23218 | 23269 | /////////////// End Imported from stablehlo |
23219 | 23270 |
|
23220 | 23271 | // clang-format off |
@@ -23826,7 +23877,8 @@ struct EnzymeHLOOptPass |
23826 | 23877 | MulReduceSliceFusion, |
23827 | 23878 | MinReduceSliceFusion, |
23828 | 23879 | MaxReduceSliceFusion, |
23829 | | - CaseToIf |
| 23880 | + CaseToIf, |
| 23881 | + RemoveNoOpsFromWhileLoop |
23830 | 23882 | >(context); |
23831 | 23883 |
|
23832 | 23884 | patterns.add< |
|
0 commit comments