Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20378,6 +20378,139 @@ struct SplitConvolutionIntoReverseConvolution final
}
};

// partial workaround for: https://github.com/openxla/xla/issues/29362
//
// Splits a multi-result stablehlo.scatter whose update computation handles
// every result independently into one single-result scatter per result.
// A result is "independent" when the value returned for it depends only on
// its own operand/update block arguments (or on dependency-free ops such as
// constants defined in the body).
struct SplitMultiResultScatter
    : public CheckedOpRewritePattern<stablehlo::ScatterOp,
                                     SplitMultiResultScatter> {
  using CheckedOpRewritePattern::CheckedOpRewritePattern;

  LogicalResult matchAndRewriteImpl(stablehlo::ScatterOp op,
                                    PatternRewriter &rewriter) const {
    // This pattern only applies to scatters with multiple results.
    if (op.getNumResults() <= 1) {
      return rewriter.notifyMatchFailure(
          op, "scatter does not have multiple results");
    }

    Block &body = op.getUpdateComputation().front();
    auto returnOp = dyn_cast<stablehlo::ReturnOp>(body.getTerminator());
    if (!returnOp || returnOp.getNumOperands() != op.getNumResults()) {
      return rewriter.notifyMatchFailure(
          op, "scatter body does not have a suitable return op");
    }

    // Body arguments are laid out as
    // [operand_0 .. operand_{n-1}, update_0 .. update_{n-1}].
    size_t numInputs = op.getInputs().size();

    // Map from each value in the body to the set of block-argument indices
    // it transitively depends on. Block arguments seed the map with
    // themselves; every op's results get the union of its operands' sets.
    // Values captured from outside the body are not in the map and
    // contribute no argument dependencies.
    llvm::DenseMap<Value, llvm::SmallSet<unsigned, 4>> dependencyMap;
    for (unsigned i = 0; i < body.getNumArguments(); ++i)
      dependencyMap[body.getArgument(i)].insert(i);

    for (Operation &opInBody : body.without_terminator()) {
      llvm::SmallSet<unsigned, 4> opDependencies;
      for (Value operand : opInBody.getOperands()) {
        auto it = dependencyMap.find(operand);
        if (it != dependencyMap.end())
          opDependencies.insert(it->second.begin(), it->second.end());
      }
      for (Value result : opInBody.getResults())
        dependencyMap[result] = opDependencies;
    }

    // True iff `deps` only mentions the operand/update arguments belonging
    // to result `i`. An empty set qualifies (e.g. body-local constants).
    auto dependsOnlyOnResult = [&](const llvm::SmallSet<unsigned, 4> &deps,
                                   unsigned i) {
      for (unsigned dep : deps)
        if (dep != i && dep != numInputs + i)
          return false;
      return true;
    };

    // Each returned value must be computed purely from its own operand /
    // update arguments, and must be defined inside the body so that it can
    // be remapped into the split scatter's region.
    for (unsigned i = 0; i < returnOp.getNumOperands(); ++i) {
      Value returnedVal = returnOp.getOperand(i);

      bool definedInBody = false;
      if (auto blockArg = dyn_cast<BlockArgument>(returnedVal))
        definedInBody = blockArg.getOwner() == &body;
      else if (Operation *def = returnedVal.getDefiningOp())
        definedInBody = def->getBlock() == &body;
      if (!definedInBody) {
        return rewriter.notifyMatchFailure(
            op, "returned value " + std::to_string(i) +
                    " is not defined in the scatter body");
      }

      auto it = dependencyMap.find(returnedVal);
      if (it != dependencyMap.end() && !dependsOnlyOnResult(it->second, i)) {
        return rewriter.notifyMatchFailure(
            op, "computation for result " + std::to_string(i) +
                    " depends on arguments of another result");
      }
    }

    // The pattern matches: emit one single-result scatter per result.
    SmallVector<Value> newResults;
    for (unsigned i = 0; i < op.getNumResults(); ++i) {
      auto newScatterOp = rewriter.create<stablehlo::ScatterOp>(
          op.getLoc(), op.getResult(i).getType(), op.getInputs()[i],
          op.getScatterIndices(), op.getUpdates()[i],
          op.getScatterDimensionNumbersAttr(), op.getIndicesAreSortedAttr(),
          op.getUniqueIndicesAttr());

      {
        OpBuilder::InsertionGuard guard(rewriter);
        IRMapping mapper;
        Block *newBlock =
            rewriter.createBlock(&newScatterOp.getUpdateComputation());

        // The new body only carries the operand/update pair of result i.
        newBlock->addArgument(body.getArgument(i).getType(), op.getLoc());
        newBlock->addArgument(body.getArgument(numInputs + i).getType(),
                              op.getLoc());
        mapper.map(body.getArgument(i), newBlock->getArgument(0));
        mapper.map(body.getArgument(numInputs + i), newBlock->getArgument(1));

        // Clone every op whose argument dependencies are compatible with
        // result i (a subset of {i, numInputs + i}). This deliberately
        // includes dependency-free ops such as constants: skipping them
        // would leave their users — possibly the returned value itself —
        // unmapped. Ops tied (even partially) to other results are never
        // needed by result i (the match phase guarantees it) and are not
        // cloned, so no cloned op can reference an unmapped argument.
        for (Operation &opInBody : body.without_terminator()) {
          bool cloneable = true;
          for (Value result : opInBody.getResults()) {
            auto it = dependencyMap.find(result);
            if (it != dependencyMap.end() &&
                !dependsOnlyOnResult(it->second, i)) {
              cloneable = false;
              break;
            }
          }
          if (cloneable)
            rewriter.clone(opInBody, mapper);
        }

        rewriter.setInsertionPointToEnd(newBlock);
        rewriter.create<stablehlo::ReturnOp>(
            op.getLoc(), mapper.lookup(returnOp.getOperand(i)));
      }

      newResults.push_back(newScatterOp.getResult(0));
    }

    rewriter.replaceOp(op, newResults);

    return success();
  }
};

