Skip to content

Commit

Permalink
Transform SequenceAt to split for special cases (#3018)
Browse files Browse the repository at this point in the history
* implement

Signed-off-by: chentong319 <[email protected]>

* test case

Signed-off-by: chentong319 <[email protected]>

* format

Signed-off-by: chentong319 <[email protected]>

* fix

Signed-off-by: chentong319 <[email protected]>

---------

Signed-off-by: chentong319 <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
  • Loading branch information
chentong319 and AlexandreEichenberger authored Dec 4, 2024
1 parent e801b36 commit 45f07d5
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 2 deletions.
1 change: 0 additions & 1 deletion src/Dialect/ONNX/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,6 @@ TensorType OnnxBuilder::toTensor(Type input) const {
}

TypeRange OnnxBuilder::toTensors(TypeRange inputs) const {
assert(inputs.size() >= 2 && "Expect at least two inputs");
if (llvm::all_of(inputs, [](Type t) { return (mlir::isa<TensorType>(t)); }))
return inputs;
assert(llvm::all_of(inputs, [](Type t) {
Expand Down
3 changes: 2 additions & 1 deletion src/Dialect/ONNX/ONNXOps/OpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ void ArrayAttrIntVals(ArrayAttr a, mlir::SmallVectorImpl<int64_t> &i) {

// Returns the `value` attribute of the ONNXConstantOp found for `value` as an
// ElementsAttr. Returns nullptr when no constant op is found, when the
// attribute is not set, or when it is not an ElementsAttr.
ElementsAttr getElementAttributeFromONNXValue(Value value) {
  ONNXConstantOp constOp = getONNXConstantOp(value);
  // The value attribute may be missing in case the ConstantOp has not been
  // normalized yet.
  if (!constOp || !constOp.getValueAttr())
    return nullptr;
  return mlir::dyn_cast<ElementsAttr>(constOp.getValueAttr());
}
Expand Down
118 changes: 118 additions & 0 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,120 @@ bool hasStaticSpatialDims(Value v) {
return llvm::none_of(Ds, ShapedType::isDynamic);
}

// In the following pattern, a SequenceAt can be replaced with Split
// %seq = onnx.SplitToSequence(%input, %split) {%axis : }
// %res = onnx.SequenceAt(%seq, %position)
// We just try to avoid using the sequence related ops, which are less
// optimized, or even not implemented in onnx-mlir.
// In the targeted use case, %split and %position are constant scalar and the
// tensor of %input and %res have static shape.
// This condition greatly reduces the complexity of code generation to replace
// SequenceAt with split op
// %res = onnx.Split(%input, onnx.expand(%split, %input.shape()[%axis]))
// {%axis : } : %position
// onnx.expand(%split, %input.shape()[%axis]) can be a constant under the
// assumed condition.
// Here %position has to be a compile-time constant.
// For multiple SequenceAt from the same SplitToSequence result, the onnx.split
// for different SequenceAt are expected to be merged by optimization.
// Alternatively, Slice can be used
// %res = onnx.Slice(%input, %start, %end, %step)
// The start, and end for slice will be onnx.constant:
// start: %position*%split for %axis, 0 for other dimensions
// end: (%position+1)*%split for %axis, upper bound for other dimensions
// step: 1 for all dimensions
// The split approach may have better performance than the alternative slice
// approach, because the slicing is done separately.

// Returns true when the SequenceAt op producing `sequenceAtResult` matches
// the restricted SplitToSequence + SequenceAt pattern that can be rewritten
// with a Split op. All of the following must hold:
//  - the SequenceAt result and the SplitToSequence input have static shapes;
//  - %position is a dense constant and %split is a scalar constant;
//  - the axis is a valid non-negative dimension index of the input;
//  - %split is positive and evenly divides the input extent along the axis,
//    so that every chunk has the same size;
//  - %position is a non-negative index smaller than the number of chunks.
// Anything else is conservatively rejected so that the rewrite never indexes
// out of bounds, divides by zero, or builds a Split whose sizes do not cover
// the whole axis.
bool canSequenceAtBeReplaced(Value sequenceAtResult) {
  if (!hasStaticShape(sequenceAtResult.getType()))
    return false;

  ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp<ONNXSequenceAtOp>();
  if (!op)
    return false;

  Value inputSequence = op.getInputSequence();
  Value position = op.getPosition();

  if (!isDenseONNXConstant(position))
    return false;
  ONNXConstantOp positionConstant = position.getDefiningOp<ONNXConstantOp>();
  if (!positionConstant)
    return false;

  // Input sequence should be defined with SplitToSequence
  ONNXSplitToSequenceOp splitToSequence =
      inputSequence.getDefiningOp<ONNXSplitToSequenceOp>();
  if (!splitToSequence)
    return false;

  // Check the pattern of the SplitToSequence op
  Value input = splitToSequence.getInput();
  if (!hasStaticShape(input.getType()))
    return false;
  Value split = splitToSequence.getSplit();
  if (!isScalarConstantTensor(split))
    return false;
  ONNXConstantOp splitConstant = split.getDefiningOp<ONNXConstantOp>();
  if (!splitConstant)
    return false;

  // The axis is used to index the static shape directly.
  auto shape = getShape(input.getType());
  int64_t rank = static_cast<int64_t>(shape.size());
  int64_t axisInt = splitToSequence.getAxis();
  if (axisInt < 0 || axisInt >= rank)
    return false;

  // Only equal-sized chunks are supported: a non-positive split is invalid,
  // and a non-dividing split would leave a smaller trailing chunk.
  int64_t splitInt = getScalarValue<int64_t>(splitConstant);
  if (splitInt <= 0 || shape[axisInt] % splitInt != 0)
    return false;

  // The position is used to index the results of the generated Split.
  int64_t numChunks = shape[axisInt] / splitInt;
  int64_t positionInt = getScalarValue<int64_t>(positionConstant);
  return positionInt >= 0 && positionInt < numChunks;
}

// Builds the replacement for a SequenceAt op whose input sequence comes from
// a SplitToSequence op with a constant scalar split and a constant position.
//
// `rewriter` and `loc` are used to create the new ops; `sequenceAtResult` is
// the result value of the SequenceAt op being replaced. The input of the
// SplitToSequence is split along the axis into equal chunks of size %split,
// and the chunk selected by %position is returned. When the type of that
// chunk differs from the SequenceAt result type, the chunk is squeezed along
// the split axis to the SequenceAt result type.
//
// Negative axis and position values are normalized by adding the rank and the
// number of chunks, respectively. The remaining preconditions are asserted.
Value replaceSequenceAt(
    PatternRewriter &rewriter, Location loc, Value sequenceAtResult) {
  ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp<ONNXSequenceAtOp>();

  Value inputSequence = op.getInputSequence();
  Value position = op.getPosition();

  ONNXConstantOp positionConstant =
      mlir::cast<ONNXConstantOp>(position.getDefiningOp());
  int64_t positionInt = getScalarValue<int64_t>(positionConstant);

  ONNXSplitToSequenceOp splitToSequence =
      mlir::cast<ONNXSplitToSequenceOp>(inputSequence.getDefiningOp());

  Value input = splitToSequence.getInput();
  Value split = splitToSequence.getSplit();

  ONNXConstantOp splitConstant =
      mlir::cast<ONNXConstantOp>(split.getDefiningOp());
  int64_t splitInt = getScalarValue<int64_t>(splitConstant);
  int64_t axisInt = splitToSequence.getAxis();

  auto shape = getShape(input.getType());
  int64_t rank = static_cast<int64_t>(shape.size());

  // A negative axis counts from the last dimension.
  if (axisInt < 0)
    axisInt += rank;
  assert(axisInt >= 0 && axisInt < rank &&
         "SplitToSequence axis is out of range");
  assert(splitInt > 0 && "SplitToSequence split must be positive");
  assert(shape[axisInt] % splitInt == 0 &&
         "split must evenly divide the input dimension along axis");

  int64_t numChunks = shape[axisInt] / splitInt;

  // A negative position counts from the end of the sequence.
  if (positionInt < 0)
    positionInt += numChunks;
  assert(positionInt >= 0 && positionInt < numChunks &&
         "SequenceAt position is out of range");

  OnnxBuilder create(rewriter, loc);

  // Every chunk has the element type of the sequence.
  Type sequenceElementType =
      mlir::cast<SeqType>(inputSequence.getType()).getElementType();
  mlir::SmallVector<mlir::Type, 4> outputTypes(numChunks, sequenceElementType);
  auto numSplit = create.constantInt64(
      mlir::SmallVector<int64_t, 4>(numChunks, splitInt));
  auto resultRange = create.split(outputTypes, input, numSplit, axisInt);
  auto rawResult = resultRange[positionInt];

  if (rawResult.getType() == sequenceAtResult.getType())
    return rawResult;

  // Temporary code for the error in the model generated by torch.onnx.export.
  // The dim of the result of the SequenceAt op is different from the element
  // type of the sequence.
  // My assumption is that the exporter is confused with squeeze and unsqueeze
  // followed by the SequenceAt. There are two cases in the model:
  // clang-format off
  // Case #1:
  // %16 = "onnx.SequenceAt"(%14, %15) {onnx_node_name = "n0"} :
  // (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x100xf32>
  // %23 = "onnx.Unsqueeze"(%16, %22) {onnx_node_name = "n2"} :
  // (tensor<1x100xf32>, tensor<i64>) -> tensor<1x1x100xf32>
  // Case #2:
  // %67 = "onnx.SequenceAt"(%66, %15) {onnx_node_name = "n0"} :
  // (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
  // %71 = "onnx.Sigmoid"(%67) {onnx_node_name = "node_Sigmoid_60"} :
  // (tensor<1x1x100xf32>) -> tensor<1x1x100xf32>
  // clang-format on
  // Thus, the compiler squeezes the tensor along the split axis if needed.
  return create.squeeze(
      sequenceAtResult.getType(), rawResult, create.constantInt64(axisInt));
}

bool shouldDecomposeConvTransposeOp(Value convTransposeResult) {
ONNXConvTransposeOp op =
mlir::cast<ONNXConvTransposeOp>(convTransposeResult.getDefiningOp());
Expand Down Expand Up @@ -1246,6 +1360,10 @@ void DecomposeONNXToONNXPass::runOnOperation() {
return !isConcatFuseMatched(op, shapeOp, transposeOp);
});

target.addDynamicallyLegalOp<ONNXSequenceAtOp>([](ONNXSequenceAtOp op) {
return !onnx_mlir::canSequenceAtBeReplaced(op.getResult());
});

// Rewrite ONNXConstantOp with scalar values into the one using ElementAttrs.
target.addDynamicallyLegalOp<ONNXConstantOp>([](ONNXConstantOp op) {
return !(op.getValueFloatAttr() || op.getValueFloatsAttr() ||
Expand Down
18 changes: 18 additions & 0 deletions src/Dialect/ONNX/Transforms/Decompose.td
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ def createScalarDenseAttrRank0
// Reshape the ElementsAttr $0 with an empty shape ({}), i.e. to rank 0, using
// an OnnxElementsAttrBuilder created from the context of $0.
def ReshapeElementsAttrToRank0 : NativeCodeCall<
"onnx_mlir::OnnxElementsAttrBuilder($0.getContext()).reshape(cast<ElementsAttr>($0), {})">;

// Call the C++ helper onnx_mlir::replaceSequenceAt with the pattern builder,
// the current location, and $0 (the result value of the SequenceAt op to be
// replaced). The call expression yields the replacement value.
def ReplaceSequenceAt : NativeCodeCall<
"onnx_mlir::replaceSequenceAt($_builder, $_loc, $0)">;

// Constraint that holds when the C++ predicate
// ::onnx_mlir::canSequenceAtBeReplaced returns true for the constrained value
// ($_self), which is expected to be the result of a SequenceAt op.
def CanSequenceAtBeReplaced :
Constraint<CPred<"::onnx_mlir::canSequenceAtBeReplaced($_self)">, "check whether the SequenceAt can be replaced with split">;

// Create a DenseElementsAttr from a single attribute: $0 is first wrapped into
// a one-element ArrayAttr, which is then passed to
// ::onnx_mlir::createDenseArrayAttr.
def createDenseArrayAttrFromSingleAttr
: NativeCodeCall<"::onnx_mlir::createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">;
Expand Down Expand Up @@ -620,4 +626,16 @@ def ConstantOpNormalizationPattern6: Pat<
[(AttributeIsNotNull:$stringsAttr)]
>;

// Optimize for the pattern coming from torch.nn.LSTM exported from pytorch
// %32 = "onnx.SplitToSequence"(%30, %27) {axis = 0 : si64, keepdims = 0 : si64, onnx_node_name = "n1"} : (tensor<1x1x100xf32>, tensor<i64>) -> !onnx.Seq<tensor<1x1x100xf32>>
// %33 = "onnx.SequenceAt"(%32, %26) {onnx_node_name = "n0"} : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x100xf32>
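// When shape and size/axis related value are constant, this sequence of code
// can be translated into onnx.Split, followed by onnx.Squeeze when the type of
// the SequenceAt result differs from the element type of the sequence.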

// Rewrite a SequenceAt op, bound as $res, with the value built by
// ReplaceSequenceAt. The rewrite only applies when CanSequenceAtBeReplaced
// holds for $res. The $seq and $position operands are matched but not used
// directly here; the native helpers receive $res.
def ReplaceSequenceAtPattern: Pat<
(ONNXSequenceAtOp:$res $seq, $position),
(ReplaceSequenceAt $res),
[(CanSequenceAtBeReplaced:$res)]
>;

#endif // ONNX_DECOMPOSE
43 changes: 43 additions & 0 deletions test/mlir/onnx/onnx_decompose_canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

// RUN: onnx-mlir-opt --decompose-onnx --canonicalize %s -split-input-file | FileCheck %s

// -----

// Test one pattern in lstm_no_data.onnx.
// The type of output of SequenceAt is not the same as the element type
// of the input sequence
// Here the sequence element type is tensor<1x1x100xf32> while the SequenceAt
// result is tensor<1x100xf32>. The expected output is a single-result
// onnx.Split along axis 0 followed by an onnx.Squeeze on axis 0.
func.func @sequence_at_squeezed(%arg0 : tensor<1x1x100xf32>) -> tensor<1x100xf32> {
%26 = onnx.Constant dense<0> : tensor<i64>
%27 = onnx.Constant dense<1> : tensor<i64>
%32 = "onnx.SplitToSequence"(%arg0, %27) {axis = 0 : si64, keepdims = 0 : si64} : (tensor<1x1x100xf32>, tensor<i64>) -> !onnx.Seq<tensor<1x1x100xf32>>
%33 = "onnx.SequenceAt"(%32, %26) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x100xf32>
return %33: tensor<1x100xf32>
// CHECK-LABEL: func.func @sequence_at_squeezed
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x100xf32>) -> tensor<1x100xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor<1xi64>
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<1xi64>
// CHECK: [[VAR_2_:%.+]] = "onnx.Split"([[PARAM_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x1x100xf32>
// CHECK: [[VAR_3_:%.+]] = "onnx.Squeeze"([[VAR_2_]], [[VAR_0_]]) : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x100xf32>
// CHECK: return [[VAR_3_]] : tensor<1x100xf32>
// CHECK: }
}

// Two SequenceAt ops (positions 0 and 1) read from the same SplitToSequence,
// which splits a 400-element axis into chunks of 100. The SequenceAt result
// type equals the sequence element type, so no Squeeze is expected. The
// expected output contains one 4-result onnx.Split per SequenceAt, with
// results #0 and #1 feeding the onnx.Add.
func.func @sequence_at_multi(%arg0 : tensor<1x1x400xf32>) -> tensor<1x1x100xf32> {
%15 = onnx.Constant dense<0> : tensor<i64>
%38 = onnx.Constant dense<1> : tensor<i64>
%65 = onnx.Constant dense<100> : tensor<i64>
%66 = "onnx.SplitToSequence"(%arg0, %65) {axis = 2 : si64, keepdims = 1 : si64} : (tensor<1x1x400xf32>, tensor<i64>) -> !onnx.Seq<tensor<1x1x100xf32>>
%67 = "onnx.SequenceAt"(%66, %15) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
%68 = "onnx.SequenceAt"(%66, %38) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
%40 = "onnx.Add"(%67, %68) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32>
return %40: tensor<1x1x100xf32>
// CHECK-LABEL: func.func @sequence_at_multi
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x400xf32>) -> tensor<1x1x100xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<100> : tensor<4xi64>
// CHECK-DAG: [[VAR_1_:%.+]]:4 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>)
// CHECK-DAG: [[VAR_2_:%.+]]:4 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>)
// CHECK: [[VAR_3_:%.+]] = "onnx.Add"([[VAR_1_]]#0, [[VAR_2_]]#1) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32>
// CHECK: return [[VAR_3_]] : tensor<1x1x100xf32>
// CHECK: }
}

0 comments on commit 45f07d5

Please sign in to comment.