Skip to content

Commit 68baf18

Browse files
WIP: Debug port binding
1 parent 6105526 commit 68baf18

File tree

11 files changed

+160
-109
lines changed

11 files changed

+160
-109
lines changed

include/ion/builder.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,26 +101,41 @@ class Builder {
101101
Halide::Pipeline build(ion::PortMap& ports);
102102

103103
std::vector<Halide::Argument> get_arguments_stub() const {
104+
std::set<Port::Channel> added_ports;
104105
std::vector<Halide::Argument> args;
105106
for (const auto& node : nodes_) {
106107
for (const auto& port : node.iports()) {
107108
if (port.has_pred()) {
108109
continue;
109110
}
111+
112+
if (added_ports.count(port.impl_->pred_chan)) {
113+
continue;
114+
}
115+
added_ports.insert(port.impl_->pred_chan);
116+
110117
const auto& port_args(port.as_argument());
111118
args.insert(args.end(), port_args.begin(), port_args.end());
119+
112120
}
113121
}
114122
return args;
115123
}
116124

117125
std::vector<const void*> get_arguments_instance() const {
126+
std::set<Port::Channel> added_ports;
118127
std::vector<const void*> instances;
119128
for (const auto& node : nodes_) {
120129
for (const auto& port : node.iports()) {
121130
if (port.has_pred()) {
122131
continue;
123132
}
133+
134+
if (added_ports.count(port.impl_->pred_chan)) {
135+
continue;
136+
}
137+
added_ports.insert(port.impl_->pred_chan);
138+
124139
const auto& port_instances(port.as_instance());
125140
instances.insert(instances.end(), port_instances.begin(), port_instances.end());
126141
}

include/ion/node.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,13 @@ class Node {
8888
*/
8989
Port operator[](const std::string& name) {
9090
auto it = std::find_if(impl_->ports.begin(), impl_->ports.end(),
91-
[&](const Port& p){ return (p.pred_name() == name && p.pred_id() == impl_->id) || (p.succ_name() == name && p.succ_id() == impl_->id); });
91+
[&](const Port& p){ return (p.pred_name() == name && p.pred_id() == impl_->id) || p.has_succ({.node_id=impl_->id, .name=name}); });
9292
if (it == impl_->ports.end()) {
9393
// This is output port which is never referenced.
9494
// Bind myself as a predecessor and register
9595

9696
// TODO: Validate with arginfo
97-
Port port(impl_->id, name, "", "");
97+
Port port(impl_->id, name);
9898
impl_->ports.push_back(port);
9999
return port;
100100
} else {
@@ -138,7 +138,8 @@ class Node {
138138
std::vector<Port> iports() const {
139139
std::vector<Port> iports;
140140
for (const auto& p: impl_->ports) {
141-
if (id() == p.succ_id()) {
141+
if (std::count_if(p.impl_->succ_chans.begin(), p.impl_->succ_chans.end(),
142+
[&](const Port::Channel& c) { return std::get<0>(c) == impl_->id; })) {
142143
iports.push_back(p);
143144
}
144145
}

include/ion/port.h

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,19 @@ namespace ion {
2020
* Port class is used to create dynamic i/o for each node.
2121
*/
2222
class Port {
23+
public:
24+
using Channel = std::tuple<std::string, std::string>;
2325

26+
private:
2427
struct Impl {
25-
std::string pred_id;
26-
std::string pred_name;
28+
// std::string pred_id;
29+
// std::string pred_name;
30+
31+
// std::string succ_id;
32+
// std::string succ_name;
2733

28-
std::string succ_id;
29-
std::string succ_name;
34+
Channel pred_chan;
35+
std::set<Channel> succ_chans;
3036

3137
Halide::Type type;
3238
int32_t dimensions;
@@ -36,10 +42,10 @@ class Port {
3642

3743
Impl() {}
3844

39-
Impl(const std::string& pid, const std::string& pn, const std::string& sid, const std::string& sn, const Halide::Type& t, int32_t d)
40-
: pred_id(pid), pred_name(pn), succ_id(sid), succ_name(sn), type(t), dimensions(d)
45+
Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d)
46+
: pred_chan{.node_id=pid, .name=pn}, succ_chans{}, type(t), dimensions(d)
4147
{
42-
params[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(pid, pn, sid, sn, 0));
48+
params[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(pid, pn, 0));
4349
}
4450
};
4551

@@ -48,42 +54,41 @@ class Port {
4854
friend class Node;
4955
friend class nlohmann::adl_serializer<Port>;
5056

51-
Port() : impl_(new Impl("", "", "", "", Halide::Type(), 0)), index_(-1) {}
57+
Port() : impl_(new Impl("", "", Halide::Type(), 0)), index_(-1) {}
5258
Port(const std::shared_ptr<Impl>& impl) : impl_(impl), index_(-1) {}
5359

5460
/**
5561
* Construct new port for scalar value.
5662
* @arg k: The key of the port which should be matched with BuildingBlock Input/Output name.
5763
* @arg t: The type of the value.
5864
*/
59-
Port(const std::string& n, Halide::Type t) : impl_(new Impl("", "", "", n, t, 0)), index_(-1) {}
65+
Port(const std::string& n, Halide::Type t) : impl_(new Impl("", n, t, 0)), index_(-1) {}
6066

6167
/**
6268
* Construct new port for vector value.
6369
* @arg k: The key of the port which should be matched with BuildingBlock Input/Output name.
6470
* @arg t: The type of the element value.
6571
* @arg d: The dimension of the port. The range is 1 to 4.
6672
*/
67-
Port(const std::string& n, Halide::Type t, int32_t d) : impl_(new Impl("", "", "", n, t, d)), index_(-1) {}
73+
Port(const std::string& n, Halide::Type t, int32_t d) : impl_(new Impl("", n, t, d)), index_(-1) {}
6874

69-
const std::string& pred_name() const { return impl_->pred_name; }
70-
const std::string& succ_name() const { return impl_->succ_name; }
75+
const std::string& pred_id() const { return std::get<0>(impl_->pred_chan); }
76+
const std::string& pred_name() const { return std::get<1>(impl_->pred_chan); }
7177

7278
const Halide::Type& type() const { return impl_->type; }
7379

7480
int32_t dimensions() const { return impl_->dimensions; }
7581

76-
const std::string& pred_id() const { return impl_->pred_id; }
77-
78-
const std::string& succ_id() const { return impl_->succ_id; }
82+
// const std::string& succ_id() const { return impl_->succ_id; }
7983

8084
int32_t size() const { return impl_->params.size(); }
8185

8286
int32_t index() const { return index_; }
8387

84-
bool has_pred() const { return !pred_id().empty(); }
88+
bool has_pred() const { return !std::get<0>(impl_->pred_chan).empty(); }
8589

86-
bool has_succ() const { return !succ_id().empty(); }
90+
bool has_succ() const { return !impl_->succ_chans.empty(); }
91+
bool has_succ(const Channel& c) const { return impl_->succ_chans.count(c); }
8792

8893
void set_index(int index) { index_ = index; }
8994

@@ -100,9 +105,9 @@ class Port {
100105
void bind(T *v) {
101106
auto i = index_ == -1 ? 0 : index_;
102107
if (has_pred()) {
103-
impl_->params[i] = Halide::Internal::Parameter{Halide::type_of<T>(), false, 0, argument_name(pred_id(), pred_name(), succ_id(), succ_name(), i)};
108+
impl_->params[i] = Halide::Internal::Parameter{Halide::type_of<T>(), false, 0, argument_name(pred_id(), pred_name(), i)};
104109
} else {
105-
impl_->params[i] = Halide::Internal::Parameter{type(), false, dimensions(), argument_name(pred_id(), pred_name(), succ_id(), succ_name(), i)};
110+
impl_->params[i] = Halide::Internal::Parameter{type(), false, dimensions(), argument_name(pred_id(), pred_name(), i)};
106111
}
107112

108113
impl_->instances[i] = v;
@@ -113,9 +118,9 @@ class Port {
113118
void bind(const Halide::Buffer<T>& buf) {
114119
auto i = index_ == -1 ? 0 : index_;
115120
if (has_pred()) {
116-
impl_->params[i] = Halide::Internal::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), pred_name(), succ_id(), succ_name(), i)};
121+
impl_->params[i] = Halide::Internal::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), pred_name(), i)};
117122
} else {
118-
impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), succ_id(), succ_name(), i)};
123+
impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i)};
119124
}
120125

121126
impl_->instances[i] = buf.raw_buffer();
@@ -125,9 +130,9 @@ class Port {
125130
void bind(const std::vector<Halide::Buffer<T>>& bufs) {
126131
for (size_t i=0; i<bufs.size(); ++i) {
127132
if (has_pred()) {
128-
impl_->params[i] = Halide::Internal::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), pred_name(), succ_id(), succ_name(), i)};
133+
impl_->params[i] = Halide::Internal::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), pred_name(), i)};
129134
} else {
130-
impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), succ_id(), succ_name(), i)};
135+
impl_->params[i] = Halide::Internal::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i)};
131136
}
132137

