Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 9b3d573

Browse files
author
Scott Cyphers
committed
Fix top sort
1 parent de37f9d commit 9b3d573

File tree

9 files changed

+103
-71
lines changed

9 files changed

+103
-71
lines changed

src/ngraph/descriptor/output.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
// limitations under the License.
1515
//*****************************************************************************
1616

17-
#include "ngraph/descriptor/output.hpp"
17+
#include <algorithm>
18+
1819
#include "ngraph/descriptor/input.hpp"
20+
#include "ngraph/descriptor/output.hpp"
1921
#include "ngraph/node.hpp"
2022

2123
using namespace std;
@@ -31,12 +33,20 @@ descriptor::Output::Output(Node* node, size_t index, const shared_ptr<Tensor>& t
3133
// Add an input to the vector of inputs that use this output.
3234
void descriptor::Output::add_input(Input* input)
3335
{
34-
m_inputs.insert(input);
36+
// Keep the inputs in insertion order to keep sorts deterministic
37+
if (find(m_inputs.begin(), m_inputs.end(), input) == m_inputs.end())
38+
{
39+
m_inputs.push_back(input);
40+
}
3541
}
3642

3743
void descriptor::Output::remove_input(Input* input)
3844
{
39-
m_inputs.erase(input);
45+
auto it = find(m_inputs.begin(), m_inputs.end(), input);
46+
if (it != m_inputs.end())
47+
{
48+
m_inputs.erase(it);
49+
}
4050
}
4151

4252
shared_ptr<Node> descriptor::Output::get_node() const

src/ngraph/descriptor/output.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#pragma once
1818

1919
#include <memory>
20-
#include <set>
20+
#include <vector>
2121

2222
#include "ngraph/descriptor/input.hpp"
2323
#include "ngraph/descriptor/tensor.hpp"
@@ -48,7 +48,7 @@ namespace ngraph
4848
void set_tensor_ptr(const std::shared_ptr<Tensor>& tensor) { m_tensor = tensor; }
4949
void add_input(Input* input);
5050
void remove_input(Input* input);
51-
const std::set<Input*>& get_inputs() const { return m_inputs; }
51+
const std::vector<Input*>& get_inputs() const { return m_inputs; }
5252
Tensor& get_tensor() const;
5353

5454
/// \return the shape of the output
@@ -64,7 +64,7 @@ namespace ngraph
6464
Node* m_node;
6565
size_t m_index;
6666
std::shared_ptr<Tensor> m_tensor;
67-
std::set<Input*> m_inputs;
67+
std::vector<Input*> m_inputs;
6868

6969
private:
7070
Output(const Output&) = delete;

src/ngraph/graph_util.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,27 +81,27 @@ void ngraph::traverse_nodes(const NodeVector& subgraph_results,
8181
while (stack.size() > 0)
8282
{
8383
std::shared_ptr<Node> n = stack.front();
84+
stack.pop_front();
8485
if (instances_seen.count(n) == 0)
8586
{
8687
instances_seen.insert(n);
8788
f(n);
88-
}
89-
stack.pop_front();
90-
for (auto arg : n->get_arguments())
91-
{
92-
if (instances_seen.count(arg) == 0)
89+
for (auto arg : n->get_arguments())
9390
{
94-
stack.push_front(arg);
91+
if (instances_seen.count(arg) == 0)
92+
{
93+
stack.push_front(arg);
94+
}
9595
}
96-
}
9796

98-
if (include_control_deps)
99-
{
100-
for (auto cdep : n->get_control_dependencies())
97+
if (include_control_deps)
10198
{
102-
if (instances_seen.count(cdep) == 0)
99+
for (auto cdep : n->get_control_dependencies())
103100
{
104-
stack.push_front(cdep);
101+
if (instances_seen.count(cdep) == 0)
102+
{
103+
stack.push_front(cdep);
104+
}
105105
}
106106
}
107107
}

src/ngraph/graph_util.hpp

Lines changed: 33 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <functional>
2121
#include <list>
2222
#include <memory>
23+
#include <stack>
2324
#include <string>
2425
#include <unordered_map>
2526
#include <unordered_set>
@@ -81,66 +82,53 @@ namespace ngraph
8182
std::list<std::shared_ptr<Node>> topological_sort(const T& nodes,
8283
bool include_control_deps = false)
8384
{
84-
std::deque<ngraph::Node*> independent_nodes;
85-
std::unordered_map<const ngraph::Node*, size_t> node_dependency_count;
86-
std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>> node_map;
87-
std::unordered_map<ngraph::Node*, std::set<Node*>> control_deps_users;
85+
std::stack<ngraph::Node*> nodes_to_do;
86+
std::set<Node*> nodes_done;
87+
std::list<std::shared_ptr<Node>> result;
8888

8989
for (auto node : nodes)
9090
{
91-
//build an equivalent of node->get_users() but for control dependencies
92-
size_t control_deps_count = 0;
93-
if (include_control_deps)
94-
{
95-
for (auto cd : node->get_control_dependencies())
96-
{
97-
control_deps_count++;
98-
control_deps_users[cd.get()].insert(node.get());
99-
}
100-
}
101-
102-
node_map[node.get()] = node;
103-
size_t deps_count = node->get_input_size() + control_deps_count;
104-
node_dependency_count[node.get()] = deps_count;
105-
if (deps_count == 0)
106-
{
107-
independent_nodes.push_back(node.get());
108-
}
91+
nodes_to_do.push(node.get());
10992
}
110-
111-
std::list<std::shared_ptr<ngraph::Node>> result_list;
112-
while (independent_nodes.size() > 0)
93+
while (nodes_to_do.size() > 0)
11394
{
114-
auto independent_node = independent_nodes.front();
115-
result_list.push_back(node_map[independent_node]);
116-
independent_nodes.pop_front();
117-
118-
for (const std::shared_ptr<Node>& user : independent_node->get_users())
95+
Node* node = nodes_to_do.top();
96+
if (nodes_done.count(node) != 0)
11997
{
120-
if (--node_dependency_count[user.get()] == 0)
98+
nodes_to_do.pop();
99+
continue;
100+
}
101+
bool can_add = true;
102+
size_t arg_count = node->get_input_size();
103+
for (size_t i = 0; i < arg_count; ++i)
104+
{
105+
Node* dep = node->input(arg_count - i - 1).get_source_output().get_node();
106+
if (nodes_done.count(dep) == 0)
121107
{
122-
independent_nodes.push_back(user.get());
108+
can_add = false;
109+
nodes_to_do.push(dep);
123110
}
124111
}
125-
126112
if (include_control_deps)
127113
{
128-
auto cdit = control_deps_users.find(independent_node);
129-
if (cdit != control_deps_users.end())
130-
for (auto cd_user : cdit->second)
114+
for (auto depptr : node->get_control_dependencies())
115+
{
116+
Node* dep = depptr.get();
117+
if (nodes_done.count(dep) == 0)
131118
{
132-
node_dependency_count[cd_user] -= 1;
133-
size_t count = node_dependency_count[cd_user];
134-
if (count == 0)
135-
{
136-
independent_nodes.push_back(cd_user);
137-
}
119+
can_add = false;
120+
nodes_to_do.push(dep);
138121
}
122+
}
123+
}
124+
if (can_add)
125+
{
126+
result.push_back(node->shared_from_this());
127+
nodes_to_do.pop();
128+
nodes_done.insert(node);
139129
}
140130
}
141-
142-
NGRAPH_CHECK(nodes.size() == result_list.size());
143-
return result_list;
131+
return result;
144132
}
145133

146134
// For cases, where `nodes` is a subset of the entire graph

src/ngraph/node.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ shared_ptr<descriptor::Tensor> Node::get_output_tensor_ptr() const
344344
return m_outputs.at(0).get_tensor_ptr();
345345
}
346346

347-
const std::set<descriptor::Input*>& Node::get_output_inputs(size_t i) const
347+
const std::vector<descriptor::Input*>& Node::get_output_inputs(size_t i) const
348348
{
349349
return m_outputs.at(i).get_inputs();
350350
}

src/ngraph/node.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ namespace ngraph
257257
"output, or update calling code not to assume only one output");
258258

259259
/// Returns the set of inputs using output i
260-
const std::set<descriptor::Input*>& get_output_inputs(size_t i) const
260+
const std::vector<descriptor::Input*>& get_output_inputs(size_t i) const
261261
NGRAPH_DEPRECATED("use node->output(i).get_target_inputs() instead");
262262

263263
/// Returns the number of inputs for the op

test/control_dependencies.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ TEST(control_dependencies, cdep_ops)
8787
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
8888

8989
auto f = make_shared<Function>(cdop, ParameterVector{A, B});
90-
auto nodes = f->get_ordered_ops(true);
91-
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
90+
test_ordered_ops(f);
9291
}
9392

