Skip to content

Commit

Permalink
Merge pull request #284 from fixstars/update/add-dynamic-port-binding
Browse files Browse the repository at this point in the history
Update/add dynamic port binding
  • Loading branch information
iitaku authored Jun 20, 2024
2 parents 7f9be03 + e19e755 commit ff09aa0
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 20 deletions.
7 changes: 7 additions & 0 deletions include/ion/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class Node {

void set_iport(const std::string& name, Port port);

void set_oport(Port port);

/**
* Retrieve relevant port of the node.
* @arg name: The name of port name which is matched with first argument of Input/Output declared in user-defined class deriving BuildingBlock.
Expand Down Expand Up @@ -121,6 +123,11 @@ class Node {
Port oport(const std::string& pn);
std::vector<std::tuple<std::string, Port>> oports() const;

std::vector<std::tuple<std::string, Port>> unbound_iports() const;
std::vector<std::tuple<std::string, Port>> unbound_oports() const;

void detect_data_hazard ()const ;

private:
Node(const NodeID& id, const std::string& name, const Halide::Target& target)
: impl_(new Impl{id, name, target})
Expand Down
34 changes: 22 additions & 12 deletions include/ion/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class Port {
std::unordered_map<uint32_t, Halide::Parameter> params;
std::unordered_map<uint32_t, const void *> instances;

std::unordered_map<uint32_t, std::tuple<const void *, bool> > bound_address;

Impl();
Impl(const NodeID& nid, const std::string& pn, const Halide::Type& t, int32_t d, const GraphID &gid );
};
Expand Down Expand Up @@ -180,38 +182,43 @@ class Port {
template<typename T>
void bind(T *v) {
auto i = index_ == -1 ? 0 : index_;

if (has_pred()) {
impl_->params[i] = Halide::Parameter{Halide::type_of<T>(), false, 0, argument_name(pred_id(), pred_name(), i, graph_id())};
impl_->params[i] = Halide::Parameter{Halide::type_of<T>(), false, 0, argument_name(pred_id(), id(), pred_name(), i, graph_id())};
} else {
impl_->params[i] = Halide::Parameter{type(), false, dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())};
impl_->params[i] = Halide::Parameter{type(), false, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())};
}

impl_->instances[i] = v;
impl_->bound_address[i] = std::make_tuple(v,false);
}

template<typename T>
void bind(const Halide::Buffer<T>& buf) {
auto i = index_ == -1 ? 0 : index_;
if (has_pred()) {
impl_->params[i] = Halide::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), pred_name(), i,graph_id())};
impl_->params[i] = Halide::Parameter{buf.type(), true, buf.dimensions(), argument_name(pred_id(), id(), pred_name(), i,graph_id())};
} else {
impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i,graph_id())};
impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i,graph_id())};
}

impl_->instances[i] = buf.raw_buffer();
impl_->bound_address[i] = std::make_tuple(buf.data(),false);
}

template<typename T>
void bind(const std::vector<Halide::Buffer<T>>& bufs) {
for (int i=0; i<static_cast<int>(bufs.size()); ++i) {
if (has_pred()) {
impl_->params[i] = Halide::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())};
impl_->params[i] = Halide::Parameter{bufs[i].type(), true, bufs[i].dimensions(), argument_name(pred_id(), id(),pred_name(), i, graph_id())};
} else {
impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), pred_name(), i, graph_id())};
impl_->params[i] = Halide::Parameter{type(), true, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())};
}

impl_->instances[i] = bufs[i].raw_buffer();
impl_->bound_address[i] = std::make_tuple(bufs[i].data(),false);
}

}

static std::tuple<std::shared_ptr<Impl>, bool> find_impl(const std::string& id);
Expand All @@ -226,15 +233,15 @@ class Port {
if (es.size() <= i) {
es.resize(i+1, Halide::Expr());
}
es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), pred_name(), i, graph_id()), param);
es[i] = Halide::Internal::Variable::make(type(), argument_name(pred_id(), id(), pred_name(), i, graph_id()), param);
}
return es;
}

std::vector<Halide::Func> as_func() const {
if (dimensions() == 0) {
throw std::runtime_error("Unreachable");
}
// if (dimensions() == 0) {
// throw std::runtime_error("Unreachable");
// }

std::vector<Halide::Func> fs;
for (const auto& [i, param] : impl_->params ) {
Expand All @@ -247,9 +254,12 @@ class Port {
args.push_back(Halide::Var::implicit(i));
args_expr.push_back(Halide::Var::implicit(i));
}
Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), pred_name(), i, graph_id()) + "_im");
Halide::Func f(param.type(), param.dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id()) + "_im");
f(args) = Halide::Internal::Call::make(param, args_expr);
fs[i] = f;
if(std::get<1>(impl_->bound_address[i])){
f.compute_root();
}
}
return fs;
}
Expand All @@ -261,7 +271,7 @@ class Port {
args.resize(i+1, Halide::Argument());
}
auto kind = dimensions() == 0 ? Halide::Argument::InputScalar : Halide::Argument::InputBuffer;
args[i] = Halide::Argument(argument_name(pred_id(), pred_name(), i, graph_id()), kind, type(), dimensions(), Halide::ArgumentEstimates());
args[i] = Halide::Argument(argument_name(pred_id(), id(), pred_name(), i, graph_id()), kind, type(), dimensions(), Halide::ArgumentEstimates());
}
return args;
}
Expand Down
2 changes: 1 addition & 1 deletion include/ion/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ using NodeID = StringID<node_tag>;
using GraphID = StringID<graph_tag>;
using PortID = StringID<port_tag>;

std::string argument_name(const NodeID& node_id, const std::string& name, int32_t index, const GraphID& graph_id);
std::string argument_name(const NodeID& node_id, const PortID & portId, const std::string& name, int32_t index, const GraphID& graph_id);

} // namespace ion

Expand Down
5 changes: 5 additions & 0 deletions src/lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ Halide::Pipeline lower(Builder builder, std::vector<Node>& nodes, bool implicit_
// This operation is required especially for the graph which is loaded from JSON definition.
topological_sort(nodes);

// detect data hazard, If the input port is bound to the same address as the output port, call compute_root first
for (auto n : nodes) {
n.detect_data_hazard();
}

// Constructing Generator object and setting static parameters
std::unordered_map<NodeID, Halide::Internal::AbstractGeneratorPtr, NodeID::StringIDHash> bbs;
for (auto n : nodes) {
Expand Down
79 changes: 77 additions & 2 deletions src/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ Port Node::operator[](const std::string& name) {
// This is output port which is never referenced.
// Bind myself as a predecessor and register
Port port(id(), name);
port.impl_ ->graph_id = impl_->graph_id;
impl_->ports.push_back(port);
set_oport(port);
return port;
} else {
// Port is already registered
Expand Down Expand Up @@ -111,6 +110,38 @@ std::vector<std::tuple<std::string, Port>> Node::iports() const {
return iports;
}


std::vector<std::tuple<std::string, Port>> Node::unbound_iports() const {
std::vector<std::tuple<std::string, Port>> unbound_iports;
int iports_size = 0;

for (const auto& p: impl_->ports) {
auto it = std::find_if(p.impl_->succ_chans.begin(), p.impl_->succ_chans.end(),
[&](const Port::Channel& c) { return std::get<0>(c) == impl_->id; });
if (it != p.impl_->succ_chans.end()) {
iports_size+=1;
}
}

int iports_idx = 0;
for (auto & arginfo: impl_->arginfos){
if (arginfo.dir == Halide::Internal::ArgInfoDirection::Input) {
if(iports_idx>=iports_size){
Port port("_ion_iport_" + std::to_string(iports_idx), arginfo.types.front());
port.impl_->dimensions = arginfo.dimensions;
unbound_iports.push_back(std::make_tuple(arginfo.name, port));
}
iports_idx ++;
}
}
return unbound_iports;
}

void Node::set_oport(Port port) {
port.impl_ ->graph_id = impl_->graph_id;
impl_->ports.push_back(port);
}

Port Node::oport(const std::string& pn) {
return this->operator[](pn);

Expand Down Expand Up @@ -138,4 +169,48 @@ std::vector<std::tuple<std::string, Port>> Node::oports() const {
return oports;
}

std::vector<std::tuple<std::string, Port>> Node::unbound_oports() const {
std::vector<std::tuple<std::string, Port>> unbound_oports;
int oports_size = 0;

for (const auto& p: impl_->ports) {
if (id() == p.pred_id()) {
oports_size +=1;
}
}
int oports_idx = 0;
for (auto & arginfo: impl_->arginfos){
if (arginfo.dir == Halide::Internal::ArgInfoDirection::Output) {
if(oports_idx>=oports_size){
Port port(id(), arginfo.name);
port.impl_ ->type = arginfo.types.front();
port.impl_->dimensions = arginfo.dimensions;
unbound_oports.push_back(std::make_tuple(arginfo.name, port));
}
oports_idx ++;
}
}
return unbound_oports;
}

void Node::detect_data_hazard ()const {
std::vector<std::tuple<std::string, Port>> oports = Node::oports() ;
std::vector<std::tuple<std::string, Port>> iports = Node::iports() ;
std::set<std::tuple<const void *, bool>> address_set;

for (auto& [pn, port] :oports) {
for(auto& [i, t] : port.impl_->bound_address){
address_set.insert(t);
}
}

for (auto& [pn, port] :iports) {
for(auto& [i, t] : port.impl_->bound_address){
if (address_set.find(t) != address_set.end()) {
std::get<1>(t) = true;
}
}
}
};

} // namespace ion
2 changes: 1 addition & 1 deletion src/port.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Port::Impl::Impl()
Port::Impl::Impl(const NodeID & nid, const std::string& pn, const Halide::Type& t, int32_t d, const GraphID & gid)
: id(PortID(sole::uuid4().str())), pred_chan{nid, pn}, succ_chans{}, type(t), dimensions(d), graph_id(gid)
{
params[0] = Halide::Parameter(type, dimensions != 0, dimensions, argument_name(nid, pn, 0, gid));
params[0] = Halide::Parameter(type, dimensions != 0, dimensions, argument_name(nid, id, pn, 0, gid));
}

void Port::determine_succ(const NodeID& nid, const std::string& old_pn, const std::string& new_pn) {
Expand Down
2 changes: 1 addition & 1 deletion src/serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct adl_serializer<ion::Port> {
impl->dimensions = j["dimensions"];
for (auto i=0; i<j["size"]; ++i) {
impl->params[i] = Halide::Parameter(impl->type, impl->dimensions != 0, impl->dimensions,
ion::argument_name(std::get<0>(impl->pred_chan), std::get<1>(impl->pred_chan), i, impl->graph_id.value()));
ion::argument_name(std::get<0>(impl->pred_chan), impl->id, std::get<1>(impl->pred_chan), i, impl->graph_id.value()));
}
}
v = ion::Port(impl, j["index"]);
Expand Down
5 changes: 2 additions & 3 deletions src/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@

namespace ion {

std::string argument_name(const NodeID & node_id, const std::string& name, int32_t index, const GraphID & graph_id) {
std::string argument_name(const NodeID & node_id, const PortID & portId, const std::string& name, int32_t index, const GraphID & graph_id) {
if (index == -1) {
index = 0;
}

std::string s = "_" + node_id.value() + "_" + name + std::to_string(index) + "_" + graph_id.value();
std::string s = "_" + node_id.value() + "_" + portId.value() + "_" + name + std::to_string(index) + "_" + graph_id.value();
std::replace(s.begin(), s.end(), '-', '_');

return s;
Expand Down
3 changes: 3 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ ion_jit_executable(validation SRCS validation.cc)
# Multi
ion_jit_executable(multi_pipe SRCS multi_pipe.cc)

# Unbound Binding
ion_jit_executable(unbound_binding SRCS unbound_binding.cc)

# Graph
ion_jit_executable(graph SRCS graph.cc)

Expand Down
19 changes: 19 additions & 0 deletions test/test-bb.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,24 @@ class SubI32x2 : public BuildingBlock<SubI32x2> {
}
};

class IncByOffset : public BuildingBlock<IncByOffset> {
public:
Input<Halide::Func> input{"input", Int(32), 2};
Input<Halide::Func> input_offset{"input_offset", Int(32), 0}; // to imitate scalar input
BuildingBlockParam<int32_t> v{"v", 1};
Output<Halide::Func> output{"output", Int(32), 2};
Output<int32_t> output_offset{"output_offset"};

void generate() {
output(x, y) = input(x, y) + input_offset();
output_offset() = input_offset() + v;
}

private:
Halide::Var x, y;
};


} // test
} // bb
} // ion
Expand All @@ -313,5 +331,6 @@ ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ArrayCopy, test_array_copy);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ExternIncI32x2, test_extern_inc_i32x2);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::AddI32x2, test_add_i32x2);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::SubI32x2, test_sub_i32x2);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::IncByOffset, test_inc_by_offset);

#endif
78 changes: 78 additions & 0 deletions test/unbound_binding.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include <exception>

#include "ion/ion.h"

#include "test-bb.h"
#include "test-rt.h"

using namespace std;
using namespace ion;

int main() {
try {

{
constexpr size_t h = 4, w = 4;

Halide::Buffer<int32_t> in(w, h);
in.fill(42);

Halide::Buffer<int32_t> out(w, h);
out.fill(0);

Builder b;
Target target = Halide::get_host_target();
target.set_feature(Target::Debug);
b.set_target(target);

auto n = b.add("test_inc_by_offset")(in);

n = b.add("test_inc_by_offset")(n["output"]);

n["output"].bind(out);

Halide::Buffer<int32_t> param_buf = Halide::Buffer<int32_t>::make_scalar();
param_buf.fill(1);

// std::vector< int > sizes;
// Halide::Buffer<int32_t> param_buf1(param_buf.data(),sizes);

for(auto &n:b.nodes()){
for (auto& [pn, port] : n.unbound_iports()) {
port.bind(param_buf);
n.set_iport(port);
}

for (auto& [pn, port] : n.unbound_oports()) {
port.bind(param_buf);
n.set_oport(port);
}

}

b.run();
for (int y = 0; y < h; ++y) {
for (int x = 0; x < w; ++x) {
if (out(0,0) != 45 ) {
throw runtime_error("Unexpected out value");
}
}
}

if (param_buf(0) != 3 ) {
throw runtime_error("Unexpected value");
}
}

} catch (const Halide::Error& e) {
std::cerr << e.what() << std::endl;
return 1;
} catch (const std::exception& e) {
std::cerr << e.what() << std::endl;
return 1;
}

std::cout << "Passed" << std::endl;

return 0;
}

0 comments on commit ff09aa0

Please sign in to comment.