Skip to content

Commit

Permalink
WIP: input array
Browse files Browse the repository at this point in the history
  • Loading branch information
Fixstars-iizuka committed Dec 19, 2023
1 parent c84e16d commit d4afb8a
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 67 deletions.
19 changes: 10 additions & 9 deletions include/ion/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ class Port {
* @arg t: The type of the value.
*/
Port(const std::string& k, Halide::Type t)
: key_(k), type_(t), dimensions_(0), index_(-1), node_id_(),
param_(t, false, 0, k)
{}
: key_(k), type_(t), dimensions_(0), index_(-1), node_id_() {}

/**
* Construct new port for vector value.
Expand All @@ -37,9 +35,7 @@ class Port {
* @arg d: The dimension of the port. The range is 1 to 4.
*/
Port(const std::string& k, Halide::Type t, int32_t d)
: key_(k), type_(t), dimensions_(d), index_(-1), node_id_(),
param_(t, true, d, k)
{}
: key_(k), type_(t), dimensions_(d), index_(-1), node_id_() {}

std::string key() const { return key_; }
std::string& key() { return key_; }
Expand All @@ -56,8 +52,13 @@ class Port {
std::string node_id() const { return node_id_; }
std::string& node_id() { return node_id_; }

Halide::Internal::Parameter& param() {
return param_;
std::vector<Halide::Internal::Parameter>& params() {
if (index_ == -1) {
params_.resize(1, Halide::Internal::Parameter{type_, dimensions_ != 0, dimensions_, key_});
} else {
params_.resize(index_+1, Halide::Internal::Parameter{type_, dimensions_ != 0, dimensions_, key_});
}
return params_;
}

bool is_bound() const {
Expand Down Expand Up @@ -89,7 +90,7 @@ class Port {
int32_t index_;
std::string node_id_;

Halide::Internal::Parameter param_;
std::vector<Halide::Internal::Parameter> params_;
};

} // namespace ion
Expand Down
54 changes: 36 additions & 18 deletions include/ion/port_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,16 @@ class PortMap {
*/
template<typename T>
void set(Port p, T v) {
auto param = p.param();
param.set_scalar(v);
param_[(p.key())] = param;
auto params(p.params());
auto i = p.index();
if (i == -1) {
params[0].set_scalar(v);
params_[(p.key())] = params;
} else {
params[i].set_scalar(v);
params_[(p.key())].resize(i+1);
params_[(p.key())][i] = params[i];
}
dirty_ = true;
}

Expand Down Expand Up @@ -95,9 +102,16 @@ class PortMap {
// This is just an output.
output_buffer_[std::make_tuple(p.node_id(), p.key(), p.index())] = { buf };
} else {
auto param = p.param();
param.set_buffer(buf);
param_[p.key()] = param;
auto params(p.params());
auto i = p.index();
if (i == -1) {
params[0].set_buffer(buf);
params_[p.key()] = params;
} else {
params[i].set_buffer(buf);
params_[p.key()].resize(i+1);
params_[p.key()][i] = params[i];
}
}

dirty_ = true;
Expand Down Expand Up @@ -138,11 +152,11 @@ class PortMap {
// }

bool is_mapped(const std::string& k) const {
return param_.count(k);
return params_.count(k);
}

Halide::Internal::Parameter get_param(const std::string& k) const {
return param_.at(k);
std::vector<Halide::Internal::Parameter> get_params(const std::string& k) const {
return params_.at(k);
}

std::unordered_map<std::tuple<std::string, std::string, int>, std::vector<Halide::Buffer<>>> get_output_buffer() const {
Expand All @@ -151,20 +165,24 @@ class PortMap {

std::vector<Halide::Argument> get_arguments_stub() const {
std::vector<Halide::Argument> args;
for (auto kv : param_) {
auto kind = kv.second.is_buffer() ? Halide::Argument::InputBuffer : Halide::Argument::InputScalar;
args.push_back(Halide::Argument(kv.first, kind, kv.second.type(), kv.second.dimensions(), Halide::ArgumentEstimates()));
for (const auto& kv : params_) {
for (const auto& p : kv.second) {
auto kind = p.is_buffer() ? Halide::Argument::InputBuffer : Halide::Argument::InputScalar;
args.push_back(Halide::Argument(kv.first, kind, p.type(), p.dimensions(), Halide::ArgumentEstimates()));
}
}
return args;
}

std::vector<const void*> get_arguments_instance() const {
std::vector<const void*> args;
for (auto kv : param_) {
if (kv.second.is_buffer()) {
args.push_back(kv.second.raw_buffer());
} else {
args.push_back(kv.second.scalar_address());
for (const auto& kv : params_) {
for (const auto& p : kv.second) {
if (p.is_buffer()) {
args.push_back(p.raw_buffer());
} else {
args.push_back(p.scalar_address());
}
}
}
return args;
Expand All @@ -180,7 +198,7 @@ class PortMap {

private:
bool dirty_;
std::unordered_map<std::string, Halide::Internal::Parameter> param_;
std::unordered_map<std::string, std::vector<Halide::Internal::Parameter>> params_;
std::unordered_map<std::tuple<std::string, std::string, int>, std::vector<Halide::Buffer<>>> output_buffer_;
};

Expand Down
72 changes: 46 additions & 26 deletions src/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@ std::map<Halide::OutputFileType, std::string> compute_output_files(const Halide:

bool is_ready(const std::vector<Node>& sorted, const Node& n) {
bool ready = true;
for (auto p : n.ports()) {
for (auto port : n.ports()) {
// This port has external dependency. Always ready to add.
if (p.node_id().empty()) {
if (port.node_id().empty()) {
continue;
}

// Check port dependent node is already added
ready &= std::find_if(sorted.begin(), sorted.end(),
[&p](const Node& n) {
return n.id() == p.node_id();
[&port](const Node& n) {
return n.id() == port.node_id();
}) != sorted.end();
}
return ready;
Expand Down Expand Up @@ -260,23 +260,23 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {
const auto& bb = bbs[n.id()];
auto arginfos = bb->arginfos();
for (size_t j=0; j<n.ports().size(); ++j) {
auto p = n.ports()[j];
auto index = p.index();
auto port = n.ports()[j];
auto index = port.index();
// Unbounded parameter
const auto& arginfo = arginfos[j];
if (p.is_bound()) {
auto fs = bbs[p.node_id()]->output_func(p.key());
if (port.is_bound()) {
auto fs = bbs[port.node_id()]->output_func(port.key());
if (arginfo.kind == Halide::Internal::ArgInfoKind::Scalar) {
bb->bind_input(arginfo.name, fs);
} else if (arginfo.kind == Halide::Internal::ArgInfoKind::Function) {
auto fs = bbs[p.node_id()]->output_func(p.key());
auto fs = bbs[port.node_id()]->output_func(port.key());
// no specific index provided, direct output Port
if (index == -1) {
bb->bind_input(arginfo.name, fs);
} else {
// access to Port[index]
if (index>=fs.size()){
throw std::runtime_error("Port index out of range: " + p.key());
throw std::runtime_error("Port index out of range: " + port.key());
}
bb->bind_input(arginfo.name, {fs[index]});
}
Expand All @@ -285,35 +285,55 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {
}
} else {
if (arginfo.kind == Halide::Internal::ArgInfoKind::Scalar) {
if (pm.is_mapped(p.key())) {
if (pm.is_mapped(port.key())) {
// This block should be executed when g.run is called with appropriate PortMap.
// const std::vector<Halide::Expr>& vs = { pm.get_param_expr(p.key()) };
auto param = pm.get_param(p.key());
const auto& params(pm.get_params(port.key()));

// validation
// if (arginfo.types.size() != vs.size()) {
// log::error("E");
// }
// for (auto i=0; i<vs.size(); ++i) {
// if (arginfo.types[i] != vs[i].type()) {
// log::error("Type mismatch: BB {} expects {}, but port {} has {}", bb->name(), Halide::type_to_c_type(arginfo.types[i], false), p.key(), Halide::type_to_c_type(vs[i].type(), false));
// log::error("Type mismatch: BB {} expects {}, but port {} has {}", bb->name(), Halide::type_to_c_type(arginfo.types[i], false), port.key(), Halide::type_to_c_type(vs[i].type(), false));
// }
// }
// validation
bb->bind_input(arginfo.name, { Halide::Internal::Variable::make(p.type(), p.key(), param) });

std::vector<Halide::Expr> es;
for (const auto& p : params) {
es.push_back(Halide::Internal::Variable::make(port.type(), port.key(), p));
}
bb->bind_input(arginfo.name, es);
} else {
bb->bind_input(arginfo.name, { Halide::Internal::Variable::make(p.type(), p.key(), p.param()) });
std::vector<Halide::Expr> es;
for (const auto& p : port.params()) {
es.push_back(Halide::Internal::Variable::make(port.type(), port.key(), p));
}
bb->bind_input(arginfo.name, es);
}
} else if (arginfo.kind == Halide::Internal::ArgInfoKind::Function) {
if (pm.is_mapped(p.key())) {
if (pm.is_mapped(port.key())) {
// This block should be executed when g.run is called with appropriate PortMap.
auto b = pm.get_param(p.key()).buffer();
Halide::Func f;
f(Halide::_) = b(Halide::_);
bb->bind_input(arginfo.name, { f });
const auto& params(pm.get_params(port.key()));

std::vector<Halide::Func> fs;
for (const auto& p : params) {
auto b(p.buffer());
Halide::Func f;
f(Halide::_) = b(Halide::_);
fs.push_back(f);
}

bb->bind_input(arginfo.name, fs);
} else {
Halide::ImageParam param(p.type(), p.dimensions(), p.key());
bb->bind_input(arginfo.name, { param });
std::vector<Halide::Func> fs;
if (port.index() == -1) {
fs.resize(1, Halide::ImageParam(port.type(), port.dimensions(), port.key()));
} else {
fs.resize(port.index()+1, Halide::ImageParam(port.type(), port.dimensions(), port.key()));
}
bb->bind_input(arginfo.name, fs);
}
} else {
throw std::runtime_error("fixme");
Expand All @@ -331,10 +351,10 @@ Halide::Pipeline Builder::build(ion::PortMap& pm) {
// Traverse bbs and bundling all outputs
std::unordered_map<std::string, std::vector<std::string>> dereferenced;
for (const auto& n : nodes_) {
for (const auto& p : n.ports()) {
auto node_id = p.node_id();
for (const auto& port : n.ports()) {
auto node_id = port.node_id();
if (!node_id.empty()) {
for (const auto &f : bbs[node_id]->output_func(p.key())) {
for (const auto &f : bbs[node_id]->output_func(port.key())) {
dereferenced[node_id].emplace_back(f.name());
}
}
Expand Down
13 changes: 7 additions & 6 deletions src/serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,22 @@ class adl_serializer<ion::Port> {
j["key_"] = v.key();
j["type_"] = static_cast<halide_type_t>(v.type());
j["dimensions_"] = v.dimensions();
j["index_"] = v.index();
j["node_id_"] = v.node_id();
}

static void from_json(const json& j, ion::Port& v) {
v.key() = j["key_"].get<std::string>();
v.type() = j["type_"].get<halide_type_t>();
v.dimensions() = j["dimensions_"];
v.index() = j["index_"];
v.node_id() = j["node_id_"].get<std::string>();
if (v.node_id().empty()) {
// if (v.dimensions() == 0) {
// v.expr() = Halide::Internal::Variable::make(v.type(), v.key(), Halide::Internal::Parameter(v.type(), false, 0, v.key()));
// } else {
// v.func() = Halide::ImageParam(v.type(), v.dimensions(), v.key());
// }
v.param() = Halide::Internal::Parameter(v.type(), v.dimensions() != 0, v.dimensions(), v.key());
if (v.index() == -1) {
v.params() = { Halide::Internal::Parameter(v.type(), v.dimensions() != 0, v.dimensions(), v.key()) };
} else {
v.params() = std::vector<Halide::Internal::Parameter>(v.index()+1, Halide::Internal::Parameter{v.type(), v.dimensions() != 0, v.dimensions(), v.key()});
}
}
}
};
Expand Down
4 changes: 3 additions & 1 deletion test/array_input.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ int main() {
Port input{"input", Halide::type_of<int32_t>(), 2};
Builder b;
b.set_target(Halide::get_host_target());
auto n = b.add("test_array_input")(input);
Node n;
n = b.add("test_array_copy")(input).set_param(Param{"len", std::to_string(len)});
n = b.add("test_array_input")(n["array_output"]).set_param(Param{"len", std::to_string(len)});

Halide::Buffer<int32_t> in0(w, h), in1(w, h), in2(w, h), in3(w, h), in4(w, h);

Expand Down
4 changes: 3 additions & 1 deletion test/array_output.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ int main() {
Port input{"input", Halide::type_of<int32_t>(), 2};
Builder b;
b.set_target(Halide::get_host_target());
auto n = b.add("test_array_output")(input).set_param(Param{"len", std::to_string(len)});
Node n;
n = b.add("test_array_output")(input).set_param(Param{"len", std::to_string(len)});
n = b.add("test_array_copy")(n["array_output"]).set_param(Param{"len", std::to_string(len)});

Halide::Buffer<int32_t> in(w, h);
for (int y = 0; y < h; ++y) {
Expand Down
32 changes: 26 additions & 6 deletions test/test-bb.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,25 @@ class MultiOut : public BuildingBlock<MultiOut> {
Halide::Var x, y, c;
};

class ArrayInput : public BuildingBlock<ArrayInput> {
public:
GeneratorParam<int> len{"len", 5};

Input<Halide::Func[]> array_input{"array_input", Int(32), 2};
Output<Halide::Func> output{"output", Int(32), 2};

void generate() {
Halide::Expr v = 0;
for (int i = 0; i < len; ++i) {
v += array_input[i](x, y);
}
output(x, y) = v;
}

private:
Halide::Var x, y;
};

class ArrayOutput : public BuildingBlock<ArrayOutput> {
public:
GeneratorParam<int> len{"len", 5};
Expand All @@ -181,25 +200,25 @@ class ArrayOutput : public BuildingBlock<ArrayOutput> {
Halide::Var x, y;
};

class ArrayInput : public BuildingBlock<ArrayInput> {
class ArrayCopy : public BuildingBlock<ArrayCopy> {
public:
GeneratorParam<int> len{"len", 5};

Input<Halide::Func[]> array_input{"array_input", Int(32), 2};
Output<Halide::Func> output{"output", Int(32), 2};
Output<Halide::Func[]> array_output{"array_output", Int(32), 2};

void generate() {
Halide::Expr v = 0;
array_output.resize(len);
for (int i = 0; i < len; ++i) {
v += array_input[i](x, y);
array_output[i](x, y) = array_input[i](x, y);
}
output(x, y) = v;
}

private:
Halide::Var x, y;
};


class ExternIncI32x2 : public BuildingBlock<ExternIncI32x2> {
public:
GeneratorParam<int32_t> v{"v", 0};
Expand Down Expand Up @@ -235,8 +254,9 @@ ION_REGISTER_BUILDING_BLOCK(ion::bb::test::IncI32x2, test_inc_i32x2);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::Dup, test_dup);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::Scale2x, test_scale2x);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::MultiOut, test_multi_out);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ArrayOutput, test_array_output);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ArrayInput, test_array_input);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ArrayOutput, test_array_output);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ArrayCopy, test_array_copy);
ION_REGISTER_BUILDING_BLOCK(ion::bb::test::ExternIncI32x2, test_extern_inc_i32x2);

#endif

0 comments on commit d4afb8a

Please sign in to comment.