From a5b801e0a397b1ad448c883bb00c637c09710c10 Mon Sep 17 00:00:00 2001 From: Philipp Braun Date: Fri, 15 Jul 2022 08:21:13 +0100 Subject: [PATCH 01/10] Update type converter for ONNXToTOSA conversion passes --- .../ONNXToTOSA/ConvertONNXToTOSA.cpp | 85 ++++++++----------- .../ONNXToTOSA/ONNXToTOSACommon.hpp | 47 ++++++++++ 2 files changed, 82 insertions(+), 50 deletions(-) create mode 100644 src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp index 06067c4b6a..8de2e06991 100644 --- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp +++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp @@ -12,54 +12,26 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "src/Dialect/ONNX/DialectBuilder.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" -#include "src/Dialect/ONNX/ONNXOpsHelper.hpp" -#include "src/Pass/Passes.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" using namespace mlir; namespace onnx_mlir { -// This defines a template to construct ops whose legalizations are -// specialized. -template -class ConvertOnnxOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename OnnxOpT::Adaptor; - LogicalResult matchAndRewrite(OnnxOpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -template <> -LogicalResult ConvertOnnxOp::matchAndRewrite(ONNXReluOp op, - OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Value input = adaptor.X(); - auto inputTy = input.getType().dyn_cast(); - - if (!inputTy) - return op.emitError("Only Tensor types supported in TOSA"); - - if (!inputTy.getElementType().isa()) { - return op.emitError( - "Only floating-point datatype legalization currently supported"); - } +inline bool isa_tosa_signed_int(Type type) { + IntegerType intType = type.dyn_cast(); + std::set intWidth{8, 16, 32, 48, 64}; + return intType && (intWidth.find(intType.getWidth()) != intWidth.end()); +} - // Rescale the clampIn for quantized types. TBD - // Maps to tosa.clamp which has both int and fp limits. - Value clampIn = input; +inline bool isa_tosa_float(Type type) { + return type.isa(); +} - rewriter.replaceOpWithNewOp(op, op.getType(), clampIn, - rewriter.getI64IntegerAttr(0), - rewriter.getI64IntegerAttr(std::numeric_limits::max()), - rewriter.getF32FloatAttr(0.0f), - rewriter.getF32FloatAttr(std::numeric_limits::max())); - return success(); +void populateONNXToTOSAConversionPattern(RewritePatternSet &patterns, + TypeConverter &typeConverter, MLIRContext *ctx) { + // Math + populateLoweringONNXElementwiseOpToTOSAPattern(patterns, typeConverter, ctx); } // Performs lowering to TOSA dialect @@ -79,24 +51,37 @@ struct FrontendToTosaLoweringPass }; void FrontendToTosaLoweringPass::runOnOperation() { + ModuleOp module = getOperation(); + // Define final conversion target MLIRContext *context = &getContext(); RewritePatternSet patterns(context); ConversionTarget target(*context); + // We use the type converter to legalize types before any conversion patterns + // are executed. This ensures that we do not need to trigger separate + // conversion failures. TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); - + typeConverter.addConversion([](Type type) -> Optional { + if (isa_tosa_signed_int(type) || isa_tosa_float(type)) + return type; + return llvm::None; + }); + typeConverter.addConversion([](TensorType type) -> Optional { + if (isa_tosa_signed_int(type.getElementType()) || + isa_tosa_float(type.getElementType())) + return type; + return llvm::None; + }); + + // Define legal dialects and operations target.addLegalDialect(); -#define INSERT_ONNXOP_PATTERN(OnnxOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_ONNXOP_PATTERN(ONNXReluOp); -#undef INSERT_ONNXOP_PATTERN + // Define patterns + populateONNXToTOSAConversionPattern(patterns, typeConverter, context); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); + } } std::unique_ptr createConvertONNXToTOSAPass() { diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp new file mode 100644 index 0000000000..0b650fdf15 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===// +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file contains common code shared by the functions performing the +// lowering to the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "src/Dialect/ONNX/DialectBuilder.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Dialect/ONNX/ONNXOpsHelper.hpp" +#include "src/Pass/Passes.hpp" +#include "src/Transform/ONNX/ConstPropHelper.hpp" + +//===----------------------------------------------------------------------===// +// Functions to add lowering patterns for frontend operations. +//===----------------------------------------------------------------------===// + +namespace onnx_mlir { + +//===----------------------------------------------------------------------===// +// This is to get a TOSA operation of a given type for a specific operation. +//===----------------------------------------------------------------------===// +template +struct TOSADialectOp { + using Op = void; +}; + +template +using TOSAOp = typename TOSADialectOp::Op; + +// `Math` directory methods: +void populateLoweringONNXElementwiseOpToTOSAPattern( + RewritePatternSet &, TypeConverter &, MLIRContext *); +} // namespace onnx_mlir From 733cf9c04b36b1c2be3ca77f8f0c89729bbbf703 Mon Sep 17 00:00:00 2001 From: Philipp Braun Date: Fri, 15 Jul 2022 08:21:43 +0100 Subject: [PATCH 02/10] Add unary op conversion passes --- src/Conversion/ONNXToTOSA/CMakeLists.txt | 3 + .../ONNXToTOSA/Math/Elementwise.cpp | 68 +++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 src/Conversion/ONNXToTOSA/Math/Elementwise.cpp diff --git a/src/Conversion/ONNXToTOSA/CMakeLists.txt b/src/Conversion/ONNXToTOSA/CMakeLists.txt index 392a1dd531..52309020bd 100644 --- a/src/Conversion/ONNXToTOSA/CMakeLists.txt +++ b/src/Conversion/ONNXToTOSA/CMakeLists.txt @@ -2,6 +2,9 @@ add_onnx_mlir_library(OMONNXToTOSA ConvertONNXToTOSA.cpp + + Math/Elementwise.cpp + LINK_LIBS PUBLIC OMONNXOps MLIRTosa diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp new file mode 100644 index 0000000000..b96c9c0dfb --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Elementwise.cpp - Elementwise Op --------------------===// +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX element-wise operators to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +template +class ONNXUnaryOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXOpT::Adaptor; + LogicalResult matchAndRewrite(ONNXOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), adaptor.X()); + return success(); + } +}; + +class ONNXReluOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Rescale the input for quantized types. TBD + // Maps to tosa.clamp which has both int and fp limits. + Value input = adaptor.X(); + + rewriter.replaceOpWithNewOp(op, op.getType(), input, + rewriter.getI64IntegerAttr(0), + rewriter.getI64IntegerAttr(std::numeric_limits::max()), + rewriter.getF32FloatAttr(0.0f), + rewriter.getF32FloatAttr(std::numeric_limits::max())); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXElementwiseOpToTOSAPattern(RewritePatternSet &patterns, + TypeConverter &typeConverter, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); + +#define INSERT_UNARY_PATTERN(ONNXOp, TOSAOp) \ + patterns.insert>( \ + typeConverter, ctx); + INSERT_UNARY_PATTERN(ONNXNegOp, tosa::NegateOp) + INSERT_UNARY_PATTERN(ONNXFloorOp, tosa::FloorOp) +#undef INSERT_UNARY_PATTERN +} + +} // namespace onnx_mlir From f9b4e480bd2105e7fded67b55c854996766ae742 Mon Sep 17 00:00:00 2001 From: Philipp Braun Date: Fri, 15 Jul 2022 08:22:14 +0100 Subject: [PATCH 03/10] Add ONNXToTOSA unary lit tests --- .../onnx_to_tosa/Math/Elementwise.mlir | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir new file mode 100644 index 0000000000..e454dbb0de --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -0,0 +1,37 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-tosa %s -split-input-file | FileCheck %s + +func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_relu +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> +// CHECK-NEXT: } +} + +func.func @test_relu_dynamic(%arg0 : tensor) -> tensor { + %0 = "onnx.Relu"(%arg0) : (tensor) -> tensor + "func.return"(%0) : (tensor) -> () +// CHECK-LABEL: func @test_relu_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor +// CHECK-NEXT: return [[VAR_0_]] : tensor +// CHECK-NEXT: } +} + +func.func @test_neg(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Neg"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_neg +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.negate"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32> +} + +func.func @test_floor(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Floor"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_floor +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.floor"([[PARAM_0_]]) : (tensor<10x10xf32>) -> tensor<10x10xf32> +} From fe8b4dcecdf680d83e4469e2a087562cf9015368 Mon Sep 17 00:00:00 2001 From: Philipp Braun Date: Fri, 15 Jul 2022 09:14:04 +0100 Subject: [PATCH 04/10] Add ONNX to TOSA ArgMax conversion pass --- src/Conversion/ONNXToTOSA/CMakeLists.txt | 1 + .../ONNXToTOSA/ConvertONNXToTOSA.cpp | 2 + .../ONNXToTOSA/ONNXToTOSACommon.hpp | 4 ++ src/Conversion/ONNXToTOSA/Tensor/ArgMax.cpp | 50 +++++++++++++++++++ .../onnx_to_tosa/Tensor/ArgMax.mlir | 11 ++++ 5 files changed, 68 insertions(+) create mode 100644 src/Conversion/ONNXToTOSA/Tensor/ArgMax.cpp create mode 100644 test/mlir/conversion/onnx_to_tosa/Tensor/ArgMax.mlir diff --git a/src/Conversion/ONNXToTOSA/CMakeLists.txt b/src/Conversion/ONNXToTOSA/CMakeLists.txt index 52309020bd..42fc319e6a 100644 --- a/src/Conversion/ONNXToTOSA/CMakeLists.txt +++ b/src/Conversion/ONNXToTOSA/CMakeLists.txt @@ -4,6 +4,7 @@ add_onnx_mlir_library(OMONNXToTOSA ConvertONNXToTOSA.cpp Math/Elementwise.cpp + Tensor/ArgMax.cpp LINK_LIBS PUBLIC OMONNXOps diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp index 8de2e06991..d78328ee00 100644 --- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp +++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp @@ -32,6 +32,8 @@ void populateONNXToTOSAConversionPattern(RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx) { // Math populateLoweringONNXElementwiseOpToTOSAPattern(patterns, typeConverter, ctx); + // Tensor + populateLoweringONNXArgMaxOpToTOSAPattern(patterns, typeConverter, ctx); } // Performs lowering to TOSA dialect diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp index 0b650fdf15..0eaa02051f 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp @@ -44,4 +44,8 @@ using TOSAOp = typename TOSADialectOp::Op; // `Math` directory methods: void populateLoweringONNXElementwiseOpToTOSAPattern( RewritePatternSet &, TypeConverter &, MLIRContext *); +// `Tensor` directory methods: +void populateLoweringONNXArgMaxOpToTOSAPattern( + RewritePatternSet &, TypeConverter &, MLIRContext *); + } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/ArgMax.cpp b/src/Conversion/ONNXToTOSA/Tensor/ArgMax.cpp new file mode 100644 index 0000000000..da2e389aa6 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/ArgMax.cpp @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===----------------- Softmax.cpp - Softmax Op ---------------------------===// +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX ArgMax operator to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXArgMaxOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (adaptor.keepdims() != 1) + return rewriter.notifyMatchFailure(op, "keepdims != 1 is not supported"); + + if (adaptor.select_last_index() != 0) + return rewriter.notifyMatchFailure( + op, "select_last_index != 0 is not supported"); + + IntegerAttr axis = rewriter.getI64IntegerAttr(adaptor.axis()); + rewriter.replaceOpWithNewOp( + op, op.getType(), adaptor.data(), axis); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXArgMaxOpToTOSAPattern(RewritePatternSet &patterns, + TypeConverter &typeConverter, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/ArgMax.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/ArgMax.mlir new file mode 100644 index 0000000000..d55c571e94 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/ArgMax.mlir @@ -0,0 +1,11 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-tosa %s -split-input-file | FileCheck %s + +func.func @test_argmax(%arg0: tensor<8x16x32xf32>) -> tensor<8x16x32xi64> { + %0 = "onnx.ArgMax"(%arg0) {axis = 2 : si64, keepdims = 1 : si64, onnx_node_name = "ArgMax_0"} : (tensor<8x16x32xf32>) -> tensor<8x16x32xi64> + return %0 : tensor<8x16x32xi64> +// CHECK-LABEL: func @test_argmax +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xi64> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.argmax"([[PARAM_0_]]) {axis = 2 : i64} : (tensor<8x16x32xf32>) -> tensor<8x16x32xi64> +// CHECK-NEXT: return [[VAR_0_]] : tensor<8x16x32xi64> +// CHECK-NEXT: } +} From c070a5e35b261c1faf34c8a7aff25374972f8896 Mon Sep 17 00:00:00 2001 From: Philipp Braun Date: Fri, 15 Jul 2022 09:16:36 +0100 Subject: [PATCH 05/10] Update ONNXToTOSA unary ops lit test --- test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index e454dbb0de..0517cc0d9f 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -5,7 +5,7 @@ func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { "func.return"(%0) : (tensor<10x10xf32>) -> () // CHECK-LABEL: func @test_relu // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { -// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32> // CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> // CHECK-NEXT: } } @@ -15,7 +15,7 @@ func.func @test_relu_dynamic(%arg0 : tensor) -> tensor { "func.return"(%0) : (tensor) -> () // CHECK-LABEL: func @test_relu_dynamic // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { -// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor // CHECK-NEXT: return [[VAR_0_]] : tensor // CHECK-NEXT: } } From c10d29b0e773ddafd878395d5ef14d332bd466bd Mon Sep 17 00:00:00 2001 From: Philipp Braun Date: Fri, 15 Jul 2022 09:42:39 +0100 Subject: [PATCH 06/10] Update ArgMax title --- src/Conversion/ONNXToTOSA/Tensor/ArgMax.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Conversion/ONNXToTOSA/Tensor/ArgMax.cpp b/src/Conversion/ONNXToTOSA/Tensor/ArgMax.cpp index da2e389aa6..fedc8c3dc9 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/ArgMax.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/ArgMax.cpp @@ -2,7 +2,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -//===----------------- Softmax.cpp - Softmax Op ---------------------------===// +//===------------------- ArgMax.cpp - ArgMax Op ---------------------------===// // // Copyright (c) 2022 Advanced Micro Devices, Inc. // From e86ace9a866154b5e3dd79cc16f868613cd00aef Mon Sep 17 00:00:00 2001 From: Philipp Braun Date: Fri, 15 Jul 2022 11:29:14 +0100 Subject: [PATCH 07/10] Mark unary ONNX ops illegal & fail on quantized types --- .../ONNXToTOSA/ConvertONNXToTOSA.cpp | 24 ++++++++++--------- .../ONNXToTOSA/Math/Elementwise.cpp | 14 +++++++---- .../ONNXToTOSA/ONNXToTOSACommon.hpp | 4 +++- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp index 8de2e06991..069ec82712 100644 --- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp +++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp @@ -18,20 +18,23 @@ using namespace mlir; namespace onnx_mlir { -inline bool isa_tosa_signed_int(Type type) { +static bool isSignedInt(Type type) { IntegerType intType = type.dyn_cast(); std::set intWidth{8, 16, 32, 48, 64}; - return intType && (intWidth.find(intType.getWidth()) != intWidth.end()); + return intType && intType.isSigned() && + (intWidth.find(intType.getWidth()) != intWidth.end()); } -inline bool isa_tosa_float(Type type) { +static bool isFloat(Type type) { return type.isa(); } -void populateONNXToTOSAConversionPattern(RewritePatternSet &patterns, - TypeConverter &typeConverter, MLIRContext *ctx) { +void populateONNXToTOSAConversionPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { // Math - populateLoweringONNXElementwiseOpToTOSAPattern(patterns, typeConverter, ctx); + populateLoweringONNXElementwiseOpToTOSAPattern( + target, patterns, typeConverter, ctx); } // Performs lowering to TOSA dialect @@ -62,13 +65,12 @@ void FrontendToTosaLoweringPass::runOnOperation() { // conversion failures. TypeConverter typeConverter; typeConverter.addConversion([](Type type) -> Optional { - if (isa_tosa_signed_int(type) || isa_tosa_float(type)) + if (isSignedInt(type) || isFloat(type)) return type; return llvm::None; }); - typeConverter.addConversion([](TensorType type) -> Optional { - if (isa_tosa_signed_int(type.getElementType()) || - isa_tosa_float(type.getElementType())) + typeConverter.addConversion([&](TensorType type) -> Optional { + if (typeConverter.isLegal(type.getElementType())) return type; return llvm::None; }); @@ -77,7 +79,7 @@ void FrontendToTosaLoweringPass::runOnOperation() { target.addLegalDialect(); // Define patterns - populateONNXToTOSAConversionPattern(patterns, typeConverter, context); + populateONNXToTOSAConversionPattern(target, patterns, typeConverter, context); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index b96c9c0dfb..0a53c98a06 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -38,10 +38,14 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern { LogicalResult matchAndRewrite(ONNXReluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Rescale the input for quantized types. TBD - // Maps to tosa.clamp which has both int and fp limits. Value input = adaptor.X(); + // Rescale the input for quantized types. TBD + if (input.getType().isa()) + return rewriter.notifyMatchFailure( + op, "quantized types are not supported"); + + // Maps to `tosa.clamp` which has both int and fp limits. rewriter.replaceOpWithNewOp(op, op.getType(), input, rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(std::numeric_limits::max()), @@ -53,11 +57,13 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern { } // namespace -void populateLoweringONNXElementwiseOpToTOSAPattern(RewritePatternSet &patterns, - TypeConverter &typeConverter, MLIRContext *ctx) { +void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { patterns.insert(typeConverter, ctx); #define INSERT_UNARY_PATTERN(ONNXOp, TOSAOp) \ + target.addIllegalOp(); \ patterns.insert>( \ typeConverter, ctx); INSERT_UNARY_PATTERN(ONNXNegOp, tosa::NegateOp) diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp index 0b650fdf15..ba5dd71756 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp @@ -13,7 +13,9 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Quant/QuantTypes.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" + #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -43,5 +45,5 @@ using TOSAOp = typename TOSADialectOp::Op; // `Math` directory methods: void populateLoweringONNXElementwiseOpToTOSAPattern( - RewritePatternSet &, TypeConverter &, MLIRContext *); + ConversionTarget &, RewritePatternSet &, TypeConverter &, MLIRContext *); } // namespace onnx_mlir From dd933f506b2b3a80059f367d8ca32c449d213641 Mon Sep 17 00:00:00 2001 From: Philipp Braun Date: Fri, 15 Jul 2022 11:50:41 +0100 Subject: [PATCH 08/10] Align TOSA lit tests --- .../conversion/onnx_to_tosa/Math/Elementwise.mlir | 8 ++++---- test/mlir/tosa/onnx_lowering.mlir | 11 ----------- 2 files changed, 4 insertions(+), 15 deletions(-) delete mode 100644 test/mlir/tosa/onnx_lowering.mlir diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index 0517cc0d9f..b63b0cac7d 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --convert-onnx-to-tosa %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { %0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> @@ -10,9 +10,9 @@ func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { // CHECK-NEXT: } } -func.func @test_relu_dynamic(%arg0 : tensor) -> tensor { - %0 = "onnx.Relu"(%arg0) : (tensor) -> tensor - "func.return"(%0) : (tensor) -> () +func.func @test_relu_dynamic(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.Relu"(%arg0) : (tensor) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: func @test_relu_dynamic // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { // CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.clamp"([[PARAM_0_]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor diff --git a/test/mlir/tosa/onnx_lowering.mlir b/test/mlir/tosa/onnx_lowering.mlir deleted file mode 100644 index 8ac357ab21..0000000000 --- a/test/mlir/tosa/onnx_lowering.mlir +++ /dev/null @@ -1,11 +0,0 @@ -// RUN: onnx-mlir-opt -O3 --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s - -func private @test_relu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.Relu"(%arg0) : (tensor) -> tensor<*xf32> - "func.return"(%0) : (tensor<*xf32>) -> () - -// CHECK-LABEL: func private @test_relu( -// CHECK-SAME: [[INPUT:%.+]]: tensor) -> tensor { -// CHECK: [[OUTPUT:%.+]] = "tosa.clamp"([[INPUT]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor -// CHECK: return [[OUTPUT]] : tensor -} From 73bbeac30160178114c91af4a63776a97853e358 Mon Sep 17 00:00:00 2001 From: Philipp Braun Date: Mon, 18 Jul 2022 08:36:36 +0100 Subject: [PATCH 09/10] Remove quantization check --- src/Conversion/ONNXToTOSA/Math/Elementwise.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 0a53c98a06..9ba967bcac 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -40,11 +40,8 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern { Value input = adaptor.X(); - // Rescale the input for quantized types. TBD - if (input.getType().isa()) - return rewriter.notifyMatchFailure( - op, "quantized types are not supported"); - + // Quantized types are not supported right now. Rescale the input for + // quantized types. (TBD) // Maps to `tosa.clamp` which has both int and fp limits. rewriter.replaceOpWithNewOp(op, op.getType(), input, rewriter.getI64IntegerAttr(0), From 105cdc29f59f821a7f06e17f179e762fc02e3a82 Mon Sep 17 00:00:00 2001 From: Philipp Braun Date: Mon, 18 Jul 2022 14:48:53 +0100 Subject: [PATCH 10/10] Update comments --- src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp | 2 +- src/Conversion/ONNXToTOSA/Math/Elementwise.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp index ffc1fc6897..ed7eec5424 100644 --- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp +++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp @@ -64,7 +64,7 @@ void FrontendToTosaLoweringPass::runOnOperation() { // We use the type converter to legalize types before any conversion patterns // are executed. This ensures that we do not need to trigger separate - // conversion failures. + // conversion failures. Quantized types are not supported right now. TypeConverter typeConverter; typeConverter.addConversion([](Type type) -> Optional { if (isSignedInt(type) || isFloat(type)) diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 9ba967bcac..cf67072c09 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -40,8 +40,8 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern { Value input = adaptor.X(); - // Quantized types are not supported right now. Rescale the input for - // quantized types. (TBD) + // Quantized types are not supported right now (in type conversion). + // Once they are, the input should be rescaled for quantized types. (TBD) // Maps to `tosa.clamp` which has both int and fp limits. rewriter.replaceOpWithNewOp(op, op.getType(), input, rewriter.getI64IntegerAttr(0),