diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index a126026678..2fc817cc26 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -20378,6 +20378,139 @@ struct SplitConvolutionIntoReverseConvolution final
   }
 };
 
+// partial workaround for: https://github.com/openxla/xla/issues/29362
+struct SplitMultiResultScatter
+    : public CheckedOpRewritePattern<stablehlo::ScatterOp,
+                                     SplitMultiResultScatter> {
+  using CheckedOpRewritePattern::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::ScatterOp op,
+                                    PatternRewriter &rewriter) const {
+    // This pattern is for scatters with multiple independent results.
+    if (op.getNumResults() <= 1) {
+      return rewriter.notifyMatchFailure(
+          op, "scatter does not have multiple results");
+    }
+
+    // The body of the scatter should perform an independent update.
+    // This means the returned values are just the update arguments.
+    Block &body = op.getUpdateComputation().front();
+    auto returnOp = dyn_cast<stablehlo::ReturnOp>(body.getTerminator());
+    if (!returnOp || returnOp.getNumOperands() != op.getNumResults()) {
+      return rewriter.notifyMatchFailure(
+          op, "scatter body does not have a suitable return op");
+    }
+
+    // The arguments to the body are first the operands, then the updates.
+    // We expect the return op to return the update arguments.
+    size_t numInputs = op.getInputs().size();
+    // Map from a value to the set of block argument indices it depends on.
+    llvm::DenseMap<Value, llvm::SmallSet<unsigned, 4>> dependencyMap;
+
+    // Initialize dependencies for block arguments.
+    for (unsigned i = 0; i < body.getNumArguments(); ++i) {
+      dependencyMap[body.getArgument(i)].insert(i);
+    }
+
+    // Propagate dependencies through the operations in the body.
+    for (Operation &opInBody : body.without_terminator()) {
+      llvm::SmallSet<unsigned, 4> opDependencies;
+      for (Value operand : opInBody.getOperands()) {
+        if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
+          if (blockArg.getOwner() == &body) {
+            opDependencies.insert(blockArg.getArgNumber());
+          }
+        } else if (dependencyMap.count(operand)) {
+          opDependencies.insert(dependencyMap[operand].begin(),
+                                dependencyMap[operand].end());
+        }
+      }
+      for (Value result : opInBody.getResults()) {
+        dependencyMap[result] = opDependencies;
+      }
+    }
+
+    // Check dependencies for each returned value.
+    for (unsigned i = 0; i < returnOp.getNumOperands(); ++i) {
+      Value returnedVal = returnOp.getOperand(i);
+      llvm::SmallSet<unsigned, 4> dependencies;
+
+      if (auto blockArg = dyn_cast<BlockArgument>(returnedVal)) {
+        if (blockArg.getOwner() == &body) {
+          dependencies.insert(blockArg.getArgNumber());
+        }
+      } else if (dependencyMap.count(returnedVal)) {
+        dependencies = dependencyMap[returnedVal];
+      }
+
+      for (unsigned depIndex : dependencies) {
+        bool isCurrentInput = depIndex == i;
+        bool isCurrentUpdate = depIndex == (numInputs + i);
+        if (!isCurrentInput && !isCurrentUpdate) {
+          return rewriter.notifyMatchFailure(
+              op, "computation for result " + std::to_string(i) +
+                      " has an invalid dependency on argument " +
+                      std::to_string(depIndex));
+        }
+      }
+    }
+
+    // If we reached here, the pattern matches. Now, rewrite.
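+    // For each result i, build a stand-alone single-input scatter over the
+    // i-th operand and i-th update, give its update region two matching
+    // block arguments, and clone in only the body ops whose dependencies
+    // (per dependencyMap above) touch result i.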
+    SmallVector<Value> newResults;
+    for (unsigned i = 0; i < op.getNumResults(); ++i) {
+      auto newScatterOp = rewriter.create<stablehlo::ScatterOp>(
+          op.getLoc(), op.getResult(i).getType(), op.getInputs()[i],
+          op.getScatterIndices(), op.getUpdates()[i],
+          op.getScatterDimensionNumbersAttr(), op.getIndicesAreSortedAttr(),
+          op.getUniqueIndicesAttr());
+
+      {
+        OpBuilder::InsertionGuard guard(rewriter);
+        IRMapping mapper;
+        Block *newBlock =
+            rewriter.createBlock(&newScatterOp.getUpdateComputation());
+
+        Type operandArgType = body.getArgument(i).getType();
+        Type updateArgType = body.getArgument(numInputs + i).getType();
+        newBlock->addArgument(operandArgType, op.getLoc());
+        newBlock->addArgument(updateArgType, op.getLoc());
+
+        mapper.map(body.getArgument(i), newBlock->getArgument(0));
+        mapper.map(body.getArgument(numInputs + i), newBlock->getArgument(1));
+
+        for (Operation &opInBody : body.without_terminator()) {
+          // Only clone operations that are relevant for the current result.
+          // We check if any of the op's results are in the dependency map.
+          bool isRelevant = false;
+          for (Value result : opInBody.getResults()) {
+            if (dependencyMap.count(result)) {
+              const auto &deps = dependencyMap[result];
+              for (unsigned dep : deps) {
+                if (dep == i || dep == numInputs + i) {
+                  isRelevant = true;
+                  break;
+                }
+              }
+            }
+            if (isRelevant) break;
+          }
+
+          if (isRelevant) {
+            rewriter.clone(opInBody, mapper);
+          }
+        }
+
+        Value returnedVal = returnOp.getOperand(i);
+        rewriter.setInsertionPointToEnd(newBlock);
+        rewriter.create<stablehlo::ReturnOp>(op.getLoc(),
+                                             mapper.lookup(returnedVal));
+      }
+
+      newResults.push_back(newScatterOp.getResult(0));
+    }
+
+    rewriter.replaceOp(op, newResults);
+
+    return success();
+  }
+};
+
 struct ScatterMultiplySimplify final
     : public CheckedOpRewritePattern {
@@ -22201,6 +22334,7 @@ struct EnzymeHLOOptPass
     patterns.add(context);
    patterns.add(context);
    patterns.add(context);
+    patterns.add<SplitMultiResultScatter>(context);
    patterns.add(context);
    patterns.add<
diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td
index a25438f123..ac68def810 100644
--- a/src/enzyme_ad/jax/TransformOps/TransformOps.td
+++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -2129,6 +2129,13 @@ def ApplySplitConvolutionIntoReverseConvolution : EnzymeHLOPatternOp<
   ];
 }
 
+def ApplySplitMultiResultScatter : EnzymeHLOPatternOp<
+    "split_multi_result_scatter"> {
+  let patterns = [
+    "SplitMultiResultScatter"
+  ];
+}
+
 def ApplyScatterMultiplySimplify : EnzymeHLOPatternOp<"scatter_multiply_simplify"> {
   let patterns = [
     "ScatterMultiplySimplify"
diff --git a/test/lit_tests/split_multi_result_scatter.mlir b/test/lit_tests/split_multi_result_scatter.mlir
new file mode 100644
index 0000000000..2ee8f250c7
--- /dev/null
+++ b/test/lit_tests/split_multi_result_scatter.mlir
@@ -0,0 +1,47 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=split_multi_result_scatter" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
+
+// CHECK: func.func private @can_rewrite(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>) {
+// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
+// CHECK-NEXT: %c = stablehlo.constant dense<{{\[}}{{\[}}0], [1]]> : tensor<2x1xi32>
+// CHECK-NEXT: %0 = "stablehlo.scatter"(%cst, %c, %arg0) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
+// CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+// CHECK-NEXT: %2 = stablehlo.abs %arg3 : tensor<f32>
+// CHECK-NEXT: stablehlo.return %2 : tensor<f32>
+// CHECK-NEXT: }) : (tensor<3xf32>, tensor<2x1xi32>, tensor<2xf32>) -> tensor<3xf32>
+// CHECK-NEXT: %1 = "stablehlo.scatter"(%cst, %c, %arg1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
+// CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+// CHECK-NEXT: stablehlo.return %arg3 : tensor<f32>
+// CHECK-NEXT: }) : (tensor<3xf32>, tensor<2x1xi32>, tensor<2xf32>) -> tensor<3xf32>
+// CHECK-NEXT: return %0, %1 : tensor<3xf32>, tensor<3xf32>
+// CHECK-NEXT: }
+func.func private @can_rewrite(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>) {
+  %cst = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
+  %c = stablehlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
+  %0:2 = "stablehlo.scatter"(%cst, %cst, %c, %arg0, %arg1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>):
+    %a = "stablehlo.abs"(%arg4) : (tensor<f32>) -> tensor<f32>
+    stablehlo.return %a, %arg5 : tensor<f32>, tensor<f32>
+  }) : (tensor<3xf32>, tensor<3xf32>, tensor<2x1xi32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>)
+  return %0#0, %0#1 : tensor<3xf32>, tensor<3xf32>
+}
+
+// CHECK: func.func private @cannot_rewrite(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>) {
+// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
+// CHECK-NEXT: %c = stablehlo.constant dense<{{\[}}{{\[}}0], [1]]> : tensor<2x1xi32>
+// CHECK-NEXT: %0:2 = "stablehlo.scatter"(%cst, %cst, %c, %arg0, %arg1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
+// CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>):
+// CHECK-NEXT: %1 = stablehlo.add %arg4, %arg3 : tensor<f32>
+// CHECK-NEXT: stablehlo.return %1, %arg5 : tensor<f32>, tensor<f32>
+// CHECK-NEXT: }) : (tensor<3xf32>, tensor<3xf32>, tensor<2x1xi32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>)
+// CHECK-NEXT: return %0#0, %0#1 : tensor<3xf32>, tensor<3xf32>
+// CHECK-NEXT: }
+func.func private @cannot_rewrite(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>) {
+  %cst = stablehlo.constant dense<0.000000e+00> : tensor<3xf32>
+  %c = stablehlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
+  %0:2 = "stablehlo.scatter"(%cst, %cst, %c, %arg0, %arg1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>):
+    %a = "stablehlo.add"(%arg4, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32> // result depends on multiple operands: can't be transformed!
+    stablehlo.return %a, %arg5 : tensor<f32>, tensor<f32>
+  }) : (tensor<3xf32>, tensor<3xf32>, tensor<2x1xi32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<3xf32>, tensor<3xf32>)
+  return %0#0, %0#1 : tensor<3xf32>, tensor<3xf32>
+}