Skip to content

Commit

Permalink
Merge pull request #209 from fixstars/feature/validation
Browse files Browse the repository at this point in the history
Improve validation
  • Loading branch information
iitaku authored Jan 18, 2024
2 parents 87c5a67 + b91091b commit 564bd08
Show file tree
Hide file tree
Showing 22 changed files with 463 additions and 199 deletions.
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ endif()
file(GLOB ION_CORE_SRC LIST_DIRECTORIES false ${PROJECT_SOURCE_DIR}/src/*)
add_library(ion-core SHARED ${ION_CORE_SRC})
target_include_directories(ion-core PUBLIC ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src)
if (UNIX)
if (APPLE)
target_link_libraries(ion-core PUBLIC Halide::Halide Halide::Runtime)
elseif (UNIX)
target_link_libraries(ion-core PUBLIC Halide::Halide Halide::Runtime dl pthread z m stdc++)
else()
target_link_libraries(ion-core PUBLIC Halide::Halide Halide::Runtime)
Expand Down
18 changes: 5 additions & 13 deletions cmake/IonUtil.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,11 @@ function(ion_aot_executable NAME_PREFIX)
# Build compile
set(COMPILE_NAME ${NAME_PREFIX}_compile)
add_executable(${COMPILE_NAME} ${IAE_SRCS_COMPILE})
if(UNIX AND NOT APPLE)
if(UNIX)
target_compile_options(${COMPILE_NAME} PUBLIC -fno-rtti) # For Halide::Generator
target_link_options(${COMPILE_NAME} PUBLIC -Wl,--export-dynamic) # For JIT compiling
endif()
IF (APPLE)
target_compile_options(${COMPILE_NAME}
PUBLIC -fno-rtti # For Halide::Generator
PUBLIC -rdynamic) # For JIT compiling
if(NOT APPLE)
target_link_options(${COMPILE_NAME} PUBLIC -Wl,--export-dynamic) # For JIT compiling
endif()
endif()
target_include_directories(${COMPILE_NAME} PUBLIC "${PROJECT_SOURCE_DIR}/include")
target_link_libraries(${COMPILE_NAME} PRIVATE ion-core ${PLATFORM_LIBRARIES})
Expand Down Expand Up @@ -105,12 +102,7 @@ function(ion_jit_executable NAME_PREFIX)
add_executable(${NAME} ${IJE_SRCS})
if (UNIX AND NOT APPLE)
target_link_options(${NAME} PUBLIC -Wl,--export-dynamic) # For JIT compiling
endif()
if (APPLE)
target_compile_options(${NAME}
PUBLIC -fno-rtti # For Halide::Generator
PUBLIC -rdynamic) # For JIT compiling
endif()
endif()
add_dependencies(${NAME} ion-bb ion-bb-test)
target_include_directories(${NAME} PUBLIC ${PROJECT_SOURCE_DIR}/include ${ION_BB_INCLUDE_DIRS} ${IJE_INCS})
target_link_libraries(${NAME} PRIVATE ion-core ${ION_BB_LIBRARIES} ${PLATFORM_LIBRARIES} ${IJE_LIBS})
Expand Down
4 changes: 2 additions & 2 deletions example/sgm_compile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ int main() {
Node ln = b.add("image_io_color_data_loader").set_param(Param{"url", "http://ion-kit.s3.us-west-2.amazonaws.com/images/aloe_left.jpg"}, Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)});
ln = b.add("base_normalize_3d_uint8")(ln["output"]);
ln = b.add("image_processing_resize_nearest_3d")(ln["output"]).set_param(Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)}, Param{"scale", std::to_string(scale)});
ln = b.add("base_schedule")(ln["output"]).set_param(Param{"output_name", "scaled_left"}, Param{"compute_level", "compute_root"});
ln = b.add("base_schedule")(ln["output"]).set_param(Param{"output_name", "scaled_left"}, Param{"compute_level", "compute_root"}, Param("input.type", "float32"), Param("input.dim", 3));
ln = b.add("image_processing_calc_luminance")(ln["output"]).set_param(Param{"luminance_method", "Average"});
ln = b.add("base_denormalize_2d_uint8")(ln["output"]);

Node rn = b.add("image_io_color_data_loader").set_param(Param{"url", "http://ion-kit.s3.us-west-2.amazonaws.com/images/aloe_right.jpg"}, Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)});
rn = b.add("base_normalize_3d_uint8")(rn["output"]);
rn = b.add("image_processing_resize_nearest_3d")(rn["output"]).set_param(Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)}, Param{"scale", std::to_string(scale)});
rn = b.add("base_schedule")(rn["output"]).set_param(Param{"output_name", "scaled_right"}, Param{"compute_level", "compute_root"});
rn = b.add("base_schedule")(rn["output"]).set_param(Param{"output_name", "scaled_right"}, Param{"compute_level", "compute_root"}, Param("input.type", "float32"), Param("input.dim", 3));
rn = b.add("image_processing_calc_luminance")(rn["output"]).set_param(Param{"luminance_method", "Average"});
rn = b.add("base_denormalize_2d_uint8")(rn["output"]);

Expand Down
2 changes: 2 additions & 0 deletions include/ion/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class Builder {

Halide::Pipeline build(bool implicit_output = false);

void determine_and_validate();

std::vector<Halide::Argument> get_arguments_stub() const;
std::vector<const void*> get_arguments_instance() const;

Expand Down
5 changes: 4 additions & 1 deletion include/ion/building_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ class BuildingBlock : public Halide::Generator<T> {

template<typename... Ts>
void register_disposer(const std::string& n) {
reinterpret_cast<Builder*>(static_cast<uint64_t>(builder_ptr))->register_disposer(bb_id, n);
auto bb(reinterpret_cast<Builder*>(static_cast<uint64_t>(builder_ptr)));
if (bb) {
bb->register_disposer(bb_id, n);
}
}

ion::Buffer<uint8_t> get_id() {
Expand Down
23 changes: 4 additions & 19 deletions include/ion/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,11 @@ class Node {
return impl_->ports;
}

std::vector<Port> iports() const {
std::vector<Port> iports;
for (const auto& p: impl_->ports) {
if (std::count_if(p.impl_->succ_chans.begin(), p.impl_->succ_chans.end(),
[&](const Port::Channel& c) { return std::get<0>(c) == impl_->id; })) {
iports.push_back(p);
}
}
return iports;
}
Port iport(const std::string& pn);
std::vector<std::tuple<std::string, Port>> iports() const;

std::vector<Port> oports() const {
std::vector<Port> oports;
for (const auto& p: impl_->ports) {
if (id() == p.pred_id()) {
oports.push_back(p);
}
}
return oports;
}
Port oport(const std::string& pn);
std::vector<std::tuple<std::string, Port>> oports() const;

private:
Node(const std::string& id, const std::string& name, const Halide::Target& target)
Expand Down
46 changes: 18 additions & 28 deletions include/ion/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,7 @@ class Port {

private:
struct Impl {
// std::string pred_id;
// std::string pred_name;

// std::string succ_id;
// std::string succ_name;

std::string id;
Channel pred_chan;
std::set<Channel> succ_chans;

Expand All @@ -68,13 +63,8 @@ class Port {
std::unordered_map<int32_t, Halide::Internal::Parameter> params;
std::unordered_map<int32_t, const void *> instances;

Impl() {}

Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d)
: pred_chan{pid, pn}, succ_chans{}, type(t), dimensions(d)
{
params[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(pid, pn, 0));
}
Impl();
Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d);
};

public:
Expand Down Expand Up @@ -103,7 +93,7 @@ class Port {
*/
template<typename T,
typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
Port(T *vptr) : impl_(new Impl("", Halide::Internal::unique_name("ion_port"), Halide::type_of<T>(), 0)), index_(-1) {
Port(T *vptr) : impl_(new Impl("", Halide::Internal::unique_name("_ion_port_"), Halide::type_of<T>(), 0)), index_(-1) {
this->bind(vptr);
}

Expand All @@ -124,6 +114,7 @@ class Port {
}

// Getter
const std::string& id() const { return impl_->id; }
const Channel& pred_chan() const { return impl_->pred_chan; }
const std::string& pred_id() const { return std::get<0>(impl_->pred_chan); }
const std::string& pred_name() const { return std::get<1>(impl_->pred_chan); }
Expand All @@ -132,15 +123,22 @@ class Port {
int32_t dimensions() const { return impl_->dimensions; }
int32_t size() const { return static_cast<int32_t>(impl_->params.size()); }
int32_t index() const { return index_; }
uintptr_t impl_ptr() const { return reinterpret_cast<uintptr_t>(impl_.get()); }

// Setter
void set_index(int index) { index_ = index; }

// Util
bool has_pred() const { return !std::get<0>(impl_->pred_chan).empty(); }
bool has_pred_by_nid(const std::string& nid) const { return !std::get<0>(impl_->pred_chan).empty(); }
bool has_succ() const { return !impl_->succ_chans.empty(); }
bool has_succ(const Channel& c) const { return impl_->succ_chans.count(c); }
bool has_succ_by_nid(const std::string& nid) const {
return std::count_if(impl_->succ_chans.begin(),
impl_->succ_chans.end(),
[&](const Port::Channel& c) { return std::get<0>(c) == nid; });
}

void determine_succ(const std::string& nid, const std::string& old_pn, const std::string& new_pn);

/**
* Overloaded operator to set the port index and return a reference to the current port. eg. port[0]
Expand Down Expand Up @@ -189,25 +187,17 @@ class Port {
}
}

static std::tuple<std::shared_ptr<Impl>, bool> find_impl(uintptr_t ptr) {
static std::unordered_map<uintptr_t, std::shared_ptr<Impl>> impls;
static std::mutex mutex;
std::scoped_lock lock(mutex);
bool found = true;
if (!impls.count(ptr)) {
impls[ptr] = std::make_shared<Impl>();
found = false;
}
return std::make_tuple(impls[ptr], found);
}
static std::tuple<std::shared_ptr<Impl>, bool> find_impl(const std::string& id);

private:
/**
* This port is created from another node
* This port is created from another node.
* In this case, it is not sure what this port is input or output.
* pid and pn is stored in both pred and succ,
* then it will determined through pipeline build process.
*/
Port(const std::string& pid, const std::string& pn) : impl_(new Impl(pid, pn, Halide::Type(), 0)), index_(-1) {}


std::vector<Halide::Argument> as_argument() const {
std::vector<Halide::Argument> args;
for (const auto& [i, param] : impl_->params) {
Expand Down
9 changes: 2 additions & 7 deletions src/bb/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,11 @@ target_include_directories(ion-bb PUBLIC ${PROJECT_SOURCE_DIR}/include ${PROJECT
target_link_libraries(ion-bb PUBLIC ion-core ${ION_BB_LIBRARIES})
if(UNIX)
target_compile_options(ion-bb PUBLIC -fno-rtti) # For Halide::Generator
if(APPLE)
target_compile_options(ion-bb
PUBLIC -fno-rtti # For Halide::Generator
PUBLIC -rdynamic) # For JIT compiling
else()
if(NOT APPLE)
target_link_options(ion-bb PUBLIC -Wl,--export-dynamic) # For JIT compiling
endif()
elseif(MSVC)
target_compile_options(ion-bb
PUBLIC /bigobj)
target_compile_options(ion-bb PUBLIC /bigobj)
endif()

#
Expand Down
Loading

0 comments on commit 564bd08

Please sign in to comment.