Skip to content

Commit

Permalink
capi
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu Li authored and Xinyu Li committed Oct 24, 2024
1 parent 7a1ddfb commit 31f566c
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 16 deletions.
30 changes: 14 additions & 16 deletions include/ion/c_ion.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ int ion_port_bind_u32(ion_port_t, uint32_t *);
int ion_port_bind_u64(ion_port_t, uint64_t *);
int ion_port_bind_f32(ion_port_t, float *);
int ion_port_bind_f64(ion_port_t, double *);

int ion_port_bind_i8_array(ion_port_t, int8_t *, int);
int ion_port_bind_i16_array(ion_port_t, int16_t *, int);
int ion_port_bind_i32_array(ion_port_t, int32_t *, int);
int ion_port_bind_i64_array(ion_port_t, int64_t *, int);
int ion_port_bind_u1_array(ion_port_t, bool *, int);
int ion_port_bind_u8_array(ion_port_t, uint8_t *, int);
int ion_port_bind_u16_array(ion_port_t, uint16_t *, int);
int ion_port_bind_u32_array(ion_port_t, uint32_t *, int);
int ion_port_bind_u64_array(ion_port_t, uint64_t *, int);
int ion_port_bind_f32_array(ion_port_t, float *, int);
int ion_port_bind_f64_array(ion_port_t, double *, int);


int ion_port_bind_buffer(ion_port_t, ion_buffer_t);
int ion_port_bind_buffer_array(ion_port_t, ion_buffer_t *, int);

Expand Down Expand Up @@ -84,22 +98,6 @@ int ion_graph_destroy(ion_graph_t);
int ion_graph_run(ion_graph_t);
int ion_graph_create_with_multiple(ion_graph_t *ptr, ion_graph_t *objs, int size);

