Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update/ scalar input array #335

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions include/ion/c_ion.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ int ion_port_bind_u32(ion_port_t, uint32_t *);
int ion_port_bind_u64(ion_port_t, uint64_t *);
int ion_port_bind_f32(ion_port_t, float *);
int ion_port_bind_f64(ion_port_t, double *);

int ion_port_bind_i8_array(ion_port_t, int8_t *, int);
int ion_port_bind_i16_array(ion_port_t, int16_t *, int);
int ion_port_bind_i32_array(ion_port_t, int32_t *, int);
int ion_port_bind_i64_array(ion_port_t, int64_t *, int);
int ion_port_bind_u1_array(ion_port_t, bool *, int);
int ion_port_bind_u8_array(ion_port_t, uint8_t *, int);
int ion_port_bind_u16_array(ion_port_t, uint16_t *, int);
int ion_port_bind_u32_array(ion_port_t, uint32_t *, int);
int ion_port_bind_u64_array(ion_port_t, uint64_t *, int);
int ion_port_bind_f32_array(ion_port_t, float *, int);
int ion_port_bind_f64_array(ion_port_t, double *, int);


int ion_port_bind_buffer(ion_port_t, ion_buffer_t);
int ion_port_bind_buffer_array(ion_port_t, ion_buffer_t *, int);

Expand Down Expand Up @@ -84,22 +98,6 @@ int ion_graph_destroy(ion_graph_t);
int ion_graph_run(ion_graph_t);
int ion_graph_create_with_multiple(ion_graph_t *ptr, ion_graph_t *objs, int size);

