From afa19b08eb7d564842d2ae7ba8a1de486156138c Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Thu, 18 May 2023 07:33:52 +0200 Subject: [PATCH] combine op_dialect_version_map_, import_handler_map_ into onnx_ops_map_ Signed-off-by: Soren Lassen --- src/Builder/FrontendDialectTransformer.cpp | 136 +++++++++++---------- 1 file changed, 72 insertions(+), 64 deletions(-) diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 7acaf6eb76..d1ed97122a 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -102,18 +102,25 @@ class FrontendGenImpl { ModuleOp module_; OpBuilder builder_; - // onnxop: list of versions for dialect - std::unordered_map> op_dialect_version_map_; + using ImportHandlerType = void (FrontendGenImpl::*)(const onnx::NodeProto &); + + struct VersionedHandler { + int version; + ImportHandlerType handler; + }; + + using ONNXOpVersions = SmallVector; + + // Maps NodeProto::op_type() to sorted vector of (version, handler) pairs. + // TODO: Key by (domain, op_type) pair so we don't rely on names being unique + // across all domains. + std::unordered_map onnx_ops_map_; // mapping between string name and symbol ValueSymbolMapping frontend_symbols_; ModelInputShaper modelInputShaper_; - using ImportHandlerType = void (FrontendGenImpl::*)(const onnx::NodeProto &); - - std::unordered_map import_handler_map_; - // The total number of elements in all initializers. This value is a rough // counter of the number of parameters in a model. int64_t num_of_parameters_ = 0; @@ -682,45 +689,6 @@ class FrontendGenImpl { node.op_type(), version, node.domain()); } - std::string GetImportVersionOfNode(const onnx::NodeProto &node) { - int64_t version = GetDomainVersion(node.domain()); - if (version == 0) - return ""; - - LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX" - << node.op_type() << " (" << node.name() << ")" - << ", Opset: " << version << "\n"); - - auto opset_list_it = op_dialect_version_map_.find(node.op_type()); - - // Custom ops may not be present in op_dialect_version_map_. If no version - // info is found, treat as unversioned (no renaming). - if (opset_list_it == op_dialect_version_map_.end()) - return ""; - - auto opset_list = opset_list_it->second; - - // A new opset is added to onnx-mlir when it becomes imcompactible. - // But the lowest opset in op_dialect_version_map_ is an exception. - // It is the current opset when onnx-mlir project is started. - // All opset lower than the last opset should use the last opset(version) - if (node.domain().compare("ai.onnx.ml") != 0 && - version < opset_list.back() && version < MINIMUM_SUPPORTED_OPSET) - llvm::outs() << "Warning: ONNX " << node.op_type() - << " in your model is using Opset " << version - << ", which is quite old. Please consider regenerating your " - "model with a newer Opset.\n"; - - for (int i = opset_list.size() - 1; i > 0; i--) { - if (version < opset_list[i - 1]) { - LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - use Opset " - << opset_list[i] << "\n"); - return "V" + std::to_string(opset_list[i]); - } - } - return ""; - } - func::FuncOp CreateFuncOp( std::string namePrefix, TypeRange operandTypes, TypeRange resultTypes) { auto funcType = builder_.getFunctionType(operandTypes, resultTypes); @@ -912,16 +880,58 @@ class FrontendGenImpl { } } + bool TryImportONNXNode(const onnx::NodeProto &node) { + int64_t version = GetDomainVersion(node.domain()); + if (version == 0) { + // Unknown domain. + return false; + } + + LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX" + << node.op_type() << " (" << node.name() << ")" + << ", Opset: " << version << "\n"); + + auto versions_it = onnx_ops_map_.find(node.op_type()); + if (versions_it == onnx_ops_map_.end()) { + // Unknown op_type. + llvm::outs() << "Warning: ONNX " << node.op_type() << " from domain '" + << node.domain() << "," + << " in your model is unsupported.\n"; + return false; + } + + const ONNXOpVersions &opVersions = versions_it->second; + + // A new opset is added to onnx-mlir when it becomes imcompatible. + // But the lowest opset in op_dialect_version_map_ is an exception. + // It is the current opset when onnx-mlir project is started. + // All opset lower than the last opset should use the last opset(version) + if (node.domain().compare("ai.onnx.ml") != 0 && + version < opVersions.back().version && + version < MINIMUM_SUPPORTED_OPSET) + llvm::outs() << "Warning: ONNX " << node.op_type() + << " in your model is using Opset " << version + << ", which is quite old. Please consider regenerating your " + "model with a newer Opset.\n"; + + ImportHandlerType handler = opVersions.front().handler; + for (int i = opVersions.size() - 1; i > 0; --i) { + if (version < opVersions[i - 1].version) { + LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - use Opset " + << opVersions[i].version << "\n"); + handler = opVersions[i].handler; + } + } + (this->*handler)(node); + return true; + } + void ImportNode(const onnx::NodeProto &node) { - std::string versionStr = GetImportVersionOfNode(node); - - // look up handler for the opName. If not found, create a node - // for a custom op, and issue a warning. - std::string versionedName = node.op_type() + versionStr; - auto handler = import_handler_map_.find(versionedName); - if (handler != import_handler_map_.end()) { - (this->*(handler->second))(node); - } else { + bool imported = TryImportONNXNode(node); + if (!imported) { + LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing Custom op " + << node.op_type() << " (" << node.name() << ")" + << ", domain: '" << node.domain() << "'\n"); ImportCustomNode(node); } } @@ -932,18 +942,16 @@ class FrontendGenImpl { if constexpr (std::is_base_of_v, T>) { StringRef name = T::getONNXName(); int version = T::getONNXSinceVersion(); - op_dialect_version_map_[name.str()].push_back(version); - - StringRef versionedName = T::getOperationName(); - bool hadOnnxPrefix = versionedName.consume_front("onnx."); - assert(hadOnnxPrefix); - import_handler_map_[versionedName.str()] = - &FrontendGenImpl::buildOperation; + ImportHandlerType handler = &FrontendGenImpl::buildOperation; + ONNXOpVersions &opVersions = onnx_ops_map_[name.str()]; + // Insert in descending version order: + auto it = opVersions.begin(); + while (it != opVersions.end() && it->version > version) { + ++it; // Skip past larger versions. + } + opVersions.insert(it, {version, handler}); } }); - for (auto &[name, versions] : op_dialect_version_map_) { - std::sort(versions.begin(), versions.end(), std::greater()); - } } /*!