Skip to content

Commit

Permalink
WIP: Port can bind buffer or scalar value
Browse files Browse the repository at this point in the history
  • Loading branch information
Fixstars-iizuka committed Dec 21, 2023
1 parent 95c13a6 commit 39ca5dc
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 32 deletions.
8 changes: 4 additions & 4 deletions include/ion/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ class Node {

/**
* Retrieve output port of the node.
* @arg key: The key of port name which is matched with first argument of Output declared in user-defined class deriving BuildingBlock.
* @return Port object which is specified by key.
* @arg name: The name of port name which is matched with first argument of Output declared in user-defined class deriving BuildingBlock.
* @return Port object which is specified by name.
*/
Port operator[](const std::string& key) {
return Port(key, impl_->id);
Port operator[](const std::string& name) {
return Port(name, impl_->id);
}

std::string id() const {
Expand Down
22 changes: 22 additions & 0 deletions include/ion/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,28 @@ class Port {
return port;
}

template<typename T>
void bind(T v) {
auto i = index_ == -1 ? 0 : index_;
impl_->params.resize(i+1, Halide::Internal::Parameter{type(), dimensions() != 0, dimensions(), argument_name(node_id(), name())});
impl_->params[i].set_scalar(v);
}

template<typename T>
void bind(const Halide::Buffer<T>& buf) {
auto i = index_ == -1 ? 0 : index_;
impl_->params.resize(i+1, Halide::Internal::Parameter{type(), dimensions() != 0, dimensions(), argument_name(node_id(), name())});
impl_->params[i].set_buffer(buf);
}

template<typename T>
void bind(const std::vector<Halide::Buffer<T>>& bufs) {
impl_->params.resize(bufs.size(), Halide::Internal::Parameter{type(), dimensions() != 0, dimensions(), argument_name(node_id(), name())});
for (size_t i=0; i<bufs.size(); ++i) {
impl_->params[i].set_buffer(bufs[i]);
}
}

static std::shared_ptr<Impl> find_impl(uintptr_t ptr) {
static std::unordered_map<uintptr_t, std::shared_ptr<Impl>> impls;
static std::mutex mutex;
Expand Down
57 changes: 29 additions & 28 deletions include/ion/port_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ class PortMap {
*/
template<typename T>
void set(Port port, T v) {
auto& params(port.params());
auto i = port.index();
if (i == -1) {
// TODO: It should be a number of array defined at BuildingBlock
i = 0;
}
params.resize(i+1, Halide::Internal::Parameter{port.type(), port.dimensions() != 0, port.dimensions(), argument_name(port.node_id(), port.name())});
params[i].set_scalar(v);
params_[argument_name(port.node_id(), port.name())].resize(i+1);
params_[argument_name(port.node_id(), port.name())][i] = params[i];

// auto& params(port.params());
// auto i = port.index();
// if (i == -1) {
// // TODO: It should be a number of array defined at BuildingBlock
// i = 0;
// }
// params.resize(i+1, Halide::Internal::Parameter{port.type(), port.dimensions() != 0, port.dimensions(), argument_name(port.node_id(), port.name())});
// params[i].set_scalar(v);
// params_[argument_name(port.node_id(), port.name())].resize(i+1);
// params_[argument_name(port.node_id(), port.name())][i] = params[i];
port.bind(v);
dirty_ = true;
}

Expand Down Expand Up @@ -103,16 +103,17 @@ class PortMap {
// This is just an output.
output_buffer_[std::make_tuple(port.node_id(), port.name(), port.index())] = { buf };
} else {
auto& params(port.params());
auto i = port.index();
if (i == -1) {
// TODO: It should be a number of array defined at BuildingBlock
i = 0;
}
params.resize(i+1, Halide::Internal::Parameter{port.type(), port.dimensions() != 0, port.dimensions(), argument_name(port.node_id(), port.name())});
params[i].set_buffer(buf);
params_[argument_name(port.node_id(), port.name())].resize(i+1);
params_[argument_name(port.node_id(), port.name())][i] = params[i];
// auto& params(port.params());
// auto i = port.index();
// if (i == -1) {
// // TODO: It should be a number of array defined at BuildingBlock
// i = 0;
// }
// params.resize(i+1, Halide::Internal::Parameter{port.type(), port.dimensions() != 0, port.dimensions(), argument_name(port.node_id(), port.name())});
// params[i].set_buffer(buf);
// params_[argument_name(port.node_id(), port.name())].resize(i+1);
// params_[argument_name(port.node_id(), port.name())][i] = params[i];
port.bind(buf);
}

dirty_ = true;
Expand Down Expand Up @@ -144,13 +145,13 @@ class PortMap {
output_buffer_[std::make_tuple(port.node_id(), port.name(), port.index())].push_back(buf);
}
} else {
auto& params(port.params());
params.resize(bufs.size(), Halide::Internal::Parameter{port.type(), port.dimensions() != 0, port.dimensions(), argument_name(port.node_id(), port.name())});
for (size_t i=0; i<bufs.size(); ++i) {
params[i].set_buffer(bufs[i]);
}
params_[argument_name(port.node_id(), port.name())] = params;

// auto& params(port.params());
// params.resize(bufs.size(), Halide::Internal::Parameter{port.type(), port.dimensions() != 0, port.dimensions(), argument_name(port.node_id(), port.name())});
// for (size_t i=0; i<bufs.size(); ++i) {
// params[i].set_buffer(bufs[i]);
// }
// params_[argument_name(port.node_id(), port.name())] = params;
port.bind(bufs);
}

dirty_ = true;
Expand Down
19 changes: 19 additions & 0 deletions src/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {
}
} else {
if (arginfo.kind == Halide::Internal::ArgInfoKind::Scalar) {
#if 0
if (pm.is_mapped(argument_name(port.node_id(), port.name()))) {
// This block should be executed when g.run is called with appropriate PortMap.
const auto& params(pm.get_params(argument_name(port.node_id(), port.name())));
Expand Down Expand Up @@ -330,7 +331,15 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {
}
bb->bind_input(arginfo.name, es);
}
#else
std::vector<Halide::Expr> es;
for (const auto& p : port.params()) {
es.push_back(Halide::Internal::Variable::make(port.type(), argument_name(port.node_id(), port.name()), p));
}
bb->bind_input(arginfo.name, es);
#endif
} else if (arginfo.kind == Halide::Internal::ArgInfoKind::Function) {
#if 0
if (pm.is_mapped(argument_name(port.node_id(), port.name()))) {
// This block should be executed when g.run is called with appropriate PortMap.
const auto& params(pm.get_params(argument_name(port.node_id(), port.name())));
Expand All @@ -347,6 +356,16 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {
} else {
bb->bind_input(arginfo.name, { Halide::ImageParam(port.type(), port.dimensions(), argument_name(port.node_id(), port.name()))});
}
#else
std::vector<Halide::Func> fs;
for (const auto& p : port.params()) {
auto b(p.buffer());
Halide::Func f;
f(Halide::_) = b(Halide::_);
fs.push_back(f);
}
bb->bind_input(arginfo.name, fs);
#endif
} else {
throw std::runtime_error("fixme");
}
Expand Down

0 comments on commit 39ca5dc

Please sign in to comment.