Shape inference for Layer Normalization Op (#2559)
Signed-off-by: Alexandre Eichenberger <[email protected]>
Co-authored-by: Tung D. Le <[email protected]>
AlexandreEichenberger and tungld authored Oct 17, 2023
1 parent 7f8f1d9 commit e678bc0
Showing 7 changed files with 173 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXOps.td.inc
@@ -3615,6 +3615,7 @@ def ONNXLayerNormalizationOp:ONNX_Op<"LayerNormalization",
return sh;
}
}];
let hasVerifier = 1;
}

def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu",
122 changes: 122 additions & 0 deletions src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp
@@ -149,3 +149,125 @@ LogicalResult ONNXInstanceNormalizationOp::verify() {
}

// TODO: should there be a shape inference for this one?

//===----------------------------------------------------------------------===//
// LayerNormalization
//===----------------------------------------------------------------------===//

LogicalResult ONNXLayerNormalizationOp::verify() {
auto operandAdaptor = ONNXLayerNormalizationOpAdaptor(*this);

// Get attributes.
int64_t axis = getAxis();

// Get operands.
Value X = operandAdaptor.getX();
Value scale = operandAdaptor.getScale();
Value B = operandAdaptor.getB();

// Check X.
if (!hasShapeAndRank(X)) {
// Won't be able to do any checking at this stage.
return success();
}
ShapedType XType = X.getType().cast<ShapedType>();
ArrayRef<int64_t> XShape = XType.getShape();
int64_t XRank = XShape.size();
Type XElementType = XType.getElementType();

// Axis attribute (if specified) must be in the range [-r,r), where r =
// rank(input).
if (axis < -XRank || axis >= XRank)
return emitOpError("axis must be in [-r, r) range]");
if (axis < 0)
axis += XRank;

// Check bias B.
if (hasShapeAndRank(B)) {
// Can check at this stage.
ShapedType bType = B.getType().cast<ShapedType>();
ArrayRef<int64_t> bShape = bType.getShape();
SmallVector<int64_t> BBroadcastShape;
    if (!OpTrait::util::getBroadcastedShape(XShape, bShape, BBroadcastShape))
      return emitOpError(
          "LayerNormalization op with incompatible B shapes (broadcast)");
    if ((int64_t)BBroadcastShape.size() != XRank)
      return emitOpError("LayerNormalization op with incompatible B shapes "
                         "(unidirectional broadcast)");
    if (bType.getElementType() != XElementType)
      return emitOpError("LayerNormalization op with incompatible B type");
}

// Check scale.
if (hasShapeAndRank(scale)) {
// Can check at this stage.
ShapedType scaleType = scale.getType().cast<ShapedType>();
ArrayRef<int64_t> scaleShape = scaleType.getShape();
SmallVector<int64_t> scaleBroadcastShape;
    if (!OpTrait::util::getBroadcastedShape(
            XShape, scaleShape, scaleBroadcastShape))
      return emitOpError(
          "LayerNormalization op with incompatible scale shapes (broadcast)");
    if ((int64_t)scaleBroadcastShape.size() != XRank)
      return emitOpError("LayerNormalization op with incompatible scale "
                         "shapes (unidirectional broadcast)");
    if (scaleType.getElementType() != XElementType)
      return emitOpError("LayerNormalization op with incompatible scale type");
}

return success();
}

namespace onnx_mlir {

mlir::LogicalResult ONNXLayerNormalizationOpShapeHelper::computeShape() {

ONNXLayerNormalizationOpAdaptor operandAdaptor(operands);
ONNXLayerNormalizationOp lnOp = llvm::cast<ONNXLayerNormalizationOp>(op);

// Get rank and axis attribute.
int64_t axis = lnOp.getAxis();
Value X = operandAdaptor.getX();
int64_t XRank = X.getType().cast<ShapedType>().getRank();
if (axis < 0)
axis += XRank;

// Compute the shape of the first output and all the inputs.
llvm::SmallVector<Value, 3> operandsForBroadcast;
operandsForBroadcast.emplace_back(X);
operandsForBroadcast.emplace_back(operandAdaptor.getScale());
if (!isNoneValue(operandAdaptor.getB()))
operandsForBroadcast.emplace_back(operandAdaptor.getB());
if (failed(ONNXBroadcastOpShapeHelper::customComputeShape(
operandsForBroadcast, nullptr)))
return failure();

// Compute mean output shape if requested.
if (!isNoneValue(lnOp.getMean())) {
DimsExpr meanShape(getOutputDims(0));
for (int64_t r = axis; r < XRank; ++r)
meanShape[r] = LiteralIndexExpr(1);
setOutputDims(meanShape, 1, false);
}
// Compute invStdDev output shape if requested.
if (!isNoneValue(lnOp.getInvStdDev())) {
DimsExpr invStdDevShape(getOutputDims(0));
for (int64_t r = axis; r < XRank; ++r)
invStdDevShape[r] = LiteralIndexExpr(1);
setOutputDims(invStdDevShape, 2, false);
}
return success();
}
} // namespace onnx_mlir

LogicalResult ONNXLayerNormalizationOp::inferShapes(
std::function<void(Region &)> doShapeInference) {
  // If any input is not a ranked tensor, do nothing. Account for the possibly
  // null input (B).
if (!hasShapeAndRank(getX()) || !hasShapeAndRank(getScale()) ||
(!isNoneValue(getB()) && !hasShapeAndRank(getB())))
return success();
Type commonType = getX().getType().cast<RankedTensorType>().getElementType();
ONNXLayerNormalizationOpShapeHelper shapeHelper(getOperation(), {});
return shapeHelper.computeShapeAndUpdateType(commonType);
}
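
For orientation, here is a minimal Python sketch (not part of the commit) of the output-shape rule that computeShape above implements. It assumes the common case where scale and B broadcast unidirectionally to X, so Y simply takes X's shape, and it ignores dynamic dimensions, which the real helper carries as IndexExpr values.

```python
def layer_norm_output_shapes(x_shape, axis):
    """Shape rule mirrored from ONNXLayerNormalizationOpShapeHelper::computeShape.

    Y keeps the (broadcast) shape of X; Mean and InvStdDev keep X's shape
    with every dimension from the normalized axis onward set to 1.
    """
    rank = len(x_shape)
    if axis < 0:  # negative axis counts from the back, as in the C++ code
        axis += rank
    y_shape = list(x_shape)
    stat_shape = list(x_shape[:axis]) + [1] * (rank - axis)
    return y_shape, stat_shape, stat_shape


# Consistent with the new lit tests further down: X is 12x3x5, axis = -1.
print(layer_norm_output_shapes((12, 3, 5), -1))
# -> ([12, 3, 5], [12, 3, 1], [12, 3, 1])
```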
2 changes: 2 additions & 0 deletions src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp
@@ -182,6 +182,8 @@ LogicalResult ONNXOpShapeHelper::computeShapeAndUpdateType(
// Invoke virtual compute shape.
if (failed(computeShape()))
return op->emitError("Failed to scan parameters successfully");
assert((elementType.isa<VectorType>() || !elementType.isa<ShapedType>()) &&
"element type cannot be a shaped type other than vector type");
uint64_t resNum = op->getNumResults();
for (uint64_t i = 0; i < resNum; ++i) {
// If we have an optional type, leave it as is.
11 changes: 11 additions & 0 deletions src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
@@ -277,6 +277,17 @@ struct ONNXPReluOpShapeHelper : public ONNXBroadcastOpShapeHelper {
}
};

// Helper for ONNXLayerNormalizationOp (B and Scale broadcast to input X).
struct ONNXLayerNormalizationOpShapeHelper : public ONNXBroadcastOpShapeHelper {
ONNXLayerNormalizationOpShapeHelper(mlir::Operation *op,
mlir::ValueRange operands, IndexExprBuilder *ieBuilder = nullptr,
IndexExprScope *scope = nullptr)
: ONNXBroadcastOpShapeHelper(op, operands, ieBuilder, scope,
/*hasUniBroadcasting*/ true) {}
virtual ~ONNXLayerNormalizationOpShapeHelper() {}
mlir::LogicalResult computeShape() final;
};
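
As an aside on the hasUniBroadcasting flag passed just above: the sketch below is a hypothetical, static-shape Python rendition of unidirectional (ONNX-style) broadcasting, where scale and B may only stretch toward X and never change its rank; dynamic dimensions, which the real helper handles symbolically, are ignored here.

```python
def unidirectional_broadcast(x_shape, w_shape):
    """Simplified sketch of the unidirectional broadcast requested for
    scale and B against X.

    w is right-aligned against x; each of its dimensions must be 1 or equal
    to the matching dimension of x, and the result always has x's rank,
    which is also what the new verifier checks.
    """
    rank = len(x_shape)
    if len(w_shape) > rank:
        raise ValueError("scale/B rank may not exceed the rank of X")
    w = [1] * (rank - len(w_shape)) + list(w_shape)
    out = []
    for xd, wd in zip(x_shape, w):
        if wd not in (1, xd):
            raise ValueError(f"incompatible dimensions {xd} vs {wd}")
        out.append(xd)
    return out


print(unidirectional_broadcast((12, 3, 5), (5,)))  # [12, 3, 5]
```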

//===----------------------------------------------------------------------===//
// Unary Ops
//===----------------------------------------------------------------------===//
1 change: 0 additions & 1 deletion src/Dialect/ONNX/ONNXUnsupportedOps.hpp
@@ -44,7 +44,6 @@ UNSUPPORTED_OPS(ONNXHammingWindowOp)
UNSUPPORTED_OPS(ONNXHannWindowOp)
UNSUPPORTED_OPS(ONNXImputerOp)
UNSUPPORTED_OPS(ONNXLabelEncoderOp)
UNSUPPORTED_OPS(ONNXLayerNormalizationOp)
UNSUPPORTED_OPS(ONNXLinearClassifierOp)
UNSUPPORTED_OPS(ONNXLinearRegressorOp)
UNSUPPORTED_OPS(ONNXLpPoolOp)
38 changes: 36 additions & 2 deletions test/mlir/onnx/onnx_shape_inference.mlir
@@ -3680,8 +3680,6 @@ module {
}
}

// -----

// Check that ClipV6 operation shape inference goes through shape inference smoothly.
// ClipV6 has no shape inference as it is supposed to be first updated to the latest ClipOp.
// Using the latest shape inference, the default is to let unimplemented ops go through shape
@@ -3738,3 +3736,39 @@ func.func @test_custom3(%arg0: tensor<1024xi32>, %arg1: tensor<4xf32>) -> tensor
// CHECK: [[VAR_0_:%.+]] = "onnx.Custom"([[PARAM_0_]], [[PARAM_1_]]) {function_name = "testcall", inputs_for_infer = [1], shape_infer_pattern = "SameAs"} : (tensor<1024xi32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: return [[VAR_0_]] : tensor<4xf32>
// CHECK: }


// -----

// Test layer norm when not decomposed

func.func @test_layer_norm_3inputs(%arg0: tensor<12x3x5xf32>, %arg1: tensor<5xf32>, %arg2: tensor<5xf32>) -> tensor<*xf32> {
%Y, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<12x3x5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
return %Y : tensor<*xf32>

// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_layer_norm_3inputs
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<12x3x5xf32>, [[PARAM_1_:%.+]]: tensor<5xf32>, [[PARAM_2_:%.+]]: tensor<5xf32>) -> tensor<12x3x5xf32> {
// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<12x3x5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<12x3x5xf32>, tensor<12x3x1xf32>, tensor<12x3x1xf32>)
// CHECK: return [[Y_]] : tensor<12x3x5xf32>
// CHECK: }
}

// -----

// Test layer norm when not decomposed

func.func @test_layer_norm_2inputs(%arg0: tensor<12x3x5xf32>, %arg1: tensor<5xf32>) -> tensor<*xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%Y, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<12x3x5xf32>, tensor<5xf32>, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
return %Y : tensor<*xf32>

// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_layer_norm_2inputs
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<12x3x5xf32>, [[PARAM_1_:%.+]]: tensor<5xf32>) -> tensor<12x3x5xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<12x3x5xf32>, tensor<5xf32>, none) -> (tensor<12x3x5xf32>, tensor<12x3x1xf32>, tensor<12x3x1xf32>)
// CHECK: return [[Y_]] : tensor<12x3x5xf32>
// CHECK: }
}

1 change: 1 addition & 0 deletions utils/gen_onnx_mlir.py
@@ -395,6 +395,7 @@
"If",
"IsInf",
"InstanceNormalization",
"LayerNormalization",
"Less",
"LessOrEqual",
"LogSoftmax",
