Transform SequenceAt to split for special cases #3018
base: main
Conversation
Signed-off-by: chentong319 <[email protected]>
LGTM.
ONNXSplitToSequenceOp splitToSequence;
if (!(splitToSequence = mlir::dyn_cast<ONNXSplitToSequenceOp>(
        inputSequence.getDefiningOp())))
  return false;
This can be shortened with:
ONNXSplitToSequenceOp splitToSequence = inputSequence.getDefiningOp<ONNXSplitToSequenceOp>();
if (!splitToSequence)
return false;
Value replaceSequenceAt(
    PatternRewriter &rewriter, Location loc, Value sequenceAtResult) {
  ONNXSequenceAtOp op =
      mlir::cast<ONNXSequenceAtOp>(sequenceAtResult.getDefiningOp());
Same here; we can use ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp<ONNXSequenceAtOp>().
@@ -444,7 +444,7 @@ TensorType OnnxBuilder::toTensor(Type input) const {
 }

 TypeRange OnnxBuilder::toTensors(TypeRange inputs) const {
-  assert(inputs.size() >= 2 && "Expect at least two inputs");
+  //assert(inputs.size() >= 2 && "Expect at least two inputs");
Looks like we can remove this line completely.
Certain torch.onnx.export paths in PyTorch break the LSTM op into lower-level operations and generate a SplitToSequence/SequenceAt operation pattern. For example:
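A hand-written sketch of that pattern (only the SplitToSequence/SequenceAt structure comes from this PR; the shapes, attribute values, and SSA names are hypothetical):

// %X is split along axis 0 into a sequence, then elements are read back.
%split = "onnx.Constant"() {value = dense<1> : tensor<2xi64>} : () -> tensor<2xi64>
%seq = "onnx.SplitToSequence"(%X, %split) {axis = 0 : si64}
    : (tensor<2x4xf32>, tensor<2xi64>) -> !onnx.Seq<tensor<1x4xf32>>
%p0 = "onnx.Constant"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
%Y0 = "onnx.SequenceAt"(%seq, %p0)
    : (!onnx.Seq<tensor<1x4xf32>>, tensor<i64>) -> tensor<1x4xf32>
%p1 = "onnx.Constant"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
%Y1 = "onnx.SequenceAt"(%seq, %p1)
    : (!onnx.Seq<tensor<1x4xf32>>, tensor<i64>) -> tensor<1x4xf32>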
ONNX-MLIR currently has no lowering for SplitToSequence, and sequence-related ops are in general not well optimized in ONNX-MLIR. I found that this code pattern can be converted into tensor operations, avoiding the sequence ops entirely.
The two onnx.Split ops should be mergeable into one, but current onnx-mlir does not merge them; this needs further investigation.
The exported model also contains another kind of SequenceAt, whose output type differs from the element type of the input sequence. I think this is an error in the exporter, but I tried to handle this case in the transformation of this PR, as illustrated below.
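A hypothetical illustration of the mismatch (the concrete result type the exporter emits may differ; this only shows the shape of the problem):

// The element type of %seq is tensor<1x4xf32>, yet the exporter gives the
// SequenceAt result an inconsistent type:
%Y = "onnx.SequenceAt"(%seq, %p0)
    : (!onnx.Seq<tensor<1x4xf32>>, tensor<i64>) -> tensor<4xf32>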
Output from this PR:
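A sketch of what the rewrite produces under the same hypothetical shapes; each SequenceAt becomes a direct use of an onnx.Split result, and the sequence ops become dead:

%s0:2 = "onnx.Split"(%X, %split) {axis = 0 : si64}
    : (tensor<2x4xf32>, tensor<2xi64>) -> (tensor<1x4xf32>, tensor<1x4xf32>)
%s1:2 = "onnx.Split"(%X, %split) {axis = 0 : si64}
    : (tensor<2x4xf32>, tensor<2xi64>) -> (tensor<1x4xf32>, tensor<1x4xf32>)
// Uses of %Y0 are replaced by %s0#0, and uses of %Y1 by %s1#1. The
// duplicated onnx.Split ops are the merge opportunity noted above.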
Other small fixes in this PR:
Without these fixes, --EmitONNXBasic will fail for this model.