Skip to content

Commit

Permalink
add_input is supported
Browse files Browse the repository at this point in the history
  • Loading branch information
Fixstars-iizuka committed Jan 11, 2024
1 parent 7fc8e27 commit 4ad44b3
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 16 deletions.
35 changes: 19 additions & 16 deletions src/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@ Node::Impl::Impl(const std::string& id_, const std::string& name_, const Halide:
void Node::set_iport(const std::vector<Port>& ports) {

size_t i = 0;
for (const auto& info : impl_->arginfos) {
if (info.dir == Halide::Internal::ArgInfoDirection::Output) {
continue;
}

if (i >= ports.size()) {
log::error("Port {} is out of range", i);
throw std::runtime_error("Failed to validate input port");
}
for (auto& port : ports) {
// TODO: Validation is better to be done lazily after BuildingBlock::configure
//
// if (info.dir == Halide::Internal::ArgInfoDirection::Output) {
// continue;
// }

auto& port(ports[i]);
// if (i >= ports.size()) {
// log::error("Port {} is out of range", i);
// throw std::runtime_error("Failed to validate input port");
// }

port.impl_->succ_chans.insert({id(), info.name});
// NOTE: Is succ_chans name OK to be just leave as it is?
port.impl_->succ_chans.insert({id(), "_ion_iport_" + i});

impl_->ports.push_back(port);

Expand All @@ -40,11 +41,13 @@ void Node::set_iport(const std::vector<Port>& ports) {
}

Port Node::operator[](const std::string& name) {
if (std::find_if(impl_->arginfos.begin(), impl_->arginfos.end(),
[&](const Halide::Internal::AbstractGenerator::ArgInfo& info) { return info.name == name; }) == impl_->arginfos.end()) {
log::error("Port {} is not found", name);
throw std::runtime_error("Failed to find port");
}
// TODO: Validation is better to be done lazily after BuildingBlock::configure
//
// if (std::find_if(impl_->arginfos.begin(), impl_->arginfos.end(),
// [&](const Halide::Internal::AbstractGenerator::ArgInfo& info) { return info.name == name; }) == impl_->arginfos.end()) {
// log::error("Port {} is not found", name);
// throw std::runtime_error("Failed to find port");
// }

auto it = std::find_if(impl_->ports.begin(), impl_->ports.end(),
[&](const Port& p){ return (p.pred_name() == name && p.pred_id() == impl_->id) || p.has_succ({impl_->id, name}); });
Expand Down
3 changes: 3 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ endif()
# Duplicate name test
ion_jit_executable(dup-port-name SRCS dup-port-name.cc)

# BuildingBlock::configure
ion_jit_executable(configure SRCS configure.cc)

# Export test
# TODO: Resolve defects in feature/win-debug branch on Windows environment
# ion_jit_executable(export SRCS export.cc)
Expand Down
78 changes: 78 additions & 0 deletions test/configure.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include <ion/ion.h>

using namespace ion;

struct Test : BuildingBlock<Test> {
// This Building Block takes 1 input, 1 output and 1 parameter.
Input<Halide::Func> input{"input", Int(32), 1};
Output<Halide::Func> output{"output", Int(32), 1};
std::vector<Input<int32_t> *> extra_scalar_inputs;
GeneratorParam<int32_t> num{"num", 0};

void configure() {
for (int32_t i=0; i<num; ++i) {
extra_scalar_inputs.push_back(add_input<int32_t>("extra_scalar_input_" + std::to_string(i)));
}
}

void generate() {
Halide::Var i;
Halide::Expr v = input(i);
for (int i=0; i<num; ++i) {
v += *extra_scalar_inputs[i];
}
output(i) = v;
}
};
ION_REGISTER_BUILDING_BLOCK(Test, test);

int main() {
try {
int32_t v = 1;
auto size = 4;

Buffer<int32_t> input{size};
input.fill(40);

// No extra
{
Builder b;
b.set_target(get_host_target());
Buffer<int32_t> output{size};
b.add("test")(input)["output"].bind(output);
b.run();
for (int i=0; i<size; ++i) {
if (output(i) != 40) {
return 1;
}
}
}

// Added Extra
{
Builder b;
b.set_target(get_host_target());
Buffer<int32_t> output{size};
b.add("test")(input, &v, &v).set_param(Param("num", 2))["output"].bind(output);
b.compile("x");
b.run();
for (int i=0; i<size; ++i) {
if (output(i) != 42) {
return 1;
}
}
}

} catch (Halide::Error& e) {
std::cerr << e.what() << std::endl;
return 1;
} catch (const std::exception& e) {
std::cerr << e.what() << std::endl;
return 1;
}

std::cout << "Passed" << std::endl;

return 0;

}

0 comments on commit 4ad44b3

Please sign in to comment.