Skip to content

Commit

Permalink
Fix onnx-to-krnl lowering of onnx.ConstantOp with string type values. (
Browse files Browse the repository at this point in the history
…#2574)

Signed-off-by: Yasushi Negishi <[email protected]>
  • Loading branch information
negiyas authored Oct 25, 2023
1 parent 5db6e46 commit 968bf5f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/Conversion/ONNXToKrnl/Tensor/Constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,20 @@ struct ONNXConstantOpLowering : public OpConversionPattern<ONNXConstantOp> {

// Emit the constant global in Krnl dialect.
MultiDialectBuilder<KrnlBuilder> create(rewriter, loc);
Value constantGlobal = create.krnl.constant(
memRefType, "constant_", constantOp.getValue().value());
mlir::Attribute constValAttr = constantOp.getValue().value();
if (memRefType.getElementType().isa<krnl::StringType>()) {
// If the onnx.ConstantOp has string type value attribute,
// The element type of the value attribute of krnl.global op should be
// "!krnl.string" instead of "!onnx.String".
ShapedType constStrType = RankedTensorType::get(
memRefType.getShape(), krnl::StringType::get(rewriter.getContext()));
SmallVector<StringRef> constStrVector(
constValAttr.dyn_cast<DenseElementsAttr>().getValues<StringAttr>());
ArrayRef<StringRef> constStrValues(constStrVector);
constValAttr = mlir::DenseElementsAttr::get(constStrType, constStrValues);
}
Value constantGlobal =
create.krnl.constant(memRefType, "constant_", constValAttr);

// Replace this operation with the generated krnl.global.
rewriter.replaceOp(op, constantGlobal);
Expand Down
36 changes: 36 additions & 0 deletions test/mlir/conversion/onnx_to_krnl/Tensor/Constant.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,39 @@ func.func private @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<
// CHECK: return [[GLOBAL]] : memref<3x2xf32>
}

// -----

func.func @test_constant_string() -> tensor<!onnx.String> {
%0 = onnx.Constant dense<"1"> : tensor<!onnx.String>
"func.return"(%0) : (tensor<!onnx.String>) -> ()
// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_constant_string
// CHECK-SAME: () -> memref<!krnl.string> {
// CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [], value = dense<"1"> : tensor<!krnl.string>} : () -> memref<!krnl.string>
// CHECK: return [[VAR_0_]] : memref<!krnl.string>
}

// -----

func.func @test_constant_string_3elem() -> tensor<3x!onnx.String> {
%0 = onnx.Constant dense<["1", "2", "3"]> : tensor<3x!onnx.String>
"func.return"(%0) : (tensor<3x!onnx.String>) -> ()
// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_constant_string_3elem
// CHECK-SAME: () -> memref<3x!krnl.string> {
// CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [3], value = dense<["1", "2", "3"]> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string>
// CHECK: return [[VAR_0_]] : memref<3x!krnl.string>
}

// -----

func.func @test_constant_string_3elem2() -> tensor<3x!onnx.String> {
%0 = onnx.Constant dense<"1"> : tensor<3x!onnx.String>
"func.return"(%0) : (tensor<3x!onnx.String>) -> ()
// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_constant_string_3elem2
// CHECK-SAME: () -> memref<3x!krnl.string> {
// CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [3], value = dense<"1"> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string>
// CHECK: return [[VAR_0_]] : memref<3x!krnl.string>
// CHECK: }
}

0 comments on commit 968bf5f

Please sign in to comment.