133138
impl_->instances[i] = bufs[i].raw_buffer();
@@ -148,7 +153,7 @@ class Port {
148153
/**
149154
* This port is created from another node
150155
*/
151-
Port(const std::string& pid, const std::string& pn, const std::string& sid, const std::string& sn) : impl_(new Impl(pid, pn, sid, sn, Halide::Type(), 0)), index_(-1) {}
156+
Port(const std::string& pid, const std::string& pn) : impl_(new Impl(pid, pn, Halide::Type(), 0)), index_(-1) {}
152157

153158

154159
std::vector<Halide::Argument> as_argument() const {
@@ -158,7 +163,7 @@ class Port {
158163
args.resize(i+1, Halide::Argument());
159164
}
160165
auto kind = dimensions() == 0 ? Halide::Argument::InputScalar : Halide::Argument::InputBuffer;
161-
args[i] = Halide::Argument(argument_name(pred_id(), pred_name(), succ_id(), succ_name(), i), kind, type(), dimensions(), Halide::ArgumentEstimates());
166+
args[i] = Halide::Argument(argument_name(pred_id(), pred_name(), i), kind, type(), dimensions(), Halide::ArgumentEstimates());
162167
}
163168
return args;
164169
}
@@ -184,7 +189,7 @@ class Port {
184189
if (es.size() <= i) {
185190
es.resize(i+1, Halide::Expr());
186191
}
187-
es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), pred_name(), succ_id(), succ_name(), i), param);
192+
es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), pred_name(), i), param);
188193
}
189194
return es;
190195
}
@@ -205,7 +210,7 @@ class Port {
205210
args.push_back(Halide::Var::implicit(i));
206211
args_expr.push_back(Halide::Var::implicit(i));
207212
}
208-
Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), pred_name(), succ_id(), succ_name(), i) + "_im");
213+
Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), pred_name(), i) + "_im");
209214
f(args) = Halide::Internal::Call::make(param, args_expr);
210215
fs[i] = f;
211216
}

