Skip to content

Commit

Permalink
Fix issues with onnx.Unique op with dynamic inputs. (#2555)
Browse files Browse the repository at this point in the history
Signed-off-by: Yasushi Negishi <[email protected]>
  • Loading branch information
negiyas authored Oct 17, 2023
1 parent 6989e96 commit cfe1b18
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/Conversion/ONNXToKrnl/Tensor/Unique.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ struct ONNXUniqueOpLowering : public ConversionPattern {
ONNXUniqueOpShapeHelper shapeHelper(op, operands, &create.krnlIE);
shapeHelper.computeShapeAndAssertOnFailure();
Value X = operandAdaptor.getX();
ArrayRef<int64_t> xShape = getShape(X.getType());
SmallVector<IndexExpr> XDims;
create.krnlIE.getShapeAsDims(X, XDims);

Type elementType = X.getType().cast<MemRefType>().getElementType();
int64_t rank = create.krnlIE.getShapedTypeRank(X);
int64_t sorted = operandAdaptor.getSorted();
Expand Down Expand Up @@ -139,20 +141,20 @@ struct ONNXUniqueOpLowering : public ConversionPattern {
if (axis < 0) {
outputYDims.emplace_back(totalDimExpr);
outputIndexDims.emplace_back(totalDimExpr);
DimIndexExpr inputDimExpr = LiteralIndexExpr(xShape[0]);
DimIndexExpr inputDimExpr = XDims[0];
for (int64_t i = 1; i < rank; i++) {
inputDimExpr = inputDimExpr * LiteralIndexExpr(xShape[i]);
inputDimExpr = inputDimExpr * XDims[i];
}
outputInverseIndexDims.emplace_back(inputDimExpr);
} else {
for (int64_t i = 0; i < rank; i++) {
DimIndexExpr tDimExpr = LiteralIndexExpr(xShape[i]);
DimIndexExpr tDimExpr = XDims[i];
if (i == axis)
tDimExpr = totalDimExpr;
outputYDims.emplace_back(tDimExpr);
}
outputIndexDims.emplace_back(totalDimExpr);
outputInverseIndexDims.emplace_back(LiteralIndexExpr(xShape[axis]));
outputInverseIndexDims.emplace_back(XDims[axis]);
}
//
// Insert an allocation and deallocation for the outputs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,31 @@ func.func @unique_with_counts(%arg0: tensor<2x2xi64>) -> tensor<*xi64> {
// CHECK: "krnl.call"([[RES_]], [[RES_1_]], [[RES_1_]]_0, [[RES_1_]]_0, [[RES_1_]]_1, [[X_]], [[CST_1_]], [[CST_1_]]) {funcName = "omTensorUnique", numOfOutput = 5 : si64} : (memref<index>, memref<2x?xi64>, memref<0xi64>, memref<0xi64>, memref<?xi64>, memref<2x2xi64>, i64, i64) -> ()
// CHECK: return [[RES_1_]] : memref<2x?xi64>
// CHECK: }

// -----

func.func @unique_with_dynamic_inputs(%arg0: tensor<?xi64>) -> (tensor<?xi64>, tensor<?xi64>, tensor<?xi64>) {
%Y, %indices, %inverse_indices, %counts = "onnx.Unique"(%arg0) {axis = 0 : si64, sorted = 1 : si64} : (tensor<?xi64>) -> (tensor<?xi64>, none, tensor<?xi64>, tensor<?xi64>)
return %Y, %inverse_indices, %counts : tensor<?xi64>, tensor<?xi64>, tensor<?xi64>

// mlir2FileCheck.py
// CHECK-LABEL: func.func @unique_with_dynamic_inputs
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?xi64>) -> (memref<?xi64>, memref<?xi64>, memref<?xi64>) {
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i64
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_1_]] : memref<?xi64>
// CHECK-DAG: [[RES_:%.+]] = memref.alloca() : memref<index>
// CHECK: krnl.store [[CST_0_1_]], [[RES_]][] : memref<index>
// CHECK: "krnl.call"([[RES_]], [[PARAM_0_]], [[CST_0_]], [[CST_1_]]) {funcName = "omTensorUniqueCount", numOfOutput = 1 : si64} : (memref<index>, memref<?xi64>, i64, i64) -> ()
// CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref<index>
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc([[LOAD_RES_MEM_]]) {{.*}}: memref<?xi64>
// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<0xi64>
// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc([[VAR_dim_]]) {{.*}}: memref<?xi64>
// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc([[LOAD_RES_MEM_]]) {{.*}}: memref<?xi64>
// CHECK: krnl.store [[CST_0_1_]], [[RES_]][] : memref<index>
// CHECK: "krnl.call"([[RES_]], [[RES_1_]], [[RES_1_]]_0, [[RES_1_]]_1, [[RES_1_]]_2, [[PARAM_0_]], [[CST_0_]], [[CST_1_]]) {funcName = "omTensorUnique", numOfOutput = 5 : si64} : (memref<index>, memref<?xi64>, memref<0xi64>, memref<?xi64>, memref<?xi64>, memref<?xi64>, i64, i64) -> ()
// CHECK: return [[RES_1_]], [[RES_1_]]_1, [[RES_1_]]_2 : memref<?xi64>, memref<?xi64>, memref<?xi64>
}

0 comments on commit cfe1b18

Please sign in to comment.