diff --git a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp index 19ef38abbd..dfdc35d910 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp @@ -39,8 +39,20 @@ struct ONNXConstantOpLowering : public OpConversionPattern { // Emit the constant global in Krnl dialect. MultiDialectBuilder create(rewriter, loc); - Value constantGlobal = create.krnl.constant( - memRefType, "constant_", constantOp.getValue().value()); + mlir::Attribute constValAttr = constantOp.getValue().value(); + if (memRefType.getElementType().isa()) { + // If the onnx.ConstantOp has string type value attribute, + // The element type of the value attribute of krnl.global op should be + // "!krnl.string" instead of "!onnx.String". + ShapedType constStrType = RankedTensorType::get( + memRefType.getShape(), krnl::StringType::get(rewriter.getContext())); + SmallVector constStrVector( + constValAttr.dyn_cast().getValues()); + ArrayRef constStrValues(constStrVector); + constValAttr = mlir::DenseElementsAttr::get(constStrType, constStrValues); + } + Value constantGlobal = + create.krnl.constant(memRefType, "constant_", constValAttr); // Replace this operation with the generated krnl.global. rewriter.replaceOp(op, constantGlobal); diff --git a/test/mlir/conversion/onnx_to_krnl/Tensor/Constant.mlir b/test/mlir/conversion/onnx_to_krnl/Tensor/Constant.mlir index 7585ebe59a..772143937f 100644 --- a/test/mlir/conversion/onnx_to_krnl/Tensor/Constant.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Tensor/Constant.mlir @@ -8,3 +8,39 @@ func.func private @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor< // CHECK: return [[GLOBAL]] : memref<3x2xf32> } +// ----- + +func.func @test_constant_string() -> tensor { + %0 = onnx.Constant dense<"1"> : tensor + "func.return"(%0) : (tensor) -> () + // mlir2FileCheck.py + // CHECK-LABEL: func.func @test_constant_string + // CHECK-SAME: () -> memref { + // CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [], value = dense<"1"> : tensor} : () -> memref + // CHECK: return [[VAR_0_]] : memref +} + +// ----- + +func.func @test_constant_string_3elem() -> tensor<3x!onnx.String> { + %0 = onnx.Constant dense<["1", "2", "3"]> : tensor<3x!onnx.String> + "func.return"(%0) : (tensor<3x!onnx.String>) -> () + // mlir2FileCheck.py + // CHECK-LABEL: func.func @test_constant_string_3elem + // CHECK-SAME: () -> memref<3x!krnl.string> { + // CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [3], value = dense<["1", "2", "3"]> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string> + // CHECK: return [[VAR_0_]] : memref<3x!krnl.string> +} + +// ----- + +func.func @test_constant_string_3elem2() -> tensor<3x!onnx.String> { + %0 = onnx.Constant dense<"1"> : tensor<3x!onnx.String> + "func.return"(%0) : (tensor<3x!onnx.String>) -> () + // mlir2FileCheck.py + // CHECK-LABEL: func.func @test_constant_string_3elem2 + // CHECK-SAME: () -> memref<3x!krnl.string> { + // CHECK: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [3], value = dense<"1"> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string> + // CHECK: return [[VAR_0_]] : memref<3x!krnl.string> + // CHECK: } +}