Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOSA] Add ONNX to TOSA ArgMax conversion pass #45

Open
wants to merge 14 commits into
base: feature/onnx_to_torch
Choose a base branch
from
Next Next commit
Update type converter for ONNXToTOSA conversion passes
Philipp Braun committed Jul 15, 2022
commit a5b801e0a397b1ad448c883bb00c637c09710c10
85 changes: 35 additions & 50 deletions src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp
Original file line number Diff line number Diff line change
@@ -12,54 +12,26 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"

using namespace mlir;

namespace onnx_mlir {

// This defines a template to construct ops whose legalizations are
// specialized.
template <typename OnnxOpT>
class ConvertOnnxOp : public OpConversionPattern<OnnxOpT> {
public:
using OpConversionPattern<OnnxOpT>::OpConversionPattern;
using OpAdaptor = typename OnnxOpT::Adaptor;
LogicalResult matchAndRewrite(OnnxOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

template <>
LogicalResult ConvertOnnxOp<ONNXReluOp>::matchAndRewrite(ONNXReluOp op,
OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
Value input = adaptor.X();
auto inputTy = input.getType().dyn_cast<TensorType>();

if (!inputTy)
return op.emitError("Only Tensor types supported in TOSA");

if (!inputTy.getElementType().isa<FloatType>()) {
return op.emitError(
"Only floating-point datatype legalization currently supported");
}
inline bool isa_tosa_signed_int(Type type) {
IntegerType intType = type.dyn_cast<IntegerType>();
std::set<unsigned> intWidth{8, 16, 32, 48, 64};
return intType && (intWidth.find(intType.getWidth()) != intWidth.end());
}

// Rescale the clampIn for quantized types. TBD
// Maps to tosa.clamp which has both int and fp limits.
Value clampIn = input;
inline bool isa_tosa_float(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type>();
}

rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, op.getType(), clampIn,
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
rewriter.getF32FloatAttr(0.0f),
rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
return success();
void populateONNXToTOSAConversionPattern(RewritePatternSet &patterns,
TypeConverter &typeConverter, MLIRContext *ctx) {
// Math
populateLoweringONNXElementwiseOpToTOSAPattern(patterns, typeConverter, ctx);
}

// Performs lowering to TOSA dialect
@@ -79,24 +51,37 @@ struct FrontendToTosaLoweringPass
};

void FrontendToTosaLoweringPass::runOnOperation() {
ModuleOp module = getOperation();
// Define final conversion target
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
ConversionTarget target(*context);

// We use the type converter to legalize types before any conversion patterns
// are executed. This ensures that we do not need to trigger separate
// conversion failures.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });

typeConverter.addConversion([](Type type) -> Optional<Type> {
if (isa_tosa_signed_int(type) || isa_tosa_float(type))
return type;
return llvm::None;
});
typeConverter.addConversion([](TensorType type) -> Optional<Type> {
if (isa_tosa_signed_int(type.getElementType()) ||
isa_tosa_float(type.getElementType()))
return type;
return llvm::None;
});

// Define legal dialects and operations
target.addLegalDialect<tosa::TosaDialect, func::FuncDialect>();

#define INSERT_ONNXOP_PATTERN(OnnxOp) \
target.addIllegalOp<OnnxOp>(); \
patterns.add<ConvertOnnxOp<OnnxOp>>(typeConverter, context);
INSERT_ONNXOP_PATTERN(ONNXReluOp);
#undef INSERT_ONNXOP_PATTERN
// Define patterns
populateONNXToTOSAConversionPattern(patterns, typeConverter, context);

if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}

std::unique_ptr<Pass> createConvertONNXToTOSAPass() {
47 changes: 47 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===//
//
// Copyright (c) 2022 Advanced Micro Devices, Inc.
//
// =============================================================================
//
// This file contains common code shared by the functions performing the
// lowering to the TOSA dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Transform/ONNX/ConstPropHelper.hpp"

//===----------------------------------------------------------------------===//
// Functions to add lowering patterns for frontend operations.
//===----------------------------------------------------------------------===//

namespace onnx_mlir {

//===----------------------------------------------------------------------===//
// This is to get a TOSA operation of a given type for a specific operation.
//===----------------------------------------------------------------------===//
template <typename ONNXOp>
struct TOSADialectOp {
using Op = void;
};

template <typename Op>
using TOSAOp = typename TOSADialectOp<Op>::Op;

// `Math` directory methods:
void populateLoweringONNXElementwiseOpToTOSAPattern(
RewritePatternSet &, TypeConverter &, MLIRContext *);
} // namespace onnx_mlir