Skip to content

Commit

Permalink
Fixed output binding
Browse files Browse the repository at this point in the history
  • Loading branch information
Fixstars-iizuka committed Dec 30, 2023
1 parent 224c803 commit 5d5603c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 77 deletions.
2 changes: 1 addition & 1 deletion include/ion/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class Builder {

private:

Halide::Pipeline build(ion::PortMap& ports);
Halide::Pipeline build(bool implicit_output = false);

std::vector<Halide::Argument> get_arguments_stub() const {
std::set<Port::Channel> added_ports;
Expand Down
100 changes: 27 additions & 73 deletions src/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ using json = nlohmann::json;
Builder::Builder()
: jit_ctx_(new Halide::JITUserContext), jit_ctx_ptr_(jit_ctx_.get())
{
#if SW
args_.push_back(&jit_ctx_ptr_);
#endif
}

Builder::~Builder()
Expand Down Expand Up @@ -138,8 +136,7 @@ void Builder::compile(const std::string& function_name, const CompileOption& opt
using namespace Halide;

// Build pipeline and module first
PortMap pm;
Pipeline p = build(pm);
Pipeline p = build(true);
if (!p.defined()) {
log::warn("This pipeline doesn't produce any outputs. Please bind a buffer with output port.");
return;
Expand Down Expand Up @@ -193,7 +190,7 @@ void Builder::run(void) {

void Builder::run(ion::PortMap& pm) {
if (!pipeline_.defined()) {
pipeline_ = build(pm);
pipeline_ = build();
if (!pipeline_.defined()) {
log::warn("This pipeline doesn't produce any outputs. Please bind a buffer with output port.");
return;
Expand All @@ -217,27 +214,10 @@ void Builder::run(ion::PortMap& pm) {
callable_ = pipeline_.compile_to_callable(get_arguments_stub(), target_);
}

#if !SW
if (pm.dirty()) {
args_.clear();
args_.push_back(&jit_ctx_ptr_);

auto args = get_arguments_instance();
args_.insert(args_.end(), args.begin(), args.end());

for (auto kv : pm.get_output_buffer()) {
for (auto b : kv.second) {
args_.push_back(b.raw_buffer());
}
}

pm.updated();
}
#endif
callable_.call_argv_fast(args_.size(), args_.data());
}

Halide::Pipeline Builder::build(ion::PortMap& pm) {
Halide::Pipeline Builder::build(bool implicit_output) {

log::info("Start building pipeline");

Expand Down Expand Up @@ -304,7 +284,7 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {
} else {
throw std::runtime_error("fixme");
}
#if SW

// Adding input args
if (added_args.count(port.impl_->pred_chan)) {
continue;
Expand All @@ -313,57 +293,37 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {

const auto& port_instances(port.as_instance());
args_.insert(args_.end(), port_instances.begin(), port_instances.end());
#endif
}
}
bb->build_pipeline();
}

std::vector<Halide::Func> output_funcs;
#if SW
for (const auto& node : nodes_) {
for (const auto& port : node.oports()) {
// if (port.has_succ()) {
// continue;
// }

const auto& port_instances(port.as_instance());
if (port_instances.empty()) {
continue;
}

auto fs(bbs[port.pred_id()]->output_func(port.pred_name()));
output_funcs.insert(output_funcs.end(), fs.begin(), fs.end());
args_.insert(args_.end(), port_instances.begin(), port_instances.end());
}
}
#else
const auto& output_buffers(pm.get_output_buffer());
if (output_buffers.empty()) {
// This is implicit mode. Make output list based on unbound output in the graph.
// Traverses bbs and bundles all outputs
// TODO: Now this can be more simplified by finding ports which doesn't have succ_id
std::unordered_map<std::string, std::vector<std::string>> dereferenced;
if (implicit_output) {
// Collects all output which is never referenced.
// This mode is used for AOT compilation
std::unordered_map<std::string, std::vector<std::string>> referenced;
for (const auto& n : nodes_) {
for (const auto& port : n.iports()) {
auto pred_id = port.pred_id();
if (!pred_id.empty()) {
for (const auto &f : bbs[pred_id]->output_func(port.pred_name())) {
dereferenced[pred_id].emplace_back(f.name());
if (port.has_pred()) {
for (const auto &f : bbs[port.pred_id()]->output_func(port.pred_name())) {
referenced[port.pred_id()].emplace_back(f.name());
}
}
}
}
for (int i=0; i<nodes_.size(); ++i) {
auto node_id = nodes_[i].id();

for (const auto& node : nodes_) {
auto node_id = node.id();
for (auto arginfo : bbs[node_id]->arginfos()) {
if (arginfo.dir != Halide::Internal::ArgInfoDirection::Output) {
// This is not output
continue;
}

// This is not output
// It is not dereferenced, then treat as outputs
const auto& dv = dereferenced[node_id];
const auto& dv = referenced[node_id];

for (auto f : bbs[node_id]->output_func(arginfo.name)) {
auto it = std::find(dv.begin(), dv.end(), f.name());
Expand All @@ -374,28 +334,22 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {
}
}
}

} else {
// This is expliti mode, mainly used in JIT compilation
for (auto kv : output_buffers) {
auto pred_id = std::get<0>(kv.first);
auto pred_name = std::get<1>(kv.first);
auto index = std::get<2>(kv.first);

if (index != -1) {
auto fs = bbs[pred_id]->output_func(pred_name);
if (index >= fs.size()) {
throw std::runtime_error("Port index out of range: " + pred_id + ", " + pred_name);
}
output_funcs.push_back(fs[index]);
} else {
for (auto f : bbs[pred_id]->output_func(pred_name)) {
output_funcs.push_back(f);
// Collects all output which is bound with buffer.
// This mode is used for JIT
for (const auto& node : nodes_) {
for (const auto& port : node.oports()) {
const auto& port_instances(port.as_instance());
if (port_instances.empty()) {
continue;
}

auto fs(bbs[port.pred_id()]->output_func(port.pred_name()));
output_funcs.insert(output_funcs.end(), fs.begin(), fs.end());
args_.insert(args_.end(), port_instances.begin(), port_instances.end());
}
}
}
#endif

if (output_funcs.empty()) {
return Halide::Pipeline();
Expand Down
3 changes: 0 additions & 3 deletions test/dup-port-name.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ int main()
n = b.add("test_branch")(input, width, height);
n = b.add("test_merge")(n["output0"], n["output1"], height);

ion::Buffer<int32_t> obuf(16, 16);
n["output"].bind(obuf);

b.compile("complex_graph");
} catch (const Halide::Error& e) {
std::cerr << e.what() << std::endl;
Expand Down

0 comments on commit 5d5603c

Please sign in to comment.