Skip to content

Commit

Permalink
Topological sort implementation (#1725)
Browse files Browse the repository at this point in the history
* Topological sort implementation

Co-authored-by: jooybar <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>

* Address comments

Signed-off-by: Philip Lassen <[email protected]>

* Improve name of constants

Signed-off-by: Philip Lassen <[email protected]>

* clang format

Signed-off-by: Philip Lassen <[email protected]>

Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
  • Loading branch information
philass authored Oct 7, 2022
1 parent f22d379 commit 79eb701
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/Builder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
add_onnx_mlir_library(OMBuilder
FrontendDialectHelper.cpp
FrontendDialectTransformer.cpp
ImportONNXUtils.cpp
ModelInputShaper.cpp

LINK_LIBS PUBLIC
Expand Down
20 changes: 17 additions & 3 deletions src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include "include/onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Builder/FrontendDialectTransformer.hpp"
#include "src/Builder/ImportONNXUtils.hpp"
#include "src/Builder/ModelInputShaper.hpp"
#include "src/Builder/SymbolTable.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
Expand Down Expand Up @@ -1208,7 +1209,7 @@ class FrontendGenImpl {
} // namespace onnx_mlir
namespace onnx_mlir {

void ImportFrontendModelInternal(onnx::ModelProto &model, MLIRContext &context,
bool ImportFrontendModelInternal(onnx::ModelProto &model, MLIRContext &context,
OwningOpRef<ModuleOp> &module, ImportOptions options) {
int originVersion = CURRENT_ONNX_OPSET;
// Get the version of the model
Expand All @@ -1221,7 +1222,14 @@ void ImportFrontendModelInternal(onnx::ModelProto &model, MLIRContext &context,
}
}

// Didnot do downward convert because support for BatchNorm is missing
if (options.allowSorting && !IsTopologicallySorted(model.graph())) {
if (!SortGraph(model.mutable_graph())) {
llvm::outs() << "The graph is not topologically sortable.\n";
return false;
}
}

// Did not do downward convert because support for BatchNorm is missing
if (options.invokeOnnxVersionConverter &&
originVersion < CURRENT_ONNX_OPSET) {
onnx::ModelProto convertModel =
Expand All @@ -1234,6 +1242,7 @@ void ImportFrontendModelInternal(onnx::ModelProto &model, MLIRContext &context,
onnx::shape_inference::InferShapes(model);
ImportFrontendModel(model, context, module, options);
}
return true;
}

// Return 0 on success, error otherwise.
Expand Down Expand Up @@ -1297,7 +1306,12 @@ int ImportFrontendModelFile(StringRef model_fname, MLIRContext &context,
return InvalidOnnxFormat;
}
}
ImportFrontendModelInternal(model, context, module, options);

if (!ImportFrontendModelInternal(model, context, module, options)) {
*errorMessage = "Onnx Model Import Failed on " + model_fname.str();
return CompilerFailure;
}

return CompilerSuccess;
}

Expand Down
1 change: 1 addition & 0 deletions src/Builder/FrontendDialectTransformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ struct ImportOptions {
// variables)
bool useOnnxModelTypes = false;
bool invokeOnnxVersionConverter = false;
bool allowSorting = true;
// Custom shape information for the graph inputs.
// Its format is 'input_id:dim,dim,dim|input_id:dim,dim,dim'
// E.g. An ONNX model has two dynamic inputs
Expand Down
155 changes: 155 additions & 0 deletions src/Builder/ImportONNXUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===--------------------- ImportONNXUtils.cpp ----------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// Helper methods for importing and cleaning of onnx models.
//
//===----------------------------------------------------------------------===//

#include "src/Builder/ImportONNXUtils.hpp"

#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Returns true if every node input in |graph| is produced before it is
// consumed, i.e. the node list is already in topological order.
// Graph inputs, initializers, and the empty name (an omitted optional
// input in ONNX) count as available from the start; this matches the
// set of names SortGraph treats as externally provided.
bool IsTopologicallySorted(const onnx::GraphProto &graph) {
  std::set<std::string> visited;
  for (const auto &initializer : graph.initializer()) {
    const auto &initializerName = initializer.name();
    visited.insert(initializerName);
  }
  for (const auto &input : graph.input()) {
    visited.insert(input.name());
  }
  // An empty input name denotes an omitted optional input; it has no
  // producer and must not be treated as an unsatisfied dependency
  // (keeps this check consistent with SortGraph).
  visited.insert("");
  for (const auto &node : graph.node()) {
    for (const auto &input : node.input()) {
      if (!visited.count(input))
        return false;
    }
    for (const auto &output : node.output()) {
      visited.insert(output);
    }
  }
  return true;
}

// Reorder the nodes of |graph| in place into the lexicographically
// smallest topological order, using Kahn's algorithm and always
// extracting the lowest-numbered ready node.
// Returns true on success; false if the graph has a cycle or a node
// input that is neither externally provided nor produced by any node.
bool SortGraph(onnx::GraphProto *graph) {
  const int numNodes = graph->node().size();

  // Map each node output name to the index of the node that produces it.
  std::map<std::string, int> producerOf;
  {
    int nodeIdx = 0;
    for (const auto &node : graph->node()) {
      for (const auto &output : node.output())
        producerOf[output] = nodeIdx;
      ++nodeIdx;
    }
    assert(nodeIdx == numNodes);
  }

  // Names satisfied without a producing node: graph inputs, initializers,
  // and the empty name (an omitted optional input).
  std::set<std::string> externallyProvided;
  for (const auto &initializer : graph->initializer())
    externallyProvided.insert(initializer.name());
  for (const auto &input : graph->input())
    externallyProvided.insert(input.name());
  externallyProvided.insert("");

  // consumers[p] lists, with multiplicity, the indices of nodes that read
  // an output of node p.
  std::vector<std::vector<int>> consumers(numNodes);
  {
    int nodeIdx = 0;
    for (const auto &node : graph->node()) {
      for (const auto &input : node.input()) {
        if (externallyProvided.count(input))
          continue;
        // A non-external input with no producer means the graph cannot
        // be topologically sorted.
        auto producer = producerOf.find(input);
        if (producer == producerOf.end())
          return false;
        consumers[producer->second].push_back(nodeIdx);
      }
      ++nodeIdx;
    }
  }

  // inDegree[i] = number of input edges of node i fed by other nodes
  // (external names excluded, duplicates counted).
  std::vector<int> inDegree(numNodes, 0);
  {
    int nodeIdx = 0;
    for (const auto &node : graph->node()) {
      for (const auto &input : node.input())
        if (!externallyProvided.count(input))
          ++inDegree[nodeIdx];
      ++nodeIdx;
    }
    assert(nodeIdx == numNodes);
  }

  // Ready set: nodes with all inputs satisfied. Node indices are unique,
  // and std::set keeps them ordered, so begin() is always the smallest.
  std::set<int> ready;
  for (int i = 0; i < numNodes; ++i)
    if (inDegree[i] == 0)
      ready.insert(i);

  // Kahn's algorithm: repeatedly emit the smallest ready node and
  // decrement the in-degree of each of its consumers.
  std::vector<int> topOrder;
  topOrder.reserve(numNodes);
  while (!ready.empty()) {
    const int current = *ready.begin();
    ready.erase(ready.begin());
    topOrder.push_back(current);
    for (int consumer : consumers[current])
      if (--inDegree[consumer] == 0)
        ready.insert(consumer);
  }

  // Visiting fewer nodes than exist means a cycle: no topological order.
  if (static_cast<int>(topOrder.size()) != numNodes)
    return false;

  // Permute the repeated node field into topOrder. Protobuf repeated
  // fields only expose SwapElements, so realize the permutation
  // selection-sort style, tracking where each original node currently is.
  std::vector<int> curOrder(numNodes);
  for (int i = 0; i < numNodes; ++i)
    curOrder[i] = i;
  for (int dst = 0; dst < numNodes; ++dst) {
    if (curOrder[dst] == topOrder[dst])
      continue;
    for (int src = dst + 1; src < numNodes; ++src) {
      if (curOrder[src] == topOrder[dst]) {
        graph->mutable_node()->SwapElements(dst, src);
        std::swap(curOrder[dst], curOrder[src]);
        break;
      }
    }
  }
  return true; // Graph successfully sorted.
}
21 changes: 21 additions & 0 deletions src/Builder/ImportONNXUtils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===--------------------- ImportONNXUtils.hpp ----------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// Helper methods for importing and cleaning of onnx models.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "onnx/onnx_pb.h"

bool IsTopologicallySorted(const onnx::GraphProto &graph);

bool SortGraph(onnx::GraphProto *graph);
4 changes: 4 additions & 0 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ llvm::cl::opt<bool> verifyInputTensors("verifyInputTensors",
"at runtime."),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));

llvm::cl::opt<bool> allowSorting("allowSorting",
llvm::cl::desc("Allow onnx-mlir to perform topological sort on onnx graph"),
llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions));