[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_create(ion_port_map_t *);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_destroy(ion_port_map_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i8(ion_port_map_t, ion_port_t, int8_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i16(ion_port_map_t, ion_port_t, int16_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i32(ion_port_map_t, ion_port_t, int32_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i64(ion_port_map_t, ion_port_t, int64_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u1(ion_port_map_t, ion_port_t, bool);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u8(ion_port_map_t, ion_port_t, uint8_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u16(ion_port_map_t, ion_port_t, uint16_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u32(ion_port_map_t, ion_port_t, uint32_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u64(ion_port_map_t, ion_port_t, uint64_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_f32(ion_port_map_t, ion_port_t, float);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_f64(ion_port_map_t, ion_port_t, double);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_buffer(ion_port_map_t, ion_port_t, ion_buffer_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_buffer_array(ion_port_map_t, ion_port_t, ion_buffer_t *, int);

#if defined __cplusplus
}
#endif
Expand Down
17 changes: 17 additions & 0 deletions include/ion/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,23 @@ class Port {
}
}


// For C API and Python binding compatibility
template<class T>
void bind(T *v, int size) {
for (int i = 0; i < size; i++) {
if (has_pred()) {
impl_->params[i] = Halide::Parameter{Halide::type_of<T>(), false, 0, argument_name(pred_id(), id(), pred_name(), i, graph_id())};
} else {
impl_->params[i] = Halide::Parameter{type(), false, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())};
}
impl_->instances[i] = v;
impl_->bound_address[i] = std::make_tuple(v, false);
v +=1;
}
}


template<typename T>
void bind(const Halide::Buffer<T> &buf) {
auto i = index_ == -1 ? 0 : index_;
Expand Down
33 changes: 33 additions & 0 deletions src/c_ion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,39 @@ ION_PORT_BIND_IMPL(double *, f64)

#undef ION_PORT_BIND_IMPL

#define ION_PORT_BIND_ARRAY_IMPL(T, POSTFIX) \
int ion_port_bind_##POSTFIX##_array(ion_port_t obj, T v, int size) { \
try { \
reinterpret_cast<Port *>(obj)->bind(v, size); \
} catch (const Halide::Error &e) { \
log::error(e.what()); \
return 1; \
} catch (const std::exception &e) { \
log::error(e.what()); \
return 1; \
} catch (...) { \
log::error("Unknown exception was happened"); \
return 1; \
} \
\
return 0; \
}

ION_PORT_BIND_ARRAY_IMPL(int8_t *, i8)
ION_PORT_BIND_ARRAY_IMPL(int16_t *, i16)
ION_PORT_BIND_ARRAY_IMPL(int32_t *, i32)
ION_PORT_BIND_ARRAY_IMPL(int64_t *, i64)
ION_PORT_BIND_ARRAY_IMPL(bool *, u1)
ION_PORT_BIND_ARRAY_IMPL(uint8_t *, u8)
ION_PORT_BIND_ARRAY_IMPL(uint16_t *, u16)
ION_PORT_BIND_ARRAY_IMPL(uint32_t *, u32)
ION_PORT_BIND_ARRAY_IMPL(uint64_t *, u64)
ION_PORT_BIND_ARRAY_IMPL(float *, f32)
ION_PORT_BIND_ARRAY_IMPL(double *, f64)

#undef ION_PORT_BIND_ARRAY_IMPL


int ion_port_bind_buffer(ion_port_t obj, ion_buffer_t b) {
try {
// NOTE: Halide::Buffer class layout is safe to call Halide::Buffer<void>::type()
Expand Down
127 changes: 127 additions & 0 deletions test/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,133 @@ int main() {
if (ret != 0)
return ret;
}

{
ion_type_t t = {.code = ion_type_int, .bits = 32, .lanes = 1};

ion_port_t ip;
ret = ion_port_create(&ip, "input", t, 2);
if (ret != 0)
return ret;


ion_port_t offsets_p;
ret = ion_port_create(&offsets_p, "input_offsets", t, 0);
if (ret != 0)
return ret;

ion_param_t len;
ret = ion_param_create(&len, "input_offsets.size", "4");
if (ret != 0)
return ret;

ion_builder_t b;
ret = ion_builder_create(&b);
if (ret != 0)
return ret;

ret = ion_builder_set_target(b, "host");
if (ret != 0)
return ret;

ret = ion_builder_with_bb_module(b, "ion-bb-test");
if (ret != 0)
return ret;


ion_node_t n;
ret = ion_builder_add_node(b, "test_scalar_array", &n);
if (ret != 0)
return ret;

ret = ion_node_set_params(n, &len, 1);
if (ret != 0)
return ret;

ion_port_t *ports = (ion_port_t *) malloc(2 * sizeof(ion_port_t));
ports[0] = ip;
ports[1] = offsets_p;
ret = ion_node_set_iports(n, ports, 2);
if (ret != 0)
return ret;

int sizes[] = {4, 4};
ion_buffer_t ibuf;
ret = ion_buffer_create(&ibuf, t, sizes, 2);
if (ret != 0)
return ret;

int in[4 * 4];
for (int i = 0; i < 4 * 4; ++i) {
in[i] = 42;
}
ret = ion_buffer_write(ibuf, in, 4 * 4 * sizeof(int));
if (ret != 0)
return ret;

ion_port_t op;
ret = ion_node_get_port(n, "output", &op);
if (ret != 0)
return ret;

ion_buffer_t *obufs = (ion_buffer_t *) malloc(4 * sizeof(ion_buffer_t));
for (int i = 0; i < 4; ++i) {
ret = ion_buffer_create(obufs + i, t, sizes, 2);
if (ret != 0)
return ret;
}

int in_offsets[4];
for (int i = 0; i < 4; ++i) {
in_offsets[i] = i;
}

ret = ion_port_bind_i32_array(offsets_p, (int *) (&in_offsets), 4);
if (ret != 0)
return ret;

ret = ion_port_bind_buffer(ip, ibuf);
if (ret != 0)
return ret;

ret = ion_port_bind_buffer_array(op, obufs, 4);
if (ret != 0)
return ret;

ret = ion_builder_run(b);
if (ret != 0)
return ret;

for (int i = 0;i < 4 ;i++){
int out[4 * 4] = {0};
ret = ion_buffer_read(*(obufs + i), out, 4 * 4 * sizeof(int));
if (ret != 0)
return ret;
if (out[0] != 42 + i) {
printf("%d\n", out[0]);
return -1;
}
}

ret = ion_port_destroy(ip);
if (ret != 0)
return ret;

ret = ion_port_destroy(offsets_p);
if (ret != 0)
return ret;

ret = ion_port_destroy(op);
if (ret != 0)
return ret;

ret = ion_builder_destroy(b);
if (ret != 0)
return ret;

free(ports);
}

{

ion_type_t t = {.code = ion_type_int, .bits = 32, .lanes = 1};
Expand Down

0 comments on commit 31f566c

Please sign in to comment.