From 7e7922ae27b9dc139c7b9ff973385f8b599e11b7 Mon Sep 17 00:00:00 2001 From: Jiacai Liu Date: Tue, 5 Nov 2024 23:24:06 +0800 Subject: [PATCH] refactor: make array to store error message (#1) * clang format * refactor: make array to store error message * run tests * fix ci * timeout --- .github/workflows/CI.yml | 33 +++ Makefile | 8 + include/wrapper.h | 89 +++----- src/ffi.rs | 12 +- src/wrapper.cpp | 476 ++++++++++++++++++++------------------- 5 files changed, 334 insertions(+), 284 deletions(-) create mode 100644 .github/workflows/CI.yml create mode 100644 Makefile diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 0000000..24bdac8 --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,33 @@ +name: CI + +on: + workflow_dispatch: + pull_request: + paths-ignore: + - '**.md' + push: + branches: + - main + - master + paths-ignore: + - '**.md' + +env: + CC: gcc + +jobs: + test: + timeout-minutes: 60 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install deps + run: | + sudo apt install -y gfortran python3-dev libomp-15-dev lcov intel-mkl + - uses: Swatinem/rust-cache@v2 + - name: fmt + run: | + make fmt + - name: test + run: | + make test diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1ddefd6 --- /dev/null +++ b/Makefile @@ -0,0 +1,8 @@ + +.PHONY: fmt +fmt: + find src include -iname "*.h" -o -iname "*.cpp" | xargs clang-format -i + +.PHONY: test +test: + cargo test diff --git a/include/wrapper.h b/include/wrapper.h index 81d26bb..2d3e092 100644 --- a/include/wrapper.h +++ b/include/wrapper.h @@ -20,60 +20,45 @@ extern "C" { -struct CError { - int type_; - const char* message; +#define VSAG_WRAPPER_MAX_ERROR_MESSAGE_LENGTH 256 - CError(int type, const char* message) : type_(type), message(message) {} +struct CError { + int type_; + char message[VSAG_WRAPPER_MAX_ERROR_MESSAGE_LENGTH]; }; -const CError* create_index( - const char* in_index_type, - const char* in_parameters, - - void** out_index_ptr -); - -const CError* build_index( - void* in_index_ptr, - size_t in_num_vectors, - size_t in_dim, - const int64_t* in_ids, - const float* in_vectors, - - const int64_t** out_failed_ids, - size_t* out_num_failed -); - -const CError* knn_search_index( - void* in_index_ptr, - size_t in_dim, - const float* in_query_vector, - size_t in_k, - const char* in_search_parameters, - - const int64_t** out_ids, - const float** out_distances, - size_t* out_num_results -); - -const CError* dump_index( - void* in_index_ptr, - const char* in_file_path -); - -const CError* load_index( - const char* in_file_path, - const char* in_index_type, - const char* in_parameters, - - void** out_index_ptr -); - -void free_error(const CError*); -void free_index(void* index_ptr); -void free_i64_vector(int64_t* vector); -void free_f32_vector(float* vector); +CError *new_error(int type_, const char *msg); +void free_error(const CError *); + +const CError *create_index(const char *in_index_type, const char *in_parameters, + + void **out_index_ptr); + +const CError *build_index(void *in_index_ptr, size_t in_num_vectors, + size_t in_dim, const int64_t *in_ids, + const float *in_vectors, + + const int64_t **out_failed_ids, + size_t *out_num_failed); + +const CError *knn_search_index(void *in_index_ptr, size_t in_dim, + const float *in_query_vector, size_t in_k, + const char *in_search_parameters, + + const int64_t **out_ids, + const float **out_distances, + size_t *out_num_results); + +const CError *dump_index(void *in_index_ptr, const char *in_file_path); + +const CError *load_index(const char *in_file_path, const char *in_index_type, + const char *in_parameters, + + void **out_index_ptr); + +void free_index(void *index_ptr); +void free_i64_vector(int64_t *vector); +void free_f32_vector(float *vector); } // extern "C" -#endif // WRAPPER_H \ No newline at end of file +#endif // WRAPPER_H diff --git a/src/ffi.rs b/src/ffi.rs index 6fa1377..637eccb 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -64,16 +64,20 @@ extern "C" { #[repr(C)] pub struct CError { pub type_: c_int, - pub message: *const c_char, + // should be same length with CError defined in wrapper.h + pub message: [u8; 256], } pub fn from_c_error(err: *const CError) -> crate::error::Error { let error = crate::error::Error { error_type: unsafe { std::mem::transmute::((*err).type_) }, message: unsafe { - std::ffi::CStr::from_ptr((*err).message) - .to_string_lossy() - .into_owned() + let null_pos = (*err) + .message + .iter() + .position(|&x| x == 0) + .unwrap_or((*err).message.len()); + String::from_utf8_lossy(&(*err).message[..null_pos]).into_owned() }, }; unsafe { diff --git a/src/wrapper.cpp b/src/wrapper.cpp index 58409cd..706ae99 100644 --- a/src/wrapper.cpp +++ b/src/wrapper.cpp @@ -13,275 +13,295 @@ // limitations under the License. #include "wrapper.h" +#include "vsag/factory.h" +#include "vsag/index.h" +#include #include #include +#include #include #include #include -#include "vsag/index.h" -#include "vsag/factory.h" -#include -#include template -static void -writeBinaryPOD(std::ostream& out, const T& podRef) { - out.write((char*)&podRef, sizeof(T)); +static void writeBinaryPOD(std::ostream &out, const T &podRef) { + out.write((char *)&podRef, sizeof(T)); } -template -static void -readBinaryPOD(std::istream& in, T& podRef) { - in.read((char*)&podRef, sizeof(T)); +template static void readBinaryPOD(std::istream &in, T &podRef) { + in.read((char *)&podRef, sizeof(T)); } extern "C" { - -const CError* create_index(const char* in_index_type, const char* in_parameters, void** out_index_ptr) { - if (!in_index_type || !in_parameters || !out_index_ptr) { - return new CError{static_cast(vsag::ErrorType::INVALID_ARGUMENT), "Invalid null argument."}; - } - - auto result = vsag::Factory::CreateIndex(in_index_type, in_parameters); - - if (!result.has_value()) { - // Convert C++ error to dynamically allocated CError - return new CError{static_cast(result.error().type), strdup(result.error().message.c_str())}; - } - - auto pIndex = new std::shared_ptr(result.value()); - *out_index_ptr = static_cast(pIndex); - - return nullptr; // Success: Return NULL +CError *new_error(int type_, const char *msg) { + CError *err = (CError *)malloc(sizeof(CError)); + if (err == NULL) { + return NULL; + } + + size_t msg_size = strlen(msg); + memcpy(err->message, msg, + msg_size > VSAG_WRAPPER_MAX_ERROR_MESSAGE_LENGTH + ? VSAG_WRAPPER_MAX_ERROR_MESSAGE_LENGTH + : msg_size); + + return err; } -const CError* build_index( - void* in_index_ptr, - size_t in_num_vectors, - size_t in_dim, - const int64_t* in_ids, - const float* in_vectors, - - const int64_t** out_failed_ids, - size_t* out_num_failed -) { - if (!in_index_ptr || !in_ids || !in_vectors || !out_failed_ids || !out_num_failed) { - return new CError{static_cast(vsag::ErrorType::INVALID_ARGUMENT), "Invalid null argument."}; - } - - // Cast the void pointer back to the original pointer type, std::shared_ptr* - auto pIndex = static_cast*>(in_index_ptr); - - auto base = vsag::Dataset::Make(); - base->NumElements(in_num_vectors)->Dim(in_dim)->Ids(in_ids)->Float32Vectors(in_vectors)->Owner(false); - auto result = (*pIndex)->Build(base); - - if (!result.has_value()) { - // Convert C++ error to dynamically allocated CError - return new CError{static_cast(result.error().type), strdup(result.error().message.c_str())}; - } - - // Copy the failed IDs to the output array - auto failed_ids = result.value(); - auto failed_ids_array = new int64_t[failed_ids.size()]; - std::copy(failed_ids.begin(), failed_ids.end(), failed_ids_array); - *out_failed_ids = failed_ids_array; - *out_num_failed = static_cast(failed_ids.size()); - - return nullptr; // Success: Return NULL +void free_error(const CError *error) { + if (error) { + free(const_cast(error)); // Deallocate the error struct + } } -const CError* knn_search_index( - void* in_index_ptr, - size_t in_dim, - const float* in_query_vector, - size_t in_k, - const char* in_search_parameters, - - const int64_t** out_ids, - const float** out_distances, - size_t* out_num_results -) { - if (!in_index_ptr || !in_query_vector || !in_search_parameters || !out_ids || !out_distances || !out_num_results) { - return new CError{static_cast(vsag::ErrorType::INVALID_ARGUMENT), "Invalid null argument."}; - } +const CError *create_index(const char *in_index_type, const char *in_parameters, + void **out_index_ptr) { + if (!in_index_type || !in_parameters || !out_index_ptr) { + return new_error(static_cast(vsag::ErrorType::INVALID_ARGUMENT), + "Invalid null argument."); + } - // Cast the void pointer back to the original pointer type, std::shared_ptr* - auto pIndex = static_cast*>(in_index_ptr); + auto result = vsag::Factory::CreateIndex(in_index_type, in_parameters); - auto query = vsag::Dataset::Make(); - query->NumElements(1)->Dim(in_dim)->Float32Vectors(in_query_vector)->Owner(false); - auto result = (*pIndex)->KnnSearch(query, in_k, in_search_parameters); + if (!result.has_value()) { + // Convert C++ error to dynamically allocated CError + return new_error(static_cast(result.error().type), + result.error().message.c_str()); + } - if (!result.has_value()) { - // Convert C++ error to dynamically allocated CError - return new CError{static_cast(result.error().type), strdup(result.error().message.c_str())}; - } - - auto dataset = result.value(); - auto num = dataset->GetDim(); - *out_num_results = num; - - auto ids_array = new int64_t[num]; - auto ids = dataset->GetIds(); - std::copy(ids, ids + num, ids_array); - auto distances_array = new float[num]; - auto distances = dataset->GetDistances(); - std::copy(distances, distances + num, distances_array); + auto pIndex = new std::shared_ptr(result.value()); + *out_index_ptr = static_cast(pIndex); - *out_ids = ids_array; - *out_distances = distances_array; - - return nullptr; // Success: Return NULL + return nullptr; // Success: Return NULL } -const CError* dump_index(void* in_index_ptr, const char* in_file_path) { - if (!in_index_ptr || !in_file_path) { - return new CError{static_cast(vsag::ErrorType::INVALID_ARGUMENT), "Invalid null argument."}; - } - - // Cast the void pointer back to the original pointer type, std::shared_ptr* - auto pIndex = static_cast*>(in_index_ptr); - - if (auto bs = (*pIndex)->Serialize(); bs.has_value()) { - auto keys = bs->GetKeys(); - std::vector offsets; - - std::ofstream file(in_file_path, std::ios::binary); - uint64_t offset = 0; - for (auto key : keys) { - // [len][data...][len][data...]... - vsag::Binary b = bs->Get(key); - writeBinaryPOD(file, b.size); - file.write((const char*)b.data.get(), b.size); - offsets.push_back(offset); - offset += sizeof(b.size) + b.size; - } - // footer - for (uint64_t i = 0; i < keys.size(); ++i) { - // [len][key...][offset][len][key...][offset]... - const auto& key = keys[i]; - int64_t len = key.length(); - writeBinaryPOD(file, len); - file.write(key.c_str(), key.length()); - writeBinaryPOD(file, offsets[i]); - } - // [num_keys][footer_offset]$ - writeBinaryPOD(file, keys.size()); - writeBinaryPOD(file, offset); - file.close(); - } else { - auto err = bs.error(); - return new CError{static_cast(err.type), strdup(err.message.c_str())}; - } - - return nullptr; // Success: Return NULL +const CError *build_index(void *in_index_ptr, size_t in_num_vectors, + size_t in_dim, const int64_t *in_ids, + const float *in_vectors, + + const int64_t **out_failed_ids, + size_t *out_num_failed) { + if (!in_index_ptr || !in_ids || !in_vectors || !out_failed_ids || + !out_num_failed) { + return new_error(static_cast(vsag::ErrorType::INVALID_ARGUMENT), + "Invalid null argument."); + } + + // Cast the void pointer back to the original pointer type, + // std::shared_ptr* + auto pIndex = static_cast *>(in_index_ptr); + + auto base = vsag::Dataset::Make(); + base->NumElements(in_num_vectors) + ->Dim(in_dim) + ->Ids(in_ids) + ->Float32Vectors(in_vectors) + ->Owner(false); + auto result = (*pIndex)->Build(base); + + if (!result.has_value()) { + // Convert C++ error to dynamically allocated CError + return new_error(static_cast(result.error().type), + result.error().message.c_str()); + } + + // Copy the failed IDs to the output array + auto failed_ids = result.value(); + auto failed_ids_array = new int64_t[failed_ids.size()]; + std::copy(failed_ids.begin(), failed_ids.end(), failed_ids_array); + *out_failed_ids = failed_ids_array; + *out_num_failed = static_cast(failed_ids.size()); + + return nullptr; // Success: Return NULL } -const CError* load_index( - const char* in_file_path, - const char* in_index_type, - const char* in_parameters, +const CError *knn_search_index(void *in_index_ptr, size_t in_dim, + const float *in_query_vector, size_t in_k, + const char *in_search_parameters, + + const int64_t **out_ids, + const float **out_distances, + size_t *out_num_results) { + if (!in_index_ptr || !in_query_vector || !in_search_parameters || !out_ids || + !out_distances || !out_num_results) { + return new_error(static_cast(vsag::ErrorType::INVALID_ARGUMENT), + "Invalid null argument."); + } + + // Cast the void pointer back to the original pointer type, + // std::shared_ptr* + auto pIndex = static_cast *>(in_index_ptr); + + auto query = vsag::Dataset::Make(); + query->NumElements(1) + ->Dim(in_dim) + ->Float32Vectors(in_query_vector) + ->Owner(false); + auto result = (*pIndex)->KnnSearch(query, in_k, in_search_parameters); + + if (!result.has_value()) { + // Convert C++ error to dynamically allocated CError + return new_error(static_cast(result.error().type), + result.error().message.c_str()); + } + + auto dataset = result.value(); + auto num = dataset->GetDim(); + *out_num_results = num; + + auto ids_array = new int64_t[num]; + auto ids = dataset->GetIds(); + std::copy(ids, ids + num, ids_array); + auto distances_array = new float[num]; + auto distances = dataset->GetDistances(); + std::copy(distances, distances + num, distances_array); + + *out_ids = ids_array; + *out_distances = distances_array; + + return nullptr; // Success: Return NULL +} - void** out_index_ptr -) { - if (!in_file_path || !in_index_type || !in_parameters || !out_index_ptr) { - return new CError{static_cast(vsag::ErrorType::INVALID_ARGUMENT), "Invalid null argument."}; - } +const CError *dump_index(void *in_index_ptr, const char *in_file_path) { + if (!in_index_ptr || !in_file_path) { + return new_error(static_cast(vsag::ErrorType::INVALID_ARGUMENT), + "Invalid null argument."); + } - std::ifstream file(in_file_path, std::ios::in); - file.seekg(-sizeof(uint64_t) * 2, std::ios::end); - uint64_t num_keys, footer_offset; - readBinaryPOD(file, num_keys); - readBinaryPOD(file, footer_offset); - // std::cout << "num_keys: " << num_keys << std::endl; - // std::cout << "footer_offset: " << footer_offset << std::endl; - file.seekg(footer_offset, std::ios::beg); + // Cast the void pointer back to the original pointer type, + // std::shared_ptr* + auto pIndex = static_cast *>(in_index_ptr); - std::vector keys; + if (auto bs = (*pIndex)->Serialize(); bs.has_value()) { + auto keys = bs->GetKeys(); std::vector offsets; - for (uint64_t i = 0; i < num_keys; ++i) { - int64_t key_len; - readBinaryPOD(file, key_len); - // std::cout << "key_len: " << key_len << std::endl; - char key_buf[key_len + 1]; - memset(key_buf, 0, key_len + 1); - file.read(key_buf, key_len); - // std::cout << "key: " << key_buf << std::endl; - keys.push_back(key_buf); - - uint64_t offset; - readBinaryPOD(file, offset); - // std::cout << "offset: " << offset << std::endl; - offsets.push_back(offset); - } - vsag::ReaderSet rs; - for (uint64_t i = 0; i < num_keys; ++i) { - int64_t size = 0; - if (i + 1 == num_keys) { - size = footer_offset; - } else { - size = offsets[i + 1]; - } - size -= (offsets[i] + sizeof(uint64_t)); - auto file_reader = vsag::Factory::CreateLocalFileReader( - in_file_path, offsets[i] + sizeof(uint64_t), size); - rs.Set(keys[i], file_reader); + std::ofstream file(in_file_path, std::ios::binary); + uint64_t offset = 0; + for (auto key : keys) { + // [len][data...][len][data...]... + vsag::Binary b = bs->Get(key); + writeBinaryPOD(file, b.size); + file.write((const char *)b.data.get(), b.size); + offsets.push_back(offset); + offset += sizeof(b.size) + b.size; } - - std::shared_ptr hnsw; - if (auto index = vsag::Factory::CreateIndex(in_index_type, in_parameters); - index.has_value()) { - hnsw = index.value(); - } else { - auto err = index.error(); - return new CError{static_cast(err.type), strdup(err.message.c_str())}; - } - auto res = hnsw->Deserialize(rs); - if (!res.has_value()) { - auto err = res.error(); - return new CError{static_cast(err.type), strdup(err.message.c_str())}; + // footer + for (uint64_t i = 0; i < keys.size(); ++i) { + // [len][key...][offset][len][key...][offset]... + const auto &key = keys[i]; + int64_t len = key.length(); + writeBinaryPOD(file, len); + file.write(key.c_str(), key.length()); + writeBinaryPOD(file, offsets[i]); } - - auto pIndex = new std::shared_ptr(hnsw); - *out_index_ptr = static_cast(pIndex); - - return nullptr; // Success: Return NULL + // [num_keys][footer_offset]$ + writeBinaryPOD(file, keys.size()); + writeBinaryPOD(file, offset); + file.close(); + } else { + auto err = bs.error(); + return new_error(static_cast(err.type), err.message.c_str()); + } + + return nullptr; // Success: Return NULL } -void free_error(const CError* error) { - if (error) { - free(const_cast(error->message)); // Properly deallocate the dynamically allocated message - delete error; // Deallocate the error struct +const CError *load_index(const char *in_file_path, const char *in_index_type, + const char *in_parameters, + + void **out_index_ptr) { + if (!in_file_path || !in_index_type || !in_parameters || !out_index_ptr) { + return new_error(static_cast(vsag::ErrorType::INVALID_ARGUMENT), + "Invalid null argument."); + } + + std::ifstream file(in_file_path, std::ios::in); + file.seekg(-sizeof(uint64_t) * 2, std::ios::end); + uint64_t num_keys, footer_offset; + readBinaryPOD(file, num_keys); + readBinaryPOD(file, footer_offset); + // std::cout << "num_keys: " << num_keys << std::endl; + // std::cout << "footer_offset: " << footer_offset << std::endl; + file.seekg(footer_offset, std::ios::beg); + + std::vector keys; + std::vector offsets; + for (uint64_t i = 0; i < num_keys; ++i) { + int64_t key_len; + readBinaryPOD(file, key_len); + // std::cout << "key_len: " << key_len << std::endl; + char key_buf[key_len + 1]; + memset(key_buf, 0, key_len + 1); + file.read(key_buf, key_len); + // std::cout << "key: " << key_buf << std::endl; + keys.push_back(key_buf); + + uint64_t offset; + readBinaryPOD(file, offset); + // std::cout << "offset: " << offset << std::endl; + offsets.push_back(offset); + } + + vsag::ReaderSet rs; + for (uint64_t i = 0; i < num_keys; ++i) { + int64_t size = 0; + if (i + 1 == num_keys) { + size = footer_offset; + } else { + size = offsets[i + 1]; } + size -= (offsets[i] + sizeof(uint64_t)); + auto file_reader = vsag::Factory::CreateLocalFileReader( + in_file_path, offsets[i] + sizeof(uint64_t), size); + rs.Set(keys[i], file_reader); + } + + std::shared_ptr hnsw; + if (auto index = vsag::Factory::CreateIndex(in_index_type, in_parameters); + index.has_value()) { + hnsw = index.value(); + } else { + auto err = index.error(); + return new_error(static_cast(err.type), err.message.c_str()); + } + auto res = hnsw->Deserialize(rs); + if (!res.has_value()) { + auto err = res.error(); + return new_error(static_cast(err.type), err.message.c_str()); + } + + auto pIndex = new std::shared_ptr(hnsw); + *out_index_ptr = static_cast(pIndex); + + return nullptr; // Success: Return NULL } -void free_index(void* index_ptr) { - if (index_ptr) { - // Cast the void pointer back to the original pointer type, std::shared_ptr* - std::shared_ptr* pIndex = static_cast*>(index_ptr); +void free_index(void *index_ptr) { + if (index_ptr) { + // Cast the void pointer back to the original pointer type, + // std::shared_ptr* + std::shared_ptr *pIndex = + static_cast *>(index_ptr); - // Delete the std::shared_ptr which was dynamically allocated - delete pIndex; + // Delete the std::shared_ptr which was dynamically allocated + delete pIndex; - // Note: Deleting the std::shared_ptr will automatically handle - // the decrement of the reference count and will delete the managed Index object - // if the reference count goes to zero. - } + // Note: Deleting the std::shared_ptr will automatically handle + // the decrement of the reference count and will delete the managed Index + // object if the reference count goes to zero. + } } -void free_i64_vector(int64_t* vector) { - if (vector) { - delete[] vector; - } +void free_i64_vector(int64_t *vector) { + if (vector) { + delete[] vector; + } } -void free_f32_vector(float* vector) { - if (vector) { - delete[] vector; - } +void free_f32_vector(float *vector) { + if (vector) { + delete[] vector; + } } } // extern "C" - -