// Configuration states associated with certain options.
// For example, when maccel is specified, NNPA can register
// dependent libdnn.
Expand Down
3 changes: 2 additions & 1 deletion src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ extern llvm::cl::list<std::string> Xopt;
extern llvm::cl::list<std::string> Xllc;
extern llvm::cl::opt<std::string> mllvm;
extern llvm::cl::opt<bool> verifyInputTensors;
extern llvm::cl::opt<bool> allowSorting;

extern llvm::cl::opt<std::string> instrumentONNXOps;
extern llvm::cl::bits<InstrumentActions> instrumentControlBits;
Expand Down Expand Up @@ -119,4 +120,4 @@ std::vector<std::string> getCompilerConfig(std::string k);
void addCompilerConfig(std::string k, std::vector<std::string> v);
void delCompilerConfig(std::string k, std::vector<std::string> v);

} // namespace onnx_mlir
} // namespace onnx_mlir
2 changes: 2 additions & 0 deletions src/Compiler/CompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,7 @@ int processInputFile(StringRef inputFilename, mlir::MLIRContext &context,
options.useOnnxModelTypes = useOnnxModelTypes;
options.invokeOnnxVersionConverter = invokeOnnxVersionConverter;
options.shapeInformation = shapeInformation;
options.allowSorting = allowSorting;
options.externalDataDir = dirName(inputFilename);
return ImportFrontendModelFile(
inputFilename, context, module, errorMessage, options);
Expand All @@ -657,6 +658,7 @@ int processInputArray(const void *onnxBuffer, int bufferSize,
ImportOptions options;
options.useOnnxModelTypes = useOnnxModelTypes;
options.invokeOnnxVersionConverter = invokeOnnxVersionConverter;
options.allowSorting = allowSorting;
options.shapeInformation = shapeInformation;
return ImportFrontendModelArray(
onnxBuffer, bufferSize, context, module, errorMessage, options);
Expand Down

0 comments on commit 79eb701

Please sign in to comment.