Transform SequenceAt to split for special cases #3018

Open · wants to merge 3 commits into main
Conversation

chentong319 (Collaborator) commented:

In certain cases, PyTorch's torch.onnx.export breaks the LSTM op into lower-level operations and generates a SplitToSequence and SequenceAt operation pattern. For example:

  %15 = onnx.Constant dense<0> : tensor<i64>
  %38 = onnx.Constant dense<1> : tensor<i64>
  %65 = onnx.Constant dense<100> : tensor<i64>
  %66 = "onnx.SplitToSequence"(%arg0, %65) {axis = 2 : si64, keepdims = 1 : si64} : (tensor<1x1x400xf32>, tensor<i64>) -> !onnx.Seq<tensor<1x1x100xf32>>
  %67 = "onnx.SequenceAt"(%66, %15) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
  %68 = "onnx.SequenceAt"(%66, %38) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x1x100xf32>
  %40 = "onnx.Add"(%67, %68) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32>

ONNX-MLIR currently has no lowering for SplitToSequence, and sequence-related ops in general are not well optimized in ONNX-MLIR. I found that such a code pattern can be converted into tensor operations to avoid the sequence ops:

    %0 = onnx.Constant dense<100> : tensor<4xi64>
    %1:4 = "onnx.Split"(%arg0, %0) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>)
    %2:4 = "onnx.Split"(%arg0, %0) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>)
    %3 = "onnx.Add"(%1#0, %2#1) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32>

The two onnx.Split ops should be mergeable into one, but onnx-mlir currently does not merge them; this needs further investigation.
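
For reference, the merged form would simply reuse a single Split's results, i.e., ordinary common-subexpression elimination (the IR below is hand-written for illustration, not compiler output):

    %0 = onnx.Constant dense<100> : tensor<4xi64>
    %1:4 = "onnx.Split"(%arg0, %0) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>)
    %2 = "onnx.Add"(%1#0, %1#1) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32>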

But the exported model also contains another kind of SequenceAt, in which the output type of the SequenceAt differs from the element type of the input sequence (here the SplitToSequence has keepdims = 0, so the split axis should be squeezed from each element). I think this is an error in the exporter; nevertheless, I handle it in the transformation in this PR:

    %26 = onnx.Constant dense<0> : tensor<i64>
    %27 = onnx.Constant dense<1> : tensor<i64>
    %32 = "onnx.SplitToSequence"(%arg0, %27) {axis = 0 : si64, keepdims = 0 : si64} : (tensor<1x1x100xf32>, tensor<i64>) -> !onnx.Seq<tensor<1x1x100xf32>>
    %33 = "onnx.SequenceAt"(%32, %26) : (!onnx.Seq<tensor<1x1x100xf32>>, tensor<i64>) -> tensor<1x100xf32>

Output from this PR:

    %0 = onnx.Constant dense<0> : tensor<1xi64>
    %1 = onnx.Constant dense<1> : tensor<1xi64>
    %2 = "onnx.Split"(%arg0, %1) {axis = 0 : si64} : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x1x100xf32>
    %3 = "onnx.Squeeze"(%2, %0) : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x100xf32>

Other small fixes in this PR:

  1. Handle the case where onnx.ConstantOp has not been normalized yet; otherwise, --EmitONNXBasic fails for this model (see the example after this list).
  2. Drop the check that the type range size must be larger than one; in this model, the number of outputs from Split happens to be one.
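
Regarding fix 1: an un-normalized onnx.ConstantOp can carry its payload in one of the alternative value attributes (for example value_int) instead of a dense value attribute, and normalization folds these into value. The line below is an illustrative example of such a constant, not taken from the model:

    %c = "onnx.Constant"() {value_int = 0 : si64} : () -> tensor<i64>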

@tungld (Collaborator) left a comment:


LGTM.

    ONNXSplitToSequenceOp splitToSequence;
    if (!(splitToSequence = mlir::dyn_cast<ONNXSplitToSequenceOp>(
              inputSequence.getDefiningOp())))
      return false;
@tungld commented:

This can be shortened to:

    ONNXSplitToSequenceOp splitToSequence =
        inputSequence.getDefiningOp<ONNXSplitToSequenceOp>();
    if (!splitToSequence)
      return false;

    Value replaceSequenceAt(
        PatternRewriter &rewriter, Location loc, Value sequenceAtResult) {
      ONNXSequenceAtOp op =
          mlir::cast<ONNXSequenceAtOp>(sequenceAtResult.getDefiningOp());
@tungld commented:

Same here; we can use ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp<ONNXSequenceAtOp>().

    @@ -444,7 +444,7 @@ TensorType OnnxBuilder::toTensor(Type input) const {
     }

     TypeRange OnnxBuilder::toTensors(TypeRange inputs) const {
    -  assert(inputs.size() >= 2 && "Expect at least two inputs");
    +  //assert(inputs.size() >= 2 && "Expect at least two inputs");
@tungld commented:
Looks like we can remove this line completely.