[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_create(ion_port_map_t *);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_destroy(ion_port_map_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i8(ion_port_map_t, ion_port_t, int8_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i16(ion_port_map_t, ion_port_t, int16_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i32(ion_port_map_t, ion_port_t, int32_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_i64(ion_port_map_t, ion_port_t, int64_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u1(ion_port_map_t, ion_port_t, bool);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u8(ion_port_map_t, ion_port_t, uint8_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u16(ion_port_map_t, ion_port_t, uint16_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u32(ion_port_map_t, ion_port_t, uint32_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_u64(ion_port_map_t, ion_port_t, uint64_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_f32(ion_port_map_t, ion_port_t, float);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_f64(ion_port_map_t, ion_port_t, double);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_buffer(ion_port_map_t, ion_port_t, ion_buffer_t);
[[deprecated("ion_port_bind* can be used instead of ion_port_map.")]] int ion_port_map_set_buffer_array(ion_port_map_t, ion_port_t, ion_buffer_t *, int);

#if defined __cplusplus
}
#endif
Expand Down
9 changes: 9 additions & 0 deletions include/ion/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,15 @@ class Node {
return Port(arg, impl_->graph_id);
}


template<class T, size_t N>
Port make_iport(std::array<T, N> * arg) const {
if (to_string(impl_->graph_id).empty())
return Port(arg);
else
return Port(arg, impl_->graph_id);
}

std::shared_ptr<Impl> impl_;
};

Expand Down
51 changes: 51 additions & 0 deletions include/ion/port.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,26 @@ class Port {
this->bind(vptr);
}


/**
* Construct new port from scalar array
*/
template<class T, size_t N>
Port(std::array<T, N> * arr)
: impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of<T>(), 0, GraphID(""))), index_(-1) {
this->bind(arr);
}

/**
* Construct new port from scalar array and bind graph id to port
*/
template<class T, size_t N>
Port(std::array<T, N> * arr, const GraphID &gid)
: impl_(new Impl(NodeID(""), Halide::Internal::unique_name("_ion_port_"), Halide::type_of<T>(), 0, gid)), index_(-1) {
this->bind(arr);
}


/**
* Construct new port from buffer
*/
Expand Down Expand Up @@ -239,6 +259,37 @@ class Port {
impl_->bound_address[i] = std::make_tuple(v, false);
}


template<class T, size_t N>
void bind(std::array<T, N> * arr) {
for (int i = 0; i < N; i++) {
if (has_pred()) {
impl_->params[i] = Halide::Parameter{Halide::type_of<T>(), false, 0, argument_name(pred_id(), id(), pred_name(), i, graph_id())};
} else {
impl_->params[i] = Halide::Parameter{type(), false, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())};
}
impl_->instances[i] = &(arr->at(i));
impl_->bound_address[i] = std::make_tuple(&(arr->at(i)), false);
}
}


// For C API and Python binding compatibility
template<class T>
void bind(T *v, int size) {
for (int i = 0; i < size; i++) {
if (has_pred()) {
impl_->params[i] = Halide::Parameter{Halide::type_of<T>(), false, 0, argument_name(pred_id(), id(), pred_name(), i, graph_id())};
} else {
impl_->params[i] = Halide::Parameter{type(), false, dimensions(), argument_name(pred_id(), id(), pred_name(), i, graph_id())};
}
impl_->instances[i] = v;
impl_->bound_address[i] = std::make_tuple(v, false);
v +=1;
}
}


template<typename T>
void bind(const Halide::Buffer<T> &buf) {
auto i = index_ == -1 ? 0 : index_;
Expand Down
109 changes: 78 additions & 31 deletions python/ionpy/Port.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,21 @@
ion_port_bind_f32,
ion_port_bind_f64,
ion_port_bind_buffer,
ion_port_bind_buffer_array
ion_port_bind_buffer_array,

ion_port_bind_i8_array,
ion_port_bind_i16_array,
ion_port_bind_i32_array,
ion_port_bind_i64_array,

ion_port_bind_u1_array,
ion_port_bind_u8_array,
ion_port_bind_u16_array,
ion_port_bind_u32_array,
ion_port_bind_u64_array,

ion_port_bind_f32_array,
ion_port_bind_f64_array,
)

from .Type import Type
Expand Down Expand Up @@ -62,38 +76,71 @@ def __del__(self):
if self.obj: # check not nullptr
ion_port_destroy(self.obj)

def bind(self, v: Union[int, float, Buffer, List[Buffer]]):
def bind(self, v: Union[int, float, Buffer, List[Union[Buffer, int, float]]]):
if self.dim == 0:
if self.bind_value is None:
self.bind_value = np.ctypeslib.as_ctypes_type(self.type.to_dtype())(v)
else:
self.bind_value.value = v
# scalar
if self.type.code_ == TypeCode.Int:
if self.type.bits_ == 8 and ion_port_bind_i8(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
elif self.type.bits_ == 16 and ion_port_bind_i16(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
elif self.type.bits_ == 32 and ion_port_bind_i32(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
elif self.type.bits_ == 64 and ion_port_bind_i64(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
elif self.type.code_ == TypeCode.Uint:
if self.type.bits_ == 1 and ion_port_bind_u1(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 8 and ion_port_bind_u8(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 16 and ion_port_bind_u16(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 32 and ion_port_bind_u32(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 64 and ion_port_bind_u64(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
elif self.type.code_ == TypeCode.Float:
if self.type.bits_ == 32 and ion_port_bind_f32(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 64 and ion_port_bind_f64(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
if type(v) is not list:
if self.bind_value is None:
self.bind_value = np.ctypeslib.as_ctypes_type(self.type.to_dtype())(v)
else:
self.bind_value.value = v

if self.type.code_ == TypeCode.Int:
if self.type.bits_ == 8 and ion_port_bind_i8(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
elif self.type.bits_ == 16 and ion_port_bind_i16(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
elif self.type.bits_ == 32 and ion_port_bind_i32(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
elif self.type.bits_ == 64 and ion_port_bind_i64(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
elif self.type.code_ == TypeCode.Uint:
if self.type.bits_ == 1 and ion_port_bind_u1(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 8 and ion_port_bind_u8(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 16 and ion_port_bind_u16(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 32 and ion_port_bind_u32(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 64 and ion_port_bind_u64(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
elif self.type.code_ == TypeCode.Float:
if self.type.bits_ == 32 and ion_port_bind_f32(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 64 and ion_port_bind_f64(self.obj, ctypes.byref(self.bind_value)) != 0:
raise Exception('Invalid operation')
else:
# scalar array
ctype = np.ctypeslib.as_ctypes_type(self.type.to_dtype())
c_arr = (ctype* len(v))(*v)
self.bind_value = c_arr
if self.type.code_ == TypeCode.Int:
if self.type.bits_ == 8 and ion_port_bind_i8_array(self.obj, self.bind_value, len(v)) != 0:
raise Exception('Invalid operation')
elif self.type.bits_ == 16 and ion_port_bind_i16_array(self.obj, self.bind_value, len(v)) != 0:
raise Exception('Invalid operation')
elif self.type.bits_ == 32 and ion_port_bind_i32_array(self.obj, self.bind_value, len(v)) != 0:
raise Exception('Invalid operation')
elif self.type.bits_ == 64 and ion_port_bind_i64_array(self.obj, self.bind_value, len(v)) != 0:
raise Exception('Invalid operation')
elif self.type.code_ == TypeCode.Uint:
if self.type.bits_ == 1 and ion_port_bind_u1_array(self.obj, self.bind_value, len(v)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 8 and ion_port_bind_u8_array(self.obj, self.bind_value, len(v)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 16 and ion_port_bind_u16_array(self.obj,self.bind_value, len(v)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 32 and ion_port_bind_u32_array(self.obj, self.bind_value, len(v)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 64 and ion_port_bind_u64_array(self.obj, self.bind_value, len(v)) != 0:
raise Exception('Invalid operation')
elif self.type.code_ == TypeCode.Float:
if self.type.bits_ == 32 and ion_port_bind_f32_array(self.obj,self.bind_value, len(v)) != 0:
raise Exception('Invalid operation')
if self.type.bits_ == 64 and ion_port_bind_f64_array(self.obj,self.bind_value, len(v)) != 0:
raise Exception('Invalid operation')

# vector
else:

Expand Down
58 changes: 57 additions & 1 deletion python/ionpy/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,67 @@ class c_builder_compile_option_t(ctypes.Structure):
ion_port_bind_f32.restype = ctypes.c_int
ion_port_bind_f32.argtypes = [ c_ion_port_t, ctypes.POINTER(ctypes.c_float) ]

# int ion_port_bind_f64(ion_port_t, double*;
# int ion_port_bind_f64(ion_port_t, double*;);
ion_port_bind_f64 = ion_core.ion_port_bind_f64
ion_port_bind_f64.restype = ctypes.c_int
ion_port_bind_f64.argtypes = [ c_ion_port_t, ctypes.POINTER(ctypes.c_double) ]

# int ion_port_bind_i8_array(ion_port_t, int8_t*, int);
ion_port_bind_i8_array = ion_core.ion_port_bind_i8_array
ion_port_bind_i8_array.restype = ctypes.c_int
ion_port_bind_i8_array.argtypes = [c_ion_port_t, ctypes.POINTER(ctypes.c_int8), ctypes.c_int ]

# int ion_port_bind_i16_array(ion_port_t, int16_t*, int);
ion_port_bind_i16_array = ion_core.ion_port_bind_i16_array
ion_port_bind_i16_array.restype = ctypes.c_int
ion_port_bind_i16_array.argtypes = [c_ion_port_t, ctypes.POINTER(ctypes.c_int16), ctypes.c_int ]

# int ion_port_bind_i32_array(ion_port_t, int32_t*, int);
ion_port_bind_i32_array = ion_core.ion_port_bind_i32_array
ion_port_bind_i32_array.restype = ctypes.c_int
ion_port_bind_i32_array.argtypes = [c_ion_port_t, ctypes.POINTER(ctypes.c_int32), ctypes.c_int ]

# int ion_port_bind_i64_array(ion_port_t, int64_t*, int);
ion_port_bind_i64_array = ion_core.ion_port_bind_i64_array
ion_port_bind_i64_array.restype = ctypes.c_int
ion_port_bind_i64_array.argtypes = [ c_ion_port_t, ctypes.POINTER(ctypes.c_int64), ctypes.c_int ]

# int ion_port_map_set_u1_array(ion_port_t, bool*, int);
ion_port_bind_u1_array = ion_core.ion_port_bind_u1_array
ion_port_bind_u1_array.restype = ctypes.c_int
ion_port_bind_u1_array.argtypes = [ c_ion_port_t, ctypes.POINTER(ctypes.c_bool), ctypes.c_int ]

# int ion_port_bind_u8_array(ion_port_t, uint8_t*, int);
ion_port_bind_u8_array = ion_core.ion_port_bind_u8_array
ion_port_bind_u8_array.restype = ctypes.c_int
ion_port_bind_u8_array.argtypes = [ c_ion_port_t, ctypes.POINTER(ctypes.c_uint8), ctypes.c_int ]

# int ion_port_bind_u16_array(ion_port_t, uint16_t*, int);
ion_port_bind_u16_array = ion_core.ion_port_bind_u16_array
ion_port_bind_u16_array.restype = ctypes.c_int
ion_port_bind_u16_array.argtypes = [ c_ion_port_t, ctypes.POINTER(ctypes.c_uint16), ctypes.c_int ]

# int ion_port_bind_u32_array(ion_port_t, uint32_t*, int);
ion_port_bind_u32_array = ion_core.ion_port_bind_u32_array
ion_port_bind_u32_array.restype = ctypes.c_int
ion_port_bind_u32_array.argtypes = [ c_ion_port_t, ctypes.POINTER(ctypes.c_uint32), ctypes.c_int ]

# int ion_port_bind_u64_array(ion_port_t, uint64_t*, int);
ion_port_bind_u64_array = ion_core.ion_port_bind_u64_array
ion_port_bind_u64_array.restype = ctypes.c_int
ion_port_bind_u64_array.argtypes = [ c_ion_port_t, ctypes.POINTER(ctypes.c_uint64), ctypes.c_int ]

# int ion_port_bind_f32_array(ion_port_t, float*, int);
ion_port_bind_f32_array = ion_core.ion_port_bind_f32_array
ion_port_bind_f32_array.restype = ctypes.c_int
ion_port_bind_f32_array.argtypes = [ c_ion_port_t, ctypes.POINTER(ctypes.c_float), ctypes.c_int ]

# int ion_port_bind_f64_array(ion_port_t, double*, int);
ion_port_bind_f64_array = ion_core.ion_port_bind_f64_array
ion_port_bind_f64_array.restype = ctypes.c_int
ion_port_bind_f64_array.argtypes = [ c_ion_port_t, ctypes.POINTER(ctypes.c_double), ctypes.c_int ]


# int ion_port_bind_buffer(ion_port_t, ion_buffer_t);
ion_port_bind_buffer = ion_core.ion_port_bind_buffer
ion_port_bind_buffer.restype = ctypes.c_int
Expand Down
33 changes: 33 additions & 0 deletions src/c_ion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,39 @@ ION_PORT_BIND_IMPL(double *, f64)

#undef ION_PORT_BIND_IMPL

#define ION_PORT_BIND_ARRAY_IMPL(T, POSTFIX) \
int ion_port_bind_##POSTFIX##_array(ion_port_t obj, T v, int size) { \
try { \
reinterpret_cast<Port *>(obj)->bind(v, size); \
} catch (const Halide::Error &e) { \
log::error(e.what()); \
return 1; \
} catch (const std::exception &e) { \
log::error(e.what()); \
return 1; \
} catch (...) { \
log::error("Unknown exception was happened"); \
return 1; \
} \
\
return 0; \
}

ION_PORT_BIND_ARRAY_IMPL(int8_t *, i8)
ION_PORT_BIND_ARRAY_IMPL(int16_t *, i16)
ION_PORT_BIND_ARRAY_IMPL(int32_t *, i32)
ION_PORT_BIND_ARRAY_IMPL(int64_t *, i64)
ION_PORT_BIND_ARRAY_IMPL(bool *, u1)
ION_PORT_BIND_ARRAY_IMPL(uint8_t *, u8)
ION_PORT_BIND_ARRAY_IMPL(uint16_t *, u16)
ION_PORT_BIND_ARRAY_IMPL(uint32_t *, u32)
ION_PORT_BIND_ARRAY_IMPL(uint64_t *, u64)
ION_PORT_BIND_ARRAY_IMPL(float *, f32)
ION_PORT_BIND_ARRAY_IMPL(double *, f64)

#undef ION_PORT_BIND_ARRAY_IMPL


int ion_port_bind_buffer(ion_port_t obj, ion_buffer_t b) {
try {
// NOTE: Halide::Buffer class layout is safe to call Halide::Buffer<void>::type()
Expand Down
3 changes: 3 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ ion_jit_executable(array_input SRCS array_input.cc)
# Array Output test
ion_jit_executable(array_output SRCS array_output.cc)

# Scalar array Input test
ion_jit_executable(scalar_array_input SRCS scalar_array_input.cc)

# Duplicate array names test
ion_jit_executable(array_dup_names SRCS array_dup_names.cc)

Expand Down
Loading
Loading