Skip to content

Commit af7bc5f

Browse files
committed
refactor: support dynamicsliceop
1 parent b317691 commit af7bc5f

File tree

2 files changed

+46
-32
lines changed

2 files changed

+46
-32
lines changed

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,19 @@ constructAndExtractBatchOperands(PatternRewriter &rewriter,
265265
return std::make_tuple(operands, operandIndexMap);
266266
}
267267

268+
std::tuple<bool, bool> allSameBool(const SmallVector<bool> &bools) {
269+
return {
270+
llvm::all_of(bools, [&](bool b) { return b == bools.front(); }),
271+
bools.front(),
272+
};
273+
}
274+
275+
bool allOpsAreUnique(const SmallVector<Operation *> &ops) {
276+
SmallPtrSet<Operation *, 8> seen;
277+
return llvm::all_of(ops,
278+
[&](Operation *op) { return seen.insert(op).second; });
279+
}
280+
268281
LogicalResult ConcatInsertDimToBatchBase::matchAndRewriteImpl(
269282
stablehlo::ConcatenateOp concatOp, PatternRewriter &rewriter) const {
270283
if (concatOp.getNumOperands() <= 1)
@@ -732,29 +745,8 @@ bool SliceToBatchBase::areSlicesContiguous(
732745
return true;
733746
}
734747

735-
std::tuple<bool, bool>
736-
SliceToBatchBase::allSameBool(const SmallVector<bool> &bools) const {
737-
bool first = bools[0];
738-
for (auto b : bools) {
739-
if (b != first)
740-
return {false, first};
741-
}
742-
return {true, first};
743-
}
744-
745-
bool SliceToBatchBase::allOpsAreUnique(
746-
const SmallVector<Operation *> &ops) const {
747-
SmallPtrSet<Operation *, 8> seen;
748-
for (auto op : ops) {
749-
if (!seen.insert(op).second)
750-
return false;
751-
}
752-
return true;
753-
}
754-
755-
LogicalResult
756-
GreedyWhileLoopBatchFission::matchAndRewrite(stablehlo::WhileOp whileOp,
757-
PatternRewriter &rewriter) const {
748+
LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
749+
stablehlo::WhileOp whileOp, PatternRewriter &rewriter) const {
758750
auto info = WhileLoopInfo(whileOp);
759751
auto computeInfoSuccess = info.computeInfo();
760752
if (computeInfoSuccess.failed())

src/enzyme_ad/jax/Passes/AutoBatching.h

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ struct WhileLoopInfo;
1717
}; // namespace enzyme
1818
}; // namespace mlir
1919

20+
std::tuple<bool, bool> allSameBool(const llvm::SmallVector<bool> &bools);
21+
bool allOpsAreUnique(const llvm::SmallVector<mlir::Operation *> &ops);
22+
2023
struct BatchOperandConstructionInfo {
2124
mlir::stablehlo::SliceOp sliceOp;
2225
int32_t sliceOperandIndex;
@@ -25,6 +28,27 @@ struct BatchOperandConstructionInfo {
2528
bool intermediateReshape;
2629
};
2730

31+
// TODO: update old code to use this new slice info
32+
template <typename OpTy> struct NewSliceInfo {
33+
OpTy sliceOp;
34+
llvm::SmallVector<mlir::Value> dynamicStartIndices;
35+
llvm::SmallVector<int64_t> startIndices;
36+
llvm::SmallVector<int64_t> sliceSizes;
37+
int64_t sliceDim;
38+
int64_t sliceStart;
39+
bool supported;
40+
};
41+
42+
NewSliceInfo<mlir::stablehlo::SliceOp>
43+
constructNewSliceInfo(mlir::stablehlo::SliceOp sliceOp);
44+
NewSliceInfo<mlir::stablehlo::DynamicSliceOp>
45+
constructNewSliceInfo(mlir::stablehlo::DynamicSliceOp sliceOp);
46+
47+
bool areSlicesContiguous(
48+
llvm::SmallVector<NewSliceInfo<mlir::stablehlo::SliceOp>> &slices);
49+
bool areSlicesContiguous(
50+
llvm::SmallVector<NewSliceInfo<mlir::stablehlo::DynamicSliceOp>> &slices);
51+
2852
struct ConcatInsertDimToBatchBase
2953
: public mlir::enzyme::CheckedOpRewritePattern<
3054
mlir::stablehlo::ConcatenateOp, ConcatInsertDimToBatchBase> {
@@ -108,11 +132,6 @@ struct SliceToBatchBase
108132

109133
SliceInfo extractSliceInfo(mlir::stablehlo::SliceOp slice) const;
110134
bool areSlicesContiguous(llvm::SmallVector<SliceInfo> &slices) const;
111-
std::tuple<bool, bool>
112-
allSameBool(const llvm::SmallVector<bool> &bools) const;
113-
std::tuple<bool, bool>
114-
allSameBool(const llvm::SmallVector<mlir::Operation *> &ops) const;
115-
bool allOpsAreUnique(const llvm::SmallVector<mlir::Operation *> &ops) const;
116135

117136
protected:
118137
std::function<mlir::Operation *(mlir::Operation *)> isValidTargetOp;
@@ -182,13 +201,16 @@ struct SliceToBatchElementwise : public SliceToBatchBase {
182201
};
183202

184203
struct GreedyWhileLoopBatchFission
185-
: public mlir::OpRewritePattern<mlir::stablehlo::WhileOp> {
186-
using Base = mlir::OpRewritePattern<mlir::stablehlo::WhileOp>;
204+
: public mlir::enzyme::CheckedOpRewritePattern<
205+
mlir::stablehlo::WhileOp, GreedyWhileLoopBatchFission> {
206+
using Base =
207+
mlir::enzyme::CheckedOpRewritePattern<mlir::stablehlo::WhileOp,
208+
GreedyWhileLoopBatchFission>;
187209
using Base::Base;
188210

189211
mlir::LogicalResult
190-
matchAndRewrite(mlir::stablehlo::WhileOp whileOp,
191-
mlir::PatternRewriter &rewriter) const override;
212+
matchAndRewriteImpl(mlir::stablehlo::WhileOp whileOp,
213+
mlir::PatternRewriter &rewriter) const;
192214

193215
private:
194216
struct DynamicSliceInfo {

0 commit comments

Comments
 (0)