Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve validation #209

Merged
merged 8 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ endif()
file(GLOB ION_CORE_SRC LIST_DIRECTORIES false ${PROJECT_SOURCE_DIR}/src/*)
add_library(ion-core SHARED ${ION_CORE_SRC})
target_include_directories(ion-core PUBLIC ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src)
if (UNIX)
if (APPLE)
target_link_libraries(ion-core PUBLIC Halide::Halide Halide::Runtime)
elseif (UNIX)
target_link_libraries(ion-core PUBLIC Halide::Halide Halide::Runtime dl pthread z m stdc++)
else()
target_link_libraries(ion-core PUBLIC Halide::Halide Halide::Runtime)
Expand Down
18 changes: 5 additions & 13 deletions cmake/IonUtil.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,11 @@ function(ion_aot_executable NAME_PREFIX)
# Build compile
set(COMPILE_NAME ${NAME_PREFIX}_compile)
add_executable(${COMPILE_NAME} ${IAE_SRCS_COMPILE})
if(UNIX AND NOT APPLE)
if(UNIX)
target_compile_options(${COMPILE_NAME} PUBLIC -fno-rtti) # For Halide::Generator
target_link_options(${COMPILE_NAME} PUBLIC -Wl,--export-dynamic) # For JIT compiling
endif()
IF (APPLE)
target_compile_options(${COMPILE_NAME}
PUBLIC -fno-rtti # For Halide::Generator
PUBLIC -rdynamic) # For JIT compiling
if(NOT APPLE)
target_link_options(${COMPILE_NAME} PUBLIC -Wl,--export-dynamic) # For JIT compiling
endif()
endif()
target_include_directories(${COMPILE_NAME} PUBLIC "${PROJECT_SOURCE_DIR}/include")
target_link_libraries(${COMPILE_NAME} PRIVATE ion-core ${PLATFORM_LIBRARIES})
Expand Down Expand Up @@ -105,12 +102,7 @@ function(ion_jit_executable NAME_PREFIX)
add_executable(${NAME} ${IJE_SRCS})
if (UNIX AND NOT APPLE)
target_link_options(${NAME} PUBLIC -Wl,--export-dynamic) # For JIT compiling
endif()
if (APPLE)
target_compile_options(${NAME}
PUBLIC -fno-rtti # For Halide::Generator
PUBLIC -rdynamic) # For JIT compiling
endif()
endif()
add_dependencies(${NAME} ion-bb ion-bb-test)
target_include_directories(${NAME} PUBLIC ${PROJECT_SOURCE_DIR}/include ${ION_BB_INCLUDE_DIRS} ${IJE_INCS})
target_link_libraries(${NAME} PRIVATE ion-core ${ION_BB_LIBRARIES} ${PLATFORM_LIBRARIES} ${IJE_LIBS})
Expand Down
4 changes: 2 additions & 2 deletions example/sgm_compile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ int main() {
Node ln = b.add("image_io_color_data_loader").set_param(Param{"url", "http://ion-kit.s3.us-west-2.amazonaws.com/images/aloe_left.jpg"}, Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)});
ln = b.add("base_normalize_3d_uint8")(ln["output"]);
ln = b.add("image_processing_resize_nearest_3d")(ln["output"]).set_param(Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)}, Param{"scale", std::to_string(scale)});
ln = b.add("base_schedule")(ln["output"]).set_param(Param{"output_name", "scaled_left"}, Param{"compute_level", "compute_root"});
ln = b.add("base_schedule")(ln["output"]).set_param(Param{"output_name", "scaled_left"}, Param{"compute_level", "compute_root"}, Param("input.type", "float32"), Param("input.dim", 3));
ln = b.add("image_processing_calc_luminance")(ln["output"]).set_param(Param{"luminance_method", "Average"});
ln = b.add("base_denormalize_2d_uint8")(ln["output"]);

Node rn = b.add("image_io_color_data_loader").set_param(Param{"url", "http://ion-kit.s3.us-west-2.amazonaws.com/images/aloe_right.jpg"}, Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)});
rn = b.add("base_normalize_3d_uint8")(rn["output"]);
rn = b.add("image_processing_resize_nearest_3d")(rn["output"]).set_param(Param{"width", std::to_string(input_width)}, Param{"height", std::to_string(input_height)}, Param{"scale", std::to_string(scale)});
rn = b.add("base_schedule")(rn["output"]).set_param(Param{"output_name", "scaled_right"}, Param{"compute_level", "compute_root"});
rn = b.add("base_schedule")(rn["output"]).set_param(Param{"output_name", "scaled_right"}, Param{"compute_level", "compute_root"}, Param("input.type", "float32"), Param("input.dim", 3));
rn = b.add("image_processing_calc_luminance")(rn["output"]).set_param(Param{"luminance_method", "Average"});
rn = b.add("base_denormalize_2d_uint8")(rn["output"]);

Expand Down
2 changes: 2 additions & 0 deletions include/ion/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class Builder {

Halide::Pipeline build(bool implicit_output = false);

void determine_and_validate();

std::vector<Halide::Argument> get_arguments_stub() const;
std::vector<const void*> get_arguments_instance() const;

Expand Down
5 changes: 4 additions & 1 deletion include/ion/building_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ class BuildingBlock : public Halide::Generator<T> {

template<typename... Ts>
void register_disposer(const std::string& n) {
reinterpret_cast<Builder*>(static_cast<uint64_t>(builder_ptr))->register_disposer(bb_id, n);
auto bb(reinterpret_cast<Builder*>(static_cast<uint64_t>(builder_ptr)));
if (bb) {
bb->register_disposer(bb_id, n);
}
}

ion::Buffer<uint8_t> get_id() {
Expand Down
23 changes: 4 additions & 19 deletions include/ion/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,11 @@ class Node {
return impl_->ports;
}

std::vector<Port> iports() const {
std::vector<Port> iports;
for (const auto& p: impl_->ports) {
if (std::count_if(p.impl_->succ_chans.begin(), p.impl_->succ_chans.end(),
[&](const Port::Channel& c) { return std::get<0>(c) == impl_->id; })) {
iports.push_back(p);
}
}
return iports;
}
Port iport(const std::string& pn);
std::vector<std::tuple<std::string, Port>> iports() const;

std::vector<Port> oports() const {
std::vector<Port> oports;
for (const auto& p: impl_->ports) {
if (id() == p.pred_id()) {
oports.push_back(p);
}
}
return oports;
}
Port oport(const std::string& pn);
std::vector<std::tuple<std::string, Port>> oports() const;

private:
Node(const std::string& id, const std::string& name, const Halide::Target& target)
Expand Down
46 changes: 18 additions & 28 deletions include/ion/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,7 @@ class Port {

private:
struct Impl {
// std::string pred_id;
// std::string pred_name;

// std::string succ_id;
// std::string succ_name;

std::string id;
Channel pred_chan;
std::set<Channel> succ_chans;

Expand All @@ -68,13 +63,8 @@ class Port {
std::unordered_map<int32_t, Halide::Internal::Parameter> params;
std::unordered_map<int32_t, const void *> instances;

Impl() {}

Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d)
: pred_chan{pid, pn}, succ_chans{}, type(t), dimensions(d)
{
params[0] = Halide::Internal::Parameter(type, dimensions != 0, dimensions, argument_name(pid, pn, 0));
}
Impl();
Impl(const std::string& pid, const std::string& pn, const Halide::Type& t, int32_t d);
};

public:
Expand Down Expand Up @@ -103,7 +93,7 @@ class Port {
*/
template<typename T,
typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
Port(T *vptr) : impl_(new Impl("", Halide::Internal::unique_name("ion_port"), Halide::type_of<T>(), 0)), index_(-1) {
Port(T *vptr) : impl_(new Impl("", Halide::Internal::unique_name("_ion_port_"), Halide::type_of<T>(), 0)), index_(-1) {
this->bind(vptr);
}

Expand All @@ -124,6 +114,7 @@ class Port {
}

// Getter
const std::string& id() const { return impl_->id; }
const Channel& pred_chan() const { return impl_->pred_chan; }
const std::string& pred_id() const { return std::get<0>(impl_->pred_chan); }
const std::string& pred_name() const { return std::get<1>(impl_->pred_chan); }
Expand All @@ -132,15 +123,22 @@ class Port {
int32_t dimensions() const { return impl_->dimensions; }
int32_t size() const { return static_cast<int32_t>(impl_->params.size()); }
int32_t index() const { return index_; }
uintptr_t impl_ptr() const { return reinterpret_cast<uintptr_t>(impl_.get()); }

// Setter
void set_index(int index) { index_ = index; }

// Util
bool has_pred() const { return !std::get<0>(impl_->pred_chan).empty(); }
bool has_pred_by_nid(const std::string& nid) const { return !std::get<0>(impl_->pred_chan).empty(); }
bool has_succ() const { return !impl_->succ_chans.empty(); }
bool has_succ(const Channel& c) const { return impl_->succ_chans.count(c); }
bool has_succ_by_nid(const std::string& nid) const {
return std::count_if(impl_->succ_chans.begin(),
impl_->succ_chans.end(),
[&](const Port::Channel& c) { return std::get<0>(c) == nid; });
}

void determine_succ(const std::string& nid, const std::string& old_pn, const std::string& new_pn);

/**
* Overloaded operator to set the port index and return a reference to the current port. eg. port[0]
Expand Down Expand Up @@ -189,25 +187,17 @@ class Port {
}
}

static std::tuple<std::shared_ptr<Impl>, bool> find_impl(uintptr_t ptr) {
static std::unordered_map<uintptr_t, std::shared_ptr<Impl>> impls;
static std::mutex mutex;
std::scoped_lock lock(mutex);
bool found = true;
if (!impls.count(ptr)) {
impls[ptr] = std::make_shared<Impl>();
found = false;
}
return std::make_tuple(impls[ptr], found);
}
static std::tuple<std::shared_ptr<Impl>, bool> find_impl(const std::string& id);

private:
/**
* This port is created from another node
* This port is created from another node.
* In this case, it is not sure what this port is input or output.
* pid and pn is stored in both pred and succ,
* then it will determined through pipeline build process.
*/
Port(const std::string& pid, const std::string& pn) : impl_(new Impl(pid, pn, Halide::Type(), 0)), index_(-1) {}


std::vector<Halide::Argument> as_argument() const {
std::vector<Halide::Argument> args;
for (const auto& [i, param] : impl_->params) {
Expand Down
9 changes: 2 additions & 7 deletions src/bb/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,11 @@ target_include_directories(ion-bb PUBLIC ${PROJECT_SOURCE_DIR}/include ${PROJECT
target_link_libraries(ion-bb PUBLIC ion-core ${ION_BB_LIBRARIES})
if(UNIX)
target_compile_options(ion-bb PUBLIC -fno-rtti) # For Halide::Generator
if(APPLE)
target_compile_options(ion-bb
PUBLIC -fno-rtti # For Halide::Generator
PUBLIC -rdynamic) # For JIT compiling
else()
if(NOT APPLE)
target_link_options(ion-bb PUBLIC -Wl,--export-dynamic) # For JIT compiling
endif()
elseif(MSVC)
target_compile_options(ion-bb
PUBLIC /bigobj)
target_compile_options(ion-bb PUBLIC /bigobj)
endif()

#
Expand Down
Loading