Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen committed May 16, 2023
1 parent 6fb1231 commit 13f625a
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ class FrontendGenImpl {
}
}

int64_t GetDomainVersion(const std::string &domain) {
auto it = opset_map_.find(domain);
if (it == opset_map_.end())
return 0;
return it->second;
}

void BindOnnxName(const std::string &onnx_name, Value symbol) {
frontend_symbols_.AddMapping(onnx_name, symbol);
}
Expand Down Expand Up @@ -694,20 +701,21 @@ class FrontendGenImpl {
}

const onnx::OpSchema *GetOpSchema(const onnx::NodeProto &node) {
auto &domain = node.domain();
auto version_it = opset_map_.find(domain);
if (version_it == opset_map_.end())
int64_t version = GetDomainVersion(node.domain());
if (version == 0)
return nullptr;
auto version = version_it->second;
return onnx::OpSchemaRegistry::Schema(node.op_type(), version, domain);
return onnx::OpSchemaRegistry::Schema(
node.op_type(), version, node.domain());
}

std::string GetImportVersionOfNode(const onnx::NodeProto &node) {
auto current_opset = opset_map_.find(node.domain())->second;
int64_t version = GetDomainVersion(node.domain());
if (version == 0)
return "";

LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX"
<< node.op_type() << " (" << node.name() << ")"
<< ", Opset: " << current_opset << "\n");
<< ", Opset: " << version << "\n");

auto opset_list_it = op_dialect_version_map_.find(node.op_type());

Expand All @@ -723,23 +731,20 @@ class FrontendGenImpl {
// It is the current opset when onnx-mlir project is started.
// All opset lower than the last opset should use the last opset(version)
if (node.domain().compare("ai.onnx.ml") != 0 &&
current_opset < opset_list.back() &&
current_opset < MINIMUM_SUPPORTED_OPSET)
version < opset_list.back() && version < MINIMUM_SUPPORTED_OPSET)
llvm::outs() << "Warning: ONNX " << node.op_type()
<< " in your model is using Opset " << current_opset
<< " in your model is using Opset " << version
<< ", which is quite old. Please consider regenerating your "
"model with a newer Opset.\n";

for (int i = opset_list.size() - 1; i > 0; i--) {
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - testing Opset "
<< opset_list[i - 1] << "\n");
if (current_opset < opset_list[i - 1]) {
if (version < opset_list[i - 1]) {
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - use Opset "
<< opset_list[i] << "\n");
return "V" + std::to_string(opset_list[i]);
}
}
return std::string("");
return "";
}

func::FuncOp CreateFuncOp(
Expand Down

0 comments on commit 13f625a

Please sign in to comment.