|
20 | 20 | #include <functional>
|
21 | 21 | #include <list>
|
22 | 22 | #include <memory>
|
| 23 | +#include <stack> |
23 | 24 | #include <string>
|
24 | 25 | #include <unordered_map>
|
25 | 26 | #include <unordered_set>
|
@@ -81,66 +82,53 @@ namespace ngraph
|
81 | 82 | std::list<std::shared_ptr<Node>> topological_sort(const T& nodes,
|
82 | 83 | bool include_control_deps = false)
|
83 | 84 | {
|
84 |
| - std::deque<ngraph::Node*> independent_nodes; |
85 |
| - std::unordered_map<const ngraph::Node*, size_t> node_dependency_count; |
86 |
| - std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map; |
87 |
| - std::unordered_map<ngraph::Node*, std::set<Node*>> control_deps_users; |
| 85 | + std::stack<ngraph::Node*> nodes_to_do; |
| 86 | + std::set<Node*> nodes_done; |
| 87 | + std::list<std::shared_ptr<Node>> result; |
88 | 88 |
|
89 | 89 | for (auto node : nodes)
|
90 | 90 | {
|
91 |
| - //build an equivalent of node->get_users() but for control dependencies |
92 |
| - size_t control_deps_count = 0; |
93 |
| - if (include_control_deps) |
94 |
| - { |
95 |
| - for (auto cd : node->get_control_dependencies()) |
96 |
| - { |
97 |
| - control_deps_count++; |
98 |
| - control_deps_users[cd.get()].insert(node.get()); |
99 |
| - } |
100 |
| - } |
101 |
| - |
102 |
| - node_map[node.get()] = node; |
103 |
| - size_t deps_count = node->get_input_size() + control_deps_count; |
104 |
| - node_dependency_count[node.get()] = deps_count; |
105 |
| - if (deps_count == 0) |
106 |
| - { |
107 |
| - independent_nodes.push_back(node.get()); |
108 |
| - } |
| 91 | + nodes_to_do.push(node.get()); |
109 | 92 | }
|
110 |
| - |
111 |
| - std::list<std::shared_ptr<ngraph::Node>> result_list; |
112 |
| - while (independent_nodes.size() > 0) |
| 93 | + while (nodes_to_do.size() > 0) |
113 | 94 | {
|
114 |
| - auto independent_node = independent_nodes.front(); |
115 |
| - result_list.push_back(node_map[independent_node]); |
116 |
| - independent_nodes.pop_front(); |
117 |
| - |
118 |
| - for (const std::shared_ptr<Node>& user : independent_node->get_users()) |
| 95 | + Node* node = nodes_to_do.top(); |
| 96 | + if (nodes_done.count(node) != 0) |
119 | 97 | {
|
120 |
| - if (--node_dependency_count[user.get()] == 0) |
| 98 | + nodes_to_do.pop(); |
| 99 | + continue; |
| 100 | + } |
| 101 | + bool can_add = true; |
| 102 | + size_t arg_count = node->get_input_size(); |
| 103 | + for (size_t i = 0; i < arg_count; ++i) |
| 104 | + { |
| 105 | + Node* dep = node->input(arg_count - i - 1).get_source_output().get_node(); |
| 106 | + if (nodes_done.count(dep) == 0) |
121 | 107 | {
|
122 |
| - independent_nodes.push_back(user.get()); |
| 108 | + can_add = false; |
| 109 | + nodes_to_do.push(dep); |
123 | 110 | }
|
124 | 111 | }
|
125 |
| - |
126 | 112 | if (include_control_deps)
|
127 | 113 | {
|
128 |
| - auto cdit = control_deps_users.find(independent_node); |
129 |
| - if (cdit != control_deps_users.end()) |
130 |
| - for (auto cd_user : cdit->second) |
| 114 | + for (auto depptr : node->get_control_dependencies()) |
| 115 | + { |
| 116 | + Node* dep = depptr.get(); |
| 117 | + if (nodes_done.count(dep) == 0) |
131 | 118 | {
|
132 |
| - node_dependency_count[cd_user] -= 1; |
133 |
| - size_t count = node_dependency_count[cd_user]; |
134 |
| - if (count == 0) |
135 |
| - { |
136 |
| - independent_nodes.push_back(cd_user); |
137 |
| - } |
| 119 | + can_add = false; |
| 120 | + nodes_to_do.push(dep); |
138 | 121 | }
|
| 122 | + } |
| 123 | + } |
| 124 | + if (can_add) |
| 125 | + { |
| 126 | + result.push_back(node->shared_from_this()); |
| 127 | + nodes_to_do.pop(); |
| 128 | + nodes_done.insert(node); |
139 | 129 | }
|
140 | 130 | }
|
141 |
| - |
142 |
| - NGRAPH_CHECK(nodes.size() == result_list.size()); |
143 |
| - return result_list; |
| 131 | + return result; |
144 | 132 | }
|
145 | 133 |
|
146 | 134 | // For cases, where `nodes` is a subset of the entire graph
|
|
0 commit comments