struct ScatterMultiplySimplify final
: public CheckedOpRewritePattern<stablehlo::MulOp,
ScatterMultiplySimplify> {
Expand Down Expand Up @@ -22201,6 +22334,7 @@ struct EnzymeHLOOptPass
patterns.add<TransposeWrap>(context);
patterns.add<TransposeExtend>(context);
patterns.add<TransposeRotate>(context);
patterns.add<SplitMultiResultScatter>(context);
patterns.add<SelectPad>(context);

patterns.add<
Expand Down
7 changes: 7 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2129,6 +2129,13 @@ def ApplySplitConvolutionIntoReverseConvolution : EnzymeHLOPatternOp<
];
}

// Exposes the SplitMultiResultScatter C++ pattern to the transform dialect
// under the name "split_multi_result_scatter".
def ApplySplitMultiResultScatter : EnzymeHLOPatternOp<
    "split_multi_result_scatter"> {
  let patterns = [
    "SplitMultiResultScatter"
  ];
}

def ApplyScatterMultiplySimplify : EnzymeHLOPatternOp<"scatter_multiply_simplify"> {
let patterns = [
"ScatterMultiplySimplify"
Expand Down
47 changes: 47 additions & 0 deletions test/lit_tests/split_multi_result_scatter.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=split_multi_result_scatter" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s

// Positive case: each returned value of the two-result scatter body depends
// only on its own update argument (%arg4 for result 0, %arg5 for result 1),
// so the scatter is expected to be split into two single-result scatters.
// CHECK: func.func private @can_rewrite(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>) {
// CHECK-NEXT:    %cst = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK-NEXT:    %c = stablehlo.constant dense<{{\[}}{{\[}}0], [1]]> : tensor<2x1xi32>
// CHECK-NEXT:    %0 = "stablehlo.scatter"(%cst, %c, %arg0) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
// CHECK-NEXT:    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
// CHECK-NEXT:      %2 = stablehlo.abs %arg3 : tensor<f32>
// CHECK-NEXT:      stablehlo.return %2 : tensor<f32>
// CHECK-NEXT:    }) : (tensor<3xf32>, tensor<2x1xi32>, tensor<2xf32>) -> tensor<3xf32>
// CHECK-NEXT:    %1 = "stablehlo.scatter"(%cst, %c, %arg1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
// CHECK-NEXT:    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
// CHECK-NEXT:      stablehlo.return %arg3 : tensor<f32>
// CHECK-NEXT:    }) : (tensor<3xf32>, tensor<2x1xi32>, tensor<2xf32>) -> tensor<3xf32>
// CHECK-NEXT:    return %0, %1 : tensor<3xf32>, tensor<3xf32>
// CHECK-NEXT:  }
func.func private @can_rewrite(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
    %c = stablehlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
    %0:2 = "stablehlo.scatter"(%cst, %cst, %c, %arg0, %arg1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>):
      %a = "stablehlo.abs"(%arg4) : (tensor<f32>) -> tensor<f32>
      stablehlo.return %a, %arg5 : tensor<f32>, tensor<f32>
    }) : (tensor<3xf32>, tensor<3xf32>, tensor<2x1xi32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>)
    return %0#0, %0#1 : tensor<3xf32>, tensor<3xf32>
}

// Negative case: result 0's returned value depends on both its own update
// argument (%arg4) and result 1's operand argument (%arg3), so the results
// are not independent and the scatter must be left unchanged.
// CHECK: func.func private @cannot_rewrite(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>) {
// CHECK-NEXT:    %cst = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK-NEXT:    %c = stablehlo.constant dense<{{\[}}{{\[}}0], [1]]> : tensor<2x1xi32>
// CHECK-NEXT:    %0:2 = "stablehlo.scatter"(%cst, %cst, %c, %arg0, %arg1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
// CHECK-NEXT:    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>):
// CHECK-NEXT:      %1 = stablehlo.add %arg4, %arg3 : tensor<f32>
// CHECK-NEXT:      stablehlo.return %1, %arg5 : tensor<f32>, tensor<f32>
// CHECK-NEXT:    }) : (tensor<3xf32>, tensor<3xf32>, tensor<2x1xi32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>)
// CHECK-NEXT:    return %0#0, %0#1 : tensor<3xf32>, tensor<3xf32>
// CHECK-NEXT:  }
func.func private @cannot_rewrite(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
    %c = stablehlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
    %0:2 = "stablehlo.scatter"(%cst, %cst, %c, %arg0, %arg1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>):
      %a = "stablehlo.add"(%arg4, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32> // result depends on multiple operands: can't be transformed!
      stablehlo.return %a, %arg5 : tensor<f32>, tensor<f32>
    }) : (tensor<3xf32>, tensor<3xf32>, tensor<2x1xi32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>)
    return %0#0, %0#1 : tensor<3xf32>, tensor<3xf32>
}
Loading