Skip to content

Commit 51dcbf9

Browse files
authored
Add Shape op verifier (#1711)
* Add Shape op verifier Signed-off-by: Philip Lassen <[email protected]>
1 parent b6a17f6 commit 51dcbf9

File tree

7 files changed

+26
-1
lines changed

7 files changed

+26
-1
lines changed

docs/SupportedONNXOps-cpu.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 16. Limitatio
166166
| **SequenceErase** | |unsupported | |
167167
| **SequenceInsert** |11 |Does not support unranked sequence element. | |
168168
| **SequenceLength** | |unsupported | |
169-
| **Shape** |13 | | |
169+
| **Shape** |15 |Does not support start and end attributes. | |
170170
| **Shrink** | |unsupported | |
171171
| **Sigmoid** |13 | | |
172172
| **Sign** |13 | | |

src/Dialect/ONNX/ONNXOps.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3436,6 +3436,18 @@ LogicalResult ONNXShapeOp::inferShapes(
34363436
ONNXShapeOpAdaptor>(*this, elementType);
34373437
}
34383438

3439+
LogicalResult ONNXShapeOp::verify() {
3440+
if (!data().getType().isa<RankedTensorType>())
3441+
return success();
3442+
ONNXShapeOpAdaptor operandAdaptor(*this);
3443+
int64_t start;
3444+
int64_t end;
3445+
std::tie(start, end) = getDataShapeBounds(operandAdaptor);
3446+
if (start > end)
3447+
return emitOpError() << "Start: " << start << " is after End: " << end;
3448+
return success();
3449+
}
3450+
34393451
//===----------------------------------------------------------------------===//
34403452
// Size
34413453
//===----------------------------------------------------------------------===//

src/Dialect/ONNX/ONNXOps.td.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5609,6 +5609,7 @@ def ONNXShapeOp:ONNX_Op<"Shape",
56095609
return {4};
56105610
}
56115611
}];
5612+
let hasVerifier = 1;
56125613
}
56135614

56145615
def ONNXShrinkOp:ONNX_Op<"Shrink",

src/Dialect/ONNX/ShapeInference/Shape.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ LogicalResult ONNXShapeOpShapeHelper::computeShape(
7070
int64_t end;
7171
std::tie(start, end) = getDataShapeBounds(operandAdaptor);
7272

73+
assert(start <= end && "Start must not be greater than end");
74+
7375
// Output is the actual number of values (1D)
7476
dimsForOutput().emplace_back(LiteralIndexExpr(end - start));
7577

test/backend/inference_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,7 @@ def get_test_models():
855855
#"test_sequence_insert_at_back_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
856856

857857
# ==OP== Shape
858+
# ==LIM== Does not support start and end attributes.
858859
"test_shape_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
859860
"test_shape_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
860861

test/mlir/onnx/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,14 @@ func.func @test_scatterelements_verifier_2(%arg0: tensor<2x2xf32>, %arg1: tensor
343343

344344
// -----
345345

346+
func.func @test_shape_to_dim_positive_axis_verifier(%arg0: tensor<?x256x?xi64>) -> tensor<2xi64> {
347+
// expected-error @+1 {{'onnx.Shape' op Start: 2 is after End: 0}}
348+
%0 = "onnx.Shape"(%arg0) {end = 0 : si64, start = -1 : si64} : (tensor<?x256x?xi64>) -> tensor<2xi64>
349+
return %0 : tensor<2xi64>
350+
}
351+
352+
// -----
353+
346354
func.func @test_logsoftmax_verifier_1(%arg0: tensor<2x2xf32>) -> tensor<*xf32> {
347355
// expected-error @+1 {{onnx.LogSoftmax: 'axis' value is 3, accepted range is [-2, 1]}}
348356
%1 = "onnx.LogSoftmax"(%arg0) {axis = 3 : si64} : (tensor<2x2xf32>) -> tensor<*xf32>

utils/gen_onnx_mlir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@
349349
'ScatterND',
350350
'SequenceEmpty',
351351
'SequenceInsert',
352+
'Shape',
352353
'SpaceToDepth',
353354
'Split',
354355
'SplitToSequence',

0 commit comments

Comments
 (0)