Skip to content

Commit

Permalink
support faiss ivflat index (TuGraph-family#695)
Browse files Browse the repository at this point in the history
* support vector index

* update camke

* fix cpplint

* fix cpplint

* fix cpplint

* fix cpplint

* fix cpplint && rename vector index counter

* support vector index blob to store in the index list table

* update dockerfile to support faiss dynamic library

* change the procedure function

* change vim name && ready for incremental update index algorithms

* HNSW interface

* HNSW interface ready

* update procedure

* update cmake file

* remove faiss && update cmake file

* update cmake file

* support vsag interface

* support procedure for hnsw

* update hnsw

* supports vsag Add's delete opertion

* update vsag add

* update vsag delete_ids

* update vsag serialize

* Fix bug and Submit test case

* fix some bug

* fix format

* fix format

* fix format

* Update index_manager.cpp

* Update field_extractor.cpp

* Update field_extractor.cpp

* Update field_extractor.cpp

* Update field_extractor.cpp

* support faiss hnsw

* merge

* support faiss ivf_flat

* delete file

* fix bug

* fix ci

* fix ci

* fix test_faiss_index

* fix some bug

* support ivf_flat range search

* fix ci

* fix procedure

* fix test

* fix ci

* fix test

* fix test

* fix test

* fix test

* fix test

* fix memory leak

* fix memory leak

* fix memory leak

* fix memory leak

* fix memory leak

* fix memory leak

* fix memory leak bug

* fix memory leak

* fix memory leak

* Update field_extractor.h

* fix conflict

* fix ci

* fix bug

* fix bug

* fix test

* fix test

* cmake fix

* fix makefile

* fix cmake

---------

Co-authored-by: PPPoint-t <[email protected]>
Co-authored-by: lipanpan03 <[email protected]>
Co-authored-by: PPPoint_ <[email protected]>
  • Loading branch information
4 people authored Dec 20, 2024
1 parent 5a22d93 commit cc68bcf
Show file tree
Hide file tree
Showing 17 changed files with 729 additions and 52 deletions.
1 change: 1 addition & 0 deletions include/lgraph/lgraph_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,7 @@ struct VectorIndexSpec {
std::string distance_type;
int hnsw_m;
int hnsw_ef_construction;
int ivf_flat_nlist;
};

struct EdgeUid {
Expand Down
7 changes: 7 additions & 0 deletions src/BuildLGraphApi.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ set(LGRAPH_CORE_SRC
core/transaction.cpp
core/vertex_index.cpp
core/vector_index.cpp
core/faiss_ivf_flat.cpp
core/vsag_hnsw.cpp
core/wal.cpp
core/lmdb/mdb.c
Expand Down Expand Up @@ -122,11 +123,15 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
pthread
rt
z
/opt/OpenBLAS/lib/libopenblas.a
faiss
)
elseif (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
if (CMAKE_SYSTEM_NAME STREQUAL "Darwin")
target_link_libraries(${TARGET_LGRAPH} PUBLIC
vsag
/opt/OpenBLAS/lib/libopenblas.a
faiss
${Boost_LIBRARIES}
omp
pthread
Expand All @@ -136,6 +141,8 @@ elseif (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
else ()
target_link_libraries(${TARGET_LGRAPH} PUBLIC
vsag
/opt/OpenBLAS/lib/libopenblas.a
faiss
rt
omp
pthread
Expand Down
1 change: 1 addition & 0 deletions src/BuildLGraphApiForJNI.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ set(LGRAPH_CORE_SRC
core/transaction.cpp
core/vertex_index.cpp
core/vector_index.cpp
core/faiss_ivf_flat.cpp
core/vsag_hnsw.cpp
core/wal.cpp
core/lmdb/mdb.c
Expand Down
4 changes: 4 additions & 0 deletions src/BuildLGraphServer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ if (NOT (CMAKE_SYSTEM_NAME STREQUAL "Darwin"))
geax_isogql
bolt
vsag
/opt/OpenBLAS/lib/libopenblas.a
faiss
# begin static linking
-Wl,-Bstatic
cpprest
Expand Down Expand Up @@ -140,4 +142,6 @@ target_link_libraries(${TARGET_SERVER}
${TARGET_SERVER_LIB}
librocksdb.a
vsag
/opt/OpenBLAS/lib/libopenblas.a
faiss
)
167 changes: 167 additions & 0 deletions src/core/faiss_ivf_flat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/**
* Copyright 2022 AntGroup CO., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <utility>
#include "core/faiss_ivf_flat.h"
#include "tools/lgraph_log.h"
#include "fma-common/string_formatter.h"
#include "lgraph/lgraph_exceptions.h"

namespace lgraph {
IVFFlat::IVFFlat(const std::string& label, const std::string& name,
const std::string& distance_type,
const std::string& index_type, int vec_dimension,
std::vector<int> index_spec)
: VectorIndex(label, name, distance_type, index_type,
vec_dimension, std::move(index_spec)) {
Build();
LOG_INFO() << FMA_FMT("Create IVF_Flat instance, {}:{}", GetLabel(), GetName());
}

IVFFlat::~IVFFlat() {
LOG_INFO() << FMA_FMT("Destroy IVF_Flat instance, {}:{}", GetLabel(), GetName());
IPquantizer_ = nullptr;
L2quantizer_ = nullptr;
index_ = nullptr;
}

// add vector to index
void IVFFlat::Add(const std::vector<std::vector<float>>& vectors,
const std::vector<int64_t>& vids) {
if (vectors.size() != vids.size()) {
THROW_CODE(VectorIndexException,
"size mismatch, vectors.size:{}, vids.size:{}", vectors.size(), vids.size());
}
if (vectors.empty()) {
return;
}
auto num_vectors = vectors.size();
// reduce dimension
std::vector<float> index_vectors;
index_vectors.reserve(num_vectors * vec_dimension_);
for (const auto& vec : vectors) {
index_vectors.insert(index_vectors.end(), vec.begin(), vec.end());
}
if (!index_->is_trained) {
// train after build quantizer
assert(!index_->is_trained);
index_->train(num_vectors, index_vectors.data());
assert(index_->is_trained);
index_->add_with_ids(num_vectors, index_vectors.data(), vids.data());
} else {
THROW_CODE(VectorIndexException, "failed to add vector to index");
}
}

void IVFFlat::Clear() {
IPquantizer_ = nullptr;
L2quantizer_ = nullptr;
index_ = nullptr;
Build();
}

void IVFFlat::Remove(const std::vector<int64_t>& vids) {
// not support now
}

// build index
void IVFFlat::Build() {
if (distance_type_ == "l2") {
L2quantizer_ = std::make_shared<faiss::IndexFlatL2>(vec_dimension_);
index_ = std::make_shared<faiss::IndexIVFFlat>
(L2quantizer_.get(), vec_dimension_, index_spec_[0]);
} else if (distance_type_ == "ip") {
IPquantizer_ = std::make_shared<faiss::IndexFlatIP>(vec_dimension_);
index_ = std::make_shared<faiss::IndexIVFFlat>
(IPquantizer_.get(), vec_dimension_, index_spec_[0]);
} else {
THROW_CODE(InputError, "failed to build vector index");
}
}

// serialize index
std::vector<uint8_t> IVFFlat::Save() {
faiss::VectorIOWriter writer;
faiss::write_index(index_.get(), &writer, 0);
return writer.data;
}

// load index form serialization
void IVFFlat::Load(std::vector<uint8_t>& idx_bytes) {
faiss::VectorIOReader reader;
reader.data = idx_bytes;
auto loadindex = faiss::read_index(&reader);
index_.reset(dynamic_cast<faiss::IndexIVFFlat*>(loadindex));
}

// search vector in index
std::vector<std::pair<int64_t, float>>
IVFFlat::KnnSearch(const std::vector<float>& query, int64_t top_k, int ef_search) {
if (query.empty() || top_k == 0) {
THROW_CODE(InputError, "please check the input");
}
std::vector<std::pair<int64_t, float>> ret;
std::vector<float> distances(top_k);
std::vector<int64_t> indices(top_k);
if (index_->ntotal == 0) {
THROW_CODE(InputError, "there is no indexed vector");
}
index_->nprobe = static_cast<size_t>(ef_search);
index_->search(1, query.data(), top_k, distances.data(), indices.data());
for (int64_t i = 0; i < top_k; ++i) {
ret.emplace_back(indices[i], distances[i]);
}
return ret;
}

std::vector<std::pair<int64_t, float>>
IVFFlat::RangeSearch(const std::vector<float>& query, float radius, int ef_search, int limit) {
if (query.empty()) {
THROW_CODE(InputError, "please check the input");
}
std::vector<std::pair<int64_t, float>> ret;
if (index_->ntotal == 0) {
THROW_CODE(InputError, "there is no indexed vector");
}
index_->nprobe = static_cast<size_t>(ef_search);
faiss::RangeSearchResult result(1);
index_->range_search(1, query.data(), radius, &result);
if (limit != -1) {
int64_t max = (static_cast<int64_t>(result.lims[1]) < limit) ?
static_cast<int64_t>(result.lims[1]) : limit;
for (int64_t i = 0; i < max; ++i) {
ret.emplace_back(result.labels[i], result.distances[i]);
}
} else {
for (int64_t i = 0; i < static_cast<int64_t>(result.lims[1]); ++i) {
ret.emplace_back(result.labels[i], result.distances[i]);
}
}
return ret;
}

int64_t IVFFlat::GetElementsNum() {
return index_->ntotal;
}

int64_t IVFFlat::GetMemoryUsage() {
// not support
return 0;
}

int64_t IVFFlat::GetDeletedIdsNum() {
// not support
return 0;
}

} // namespace lgraph
90 changes: 90 additions & 0 deletions src/core/faiss_ivf_flat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/**
* Copyright 2022 AntGroup CO., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include <vector>
#include <cstdint>
#include "core/vector_index.h"
#include "faiss/index_io.h"
#include "faiss/impl/io.h"
#include "faiss/IndexFlat.h"
#include "faiss/IndexIVFFlat.h"
#include "faiss/impl/AuxIndexStructures.h"

namespace lgraph {

class IVFFlat : public VectorIndex {
friend class Schema;
friend class LightningGraph;
friend class Transaction;
friend class IndexManager;

std::shared_ptr<faiss::IndexFlatL2> L2quantizer_;
std::shared_ptr<faiss::IndexFlatIP> IPquantizer_;
std::shared_ptr<faiss::IndexIVFFlat> index_;

// build index
void Build();

public:
IVFFlat(const std::string& label, const std::string& name,
const std::string& distance_type, const std::string& index_type,
int vec_dimension, std::vector<int> index_spec);

IVFFlat(const IVFFlat& rhs) = delete;

IVFFlat(IVFFlat&& rhs) = delete;

~IVFFlat() override;

IVFFlat& operator=(const IVFFlat& rhs) = delete;

IVFFlat& operator=(IVFFlat&& rhs) = delete;

// add vector to index and build index
void Add(const std::vector<std::vector<float>>& vectors,
const std::vector<int64_t>& vids) override;

void Remove(const std::vector<int64_t>& vids) override;

void Clear() override;

// serialize index
std::vector<uint8_t> Save() override;

// load index form serialization
void Load(std::vector<uint8_t>& idx_bytes) override;

// search vector in index
std::vector<std::pair<int64_t, float>> KnnSearch(
const std::vector<float>& query, int64_t top_k, int ef_search) override;

std::vector<std::pair<int64_t, float>> RangeSearch(
const std::vector<float>& query, float radius, int ef_search, int limit) override;

int64_t GetElementsNum() override;
int64_t GetMemoryUsage() override;
int64_t GetDeletedIdsNum() override;

template <typename T>
static void writeBinaryPOD(std::ostream& out, const T& podRef) {
out.write((char*)&podRef, sizeof(T));
}

template <typename T>
static void readBinaryPOD(std::istream& in, T& podRef) {
in.read((char*)&podRef, sizeof(T));
}
};
} // namespace lgraph
Loading

0 comments on commit cc68bcf

Please sign in to comment.