9493
TEST(control_dependencies, two_cdep_ops)
@@ -102,8 +101,7 @@ TEST(control_dependencies, two_cdep_ops)
102101
std::set<std::shared_ptr<Node>>{absn, absn_c});
103102

104103
auto f = make_shared<Function>(cdop, ParameterVector{A, B, C});
105-
auto nodes = f->get_ordered_ops(true);
106-
ASSERT_EQ(nodes.back()->get_argument(0), cdop);
104+
test_ordered_ops(f);
107105
}
108106

109107
TEST(control_dependencies, two_cdep_ops_op_on_top)
@@ -117,8 +115,7 @@ TEST(control_dependencies, two_cdep_ops_op_on_top)
117115
auto absn_cdop = make_shared<op::Abs>(cdop);
118116

119117
auto f = make_shared<Function>(absn_cdop, ParameterVector{A, B});
120-
auto nodes = f->get_ordered_ops(true);
121-
ASSERT_EQ(nodes.back()->get_argument(0), absn_cdop);
118+
test_ordered_ops(f);
122119
}
123120

124121
TEST(control_dependencies, clone_function_cdop)
@@ -129,6 +126,7 @@ TEST(control_dependencies, clone_function_cdop)
129126
make_shared<ControlDependencyOp>(NodeVector{A}, std::set<std::shared_ptr<Node>>{absn});
130127

