[STABLEHLO] Added Softmax Op conversion support to StableHLO dialect (#2886)

Signed-off-by: Abhishek-TyRnT <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
Co-authored-by: Tung D. Le <[email protected]>
3 people authored Jul 26, 2024
1 parent 2044d52 commit 226a0d6
Showing 5 changed files with 230 additions and 42 deletions.
1 change: 1 addition & 0 deletions src/Conversion/ONNXToStablehlo/CMakeLists.txt
@@ -47,6 +47,7 @@ add_onnx_mlir_library(OMONNXToStablehlo
Math/Gemm.cpp
Math/MatMul.cpp
Math/Reduction.cpp
Math/Softmax.cpp
NN/Conv.cpp
NN/ConvTranspose.cpp
NN/Normalization.cpp
4 changes: 1 addition & 3 deletions src/Conversion/ONNXToStablehlo/ConvertONNXToStablehlo.cpp
@@ -29,6 +29,7 @@ void populateONNXToStablehloConversionPattern(
populateLoweringONNXGemmOpToStablehloPattern(patterns, ctx);
populateLoweringONNXMatMulOpToStablehloPattern(patterns, ctx);
populateLoweringONNXReductionOpToStablehloPattern(patterns, ctx);
populateLoweringONNXSoftmaxOpToStablehloPattern(patterns, ctx);
// Neural network
populateLoweringONNXConvOpToStablehloPattern(patterns, ctx);
populateLoweringONNXConvTransposeOpToStablehloPattern(patterns, ctx);
@@ -126,9 +127,6 @@ void FrontendToStablehloLoweringPass::runOnOperation() {
populateONNXToStablehloConversionPattern(
patterns, &getContext(), enableUnroll);

// add illegal op
target.addIllegalOp<ONNXSoftmaxOp>();

// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
186 changes: 186 additions & 0 deletions src/Conversion/ONNXToStablehlo/Math/Softmax.cpp
@@ -0,0 +1,186 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- Softmax.cpp - Softmax Ops -------------------===//
//
// Copyright 2022-2024
//
// =============================================================================
//
// This file lowers the ONNX Softmax operator to the StableHLO dialect.
//
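// For reference, softmax along axis k of an input x is defined elementwise as
//   softmax(x)[..., i, ...] = exp(x[..., i, ...]) / sum_j exp(x[..., j, ...])
// where the sum over j runs along dimension k of x.
//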

#include "src/Conversion/ONNXToStablehlo/DialectBuilder.hpp"
#include "src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Support/TypeUtilities.hpp"
#include "stablehlo/dialect/BroadcastUtils.h"

using namespace mlir;

namespace onnx_mlir {

namespace {

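// Builds a 1-D index tensor describing the shape of the reduction result:
// extents of non-reduced axes are read from the operand's runtime shape, and
// each reduced axis contributes a constant extent of 1 when keepDims is set.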
Value getReductionShapeValue(Location loc, PatternRewriter &rewriter,
Value operand, llvm::SmallVector<int64_t, 4> axes, bool keepDims) {
int64_t rank = mlir::cast<RankedTensorType>(operand.getType()).getRank();

Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, operand);
SmallVector<Value> dims;
for (int64_t i = 0; i < rank; i++) {
if (!(std::find(axes.begin(), axes.end(), i) != axes.end())) {
Value dim = rewriter.create<shape::GetExtentOp>(loc, inputShape, i);
dims.push_back(dim);
} else if (keepDims) {
Value dim = rewriter.create<arith::ConstantIndexOp>(loc, 1);
dims.push_back(dim);
}
}
Value reduceShapeValue = rewriter.create<shape::FromExtentsOp>(loc, dims);
reduceShapeValue = rewriter.create<shape::ToExtentTensorOp>(loc,
RankedTensorType::get({rank}, rewriter.getIndexType()), reduceShapeValue);
return reduceShapeValue;
}

// Computes the broadcast dimensions, i.e. the axes of the operand that are
// not being reduced.
SmallVector<int64_t> getBroadcastDims(
Value operand, llvm::SmallVector<int64_t, 4> axes) {
int64_t rank = mlir::cast<RankedTensorType>(operand.getType()).getRank();
SmallVector<int64_t> dims;
for (int64_t i = 0; i < rank; i++) {
if (!(std::find(axes.begin(), axes.end(), i) != axes.end())) {
dims.push_back(i);
}
}

return dims;
}

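// Emits a stablehlo.reduce that sums `operand` over `axes`, starting from
// `identity`. When keepDims is true, the result is dynamically reshaped to
// `outputType` so that the reduced axes reappear with extent 1.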
Value computeReduceSum(Location loc, Value operand, Value identity,
SmallVector<int64_t> &reduceShape, llvm::SmallVector<int64_t, 4> axes,
PatternRewriter &rewriter, bool keepDims, ShapedType outputType) {

RankedTensorType operandType =
mlir::cast<RankedTensorType>(operand.getType());
Type reduceResultType =
RankedTensorType::get(reduceShape, operandType.getElementType());
stablehlo::ReduceOp reduce = rewriter.create<stablehlo::ReduceOp>(loc,
reduceResultType, operand, identity, rewriter.getDenseI64ArrayAttr(axes));

Region &region = reduce.getBody();
Block &block = region.emplaceBlock();
RankedTensorType blockArgumentType =
RankedTensorType::get({}, operandType.getElementType());
block.addArgument(blockArgumentType, loc);
block.addArgument(blockArgumentType, loc);

BlockArgument firstArgument = *block.args_begin();
BlockArgument secondArgument = *block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value reduceResult =
rewriter.create<stablehlo::AddOp>(loc, firstArgument, secondArgument);
rewriter.create<stablehlo::ReturnOp>(loc, reduceResult);
}
Value result = reduce.getResult(0);

if (keepDims) {
Value reduceShapeValue =
getReductionShapeValue(loc, rewriter, operand, axes, true);
result = rewriter.create<stablehlo::DynamicReshapeOp>(
loc, outputType, result, reduceShapeValue);
}
return result;
}

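// Computes the static shape of the reduction result: extents of non-reduced
// axes are copied from the input, and reduced axes are kept as size 1 only
// when isKeepdims is set.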
SmallVector<int64_t> getReductionShape(ShapedType inputType,
const llvm::SmallVector<int64_t, 4> &axes, bool isKeepdims) {
SmallVector<int64_t> reduceShape;
llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t rank = inputType.getRank();

// Mark reduction axes.
for (int64_t i = 0; i < rank; ++i) {
if (!(std::find(axes.begin(), axes.end(), i) != axes.end()))
reduceShape.push_back(inputShape[i]);
else if (isKeepdims)
reduceShape.push_back(1);
}

return reduceShape;
}

struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
ONNXSoftmaxOpLoweringToStablehlo(MLIRContext *ctx)
: ConversionPattern(ONNXSoftmaxOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {

Value operand = operands[0];
assert(
hasStaticShape(operand.getType()) && "Only static shapes are accepted");

Location loc = op->getLoc();
Type outputType = *op->result_type_begin();
assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType");
assert(mlir::cast<RankedTensorType>(operand.getType())
.getElementType()
.isF32() &&
"Currently only float32 is supported for input");

// Exponential operation
Value ElementwiseExpStableHLO = rewriter.create<stablehlo::ExpOp>(
loc, op->getResultTypes(), op->getOperands());

if (ElementwiseExpStableHLO == nullptr)
return failure();

RankedTensorType ExpOutputType =
mlir::cast<RankedTensorType>(ElementwiseExpStableHLO.getType());

// Convert a negative axis to its positive equivalent.
int64_t axis = mlir::cast<ONNXSoftmaxOp>(*op).getAxis();
if (axis < 0)
axis = ExpOutputType.getRank() + axis;

SmallVector<int64_t, 4> axes = {axis};
// Sum all the exponentials along the reduction axis to form the denominator.
SmallVector<int64_t> reducedShape =
getReductionShape(ExpOutputType, axes, false);
ShapedType ReducedShapeType = mlir::cast<ShapedType>(
RankedTensorType::get(reducedShape, ExpOutputType.getElementType()));
Value identity = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getZeroAttr(ExpOutputType.getElementType()));
Value ReduceSum = computeReduceSum(loc, ElementwiseExpStableHLO, identity,
reducedShape, axes, rewriter, false, ReducedShapeType);
if (ReduceSum == nullptr)
return failure();

SmallVector<int64_t> broadcast_dims =
getBroadcastDims(ElementwiseExpStableHLO, axes);
Value BroadCastOp =
rewriter.create<stablehlo::BroadcastInDimOp>(loc, ExpOutputType,
ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims));
if (BroadCastOp == nullptr)
return failure();

Value Softmax_output = rewriter.create<stablehlo::DivOp>(
loc, ElementwiseExpStableHLO, BroadCastOp);
if (Softmax_output == nullptr)
return failure();

rewriter.replaceOp(op, Softmax_output);
return success();
}
};
} // namespace

void populateLoweringONNXSoftmaxOpToStablehloPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXSoftmaxOpLoweringToStablehlo>(ctx);
}
} // namespace onnx_mlir
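For orientation, here is a minimal sketch (not part of this commit's diff; SSA names are chosen for illustration) of the StableHLO this pattern is expected to emit for a softmax over axis 1 of a statically shaped tensor<10x20x30xf32> input:

%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%exp = stablehlo.exponential %arg0 : tensor<10x20x30xf32>
%sum = stablehlo.reduce(%exp init: %cst) applies stablehlo.add across dimensions = [1] : (tensor<10x20x30xf32>, tensor<f32>) -> tensor<10x30xf32>
%bcast = stablehlo.broadcast_in_dim %sum, dims = [0, 2] : (tensor<10x30xf32>) -> tensor<10x20x30xf32>
%out = stablehlo.divide %exp, %bcast : tensor<10x20x30xf32>

The exponential is taken elementwise, summed along the softmax axis, and the sum is broadcast back along the non-reduced dimensions [0, 2] before the elementwise division.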
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp
@@ -215,4 +215,6 @@ void populateLoweringONNXTransposeOpToStablehloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXUnsqueezeOpToStablehloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXSoftmaxOpToStablehloPattern(
RewritePatternSet &patterns, MLIRContext *ctx);
} // namespace onnx_mlir
79 changes: 40 additions & 39 deletions test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir
@@ -32,45 +32,46 @@ func.func @test_softmax_dynamic(%arg0 : tensor<?x20x30xf32>) -> tensor<?x20x30xf
"func.return"(%0) : (tensor<?x20x30xf32>) -> ()
}

// CHECK-LABEL: func.func @test_softmax_dynamic
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
// CHECK-DAG: [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
// CHECK-DAG: [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
// CHECK: [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
// CHECK: [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
// CHECK-DAG: [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// CHECK: [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor<?x1x30xf32> -> tensor<3xindex>
// CHECK: [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// CHECK: [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor<?x20x30xf32>
// CHECK: [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor<?x20x30xf32>
// CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
// CHECK-DAG: [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
// CHECK-DAG: [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
// CHECK: [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
// CHECK: [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
// CHECK-DAG: [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// CHECK: [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor<?x1x30xf32> -> tensor<3xindex>
// CHECK: [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// CHECK: [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor<?x20x30xf32>
// CHECK: return [[VAR_28_]] : tensor<?x20x30xf32>
// CHECK: }
// TODO: Re-enable the dynamic shape test.
// func.func @test_softmax_dynamic
// ([[PARAM_0_:%.+]]: tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
// [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// [[CST_2_:%.+]] = arith.constant 2 : index
// [[CST_1_:%.+]] = arith.constant 1 : index
// [[CST_0_:%.+]] = arith.constant 0 : index
// [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// separator of consecutive DAGs
// [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
// [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// separator of consecutive DAGs
// [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
// [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
// [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
// [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
// [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
// [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor<?x1x30xf32> -> tensor<3xindex>
// [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
// [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor<?x20x30xf32>
// [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor<?x20x30xf32>
// [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
// [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// separator of consecutive DAGs
// [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
// [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
// [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
// [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
// [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
// [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor<?x1x30xf32> -> tensor<3xindex>
// [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
// [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor<?x20x30xf32>
// return [[VAR_28_]] : tensor<?x20x30xf32>
// }

// -----

