#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "src/enzyme_ad/jax/Utils.h"
#include "stablehlo/dialect/StablehloOps.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"

@@ -105,7 +106,7 @@ func::FuncOp createWrapperUnbatchedFunction(PatternRewriter &rewriter,
    return nullptr;
  rewriter.setInsertionPointToStart(modOp.getBody());

-  SmallVector<Type> argTypes;
+  SmallVector<mlir::Type> argTypes;
  for (auto v : operands) {
    auto vType = cast<RankedTensorType>(v.getType());
    auto shape = vType.getShape();
@@ -770,6 +771,8 @@ GreedyWhileLoopBatchFission::matchAndRewrite(stablehlo::WhileOp whileOp,
  int64_t limit = info.getConstantLimit().value();
  int64_t step = info.getConstantStep().value();

+  // TODO: we can support start != 0 for applying a function to part of the
+  // tensor
  if (start != 0 || step != 1)
    return failure();

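For reference, the guard above restricts the rewrite to canonical counted loops, i.e. the shape one would expect `WhileLoopInfo` to recover here (a sketch under that assumption, not code from this commit):

    // %iv = 0; while (%iv < limit) { ...body...; %iv = %iv + 1 }
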
@@ -790,13 +793,14 @@ GreedyWhileLoopBatchFission::matchAndRewrite(stablehlo::WhileOp whileOp,
  // 2. Only one variable in the body is a direct descendant of the induction
  //    variable
  // 3. The size of that dimension equals the limit
-  SmallVector<stablehlo::DynamicSliceOp> candidateSlices;
+  SmallVector<DynamicSliceInfo> candidateSlices;

  whileBody.walk([&](stablehlo::DynamicSliceOp sliceOp) {
    // Check if this dynamic slice meets our criteria
-    if (isDynamicSliceValidForBatching(sliceOp, inductionVar, limit, whileBody,
-                                       parentBlock)) {
-      candidateSlices.push_back(sliceOp);
+    auto dim = isDynamicSliceValidForBatching(sliceOp, inductionVar, limit,
+                                              whileBody, parentBlock);
+    if (dim != -1) {
+      candidateSlices.push_back(DynamicSliceInfo{sliceOp, dim});
      llvm::errs() << "Found candidate dynamic slice: " << sliceOp << "\n";
    }
  });
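`DynamicSliceInfo` is introduced by this commit but defined outside the hunk; judging from its uses here (`ds.sliceOp`, `itr->inductionVarDimension`), it is presumably a small record along these lines (a sketch, not the committed definition):

    struct DynamicSliceInfo {
      stablehlo::DynamicSliceOp sliceOp;
      // Dimension whose start index is driven by the induction variable.
      int64_t inductionVarDimension;
    };
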
@@ -809,18 +813,29 @@ GreedyWhileLoopBatchFission::matchAndRewrite(stablehlo::WhileOp whileOp,
  llvm::errs() << "Found " << candidateSlices.size()
               << " candidate dynamic slices\n";

-  bool wasLifted = false;
+  // Create a map of user operations to their corresponding dynamic slices
+  DenseMap<Operation *, SmallVector<DynamicSliceInfo>> userOpToSlicesMap;
  for (auto ds : candidateSlices) {
-    for (auto op : ds->getUsers()) {
-      // TODO: handle intermediate reshapes
+    for (auto op : ds.sliceOp->getUsers()) {
+      llvm::errs() << "mapping user op: " << *op << " to slice: " << ds.sliceOp
+                   << "\n";
+      userOpToSlicesMap[op].push_back(ds);
+    }
+  }

-      llvm::errs() << "trying to lift op: " << *op << "\n";
+  if (userOpToSlicesMap.empty()) {
+    return failure();
+  }

-      if (op->hasTrait<OpTrait::Elementwise>() &&
-          liftElementwiseOp(rewriter, whileOp, ds, op)) {
-        wasLifted = true;
-        continue;
-      }
+  llvm::errs() << "Created map with " << userOpToSlicesMap.size()
+               << " user operations\n";
+
+  // Try to lift each user op, passing all of its candidate slices together
+  bool wasLifted = false;
+  for (auto &[op, slices] : userOpToSlicesMap) {
+    if (op->hasTrait<OpTrait::Elementwise>() &&
+        liftElementwiseOp(rewriter, whileOp, slices, op, info)) {
+      wasLifted = true;
    }
  }

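Grouping by user op means an elementwise op fed by several candidate slices is now visited once with all of them available, e.g. (illustrative pseudo-IR, names and shapes assumed):

    %a = stablehlo.dynamic_slice %A, %iv, %c0, sizes = [1, 16]
    %b = stablehlo.dynamic_slice %B, %iv, %c0, sizes = [1, 16]
    %r = stablehlo.add %a, %b  // one map entry: add -> {slice of %A, slice of %B}
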
@@ -829,35 +844,96 @@ GreedyWhileLoopBatchFission::matchAndRewrite(stablehlo::WhileOp whileOp,

bool GreedyWhileLoopBatchFission::liftElementwiseOp(
    PatternRewriter &rewriter, stablehlo::WhileOp whileOp,
-    stablehlo::DynamicSliceOp sliceOp, Operation *elem) const {
+    ArrayRef<DynamicSliceInfo> sliceOps, Operation *elem,
+    WhileLoopInfo info) const {
  IRRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(whileOp);

-  // TODO: support non-unary elementwise ops
-  if (elem->getNumOperands() != 1) {
-    // For non-unary elementwise ops, we can
-    // 1. lift operands if they are produced by DSs in the list
-    // 2. defined outside the loop body, in which case we simply do a
-    //    broadcast_in_dim
-    return false;
+  // For non-unary elementwise ops, we can
+  // 1. lift operands that are produced by dynamic slices in the candidate
+  //    list, and
+  // 2. broadcast operands defined outside the loop body with a
+  //    broadcast_in_dim
+  // constructionType[i]: 0 = operand comes from a candidate dynamic slice,
+  // 1 = operand is defined outside the loop, -1 = unclassified (rejected
+  // below as "unhandled construction type")
+  SmallVector<int64_t> constructionType(elem->getNumOperands(), -1);
+  SmallVector<Value> baseOperands(elem->getNumOperands());
+  stablehlo::DynamicSliceOp sliceOp;
+  int64_t sliceDim = -1;
+  for (auto [i, operand] : llvm::enumerate(elem->getOperands())) {
+    auto defOp = operand.getDefiningOp();
+    if (!defOp)
+      return false;
+
+    if (defOp->getBlock() == whileOp->getBlock()) {
+      constructionType[i] = 1;
+      baseOperands[i] = operand;
+    }
+
+    if (auto ds = dyn_cast<stablehlo::DynamicSliceOp>(defOp)) {
+      auto itr = llvm::find_if(sliceOps, [&](const DynamicSliceInfo &info) {
+        return info.sliceOp == ds;
+      });
+      if (itr != sliceOps.end()) {
+        constructionType[i] = 0;
+        if (sliceDim == -1) {
+          sliceOp = ds;
+          sliceDim = itr->inductionVarDimension;
+        } else if (sliceDim != itr->inductionVarDimension) {
+          // All slice operands must be sliced along the same dimension.
+          return false;
+        }
+        baseOperands[i] = itr->sliceOp->getOperand(0);
+      } else {
+        return false;
+      }
+    }
  }

-  // unary ops are simpler, so we special case them
-  auto sliceOperand = sliceOp.getOperand();
+  assert(sliceDim != -1);
+
+  SmallVector<Value> newOperands;
+  for (auto [consType, baseOp] : llvm::zip(constructionType, baseOperands)) {
+    if (consType == 0) { // originating from a dynamic slice
+      newOperands.push_back(baseOp);
+    } else if (consType == 1) { // broadcast_in_dim
+      auto opShape =
+          llvm::to_vector(cast<RankedTensorType>(baseOp.getType()).getShape());
+      if (opShape[sliceDim] != 1) {
+        whileOp->emitError("broadcast_in_dim operand not broadcastable");
+        return false;
+      }
+      opShape[sliceDim] = info.getConstantLimit().value();
+      SmallVector<int64_t> mapping(opShape.size(), -1);
+      std::iota(mapping.begin(), mapping.end(), 0);
+      auto newBcastOp = rewriter.create<stablehlo::BroadcastInDimOp>(
+          whileOp->getLoc(),
+          RankedTensorType::get(
+              opShape,
+              cast<RankedTensorType>(baseOp.getType()).getElementType()),
+          baseOp, rewriter.getDenseI64ArrayAttr(mapping));
+      newOperands.push_back(newBcastOp->getResult(0));
+    } else {
+      whileOp->emitError("unhandled construction type");
+      return false;
+    }
+  }
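Here `mapping` is the identity [0, 1, ..., rank-1], so the `broadcast_in_dim` only stretches the unit dimension at `sliceDim`. An illustrative case (shapes assumed, not from the source): with `sliceDim = 0` and a constant limit of 8, an outside operand of type `tensor<1x4xf32>` becomes

    stablehlo.broadcast_in_dim %op, dims = [0, 1]
        : (tensor<1x4xf32>) -> tensor<8x4xf32>
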

-  rewriter.setInsertionPoint(whileOp);
-  auto newOp = rewriter.create(elem->getLoc(), elem->getName().getIdentifier(),
-                               ValueRange({sliceOperand}),
-                               TypeRange({sliceOperand.getType()}),
-                               elem->getAttrs(), {}, {});
+  auto newElemShape = llvm::to_vector(
+      cast<RankedTensorType>(elem->getResult(0).getType()).getShape());
+  newElemShape[sliceDim] = info.getConstantLimit().value();
+  auto newElemType = RankedTensorType::get(
+      newElemShape,
+      cast<RankedTensorType>(elem->getResult(0).getType()).getElementType());
+
+  auto newOp = rewriter.create(
+      elem->getLoc(), elem->getName().getIdentifier(), ValueRange(newOperands),
+      TypeRange({newElemType}), elem->getAttrs(), {}, {});

-  rewriter.setInsertionPoint(sliceOp);
+  rewriter.setInsertionPoint(elem);
  rewriter.replaceOpWithNewOp<stablehlo::DynamicSliceOp>(
      elem, newOp->getResult(0), sliceOp.getStartIndices(),
      sliceOp.getSliceSizes());
  return true;
}
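The net effect of the lift, sketched on a unary op (illustrative pseudo-IR; names and shapes assumed): with a loop of constant limit N slicing dimension 0, a body pattern

    %s = stablehlo.dynamic_slice %A, %iv, %c0, sizes = [1, K]  // tensor<1xKxf32>
    %r = stablehlo.tanh %s

is rewritten so the elementwise op runs once on the whole tensor before the loop, and the body merely slices its result:

    %T = stablehlo.tanh %A  // tensor<NxKxf32>, hoisted above the while
    %r = stablehlo.dynamic_slice %T, %iv, %c0, sizes = [1, K]
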

-bool GreedyWhileLoopBatchFission::isDynamicSliceValidForBatching(
+int64_t GreedyWhileLoopBatchFission::isDynamicSliceValidForBatching(
    stablehlo::DynamicSliceOp sliceOp, Value inductionVar, int64_t limit,
    Block &whileBody, Block *parentBlock) const {
  // Get the start indices of the dynamic slice
@@ -867,7 +943,7 @@ bool GreedyWhileLoopBatchFission::isDynamicSliceValidForBatching(
  auto operandShape = cast<RankedTensorType>(operand.getType()).getShape();

  if (operand.getParentBlock() != parentBlock)
-    return false;
+    return -1;

  // Track which start index corresponds to the induction variable descendant
  int32_t inductionVarDimension = -1, indicesFromBody = 0;
@@ -876,7 +952,7 @@ bool GreedyWhileLoopBatchFission::isDynamicSliceValidForBatching(
    Value startIndex = startIndices[i];
    Operation *definingOp = startIndex.getDefiningOp();
    if (!definingOp)
-      return false;
+      return -1;

    // Check if this start index is defined within the loop body
    if (definingOp->getBlock() == &whileBody) {
@@ -886,26 +962,30 @@ bool GreedyWhileLoopBatchFission::isDynamicSliceValidForBatching(
      if (isDirectDescendantOfInductionVar(startIndex, inductionVar)) {
        // Multiple dimensions are descendants of induction var - invalid
        if (inductionVarDimension != -1)
-          return false;
+          return -1;
        inductionVarDimension = i;

        // Check if the slice size in this dimension equals the limit
+        // TODO: relax the limit check at some point
        if (operandShape[i] != limit || sliceSizes[i] != 1)
-          return false;
+          return -1;
      }
    } else {
-      // TODO: ensure we are doing a full slice for now
+      // TODO: we are only considering the full slice case for now; we can
+      // generalize this
      if (!matchPattern(definingOp, m_Zero()))
-        return false;
+        return -1;

      if (sliceSizes[i] != operandShape[i])
-        return false;
+        return -1;
    }
  }

  // We should have exactly one index from the body, and it should be
  // a descendant of the induction variable
-  return (indicesFromBody == 1 && inductionVarDimension != -1);
+  if (indicesFromBody == 1)
+    return inductionVarDimension;
+  return -1;
}
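In short, the function now returns the dimension whose start index follows the induction variable, or -1 when the slice does not qualify. Illustratively (names assumed), with loop limit N it returns 0 for

    %s = stablehlo.dynamic_slice %A, %iv, %c0, sizes = [1, K]  // %A : tensor<NxKxf32>

and -1 for partial slices, non-zero starts in other dimensions, or two induction-driven indices.
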

bool GreedyWhileLoopBatchFission::isDirectDescendantOfInductionVar(