131128
auto f = make_shared<Function>(cdop, ParameterVector{A});
129+
test_ordered_ops(f);
132130
auto clone = ngraph::clone_function(*f.get());
133131
auto matcher = std::make_shared<pattern::Matcher>(cdop);
134132
auto cdop_clone = clone->get_results().at(0)->get_argument(0);

test/util/test_tools.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,35 @@ std::shared_ptr<Function> make_function_from_file(const std::string& file_name)
313313
return func;
314314
}
315315
#endif
316+
317+
::testing::AssertionResult test_ordered_ops(shared_ptr<Function> f)
318+
{
319+
set<shared_ptr<Node>> seen;
320+
for (auto node : f->get_ordered_ops())
321+
{
322+
if (seen.count(node) > 0)
323+
{
324+
return ::testing::AssertionFailure() << "Duplication in ordered ops";
325+
}
326+
size_t arg_count = node->get_input_size();
327+
for (size_t i = 0; i < arg_count; ++i)
328+
{
329+
shared_ptr<Node> dep = node->input(i).get_source_output().get_node_shared_ptr();
330+
if (seen.count(dep) == 0)
331+
{
332+
return ::testing::AssertionFailure() << "Argument " << dep
333+
<< " does not occur before op" << node;
334+
}
335+
}
336+
for (shared_ptr<Node> dep : node->get_control_dependencies())
337+
{
338+
if (seen.count(dep) == 0)
339+
{
340+
return ::testing::AssertionFailure() << "Control dependency " << dep
341+
<< " does not occur before op" << node;
342+
}
343+
}
344+
seen.insert(node);
345+
}
346+
return ::testing::AssertionSuccess();
347+
}

test/util/test_tools.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include <random>
2626
#include <vector>
2727

28+
#include "gtest/gtest.h"
29+
2830
#include "ngraph/descriptor/layout/tensor_layout.hpp"
2931
#include "ngraph/file_util.hpp"
3032
#include "ngraph/log.hpp"
@@ -276,3 +278,5 @@ std::vector<T> read_binary_file(const std::string& path)
276278
inputs_fs.read(reinterpret_cast<char*>(file_content.data()), size);
277279
return file_content;
278280
}
281+
282+
testing::AssertionResult test_ordered_ops(std::shared_ptr<ngraph::Function> f);

0 commit comments

Comments
 (0)