Skip to content

Commit 740aff6

Browse files
committed
feat: elementwise ops are completely supported
1 parent c4cebb9 commit 740aff6

File tree

2 files changed: +138 additions, −45 deletions

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 119 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "src/enzyme_ad/jax/Passes/Passes.h"
1313
#include "src/enzyme_ad/jax/Utils.h"
1414
#include "stablehlo/dialect/StablehloOps.h"
15+
#include "llvm/ADT/DenseMap.h"
1516
#include "llvm/ADT/SetVector.h"
1617
#include "llvm/ADT/SmallVector.h"
1718

@@ -105,7 +106,7 @@ func::FuncOp createWrapperUnbatchedFunction(PatternRewriter &rewriter,
105106
return nullptr;
106107
rewriter.setInsertionPointToStart(modOp.getBody());
107108

108-
SmallVector<Type> argTypes;
109+
SmallVector<mlir::Type> argTypes;
109110
for (auto v : operands) {
110111
auto vType = cast<RankedTensorType>(v.getType());
111112
auto shape = vType.getShape();
@@ -770,6 +771,8 @@ GreedyWhileLoopBatchFission::matchAndRewrite(stablehlo::WhileOp whileOp,
770771
int64_t limit = info.getConstantLimit().value();
771772
int64_t step = info.getConstantStep().value();
772773

774+
// TODO: we can support start != 0 for applying a function to part of the
775+
// tensor
773776
if (start != 0 || step != 1)
774777
return failure();
775778

@@ -790,13 +793,14 @@ GreedyWhileLoopBatchFission::matchAndRewrite(stablehlo::WhileOp whileOp,
790793
// 2. Only one variable in the body is a direct descendant of the induction
791794
// variable
792795
// 3. The size of that dimension equals the limit
793-
SmallVector<stablehlo::DynamicSliceOp> candidateSlices;
796+
SmallVector<DynamicSliceInfo> candidateSlices;
794797

795798
whileBody.walk([&](stablehlo::DynamicSliceOp sliceOp) {
796799
// Check if this dynamic slice meets our criteria
797-
if (isDynamicSliceValidForBatching(sliceOp, inductionVar, limit, whileBody,
798-
parentBlock)) {
799-
candidateSlices.push_back(sliceOp);
800+
auto dim = isDynamicSliceValidForBatching(sliceOp, inductionVar, limit,
801+
whileBody, parentBlock);
802+
if (dim != -1) {
803+
candidateSlices.push_back(DynamicSliceInfo{sliceOp, dim});
800804
llvm::errs() << "Found candidate dynamic slice: " << sliceOp << "\n";
801805
}
802806
});
@@ -809,18 +813,29 @@ GreedyWhileLoopBatchFission::matchAndRewrite(stablehlo::WhileOp whileOp,
809813
llvm::errs() << "Found " << candidateSlices.size()
810814
<< " candidate dynamic slices\n";
811815

812-
bool wasLifted = false;
816+
// Create a map of user operations to their corresponding dynamic slices
817+
DenseMap<Operation *, SmallVector<DynamicSliceInfo>> userOpToSlicesMap;
813818
for (auto ds : candidateSlices) {
814-
for (auto op : ds->getUsers()) {
815-
// TODO: handle intermediate reshapes
819+
for (auto op : ds.sliceOp->getUsers()) {
820+
llvm::errs() << "mapping user op: " << *op << " to slice: " << ds.sliceOp
821+
<< "\n";
822+
userOpToSlicesMap[op].push_back(ds);
823+
}
824+
}
816825

817-
llvm::errs() << "trying to lift op: " << *op << "\n";
826+
if (userOpToSlicesMap.empty()) {
827+
return failure();
828+
}
818829

819-
if (op->hasTrait<OpTrait::Elementwise>() &&
820-
liftElementwiseOp(rewriter, whileOp, ds, op)) {
821-
wasLifted = true;
822-
continue;
823-
}
830+
llvm::errs() << "Created map with " << userOpToSlicesMap.size()
831+
<< " user operations\n";
832+
833+
// Log statistics about multiple slices per operation
834+
bool wasLifted = false;
835+
for (auto &[op, slices] : userOpToSlicesMap) {
836+
if (op->hasTrait<OpTrait::Elementwise>() &&
837+
liftElementwiseOp(rewriter, whileOp, slices, op, info)) {
838+
wasLifted = true;
824839
}
825840
}
826841

@@ -829,35 +844,96 @@ GreedyWhileLoopBatchFission::matchAndRewrite(stablehlo::WhileOp whileOp,
829844

830845
bool GreedyWhileLoopBatchFission::liftElementwiseOp(
831846
PatternRewriter &rewriter, stablehlo::WhileOp whileOp,
832-
stablehlo::DynamicSliceOp sliceOp, Operation *elem) const {
847+
ArrayRef<DynamicSliceInfo> sliceOps, Operation *elem,
848+
WhileLoopInfo info) const {
833849
IRRewriter::InsertionGuard guard(rewriter);
850+
rewriter.setInsertionPoint(whileOp);
834851

835-
// TODO: support non-unary elementwise ops
836-
if (elem->getNumOperands() != 1) {
837-
// For non-unary elementwise ops, we can
838-
// 1. lift operands if they are produced by DSs in the list
839-
// 2. defined outside the loop body, in which case we simply do a
840-
// broadcast_in_dim
841-
return false;
852+
// For non-unary elementwise ops, we can
853+
// 1. lift operands if they are produced by DSs in the list
854+
// 2. defined outside the loop body, in which case we simply do a
855+
// broadcast_in_dim
856+
SmallVector<int64_t> constructionType(elem->getNumOperands(), -1);
857+
SmallVector<Value> baseOperands(elem->getNumOperands());
858+
stablehlo::DynamicSliceOp sliceOp;
859+
int64_t sliceDim = -1;
860+
for (auto [i, operand] : llvm::enumerate(elem->getOperands())) {
861+
auto defOp = operand.getDefiningOp();
862+
if (!defOp)
863+
return false;
864+
865+
if (defOp->getBlock() == whileOp->getBlock()) {
866+
constructionType[i] = 1;
867+
baseOperands[i] = operand;
868+
}
869+
870+
if (auto ds = dyn_cast<stablehlo::DynamicSliceOp>(defOp)) {
871+
auto itr = llvm::find_if(sliceOps, [&](const DynamicSliceInfo &info) {
872+
return info.sliceOp == ds;
873+
});
874+
if (itr != sliceOps.end()) {
875+
constructionType[i] = 0;
876+
if (sliceDim == -1) {
877+
sliceOp = ds;
878+
sliceDim = itr->inductionVarDimension;
879+
} else if (sliceDim != itr->inductionVarDimension) {
880+
return false;
881+
}
882+
baseOperands[i] = itr->sliceOp->getOperand(0);
883+
} else {
884+
return false;
885+
}
886+
}
842887
}
843888

844-
// unary ops are simpler, so we special case them
845-
auto sliceOperand = sliceOp.getOperand();
889+
assert(sliceDim != -1);
890+
891+
SmallVector<Value> newOperands;
892+
for (auto [consType, baseOp] : llvm::zip(constructionType, baseOperands)) {
893+
if (consType == 0) { // originating from a dynamic slice
894+
newOperands.push_back(baseOp);
895+
} else if (consType == 1) { // broadcast_in_dim
896+
auto opShape =
897+
llvm::to_vector(cast<RankedTensorType>(baseOp.getType()).getShape());
898+
if (opShape[sliceDim] != 1) {
899+
whileOp->emitError("broadcast_in_dim operand not broadcastable");
900+
return false;
901+
}
902+
opShape[sliceDim] = info.getConstantLimit().value();
903+
SmallVector<int64_t> mapping(opShape.size(), -1);
904+
std::iota(mapping.begin(), mapping.end(), 0);
905+
auto newBcastOp = rewriter.create<stablehlo::BroadcastInDimOp>(
906+
whileOp->getLoc(),
907+
RankedTensorType::get(
908+
opShape,
909+
cast<RankedTensorType>(baseOp.getType()).getElementType()),
910+
baseOp, rewriter.getDenseI64ArrayAttr(mapping));
911+
newOperands.push_back(newBcastOp->getResult(0));
912+
} else {
913+
whileOp->emitError("unhandled construction type");
914+
return false;
915+
}
916+
}
846917

847-
rewriter.setInsertionPoint(whileOp);
848-
auto newOp = rewriter.create(elem->getLoc(), elem->getName().getIdentifier(),
849-
ValueRange({sliceOperand}),
850-
TypeRange({sliceOperand.getType()}),
851-
elem->getAttrs(), {}, {});
918+
auto newElemShape = llvm::to_vector(
919+
cast<RankedTensorType>(elem->getResult(0).getType()).getShape());
920+
newElemShape[sliceDim] = info.getConstantLimit().value();
921+
auto newElemType = RankedTensorType::get(
922+
newElemShape,
923+
cast<RankedTensorType>(elem->getResult(0).getType()).getElementType());
924+
925+
auto newOp = rewriter.create(
926+
elem->getLoc(), elem->getName().getIdentifier(), ValueRange(newOperands),
927+
TypeRange({newElemType}), elem->getAttrs(), {}, {});
852928

853-
rewriter.setInsertionPoint(sliceOp);
929+
rewriter.setInsertionPoint(elem);
854930
rewriter.replaceOpWithNewOp<stablehlo::DynamicSliceOp>(
855931
elem, newOp->getResult(0), sliceOp.getStartIndices(),
856932
sliceOp.getSliceSizes());
857933
return true;
858934
}
859935

860-
bool GreedyWhileLoopBatchFission::isDynamicSliceValidForBatching(
936+
int64_t GreedyWhileLoopBatchFission::isDynamicSliceValidForBatching(
861937
stablehlo::DynamicSliceOp sliceOp, Value inductionVar, int64_t limit,
862938
Block &whileBody, Block *parentBlock) const {
863939
// Get the start indices of the dynamic slice
@@ -867,7 +943,7 @@ bool GreedyWhileLoopBatchFission::isDynamicSliceValidForBatching(
867943
auto operandShape = cast<RankedTensorType>(operand.getType()).getShape();
868944

869945
if (operand.getParentBlock() != parentBlock)
870-
return false;
946+
return -1;
871947

872948
// Track which start index corresponds to the induction variable descendant
873949
int32_t inductionVarDimension = -1, indicesFromBody = 0;
@@ -876,7 +952,7 @@ bool GreedyWhileLoopBatchFission::isDynamicSliceValidForBatching(
876952
Value startIndex = startIndices[i];
877953
Operation *definingOp = startIndex.getDefiningOp();
878954
if (!definingOp)
879-
return false;
955+
return -1;
880956

881957
// Check if this start index is defined within the loop body
882958
if (definingOp->getBlock() == &whileBody) {
@@ -886,26 +962,30 @@ bool GreedyWhileLoopBatchFission::isDynamicSliceValidForBatching(
886962
if (isDirectDescendantOfInductionVar(startIndex, inductionVar)) {
887963
// Multiple dimensions are descendants of induction var - invalid
888964
if (inductionVarDimension != -1)
889-
return false;
965+
return -1;
890966
inductionVarDimension = i;
891967

892968
// Check if the slice size in this dimension equals the limit
969+
// TODO: relax the limit check at some point
893970
if (operandShape[i] != limit || sliceSizes[i] != 1)
894-
return false;
971+
return -1;
895972
}
896973
} else {
897-
// TODO: ensure we are doing a full slice for now
974+
// TODO: we are only considering the full slice case for now. we can
975+
// generalize this
898976
if (!matchPattern(definingOp, m_Zero()))
899-
return false;
977+
return -1;
900978

901979
if (sliceSizes[i] != operandShape[i])
902-
return false;
980+
return -1;
903981
}
904982
}
905983

906984
// We should have exactly one index from the body, and it should be
907985
// a descendant of the induction variable
908-
return (indicesFromBody == 1 && inductionVarDimension != -1);
986+
if (indicesFromBody == 1)
987+
return inductionVarDimension;
988+
return -1;
909989
}
910990

911991
bool GreedyWhileLoopBatchFission::isDirectDescendantOfInductionVar(

src/enzyme_ad/jax/Passes/AutoBatching.h

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@
88
#include <tuple>
99
#include <vector>
1010

11+
// Loading the header causes a bunch of ambiguous errors
12+
// #include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h"
13+
namespace mlir {
14+
namespace enzyme {
15+
struct WhileLoopInfo;
16+
}; // namespace enzyme
17+
}; // namespace mlir
18+
1119
struct BatchOperandConstructionInfo {
1220
mlir::stablehlo::SliceOp sliceOp;
1321
int32_t sliceOperandIndex;
@@ -182,6 +190,11 @@ struct GreedyWhileLoopBatchFission
182190
mlir::PatternRewriter &rewriter) const override;
183191

184192
private:
193+
struct DynamicSliceInfo {
194+
mlir::stablehlo::DynamicSliceOp sliceOp;
195+
int64_t inductionVarDimension;
196+
};
197+
185198
bool isDirectDescendantOfInductionVar(mlir::Value value,
186199
mlir::Value inductionVar) const;
187200

@@ -190,13 +203,13 @@ struct GreedyWhileLoopBatchFission
190203
bool isChainOfAddSubtractConverts(mlir::Value value,
191204
mlir::Value inductionVar) const;
192205

193-
bool isDynamicSliceValidForBatching(mlir::stablehlo::DynamicSliceOp sliceOp,
194-
mlir::Value inductionVar, int64_t limit,
195-
mlir::Block &whileBody,
196-
mlir::Block *parentBlock) const;
206+
int64_t isDynamicSliceValidForBatching(
207+
mlir::stablehlo::DynamicSliceOp sliceOp, mlir::Value inductionVar,
208+
int64_t limit, mlir::Block &whileBody, mlir::Block *parentBlock) const;
197209

198210
bool liftElementwiseOp(mlir::PatternRewriter &rewriter,
199211
mlir::stablehlo::WhileOp whileOp,
200-
mlir::stablehlo::DynamicSliceOp sliceOp,
201-
mlir::Operation *op) const;
212+
llvm::ArrayRef<DynamicSliceInfo> sliceOps,
213+
mlir::Operation *op,
214+
mlir::enzyme::WhileLoopInfo info) const;
202215
};

0 commit comments

Comments (0)