diff --git a/CMakeLists.txt b/CMakeLists.txt index aceebd98..23ef1f85 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,7 +61,9 @@ endif() file(GLOB ION_CORE_SRC LIST_DIRECTORIES false ${PROJECT_SOURCE_DIR}/src/*) add_library(ion-core SHARED ${ION_CORE_SRC}) target_include_directories(ion-core PUBLIC ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src) -if (UNIX) +if (APPLE) + target_link_libraries(ion-core PUBLIC Halide::Halide Halide::Runtime) +elseif (UNIX) target_link_libraries(ion-core PUBLIC Halide::Halide Halide::Runtime dl pthread z m stdc++) else() target_link_libraries(ion-core PUBLIC Halide::Halide Halide::Runtime) diff --git a/cmake/IonUtil.cmake b/cmake/IonUtil.cmake index dff11c83..b3620924 100644 --- a/cmake/IonUtil.cmake +++ b/cmake/IonUtil.cmake @@ -18,14 +18,11 @@ function(ion_aot_executable NAME_PREFIX) # Build compile set(COMPILE_NAME ${NAME_PREFIX}_compile) add_executable(${COMPILE_NAME} ${IAE_SRCS_COMPILE}) - if(UNIX AND NOT APPLE) + if(UNIX) target_compile_options(${COMPILE_NAME} PUBLIC -fno-rtti) # For Halide::Generator - target_link_options(${COMPILE_NAME} PUBLIC -Wl,--export-dynamic) # For JIT compiling - endif() - IF (APPLE) - target_compile_options(${COMPILE_NAME} - PUBLIC -fno-rtti # For Halide::Generator - PUBLIC -rdynamic) # For JIT compiling + if(NOT APPLE) + target_link_options(${COMPILE_NAME} PUBLIC -Wl,--export-dynamic) # For JIT compiling + endif() endif() target_include_directories(${COMPILE_NAME} PUBLIC "${PROJECT_SOURCE_DIR}/include") target_link_libraries(${COMPILE_NAME} PRIVATE ion-core ${PLATFORM_LIBRARIES}) @@ -105,12 +102,7 @@ function(ion_jit_executable NAME_PREFIX) add_executable(${NAME} ${IJE_SRCS}) if (UNIX AND NOT APPLE) target_link_options(${NAME} PUBLIC -Wl,--export-dynamic) # For JIT compiling - endif() - if (APPLE) - target_compile_options(${NAME} - PUBLIC -fno-rtti # For Halide::Generator - PUBLIC -rdynamic) # For JIT compiling - endif() + endif() add_dependencies(${NAME} ion-bb ion-bb-test) target_include_directories(${NAME} PUBLIC ${PROJECT_SOURCE_DIR}/include ${ION_BB_INCLUDE_DIRS} ${IJE_INCS}) target_link_libraries(${NAME} PRIVATE ion-core ${ION_BB_LIBRARIES} ${PLATFORM_LIBRARIES} ${IJE_LIBS}) diff --git a/example/sgm_compile.cc b/example/sgm_compile.cc index c0d6b4d6..743e52a8 100644 --- a/example/sgm_compile.cc +++ b/example/sgm_compile.cc @@ -18,14 +18,14 @@ int main() { Node ln = b.add("image_io_color_data_loader").set_param(Param{"url", "http://ion-kit.s3.us-west-2.amazonaws.com/images/aloe_left.jpg"}, Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)}); ln = b.add("base_normalize_3d_uint8")(ln["output"]); ln = b.add("image_processing_resize_nearest_3d")(ln["output"]).set_param(Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)}, Param{"scale", std::to_string(scale)}); - ln = b.add("base_schedule")(ln["output"]).set_param(Param{"output_name", "scaled_left"}, Param{"compute_level", "compute_root"}); + ln = b.add("base_schedule")(ln["output"]).set_param(Param{"output_name", "scaled_left"}, Param{"compute_level", "compute_root"}, Param("input.type", "float32"), Param("input.dim", 3)); ln = b.add("image_processing_calc_luminance")(ln["output"]).set_param(Param{"luminance_method", "Average"}); ln = b.add("base_denormalize_2d_uint8")(ln["output"]); Node rn = b.add("image_io_color_data_loader").set_param(Param{"url", "http://ion-kit.s3.us-west-2.amazonaws.com/images/aloe_right.jpg"}, Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)}); rn = b.add("base_normalize_3d_uint8")(rn["output"]); rn = b.add("image_processing_resize_nearest_3d")(rn["output"]).set_param(Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)}, Param{"scale", std::to_string(scale)}); - rn = b.add("base_schedule")(rn["output"]).set_param(Param{"output_name", "scaled_right"}, Param{"compute_level", "compute_root"}); + rn = b.add("base_schedule")(rn["output"]).set_param(Param{"output_name", "scaled_right"}, Param{"compute_level", "compute_root"}, Param("input.type", "float32"), Param("input.dim", 3)); rn = b.add("image_processing_calc_luminance")(rn["output"]).set_param(Param{"luminance_method", "Average"}); rn = b.add("base_denormalize_2d_uint8")(rn["output"]); diff --git a/include/ion/builder.h b/include/ion/builder.h index 527de95e..4aeabcea 100644 --- a/include/ion/builder.h +++ b/include/ion/builder.h @@ -100,6 +100,8 @@ class Builder { Halide::Pipeline build(bool implicit_output = false); + void determine_and_validate(); + std::vector get_arguments_stub() const; std::vector get_arguments_instance() const; diff --git a/include/ion/building_block.h b/include/ion/building_block.h index fd989a6f..80ed01cb 100644 --- a/include/ion/building_block.h +++ b/include/ion/building_block.h @@ -34,7 +34,10 @@ class BuildingBlock : public Halide::Generator { template void register_disposer(const std::string& n) { - reinterpret_cast(static_cast(builder_ptr))->register_disposer(bb_id, n); + auto bb(reinterpret_cast(static_cast(builder_ptr))); + if (bb) { + bb->register_disposer(bb_id, n); + } } ion::Buffer get_id() { diff --git a/include/ion/node.h b/include/ion/node.h index 78eff26f..5ff765eb 100644 --- a/include/ion/node.h +++ b/include/ion/node.h @@ -110,26 +110,11 @@ class Node { return impl_->ports; } - std::vector iports() const { - std::vector iports; - for (const auto& p: impl_->ports) { - if (std::count_if(p.impl_->succ_chans.begin(), p.impl_->succ_chans.end(), - [&](const Port::Channel& c) { return std::get<0>(c) == impl_->id; })) { - iports.push_back(p); - } - } - return iports; - } + Port iport(const std::string& pn); + std::vector> iports() const; - std::vector oports() const { - std::vector oports; - for (const auto& p: impl_->ports) { - if (id() == p.pred_id()) { - oports.push_back(p); - } - } - return oports; - } + Port oport(const std::string& pn); + std::vector> oports() const; private: Node(const std::string& id, const std::string& name, const Halide::Target& target) diff --git a/include/ion/port.h b/include/ion/port.h index 849d4e56..66a4c0f5 100644 --- a/include/ion/port.h +++ b/include/ion/port.h @@ -53,12 +53,7 @@ class Port { private: struct Impl { - // std::string pred_id; - // std::string pred_name; - - // std::string succ_id; - // std::string succ_name; - + std::string id; Channel pred_chan; std::set succ_chans; @@ -68,13 +63,8 @@ class Port { std::unordered_map params; std::unordered_map instances; - Impl() {} - - Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d) - : pred_chan{pid, pn}, succ_chans{}, type(t), dimensions(d) - { - params[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(pid, pn, 0)); - } + Impl(); + Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d); }; public: @@ -103,7 +93,7 @@ class Port { */ template::value>::type* = nullptr> - Port(T *vptr) : impl_(new Impl("", Halide::Internal::unique_name("ion_port"), Halide::type_of(), 0)), index_(-1) { + Port(T *vptr) : impl_(new Impl("", Halide::Internal::unique_name("_ion_port_"), Halide::type_of(), 0)), index_(-1) { this->bind(vptr); } @@ -124,6 +114,7 @@ class Port { } // Getter + const std::string& id() const { return impl_->id; } const Channel& pred_chan() const { return impl_->pred_chan; } const std::string& pred_id() const { return std::get<0>(impl_->pred_chan); } const std::string& pred_name() const { return std::get<1>(impl_->pred_chan); } @@ -132,15 +123,22 @@ class Port { int32_t dimensions() const { return impl_->dimensions; } int32_t size() const { return static_cast(impl_->params.size()); } int32_t index() const { return index_; } - uintptr_t impl_ptr() const { return reinterpret_cast(impl_.get()); } // Setter void set_index(int index) { index_ = index; } // Util bool has_pred() const { return !std::get<0>(impl_->pred_chan).empty(); } + bool has_pred_by_nid(const std::string& nid) const { return !std::get<0>(impl_->pred_chan).empty(); } bool has_succ() const { return !impl_->succ_chans.empty(); } bool has_succ(const Channel& c) const { return impl_->succ_chans.count(c); } + bool has_succ_by_nid(const std::string& nid) const { + return std::count_if(impl_->succ_chans.begin(), + impl_->succ_chans.end(), + [&](const Port::Channel& c) { return std::get<0>(c) == nid; }); + } + + void determine_succ(const std::string& nid, const std::string& old_pn, const std::string& new_pn); /** * Overloaded operator to set the port index and return a reference to the current port. eg. port[0] @@ -189,25 +187,17 @@ class Port { } } - static std::tuple, bool> find_impl(uintptr_t ptr) { - static std::unordered_map> impls; - static std::mutex mutex; - std::scoped_lock lock(mutex); - bool found = true; - if (!impls.count(ptr)) { - impls[ptr] = std::make_shared(); - found = false; - } - return std::make_tuple(impls[ptr], found); - } + static std::tuple, bool> find_impl(const std::string& id); private: /** - * This port is created from another node + * This port is created from another node. + * In this case, it is not sure what this port is input or output. + * pid and pn is stored in both pred and succ, + * then it will determined through pipeline build process. */ Port(const std::string& pid, const std::string& pn) : impl_(new Impl(pid, pn, Halide::Type(), 0)), index_(-1) {} - std::vector as_argument() const { std::vector args; for (const auto& [i, param] : impl_->params) { diff --git a/src/bb/CMakeLists.txt b/src/bb/CMakeLists.txt index 099f60a5..e927a1e1 100644 --- a/src/bb/CMakeLists.txt +++ b/src/bb/CMakeLists.txt @@ -32,16 +32,11 @@ target_include_directories(ion-bb PUBLIC ${PROJECT_SOURCE_DIR}/include ${PROJECT target_link_libraries(ion-bb PUBLIC ion-core ${ION_BB_LIBRARIES}) if(UNIX) target_compile_options(ion-bb PUBLIC -fno-rtti) # For Halide::Generator - if(APPLE) - target_compile_options(ion-bb - PUBLIC -fno-rtti # For Halide::Generator - PUBLIC -rdynamic) # For JIT compiling - else() + if(NOT APPLE) target_link_options(ion-bb PUBLIC -Wl,--export-dynamic) # For JIT compiling endif() elseif(MSVC) - target_compile_options(ion-bb - PUBLIC /bigobj) + target_compile_options(ion-bb PUBLIC /bigobj) endif() # diff --git a/src/builder.cc b/src/builder.cc index 33fd17e8..f5eb3dbc 100644 --- a/src/builder.cc +++ b/src/builder.cc @@ -18,8 +18,6 @@ #include "metadata.h" #include "serializer.h" -#define SW 1 - namespace ion { namespace { @@ -38,16 +36,18 @@ std::map compute_output_files(const Halide: bool is_ready(const std::vector& sorted, const Node& n) { bool ready = true; - for (auto port : n.iports()) { + for (const auto& [pn, port] : n.iports()) { // This port has predecessor dependency. Always ready to add. if (!port.has_pred()) { continue; } + const auto& port_(port); // This is workaround for Clang-14 (MacOS) + // Check port dependent node is already added ready &= std::find_if(sorted.begin(), sorted.end(), - [&port](const Node& n) { - return n.id() == port.pred_id(); + [&](const Node& n) { + return n.id() == port_.pred_id(); }) != sorted.end(); } return ready; @@ -76,6 +76,37 @@ std::vector topological_sort(std::vector nodes) { return sorted; } +Halide::Internal::AbstractGenerator::ArgInfo make_arginfo(const std::string& name, + Halide::Internal::ArgInfoDirection dir, + Halide::Internal::ArgInfoKind kind, + const std::vector& types, + int dimensions) { + return Halide::Internal::AbstractGenerator::ArgInfo { + name, dir, kind, types, dimensions + }; +} + +bool is_free(const std::string& pn) { + return pn.find("_ion_iport_") != std::string::npos; +} + +std::tuple find_ith_input(const std::vector& arginfos, int i) { + int j = 0; + for (const auto& arginfo : arginfos) { + if (arginfo.dir != Halide::Internal::ArgInfoDirection::Input) { + continue; + } + + if (i == j) { + return std::make_tuple(arginfo, true); + } + + j++; + } + + return std::make_tuple(Halide::Internal::AbstractGenerator::ArgInfo(), false); +} + } // anonymous using json = nlohmann::json; @@ -112,6 +143,7 @@ Builder& Builder::with_bb_module(const std::string& module_path) { void Builder::save(const std::string& file_name) { + determine_and_validate(); std::ofstream ofs(file_name); json j; j["target"] = target_.to_string(); @@ -224,47 +256,48 @@ Halide::Pipeline Builder::build(bool implicit_output) { log::info("Start building pipeline"); + determine_and_validate(); + // Sort nodes prior to build. // This operation is required especially for the graph which is loaded from JSON definition. nodes_ = topological_sort(nodes_); - auto generator_names = Halide::Internal::GeneratorRegistry::enumerate(); - // Constructing Generator object and setting static parameters std::unordered_map bbs; for (auto n : nodes_) { - - if (std::find(generator_names.begin(), generator_names.end(), n.name()) == generator_names.end()) { - throw std::runtime_error("Cannot find generator : " + n.name()); - } - auto bb(Halide::Internal::GeneratorRegistry::create(n.name(), Halide::GeneratorContext(n.target()))); + + // Default parameter Halide::GeneratorParamsMap params; params["builder_ptr"] = std::to_string(reinterpret_cast(this)); params["bb_id"] = n.id(); + + // User defined parameter for (const auto& p : n.params()) { - params[p.key()] = p.val(); + params[p.key()] = p.val(); } bb->set_generatorparam_values(params); bbs[n.id()] = std::move(bb); } - // Assigning ports + // Assigning ports and build pipeline for (size_t i=0; iarginfos(); - for (size_t j=0; joutput_func(port.pred_name()); + + const auto& pred_bb(bbs[port.pred_id()]); + + auto fs = pred_bb->output_func(port.pred_name()); if (arginfo.kind == Halide::Internal::ArgInfoKind::Scalar) { bb->bind_input(arginfo.name, fs); } else if (arginfo.kind == Halide::Internal::ArgInfoKind::Function) { - auto fs = bbs[port.pred_id()]->output_func(port.pred_name()); // no specific index provided, direct output Port if (index == -1) { bb->bind_input(arginfo.name, fs); @@ -298,7 +331,7 @@ Halide::Pipeline Builder::build(bool implicit_output) { // This mode is used for AOT compilation std::unordered_map> referenced; for (const auto& n : nodes_) { - for (const auto& port : n.iports()) { + for (const auto& [pn, port] : n.iports()) { if (port.has_pred()) { for (const auto &f : bbs[port.pred_id()]->output_func(port.pred_name())) { referenced[port.pred_id()].emplace_back(f.name()); @@ -331,12 +364,25 @@ Halide::Pipeline Builder::build(bool implicit_output) { // Collects all output which is bound with buffer. // This mode is used for JIT for (const auto& node : nodes_) { - for (const auto& port : node.oports()) { + for (const auto& [pn, port] : node.oports()) { const auto& port_instances(port.as_instance()); if (port_instances.empty()) { continue; } + const auto& pred_bb(bbs[port.pred_id()]); + + // Validate port exists + const auto& port_(port); // This is workaround for Clang-14 (MacOS) + const auto& pred_arginfos(pred_bb->arginfos()); + if (!std::count_if(pred_arginfos.begin(), pred_arginfos.end(), + [&](Halide::Internal::AbstractGenerator::ArgInfo arginfo){ return port_.pred_name() == arginfo.name && Halide::Internal::ArgInfoDirection::Output == arginfo.dir; })) { + auto msg = fmt::format("BuildingBlock \"{}\" has no output \"{}\"", pred_bb->name(), port.pred_name()); + log::error(msg); + throw std::runtime_error(msg); + } + + auto fs(bbs[port.pred_id()]->output_func(port.pred_name())); output_funcs.insert(output_funcs.end(), fs.begin(), fs.end()); } @@ -350,6 +396,76 @@ Halide::Pipeline Builder::build(bool implicit_output) { return Halide::Pipeline(output_funcs); } +void Builder::determine_and_validate() { + + auto generator_names = Halide::Internal::GeneratorRegistry::enumerate(); + + for (auto n : nodes_) { + if (std::find(generator_names.begin(), generator_names.end(), n.name()) == generator_names.end()) { + throw std::runtime_error("Cannot find generator : " + n.name()); + } + + auto bb(Halide::Internal::GeneratorRegistry::create(n.name(), Halide::GeneratorContext(n.target()))); + + // Validate and set parameters + for (const auto& p : n.params()) { + try { + bb->set_generatorparam_value(p.key(), p.val()); + } catch (const Halide::CompileError& e) { + auto msg = fmt::format("BuildingBlock \"{}\" has no parameter \"{}\"", n.name(), p.key()); + log::error(msg); + throw std::runtime_error(msg); + } + } + + try { + bb->build_pipeline(); + } catch (const Halide::CompileError& e) { + log::error(e.what()); + throw std::runtime_error(e.what()); + } + + const auto& arginfos(bb->arginfos()); + + // validate input port + auto i = 0; + for (auto& [pn, port] : n.iports()) { + if (is_free(pn)) { + const auto& [arginfo, found] = find_ith_input(arginfos, i); + if (!found) { + auto msg = fmt::format("BuildingBlock \"{}\" has no input #{}", n.name(), i); + log::error(msg); + throw std::runtime_error(msg); + } + + port.determine_succ(n.id(), pn, arginfo.name); + pn = arginfo.name; + } + + const auto& pn_(pn); // This is workaround for Clang-14 (MacOS) + if (!std::count_if(arginfos.begin(), arginfos.end(), + [&](Halide::Internal::AbstractGenerator::ArgInfo arginfo){ return pn_ == arginfo.name && Halide::Internal::ArgInfoDirection::Input == arginfo.dir; })) { + auto msg = fmt::format("BuildingBlock \"{}\" has no input \"{}\"", n.name(), pn); + log::error(msg); + throw std::runtime_error(msg); + } + + i++; + } + + // validate output + for (const auto& [pn, port] : n.oports()) { + const auto& pn_(pn); // This is workaround for Clang-14 (MacOS) + if (!std::count_if(arginfos.begin(), arginfos.end(), + [&](Halide::Internal::AbstractGenerator::ArgInfo arginfo){ return pn_ == arginfo.name && Halide::Internal::ArgInfoDirection::Output == arginfo.dir; })) { + auto msg = fmt::format("BuildingBlock \"{}\" has no output \"{}\"", n.name(), pn); + log::error(msg); + throw std::runtime_error(msg); + } + } + } +} + std::string Builder::bb_metadata(void) { std::vector md; @@ -377,7 +493,7 @@ std::vector Builder::get_arguments_stub() const { std::set added_ports; std::vector args; for (const auto& node : nodes_) { - for (const auto& port : node.iports()) { + for (const auto& [pn, port] : node.iports()) { if (port.has_pred()) { continue; } @@ -400,7 +516,7 @@ std::vector Builder::get_arguments_instance() const { // Input for (const auto& node : nodes_) { - for (const auto& port : node.iports()) { + for (const auto& [pn, port] : node.iports()) { if (port.has_pred()) { continue; } @@ -417,7 +533,7 @@ std::vector Builder::get_arguments_instance() const { // Output for (const auto& node : nodes_) { - for (const auto& port : node.oports()) { + for (const auto& [pn, port] : node.oports()) { const auto& port_instances(port.as_instance()); instances.insert(instances.end(), port_instances.begin(), port_instances.end()); } diff --git a/src/node.cc b/src/node.cc index 7b5eba84..3e95f5b7 100644 --- a/src/node.cc +++ b/src/node.cc @@ -18,6 +18,11 @@ Node::Impl::Impl(const std::string& id_, const std::string& name_, const Halide: void Node::set_iport(const std::vector& ports) { + std::remove_if(impl_->ports.begin(), impl_->ports.end(), + [&](const Port& p) { + return p.has_succ_by_nid(this->id()); + }); + size_t i = 0; for (auto& port : ports) { // TODO: Validation is better to be done lazily after BuildingBlock::configure @@ -32,7 +37,7 @@ void Node::set_iport(const std::vector& ports) { // } // NOTE: Is succ_chans name OK to be just leave as it is? - port.impl_->succ_chans.insert({id(), "_ion_iport_" + i}); + port.impl_->succ_chans.insert({id(), "_ion_iport_" + std::to_string(i)}); impl_->ports.push_back(port); @@ -41,28 +46,68 @@ void Node::set_iport(const std::vector& ports) { } Port Node::operator[](const std::string& name) { - // TODO: Validation is better to be done lazily after BuildingBlock::configure - // - // if (std::find_if(impl_->arginfos.begin(), impl_->arginfos.end(), - // [&](const Halide::Internal::AbstractGenerator::ArgInfo& info) { return info.name == name; }) == impl_->arginfos.end()) { - // log::error("Port {} is not found", name); - // throw std::runtime_error("Failed to find port"); - // } + auto it = std::find_if(impl_->ports.begin(), impl_->ports.end(), + [&](const Port& p){ return p.pred_id() == impl_->id && p.pred_name() == name; }); + if (it == impl_->ports.end()) { + // This is output port which is never referenced. + // Bind myself as a predecessor and register + Port port(impl_->id, name); + impl_->ports.push_back(port); + return port; + } else { + // Port is already registered + return *it; + } +} - auto it = std::find_if(impl_->ports.begin(), impl_->ports.end(), - [&](const Port& p){ return (p.pred_name() == name && p.pred_id() == impl_->id) || p.has_succ({impl_->id, name}); }); - if (it == impl_->ports.end()) { - // This is output port which is never referenced. - // Bind myself as a predecessor and register - - // TODO: Validate with arginfo - Port port(impl_->id, name); - impl_->ports.push_back(port); - return port; - } else { - // Port is already registered - return *it; +Port Node::iport(const std::string& pn) { + for (const auto& p: impl_->ports) { + auto it = std::find_if(p.impl_->succ_chans.begin(), p.impl_->succ_chans.end(), + [&](const Port::Channel& c) { return std::get<0>(c) == impl_->id && std::get<1>(c) == pn; }); + if (it != p.impl_->succ_chans.end()) { + return p; } } + auto msg = fmt::format("BuildingBlock \"{}\" has no input \"{}\"", name(), pn); + log::error(msg); + throw std::runtime_error(msg); +} + +std::vector> Node::iports() const { + std::vector> iports; + for (const auto& p: impl_->ports) { + auto it = std::find_if(p.impl_->succ_chans.begin(), p.impl_->succ_chans.end(), + [&](const Port::Channel& c) { return std::get<0>(c) == impl_->id; }); + if (it != p.impl_->succ_chans.end()) { + iports.push_back(std::make_tuple(std::get<1>(*it), p)); + } + } + return iports; +} + +Port Node::oport(const std::string& pn) { + return this->operator[](pn); + // auto it = std::find_if(impl_->ports.begin(), impl_->ports.end(), + // [&](const Port& p) { return p.pred_id() == id() && p.pred_name() == pn; }); + + // if (it != impl_->ports.end()) { + // return *it; + // } + + // auto msg = fmt::format("BuildingBlock \"{}\" has no output \"{}\"", name(), pn); + // log::error(msg); + // throw std::runtime_error(msg); +} + +std::vector> Node::oports() const { + std::vector> oports; + for (const auto& p: impl_->ports) { + if (id() == p.pred_id()) { + oports.push_back(std::make_tuple(p.pred_name(), p)); + } + } + return oports; +} + } // namespace ion diff --git a/src/port.cc b/src/port.cc index fb08495d..a90a1506 100644 --- a/src/port.cc +++ b/src/port.cc @@ -1,6 +1,44 @@ #include "ion/port.h" +#include "uuid/sole.hpp" +#include "log.h" + namespace ion { +Port::Impl::Impl() + : id(sole::uuid4().str()), pred_chan{"", ""}, succ_chans{}, type(), dimensions(-1) +{ +} + +Port::Impl::Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d) + : id(sole::uuid4().str()), pred_chan{pid, pn}, succ_chans{}, type(t), dimensions(d) +{ + params[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(pid, pn, 0)); +} + +void Port::determine_succ(const std::string& nid, const std::string& old_pn, const std::string& new_pn) { + auto it = std::find(impl_->succ_chans.begin(), impl_->succ_chans.end(), Channel{nid, old_pn}); + if (it == impl_->succ_chans.end()) { + log::error("fixme"); + throw std::runtime_error("fixme"); + } + + log::debug("Determine free port {} as {} on Node {}", old_pn, new_pn, nid); + impl_->succ_chans.erase(it); + impl_->succ_chans.insert(Channel{nid, new_pn}); +} + +std::tuple, bool> Port::find_impl(const std::string& id) { + static std::unordered_map> impls; + static std::mutex mutex; + std::scoped_lock lock(mutex); + bool found = true; + if (!impls.count(id)) { + impls[id] = std::make_shared(); + found = false; + } + log::debug("Port {} is {}found", id, found ? "" : "not "); + return std::make_tuple(impls[id], found); +} } // namespace ion diff --git a/src/serializer.h b/src/serializer.h index 8e0c8921..c1f42967 100644 --- a/src/serializer.h +++ b/src/serializer.h @@ -42,17 +42,17 @@ static void from_json(const json& j, ion::Param& v) { template<> struct adl_serializer { static void to_json(json& j, const ion::Port& v) { + j["id"] = v.id(); j["pred_chan"] = v.pred_chan(); j["succ_chans"] = v.succ_chans(); j["type"] = static_cast(v.type()); j["dimensions"] = v.dimensions(); j["size"] = v.size(); - j["impl_ptr"] = v.impl_ptr(); j["index"] = v.index(); } static void from_json(const json& j, ion::Port& v) { - auto [impl, found] = ion::Port::find_impl(j["impl_ptr"].get()); + auto [impl, found] = ion::Port::find_impl(j["id"].get()); if (!found) { impl->pred_chan = j["pred_chan"].get(); impl->succ_chans = j["succ_chans"].get>(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 750790ee..54ff24d4 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -66,7 +66,10 @@ ion_jit_executable(configure SRCS configure.cc) # Export test # TODO: Resolve defects in feature/win-debug branch on Windows environment -# ion_jit_executable(export SRCS export.cc) +ion_jit_executable(export SRCS export.cc) + +# Validation test +ion_jit_executable(validation SRCS validation.cc) ion_aot_executable(simple_graph SRCS_COMPILE simple_graph_compile.cc SRCS_RUN simple_graph_run.cc LIBS ion-bb-test) ion_jit_executable(simple_graph SRCS simple_graph_jit.cc) diff --git a/test/array_dup_names.cc b/test/array_dup_names.cc index 400ea3c6..bd30912c 100644 --- a/test/array_dup_names.cc +++ b/test/array_dup_names.cc @@ -22,9 +22,9 @@ int main() { Builder b; b.set_target(Halide::get_host_target()); auto n = b.add("test_array_output")(in).set_param(Param("len", len)); - n = b.add("test_array_input")(n["array_output"]); + n = b.add("test_array_input")(n["array_output"]).set_param(Param("array_input.size", len)); n = b.add("test_array_output")(n["output"]).set_param(Param("len", len)); - n = b.add("test_array_input")(n["array_output"]); + n = b.add("test_array_input")(n["array_output"]).set_param(Param("array_input.size", len)); Halide::Buffer out(w, h); out.fill(0); diff --git a/test/array_inout.cc b/test/array_inout.cc index a70a4c2f..8d104820 100644 --- a/test/array_inout.cc +++ b/test/array_inout.cc @@ -25,7 +25,7 @@ int main() { Builder b; b.set_target(Halide::get_host_target()); auto n = b.add("test_array_output")(in).set_param(Param("len", len)); - n = b.add("test_array_input")(n["array_output"]); + n = b.add("test_array_input")(n["array_output"]).set_param(Param("array_input.size", len)); n["output"].bind(out); b.run(); diff --git a/test/array_input.cc b/test/array_input.cc index 963ffa20..1dba284d 100644 --- a/test/array_input.cc +++ b/test/array_input.cc @@ -18,8 +18,8 @@ int main() { Builder b; b.set_target(Halide::get_host_target()); Node n; - n = b.add("test_array_copy")(input).set_param(Param("len", len)); - n = b.add("test_array_input")(n["array_output"]).set_param(Param("len", len)); + n = b.add("test_array_copy")(input).set_param(Param("array_input.size", len)); + n = b.add("test_array_input")(n["array_output"]).set_param(Param("array_input.size", len)); std::vector> ins{ Halide::Buffer{w, h}, @@ -70,8 +70,8 @@ int main() { Builder b; b.set_target(Halide::get_host_target()); Node n; - n = b.add("test_array_copy")(input).set_param(Param("len", len)); - n = b.add("test_array_input")(n["array_output"]).set_param(Param("len", len)); + n = b.add("test_array_copy")(input).set_param(Param("array_input.size", len)); + n = b.add("test_array_input")(n["array_output"]).set_param(Param("array_input.size", len)); Halide::Buffer in0(w, h), in1(w, h), in2(w, h), in3(w, h), in4(w, h); @@ -130,8 +130,8 @@ int main() { Builder b; b.set_target(Halide::get_host_target()); Node n; - n = b.add("test_array_copy")(ins).set_param(Param("len", len)); - n = b.add("test_array_input")(n["array_output"]).set_param(Param("len", len)); + n = b.add("test_array_copy")(ins).set_param(Param("array_input.size", len)); + n = b.add("test_array_input")(n["array_output"]).set_param(Param("array_input.size", len)); for (int y = 0; y < h; ++y) { for (int x = 0; x < w; ++x) { diff --git a/test/array_output.cc b/test/array_output.cc index 959328fb..98b100b3 100644 --- a/test/array_output.cc +++ b/test/array_output.cc @@ -36,7 +36,7 @@ int main() { Node n; n = b.add("test_array_output")(in).set_param(Param("len", len)); - n = b.add("test_array_copy")(n["array_output"]).set_param(Param("len", len)); + n = b.add("test_array_copy")(n["array_output"]).set_param(Param("array_input.size", len)); for (int i=0; i::make_scalar(); auto nodes = b.nodes(); - nodes[1]["min0"].bind(&min0); - nodes[1]["extent0"].bind(&extent0); - nodes[1]["min1"].bind(&min1); - nodes[1]["extent1"].bind(&extent1); - nodes[1]["v"].bind(&v); - nodes[1]["output"].bind(out); + nodes[1].iport("min0").bind(&min0); + nodes[1].iport("extent0").bind(&extent0); + nodes[1].iport("min1").bind(&min1); + nodes[1].iport("extent1").bind(&extent1); + nodes[1].iport("v").bind(&v); + nodes[1].oport("output").bind(out); b.run(); } @@ -79,11 +79,11 @@ int main() auto nodes = b.nodes(); - nodes[0]["input"].bind(in); - nodes[1]["input_width"].bind(&size); - nodes[1]["input_height"].bind(&size); - nodes[4]["output_height"].bind(&size); - nodes[4]["output"].bind(out); + nodes[0].iport("input").bind(in); + nodes[1].iport("input_width").bind(&size); + nodes[1].iport("input_height").bind(&size); + nodes[4].iport("output_height").bind(&size); + nodes[4].oport("output").bind(out); b.compile("ex"); b.run(); @@ -120,7 +120,7 @@ int main() b.with_bb_module("ion-bb-test"); b.set_target(ion::get_host_target()); auto n = b.add("test_array_output")(input).set_param(Param("len", len)); - n = b.add("test_array_input")(n["array_output"]); + n = b.add("test_array_input")(n["array_output"]).set_param(Param("array_input.size", len)); b.save("array_inout.json"); } { @@ -140,9 +140,9 @@ int main() for (auto& n : b.nodes()) { if (n.name() == "test_array_output") { - n["input"].bind(in); + n.iport("input").bind(in); } else if (n.name() == "test_array_input") { - n["output"].bind(out); + n.oport("output").bind(out); } } diff --git a/test/inverted_dep.cc b/test/inverted_dep.cc index b06745b1..fe76a7e2 100644 --- a/test/inverted_dep.cc +++ b/test/inverted_dep.cc @@ -16,24 +16,23 @@ int main() std::string graph = R"( { "nodes": [ - { - "id": "9ebf9c1e-25bf-451d-b92e-54322c72476f", + "id": "3de72ac3-d7e4-4de1-b73e-49856f8b5fc7", "name": "test_consumer", "params": [], "ports": [ { "dimensions": 0, - "impl_ptr": 94067312766832, + "id": "2792b187-a42f-4c02-9399-25fc3acddd8e", "index": -1, "pred_chan": [ - "2c706f47-6f51-4f1e-82de-f87f2dd0e9ab", + "c4fcbdba-7da4-4149-80ab-4ad5da37b435", "output" ], "size": 1, "succ_chans": [ [ - "9ebf9c1e-25bf-451d-b92e-54322c72476f", + "3de72ac3-d7e4-4de1-b73e-49856f8b5fc7", "input" ] ], @@ -45,7 +44,7 @@ int main() }, { "dimensions": 0, - "impl_ptr": 94067313268224, + "id": "b44a2f84-b7a2-40a4-9fbf-ed80078b6123", "index": -1, "pred_chan": [ "", @@ -54,7 +53,7 @@ int main() "size": 1, "succ_chans": [ [ - "9ebf9c1e-25bf-451d-b92e-54322c72476f", + "3de72ac3-d7e4-4de1-b73e-49856f8b5fc7", "min0" ] ], @@ -66,7 +65,7 @@ int main() }, { "dimensions": 0, - "impl_ptr": 94067313265008, + "id": "2f9ab162-f72a-42c8-8b92-2cbcf5ce71f7", "index": -1, "pred_chan": [ "", @@ -75,7 +74,7 @@ int main() "size": 1, "succ_chans": [ [ - "9ebf9c1e-25bf-451d-b92e-54322c72476f", + "3de72ac3-d7e4-4de1-b73e-49856f8b5fc7", "extent0" ] ], @@ -87,7 +86,7 @@ int main() }, { "dimensions": 0, - "impl_ptr": 94067313264752, + "id": "ba2f373c-2dd7-436f-b816-0ca59ca83037", "index": -1, "pred_chan": [ "", @@ -96,7 +95,7 @@ int main() "size": 1, "succ_chans": [ [ - "9ebf9c1e-25bf-451d-b92e-54322c72476f", + "3de72ac3-d7e4-4de1-b73e-49856f8b5fc7", "min1" ] ], @@ -108,7 +107,7 @@ int main() }, { "dimensions": 0, - "impl_ptr": 94067312830512, + "id": "537fd4b2-eef1-4c69-a04f-bd09adf3c93f", "index": -1, "pred_chan": [ "", @@ -117,7 +116,7 @@ int main() "size": 1, "succ_chans": [ [ - "9ebf9c1e-25bf-451d-b92e-54322c72476f", + "3de72ac3-d7e4-4de1-b73e-49856f8b5fc7", "extent1" ] ], @@ -129,7 +128,7 @@ int main() }, { "dimensions": 0, - "impl_ptr": 94067312767360, + "id": "80f24262-a521-43b7-8063-3b410fb5c509", "index": -1, "pred_chan": [ "", @@ -138,7 +137,7 @@ int main() "size": 1, "succ_chans": [ [ - "9ebf9c1e-25bf-451d-b92e-54322c72476f", + "3de72ac3-d7e4-4de1-b73e-49856f8b5fc7", "v" ] ], @@ -147,28 +146,12 @@ int main() "code": 0, "lanes": 1 } - }, - { - "dimensions": 0, - "impl_ptr": 94067312830256, - "index": -1, - "pred_chan": [ - "9ebf9c1e-25bf-451d-b92e-54322c72476f", - "output" - ], - "size": 1, - "succ_chans": [], - "type": { - "bits": 0, - "code": 3, - "lanes": 0 - } } ], - "target": "host-trace_pipeline" + "target": "host-profile" }, { - "id": "2c706f47-6f51-4f1e-82de-f87f2dd0e9ab", + "id": "c4fcbdba-7da4-4149-80ab-4ad5da37b435", "name": "test_producer", "params": [ { @@ -179,16 +162,16 @@ int main() "ports": [ { "dimensions": 0, - "impl_ptr": 94067312766832, + "id": "2792b187-a42f-4c02-9399-25fc3acddd8e", "index": -1, "pred_chan": [ - "2c706f47-6f51-4f1e-82de-f87f2dd0e9ab", + "c4fcbdba-7da4-4149-80ab-4ad5da37b435", "output" ], "size": 1, "succ_chans": [ [ - "9ebf9c1e-25bf-451d-b92e-54322c72476f", + "3de72ac3-d7e4-4de1-b73e-49856f8b5fc7", "input" ] ], @@ -199,10 +182,10 @@ int main() } } ], - "target": "host-trace_pipeline" + "target": "host-profile" } ], - "target": "host-trace_pipeline" + "target": "host-profile" } )"; std::ofstream ofs(file_name); @@ -220,17 +203,17 @@ int main() for (auto& n : b.nodes()) { std::cout << n.name() << std::endl; if (n.name() == "test_consumer") { - n["min0"].bind(&min0); - n["extent0"].bind(&extent0); - n["min1"].bind(&min1); - n["extent1"].bind(&extent1); - n["v"].bind(&v); - n["output"].bind(r); + n.iport("min0").bind(&min0); + n.iport("extent0").bind(&extent0); + n.iport("min1").bind(&min1); + n.iport("extent1").bind(&extent1); + n.iport("v").bind(&v); + n.oport("output").bind(r); } } b.run(); - + } catch (const Halide::Error &e) { std::cerr << e.what() << std::endl; return 1; diff --git a/test/test-bb.h b/test/test-bb.h index 65be4e0b..d79d6d8f 100644 --- a/test/test-bb.h +++ b/test/test-bb.h @@ -104,12 +104,24 @@ class Merge : public BuildingBlock { template class Inc : public BuildingBlock> { public: - BuildingBlockParam v{"v", 0}; Input input{"input", Halide::type_of(), D}; Output output{"output", Halide::type_of(), D}; + BuildingBlockParam v{"v", 0}; + BuildingBlockParam enable_extra_input{"enable_extra_input", false}; + Input *extra_input; + + void configure() { + if (enable_extra_input) { + extra_input = Halide::Internal::GeneratorBase::add_input("extra_input"); + } + } void generate() { - output(Halide::_) = input(Halide::_) + v; + Halide::Expr rv = input(Halide::_) + v; + if (enable_extra_input) { + rv += static_cast(enable_extra_input) ? *extra_input : Halide::Internal::make_const(Halide::type_of(), 0); + } + output(Halide::_) = rv; } void schedule() { @@ -186,14 +198,12 @@ class MultiOut : public BuildingBlock { class ArrayInput : public BuildingBlock { public: - BuildingBlockParam len{"len", 5}; - Input array_input{"array_input", Int(32), 2}; Output output{"output", Int(32), 2}; void generate() { Halide::Expr v = 0; - for (int i = 0; i < len; ++i) { + for (int i = 0; i < array_input.size(); ++i) { v += array_input[i](x, y); } output(x, y) = v; @@ -222,14 +232,12 @@ class ArrayOutput : public BuildingBlock { class ArrayCopy : public BuildingBlock { public: - BuildingBlockParam len{"len", 5}; - Input array_input{"array_input", Int(32), 2}; Output array_output{"array_output", Int(32), 2}; void generate() { - array_output.resize(len); - for (int i = 0; i < len; ++i) { + array_output.resize(array_input.size()); + for (int i = 0; i < array_input.size(); ++i) { array_output[i](x, y) = array_input[i](x, y); } } diff --git a/test/validation.cc b/test/validation.cc new file mode 100644 index 00000000..6eab72c4 --- /dev/null +++ b/test/validation.cc @@ -0,0 +1,102 @@ +#include "ion/ion.h" + +#include "spdlog/cfg/helpers.h" +#include "spdlog/details/os.h" +#include "spdlog/sinks/stdout_color_sinks.h" +#include "spdlog/sinks/basic_file_sink.h" + +using namespace ion; + +int main() +{ + try { + Buffer input(2, 2); + Buffer output(2, 2); + + // Unknown parameter + { + Builder b; + b.with_bb_module("ion-bb-test"); + b.set_target(Halide::get_host_target()); + Node n; + n = b.add("test_inc_i32x2")(input).set_param(Param("unknown-parameter", 1)); + n = b.add("test_inc_i32x2")(n["output"]); + n["output"].bind(output); + + try { + b.run(); + } catch (const std::exception& e) { + // The error should thrown as runtime_error, not Halide::Error + std::cerr << e.what() << std::endl; + } + } + + // Unknown output port 1 + { + Builder b; + b.with_bb_module("ion-bb-test"); + b.set_target(Halide::get_host_target()); + Node n; + n = b.add("test_inc_i32x2")(input).set_param(Param("v", 41)); + n = b.add("test_inc_i32x2")(n["unknown-port"]); + n["output"].bind(output); + + try { + b.run(); + } catch (const std::exception& e) { + // The error should thrown as runtime_error, not Halide::Error + std::cerr << e.what() << std::endl; + } + } + + // Unknown output port 2 + { + Builder b; + b.with_bb_module("ion-bb-test"); + b.set_target(Halide::get_host_target()); + Node n; + n = b.add("test_inc_i32x2")(input).set_param(Param("v", 41)); + n = b.add("test_inc_i32x2")(n["output"]); + n["unknown-port"].bind(output); + + try { + b.run(); + } catch (const std::exception& e) { + // The error should thrown as runtime_error, not Halide::Error + std::cerr << e.what() << std::endl; + } + } + + // Unknown input port 1 + { + Builder b; + b.with_bb_module("ion-bb-test"); + b.set_target(Halide::get_host_target()); + + Buffer unknown(2, 2); + + Node n; + n = b.add("test_inc_i32x2")(input, unknown).set_param(Param("v", 41)); + n = b.add("test_inc_i32x2")(n["output"]); + n["output"].bind(output); + + try { + b.run(); + } catch (const std::exception& e) { + // The error should thrown as runtime_error, not Halide::Error + std::cerr << e.what() << std::endl; + } + } + + } catch (Halide::Error& e) { + std::cerr << e.what() << std::endl; + return 1; + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + return 0; + } + + std::cout << "Passed" << std::endl; + + return 0; +}