Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen committed May 16, 2023
1 parent a07f6b7 commit 6fb1231
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LineIterator.h"
#include "llvm/Support/MemoryBuffer.h"
Expand Down Expand Up @@ -52,8 +51,8 @@ SUPPRESS_WARNINGS_POP
#include <array>
#include <fstream>
#include <functional>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#define DEBUG_TYPE "frontend_dialect_transformer"
Expand Down Expand Up @@ -104,7 +103,7 @@ class FrontendGenImpl {
OpBuilder builder_;

// onnxop: list of versions for dialect
llvm::StringMap<std::vector<int>> op_dialect_version_map_;
std::unordered_map<std::string, std::vector<int>> op_dialect_version_map_;

// mapping between string name and symbol
ValueSymbolMapping frontend_symbols_;
Expand All @@ -113,7 +112,7 @@ class FrontendGenImpl {

using ImportHandlerType = void (FrontendGenImpl::*)(const onnx::NodeProto &);

llvm::StringMap<ImportHandlerType> import_handler_map_;
std::unordered_map<std::string, ImportHandlerType> import_handler_map_;

// The total number of elements in all initializers. This value is a rough
// counter of the number of parameters in a model.
Expand Down Expand Up @@ -146,7 +145,7 @@ class FrontendGenImpl {
// opset_map_ is the internal (map) representation of ModelProto::opset_import
// It maps each domain (e.g., "ai.onnx") to the specific version of that opset
// used by this model.
std::map<std::string, int64_t> opset_map_;
std::unordered_map<std::string, int64_t> opset_map_;
void SetOpSetImport(const onnx::ModelProto &model) {
opset_map_.clear();
for (auto &binding : model.opset_import()) {
Expand Down Expand Up @@ -732,6 +731,8 @@ class FrontendGenImpl {
"model with a newer Opset.\n";

for (int i = opset_list.size() - 1; i > 0; i--) {
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - testing Opset "
<< opset_list[i - 1] << "\n");
if (current_opset < opset_list[i - 1]) {
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - use Opset "
<< opset_list[i] << "\n");
Expand Down Expand Up @@ -937,8 +938,8 @@ class FrontendGenImpl {

// look up handler for the opName. If not found, create a node
// for a custom op, and issue a warning.
auto handler =
import_handler_map_.find(node.op_type() + versionStr.c_str());
std::string versionedName = node.op_type() + versionStr;
auto handler = import_handler_map_.find(versionedName);
if (handler != import_handler_map_.end()) {
(this->*(handler->second))(node);
} else {
Expand All @@ -952,8 +953,13 @@ class FrontendGenImpl {
if constexpr (std::is_base_of_v<ONNXOperationTrait<T>, T>) {
StringRef name = T::getONNXName();
int version = T::getONNXSinceVersion();
op_dialect_version_map_[name].push_back(version);
import_handler_map_[name] = &FrontendGenImpl::buildOperation<T>;
op_dialect_version_map_[name.str()].push_back(version);

StringRef versionedName = T::getOperationName();
bool hadOnnxPrefix = versionedName.consume_front("onnx.");
assert(hadOnnxPrefix);
import_handler_map_[versionedName.str()] =
&FrontendGenImpl::buildOperation<T>;
}
});
for (auto &[name, versions] : op_dialect_version_map_) {
Expand Down

0 comments on commit 6fb1231

Please sign in to comment.