@@ -17,6 +17,9 @@ struct WhileLoopInfo;
1717}; // namespace enzyme
1818}; // namespace mlir
1919
20+ std::tuple<bool , bool > allSameBool (const llvm::SmallVector<bool > &bools);
21+ bool allOpsAreUnique (const llvm::SmallVector<mlir::Operation *> &ops);
22+
2023struct BatchOperandConstructionInfo {
2124 mlir::stablehlo::SliceOp sliceOp;
2225 int32_t sliceOperandIndex;
@@ -25,6 +28,27 @@ struct BatchOperandConstructionInfo {
2528 bool intermediateReshape;
2629};
2730
31+ // TODO: update old code to use this new slice info
32+ template <typename OpTy> struct NewSliceInfo {
33+ OpTy sliceOp;
34+ llvm::SmallVector<mlir::Value> dynamicStartIndices;
35+ llvm::SmallVector<int64_t > startIndices;
36+ llvm::SmallVector<int64_t > sliceSizes;
37+ int64_t sliceDim;
38+ int64_t sliceStart;
39+ bool supported;
40+ };
41+
42+ NewSliceInfo<mlir::stablehlo::SliceOp>
43+ constructNewSliceInfo (mlir::stablehlo::SliceOp sliceOp);
44+ NewSliceInfo<mlir::stablehlo::DynamicSliceOp>
45+ constructNewSliceInfo (mlir::stablehlo::DynamicSliceOp sliceOp);
46+
47+ bool areSlicesContiguous (
48+ llvm::SmallVector<NewSliceInfo<mlir::stablehlo::SliceOp>> &slices);
49+ bool areSlicesContiguous (
50+ llvm::SmallVector<NewSliceInfo<mlir::stablehlo::DynamicSliceOp>> &slices);
51+
2852struct ConcatInsertDimToBatchBase
2953 : public mlir::enzyme::CheckedOpRewritePattern<
3054 mlir::stablehlo::ConcatenateOp, ConcatInsertDimToBatchBase> {
@@ -108,11 +132,6 @@ struct SliceToBatchBase
108132
109133 SliceInfo extractSliceInfo (mlir::stablehlo::SliceOp slice) const ;
110134 bool areSlicesContiguous (llvm::SmallVector<SliceInfo> &slices) const ;
111- std::tuple<bool , bool >
112- allSameBool (const llvm::SmallVector<bool > &bools) const ;
113- std::tuple<bool , bool >
114- allSameBool (const llvm::SmallVector<mlir::Operation *> &ops) const ;
115- bool allOpsAreUnique (const llvm::SmallVector<mlir::Operation *> &ops) const ;
116135
117136protected:
118137 std::function<mlir::Operation *(mlir::Operation *)> isValidTargetOp;
@@ -182,13 +201,16 @@ struct SliceToBatchElementwise : public SliceToBatchBase {
182201};
183202
184203struct GreedyWhileLoopBatchFission
185- : public mlir::OpRewritePattern<mlir::stablehlo::WhileOp> {
186- using Base = mlir::OpRewritePattern<mlir::stablehlo::WhileOp>;
204+ : public mlir::enzyme::CheckedOpRewritePattern<
205+ mlir::stablehlo::WhileOp, GreedyWhileLoopBatchFission> {
206+ using Base =
207+ mlir::enzyme::CheckedOpRewritePattern<mlir::stablehlo::WhileOp,
208+ GreedyWhileLoopBatchFission>;
187209 using Base::Base;
188210
189211 mlir::LogicalResult
190- matchAndRewrite (mlir::stablehlo::WhileOp whileOp,
191- mlir::PatternRewriter &rewriter) const override ;
212+ matchAndRewriteImpl (mlir::stablehlo::WhileOp whileOp,
213+ mlir::PatternRewriter &rewriter) const ;
192214
193215private:
194216 struct DynamicSliceInfo {
0 commit comments