diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 8111ee3c1fe61..e2381c8d178e7 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -503,6 +503,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.graph_remove_initialized_tensor = [](Graph& graph, const std::string& tensor_name) { graph.RemoveInitializedTensor(tensor_name); }; + the_global_api.graph_reverse_dfs_from_preemp = vaip::graph_reverse_dfs_from; if (!s_library_vitisaiep.vaip_get_version) { return reinterpret_cast(&(the_global_api.host_)); } else { diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index e7b39546fda6a..e9db58143577b 100644 --- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -89,6 +89,75 @@ Node& graph_add_node(Graph& graph, const std::string& name, const std::string& o return ret; } +// copied from graph.cc, trying to exit the function early as leave function may change the validity of the graph +void graph_reverse_dfs_from( + const Graph& graph, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& + stop) { + using WorkEntry = std::pair; // bool represents leave or not + InlinedVector stack; + stack.reserve(from.size()); + for (auto node : from) { + stack.emplace_back(node, false); + } + + InlinedVector visited(graph.MaxNodeIndex(), false); + while (!stack.empty()) { + const WorkEntry last_entry = stack.back(); + stack.pop_back(); + + if (last_entry.first == nullptr) { + continue; + } + const Node& n = *last_entry.first; + + if (last_entry.second) { + // leave node + if (leave(&n)) { + return; + } + continue; + } + + if (visited[n.Index()]) continue; + + visited[n.Index()] = true; + + if (enter) { + if (enter(&n)) { + return; + } + } + if (leave) stack.emplace_back(&n, true); + + if (comp) { + InlinedVector sorted_nodes; + for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) { + if (stop && stop(&n, &(*iter))) continue; + sorted_nodes.push_back(&(*iter)); + } + std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp); + for (const auto* in : sorted_nodes) { + const NodeIndex idx = in->Index(); + if (!visited[idx]) { + stack.emplace_back(in, false); + } + } + } else { + for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) { + if (stop && stop(&n, &(*iter))) continue; + const NodeIndex idx = (*iter).Index(); + if (!visited[idx]) { + stack.emplace_back(graph.GetNode(idx), false); + } + } + } + } +} + void graph_remove_node(Graph& graph, const NodeInput& node_input) { if (node_input.node == nullptr && node_input.node_arg != nullptr) { assert(node_input.node_arg->Exists()); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/graph.h b/onnxruntime/core/providers/vitisai/include/vaip/graph.h index 561278c73a6c1..bd8d0229d627c 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/graph.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/graph.h @@ -3,6 +3,8 @@ #pragma once #include "./node.h" #include "vaip/my_ort.h" +#include +#include namespace vaip { using namespace onnxruntime; @@ -16,4 +18,12 @@ Node& graph_fuse(Graph& graph, const std::string& name, const std::string& op_ty const std::vector& inputs, const std::vector& outputs, const std::vector& constant_initializers); Model* model_clone(const Model& original_model, int64_t external_data_threshold); + +void graph_reverse_dfs_from( + const Graph& graph, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& + stop); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 6a51ef862280b..0becc41d861f7 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (13u) +#define VAIP_ORT_API_MAJOR (14u) #define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { @@ -243,6 +243,13 @@ struct OrtApiForVaip { const std::vector& shape, const std::vector& data); // [101] void (*graph_remove_initialized_tensor)(Graph& graph, const std::string& tensor_name); // [102] + void (*graph_reverse_dfs_from_preemp)( + const Graph& graph, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& + stop); // [103] }; #ifndef USE_VITISAI