diff --git a/include/ion/port.h b/include/ion/port.h index 8800e39f..422baa4f 100644 --- a/include/ion/port.h +++ b/include/ion/port.h @@ -26,9 +26,7 @@ class Port { * @arg t: The type of the value. */ Port(const std::string& k, Halide::Type t) - : key_(k), type_(t), dimensions_(0), index_(-1), node_id_(), - param_(t, false, 0, k) - {} + : key_(k), type_(t), dimensions_(0), index_(-1), node_id_() {} /** * Construct new port for vector value. @@ -37,9 +35,7 @@ class Port { * @arg d: The dimension of the port. The range is 1 to 4. */ Port(const std::string& k, Halide::Type t, int32_t d) - : key_(k), type_(t), dimensions_(d), index_(-1), node_id_(), - param_(t, true, d, k) - {} + : key_(k), type_(t), dimensions_(d), index_(-1), node_id_() {} std::string key() const { return key_; } std::string& key() { return key_; } @@ -56,8 +52,13 @@ class Port { std::string node_id() const { return node_id_; } std::string& node_id() { return node_id_; } - Halide::Internal::Parameter& param() { - return param_; + std::vector& params() { + if (index_ == -1) { + params_.resize(1, Halide::Internal::Parameter{type_, dimensions_ != 0, dimensions_, key_}); + } else { + params_.resize(index_+1, Halide::Internal::Parameter{type_, dimensions_ != 0, dimensions_, key_}); + } + return params_; } bool is_bound() const { @@ -89,7 +90,7 @@ class Port { int32_t index_; std::string node_id_; - Halide::Internal::Parameter param_; + std::vector params_; }; } // namespace ion diff --git a/include/ion/port_map.h b/include/ion/port_map.h index 72ecd5b2..9f08e159 100644 --- a/include/ion/port_map.h +++ b/include/ion/port_map.h @@ -65,9 +65,16 @@ class PortMap { */ template void set(Port p, T v) { - auto param = p.param(); - param.set_scalar(v); - param_[(p.key())] = param; + auto params(p.params()); + auto i = p.index(); + if (i == -1) { + params[0].set_scalar(v); + params_[(p.key())] = params; + } else { + params[i].set_scalar(v); + params_[(p.key())].resize(i+1); + params_[(p.key())][i] = params[i]; + } dirty_ = true; } @@ -95,9 +102,16 @@ class PortMap { // This is just an output. output_buffer_[std::make_tuple(p.node_id(), p.key(), p.index())] = { buf }; } else { - auto param = p.param(); - param.set_buffer(buf); - param_[p.key()] = param; + auto params(p.params()); + auto i = p.index(); + if (i == -1) { + params[0].set_buffer(buf); + params_[p.key()] = params; + } else { + params[i].set_buffer(buf); + params_[p.key()].resize(i+1); + params_[p.key()][i] = params[i]; + } } dirty_ = true; @@ -138,11 +152,11 @@ class PortMap { // } bool is_mapped(const std::string& k) const { - return param_.count(k); + return params_.count(k); } - Halide::Internal::Parameter get_param(const std::string& k) const { - return param_.at(k); + std::vector get_params(const std::string& k) const { + return params_.at(k); } std::unordered_map, std::vector>> get_output_buffer() const { @@ -151,20 +165,24 @@ class PortMap { std::vector get_arguments_stub() const { std::vector args; - for (auto kv : param_) { - auto kind = kv.second.is_buffer() ? Halide::Argument::InputBuffer : Halide::Argument::InputScalar; - args.push_back(Halide::Argument(kv.first, kind, kv.second.type(), kv.second.dimensions(), Halide::ArgumentEstimates())); + for (const auto& kv : params_) { + for (const auto& p : kv.second) { + auto kind = p.is_buffer() ? Halide::Argument::InputBuffer : Halide::Argument::InputScalar; + args.push_back(Halide::Argument(kv.first, kind, p.type(), p.dimensions(), Halide::ArgumentEstimates())); + } } return args; } std::vector get_arguments_instance() const { std::vector args; - for (auto kv : param_) { - if (kv.second.is_buffer()) { - args.push_back(kv.second.raw_buffer()); - } else { - args.push_back(kv.second.scalar_address()); + for (const auto& kv : params_) { + for (const auto& p : kv.second) { + if (p.is_buffer()) { + args.push_back(p.raw_buffer()); + } else { + args.push_back(p.scalar_address()); + } } } return args; @@ -180,7 +198,7 @@ class PortMap { private: bool dirty_; - std::unordered_map param_; + std::unordered_map> params_; std::unordered_map, std::vector>> output_buffer_; }; diff --git a/src/builder.cc b/src/builder.cc index e75681cc..936bd54a 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -36,16 +36,16 @@ std::map compute_output_files(const Halide: bool is_ready(const std::vector& sorted, const Node& n) { bool ready = true; - for (auto p : n.ports()) { + for (auto port : n.ports()) { // This port has external dependency. Always ready to add. - if (p.node_id().empty()) { + if (port.node_id().empty()) { continue; } // Check port dependent node is already added ready &= std::find_if(sorted.begin(), sorted.end(), - [&p](const Node& n) { - return n.id() == p.node_id(); + [&port](const Node& n) { + return n.id() == port.node_id(); }) != sorted.end(); } return ready; @@ -260,23 +260,23 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) { const auto& bb = bbs[n.id()]; auto arginfos = bb->arginfos(); for (size_t j=0; joutput_func(p.key()); + if (port.is_bound()) { + auto fs = bbs[port.node_id()]->output_func(port.key()); if (arginfo.kind == Halide::Internal::ArgInfoKind::Scalar) { bb->bind_input(arginfo.name, fs); } else if (arginfo.kind == Halide::Internal::ArgInfoKind::Function) { - auto fs = bbs[p.node_id()]->output_func(p.key()); + auto fs = bbs[port.node_id()]->output_func(port.key()); // no specific index provided, direct output Port if (index == -1) { bb->bind_input(arginfo.name, fs); } else { // access to Port[index] if (index>=fs.size()){ - throw std::runtime_error("Port index out of range: " + p.key()); + throw std::runtime_error("Port index out of range: " + port.key()); } bb->bind_input(arginfo.name, {fs[index]}); } @@ -285,10 +285,9 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) { } } else { if (arginfo.kind == Halide::Internal::ArgInfoKind::Scalar) { - if (pm.is_mapped(p.key())) { + if (pm.is_mapped(port.key())) { // This block should be executed when g.run is called with appropriate PortMap. - // const std::vector& vs = { pm.get_param_expr(p.key()) }; - auto param = pm.get_param(p.key()); + const auto& params(pm.get_params(port.key())); // validation // if (arginfo.types.size() != vs.size()) { @@ -296,24 +295,45 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) { // } // for (auto i=0; iname(), Halide::type_to_c_type(arginfo.types[i], false), p.key(), Halide::type_to_c_type(vs[i].type(), false)); + // log::error("Type mismatch: BB {} expects {}, but port {} has {}", bb->name(), Halide::type_to_c_type(arginfo.types[i], false), port.key(), Halide::type_to_c_type(vs[i].type(), false)); // } // } // validation - bb->bind_input(arginfo.name, { Halide::Internal::Variable::make(p.type(), p.key(), param) }); + + std::vector es; + for (const auto& p : params) { + es.push_back(Halide::Internal::Variable::make(port.type(), port.key(), p)); + } + bb->bind_input(arginfo.name, es); } else { - bb->bind_input(arginfo.name, { Halide::Internal::Variable::make(p.type(), p.key(), p.param()) }); + std::vector es; + for (const auto& p : port.params()) { + es.push_back(Halide::Internal::Variable::make(port.type(), port.key(), p)); + } + bb->bind_input(arginfo.name, es); } } else if (arginfo.kind == Halide::Internal::ArgInfoKind::Function) { - if (pm.is_mapped(p.key())) { + if (pm.is_mapped(port.key())) { // This block should be executed when g.run is called with appropriate PortMap. - auto b = pm.get_param(p.key()).buffer(); - Halide::Func f; - f(Halide::_) = b(Halide::_); - bb->bind_input(arginfo.name, { f }); + const auto& params(pm.get_params(port.key())); + + std::vector fs; + for (const auto& p : params) { + auto b(p.buffer()); + Halide::Func f; + f(Halide::_) = b(Halide::_); + fs.push_back(f); + } + + bb->bind_input(arginfo.name, fs); } else { - Halide::ImageParam param(p.type(), p.dimensions(), p.key()); - bb->bind_input(arginfo.name, { param }); + std::vector fs; + if (port.index() == -1) { + fs.resize(1, Halide::ImageParam(port.type(), port.dimensions(), port.key())); + } else { + fs.resize(port.index()+1, Halide::ImageParam(port.type(), port.dimensions(), port.key())); + } + bb->bind_input(arginfo.name, fs); } } else { throw std::runtime_error("fixme"); @@ -331,10 +351,10 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) { // Traverse bbs and bundling all outputs std::unordered_map> dereferenced; for (const auto& n : nodes_) { - for (const auto& p : n.ports()) { - auto node_id = p.node_id(); + for (const auto& port : n.ports()) { + auto node_id = port.node_id(); if (!node_id.empty()) { - for (const auto &f : bbs[node_id]->output_func(p.key())) { + for (const auto &f : bbs[node_id]->output_func(port.key())) { dereferenced[node_id].emplace_back(f.name()); } } diff --git a/src/serializer.h b/src/serializer.h index f7fb39f2..e90eb20f 100644 --- a/src/serializer.h +++ b/src/serializer.h @@ -46,6 +46,7 @@ class adl_serializer { j["key_"] = v.key(); j["type_"] = static_cast(v.type()); j["dimensions_"] = v.dimensions(); + j["index_"] = v.index(); j["node_id_"] = v.node_id(); } @@ -53,14 +54,14 @@ class adl_serializer { v.key() = j["key_"].get(); v.type() = j["type_"].get(); v.dimensions() = j["dimensions_"]; + v.index() = j["index_"]; v.node_id() = j["node_id_"].get(); if (v.node_id().empty()) { - // if (v.dimensions() == 0) { - // v.expr() = Halide::Internal::Variable::make(v.type(), v.key(), Halide::Internal::Parameter(v.type(), false, 0, v.key())); - // } else { - // v.func() = Halide::ImageParam(v.type(), v.dimensions(), v.key()); - // } - v.param() = Halide::Internal::Parameter(v.type(), v.dimensions() != 0, v.dimensions(), v.key()); + if (v.index() == -1) { + v.params() = { Halide::Internal::Parameter(v.type(), v.dimensions() != 0, v.dimensions(), v.key()) }; + } else { + v.params() = std::vector(v.index()+1, Halide::Internal::Parameter{v.type(), v.dimensions() != 0, v.dimensions(), v.key()}); + } } } }; diff --git a/test/array_input.cc b/test/array_input.cc index 35654a24..d2c7fe96 100644 --- a/test/array_input.cc +++ b/test/array_input.cc @@ -15,7 +15,9 @@ int main() { Port input{"input", Halide::type_of(), 2}; Builder b; b.set_target(Halide::get_host_target()); - auto n = b.add("test_array_input")(input); + Node n; + n = b.add("test_array_copy")(input).set_param(Param{"len", std::to_string(len)}); + n = b.add("test_array_input")(n["array_output"]).set_param(Param{"len", std::to_string(len)}); Halide::Buffer in0(w, h), in1(w, h), in2(w, h), in3(w, h), in4(w, h); diff --git a/test/array_output.cc b/test/array_output.cc index a7e06ebe..010a1aef 100644 --- a/test/array_output.cc +++ b/test/array_output.cc @@ -15,7 +15,9 @@ int main() { Port input{"input", Halide::type_of(), 2}; Builder b; b.set_target(Halide::get_host_target()); - auto n = b.add("test_array_output")(input).set_param(Param{"len", std::to_string(len)}); + Node n; + n = b.add("test_array_output")(input).set_param(Param{"len", std::to_string(len)}); + n = b.add("test_array_copy")(n["array_output"]).set_param(Param{"len", std::to_string(len)}); Halide::Buffer in(w, h); for (int y = 0; y < h; ++y) { diff --git a/test/test-bb.h b/test/test-bb.h index dea4c31b..e97affd4 100644 --- a/test/test-bb.h +++ b/test/test-bb.h @@ -163,6 +163,25 @@ class MultiOut : public BuildingBlock { Halide::Var x, y, c; }; +class ArrayInput : public BuildingBlock { +public: + GeneratorParam len{"len", 5}; + + Input array_input{"array_input", Int(32), 2}; + Output output{"output", Int(32), 2}; + + void generate() { + Halide::Expr v = 0; + for (int i = 0; i < len; ++i) { + v += array_input[i](x, y); + } + output(x, y) = v; + } + +private: + Halide::Var x, y; +}; + class ArrayOutput : public BuildingBlock { public: GeneratorParam len{"len", 5}; @@ -181,25 +200,25 @@ class ArrayOutput : public BuildingBlock { Halide::Var x, y; }; -class ArrayInput : public BuildingBlock { +class ArrayCopy : public BuildingBlock { public: GeneratorParam len{"len", 5}; Input array_input{"array_input", Int(32), 2}; - Output output{"output", Int(32), 2}; + Output array_output{"array_output", Int(32), 2}; void generate() { - Halide::Expr v = 0; + array_output.resize(len); for (int i = 0; i < len; ++i) { - v += array_input[i](x, y); + array_output[i](x, y) = array_input[i](x, y); } - output(x, y) = v; } private: Halide::Var x, y; }; + class ExternIncI32x2 : public BuildingBlock { public: GeneratorParam v{"v", 0}; @@ -235,8 +254,9 @@ ION_REGISTER_BUILDING_BLOCK(ion::bb::test::IncI32x2, test_inc_i32x2); ION_REGISTER_BUILDING_BLOCK(ion::bb::test::Dup, test_dup); ION_REGISTER_BUILDING_BLOCK(ion::bb::test::Scale2x, test_scale2x); ION_REGISTER_BUILDING_BLOCK(ion::bb::test::MultiOut, test_multi_out); -ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ArrayOutput, test_array_output); ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ArrayInput, test_array_input); +ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ArrayOutput, test_array_output); +ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ArrayCopy, test_array_copy); ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ExternIncI32x2, test_extern_inc_i32x2); #endif