include/ion/port_map.h

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,20 @@ namespace std
1313
{
1414

1515
template<>
16-
struct hash<tuple<string, string, string, string, int>>
16+
struct hash<tuple<string, string, int>>
1717
{
18-
std::size_t operator()(const tuple<string, string, string, string, int>& k) const noexcept
18+
std::size_t operator()(const tuple<string, string, int>& k) const noexcept
1919
{
20-
return std::hash<std::string>{}(std::get<0>(k)) ^ std::hash<std::string>{}(std::get<1>(k)) ^
21-
std::hash<std::string>{}(std::get<2>(k)) ^ std::hash<std::string>{}(std::get<3>(k)) ^
22-
std::hash<int>{}(std::get<4>(k));
20+
return std::hash<std::string>{}(std::get<0>(k)) ^ std::hash<std::string>{}(std::get<1>(k)) ^ std::hash<int>{}(std::get<2>(k));
2321
}
2422
};
2523

2624
template<>
27-
struct equal_to<tuple<string, string, string, string, int>>
25+
struct equal_to<tuple<string, string, int>>
2826
{
29-
bool operator()(const tuple<string, string, string, string, int>& v0, const tuple<string, string, string, string, int>& v1) const
27+
bool operator()(const tuple<string, string, int>& v0, const tuple<string, string, int>& v1) const
3028
{
31-
return (std::get<0>(v0) == std::get<0>(v1) && std::get<1>(v0) == std::get<1>(v1) &&
32-
std::get<2>(v0) == std::get<2>(v1) && std::get<3>(v0) == std::get<3>(v1) &&
33-
std::get<4>(v0) == std::get<4>(v1));
29+
return (std::get<0>(v0) == std::get<0>(v1) && std::get<1>(v0) == std::get<1>(v1) && std::get<2>(v0) == std::get<2>(v1));
3430
}
3531
};
3632

@@ -69,7 +65,7 @@ class PortMap {
6965
*/
7066
template<typename T>
7167
void set(Port port, T v) {
72-
auto& buf(scalar_buffer_[argument_name(port.pred_id(), port.pred_name(), port.succ_id(), port.succ_name(), port.index())]);
68+
auto& buf(scalar_buffer_[argument_name(port.pred_id(), port.pred_name(), port.index())]);
7369
buf.resize(sizeof(v));
7470
std::memcpy(buf.data(), &v, sizeof(v));
7571
port.bind(reinterpret_cast<T*>(buf.data()));
@@ -98,7 +94,7 @@ class PortMap {
9894
void set(Port port, Halide::Buffer<T>& buf) {
9995
if (port.has_pred()) {
10096
// This is just an output.
101-
output_buffer_[std::make_tuple(port.pred_id(), port.pred_name(), port.succ_id(), port.succ_name(), port.index())] = { buf };
97+
output_buffer_[std::make_tuple(port.pred_id(), port.pred_name(), port.index())] = { buf };
10298
} else {
10399
port.bind(buf);
104100
}
@@ -129,7 +125,7 @@ class PortMap {
129125
if (port.has_pred()) {
130126
// This is just an output.
131127
for (auto buf : bufs) {
132-
output_buffer_[std::make_tuple(port.pred_id(), port.pred_name(), port.succ_id(), port.succ_name(), port.index())].push_back(buf);
128+
output_buffer_[std::make_tuple(port.pred_id(), port.pred_name(), port.index())].push_back(buf);
133129
}
134130
} else {
135131
port.bind(bufs);
@@ -138,7 +134,7 @@ class PortMap {
138134
dirty_ = true;
139135
}
140136

141-
std::unordered_map<std::tuple<std::string, std::string, std::string, std::string, int>, std::vector<Halide::Buffer<>>> get_output_buffer() const {
137+
std::unordered_map<std::tuple<std::string, std::string, int>, std::vector<Halide::Buffer<>>> get_output_buffer() const {
142138
return output_buffer_;
143139
}
144140

@@ -152,7 +148,7 @@ class PortMap {
152148

153149
private:
154150
bool dirty_;
155-
std::unordered_map<std::tuple<std::string, std::string, std::string, std::string, int>, std::vector<Halide::Buffer<>>> output_buffer_;
151+
std::unordered_map<std::tuple<std::string, std::string, int>, std::vector<Halide::Buffer<>>> output_buffer_;
156152

157153
std::unordered_map<std::string, std::vector<uint8_t>> scalar_buffer_;
158154
};

include/ion/util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace ion {
77

88
class Port;
99

10-
std::string argument_name(const std::string& pred_id, const std::string& pred_name, const std::string& sucd_id, const std::string& succ_name, int32_t index);
10+
std::string argument_name(const std::string& node_id, const std::string& name, int32_t index);
1111

1212
std::string array_name(const std::string& port_name, size_t i);
1313

src/builder.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,7 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {
351351
for (auto kv : output_buffers) {
352352
auto pred_id = std::get<0>(kv.first);
353353
auto pred_name = std::get<1>(kv.first);
354-
auto succ_id = std::get<2>(kv.first);
355-
auto succ_name = std::get<3>(kv.first);
356-
auto index = std::get<4>(kv.first);
354+
auto index = std::get<2>(kv.first);
357355

358356
if (index != -1) {
359357
auto fs = bbs[pred_id]->output_func(pred_name);

src/node.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,8 @@ void Node::set_iports(const std::vector<Port>& ports) {
3131

3232
auto& port(ports[i]);
3333

34-
if (port.succ_name().empty()) {
35-
port.impl_->succ_name = info.name;
36-
} else if (port.succ_name() != info.name) {
37-
log::error("Port {} does not match name", port.succ_name());
38-
throw std::runtime_error("Failed to validate input port");
39-
}
34+
port.impl_->succ_chans.insert({.node_id=id(), .name=info.name});
4035

41-
port.impl_->succ_id = impl_->id;
4236
impl_->ports.push_back(port);
4337

4438
i++;

src/serializer.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,8 @@ template<>
4343
class adl_serializer<ion::Port> {
4444
public:
4545
static void to_json(json& j, const ion::Port& v) {
46-
j["pred_id"] = v.impl_->pred_id;
47-
j["pred_name"] = v.impl_->pred_name;
48-
j["succ_id"] = v.impl_->succ_id;
49-
j["succ_name"] = v.impl_->succ_name;
46+
j["pred_chan"] = v.impl_->pred_chan;
47+
j["succ_chans"] = v.impl_->succ_chans;
5048
j["type"] = static_cast<halide_type_t>(v.impl_->type);
5149
j["dimensions"] = v.impl_->dimensions;
5250
j["size"] = v.impl_->params.size();
@@ -56,14 +54,12 @@ class adl_serializer<ion::Port> {
5654

5755
static void from_json(const json& j, ion::Port& v) {
5856
v = ion::Port(ion::Port::find_impl(j["impl_ptr"].get<uintptr_t>()));
59-
v.impl_->pred_id = j["pred_id"].get<std::string>();
60-
v.impl_->pred_name = j["pred_name"].get<std::string>();
61-
v.impl_->succ_id = j["succ_id"].get<std::string>();
62-
v.impl_->succ_name = j["succ_name"].get<std::string>();
57+
v.impl_->pred_chan = j["pred_chan"].get<ion::Port::Channel>();
58+
v.impl_->succ_chans = j["succ_chans"].get<std::set<ion::Port::Channel>>();
6359
v.impl_->type = j["type"].get<halide_type_t>();
6460
v.impl_->dimensions = j["dimensions"];
6561
for (auto i=0; i<j["size"]; ++i) {
66-
v.impl_->params[i] = Halide::Internal::Parameter(v.impl_->type, v.impl_->dimensions != 0, v.impl_->dimensions, ion::argument_name(v.impl_->pred_id, v.impl_->pred_name, v.impl_->succ_id, v.impl_->succ_name, i));
62+
v.impl_->params[i] = Halide::Internal::Parameter(v.impl_->type, v.impl_->dimensions != 0, v.impl_->dimensions, ion::argument_name(v.pred_id(), v.pred_name(), i));
6763
}
6864
v.index_ = j["index"];
6965
}

src/util.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
namespace ion {
77

8-
std::string argument_name(const std::string& pred_id, const std::string& pred_name, const std::string& succ_id, const std::string& succ_name, int32_t index) {
8+
std::string argument_name(const std::string& node_id, const std::string& name, int32_t index) {
99
if (index == -1) {
1010
index = 0;
1111
}
1212

13-
std::string s = "_" + pred_id + "_" + pred_name + "_" + succ_id + "_" + succ_name + "_" + std::to_string(index);;
13+
std::string s = "_" + node_id + "_" + name + std::to_string(index);;
1414
std::replace(s.begin(), s.end(), '-', '_');
1515

1616
return s;

0 commit comments

Comments
 (0)