Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion c/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ USEARCH_EXPORT size_t usearch_search(
USEARCH_EXPORT size_t usearch_filtered_search( //
usearch_index_t index, //
void const* query, usearch_scalar_kind_t query_kind, size_t results_limit, //
int (*filter)(usearch_key_t key, void* filter_state), void* filter_state, //
usearch_filtered_search_callback_t filter, void* filter_state, //
usearch_key_t* found_keys, usearch_distance_t* found_distances, usearch_error_t* error) {

USEARCH_ASSERT(index && query && filter && error && "Missing arguments");
Expand Down
6 changes: 5 additions & 1 deletion c/usearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ USEARCH_EXPORT typedef struct usearch_init_options_t {
bool multi;
} usearch_init_options_t;

extern int goFilteredSearchCallback(usearch_key_t, void*);

USEARCH_EXPORT typedef int (*usearch_filtered_search_callback_t)(usearch_key_t, void*);

/**
* @brief Retrieves the version of the library.
* @return The version of the library.
Expand Down Expand Up @@ -391,7 +395,7 @@ USEARCH_EXPORT size_t usearch_search( //
USEARCH_EXPORT size_t usearch_filtered_search( //
usearch_index_t index, //
void const* query_vector, usearch_scalar_kind_t query_kind, size_t count, //
int (*filter)(usearch_key_t key, void* filter_state), void* filter_state, //
usearch_filtered_search_callback_t filter, void* filter_state, //
usearch_key_t* keys, usearch_distance_t* distances, usearch_error_t* error);

/**
Expand Down
135 changes: 135 additions & 0 deletions golang/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ func DefaultConfig(dimensions uint) IndexConfig {
return c
}

// FilteredSearchHandler include the callback functiona and user data
type FilteredSearchHandler struct {
Callback func(key Key, handler *FilteredSearchHandler) int
Data any
}

// Index represents a USearch approximate nearest neighbor index.
// It implements io.Closer for idiomatic resource cleanup.
//
Expand Down Expand Up @@ -638,6 +644,56 @@ func (index *Index) Search(query []float32, limit uint) (keys []Key, distances [
return keys, distances, nil
}

// Search finds the k nearest neighbors to the query vector.
//
// Parameters:
// - query: Must have exactly Dimensions() elements
// - limit: Maximum number of results to return
//
// Returns:
// - keys: IDs of the nearest vectors (up to limit)
// - distances: Distance to each result (same length as keys)
// - err: Error if query is invalid or search fails
//
// The actual number of results may be less than limit if the index
// contains fewer vectors.
func (index *Index) FilteredSearch(query []float32, limit uint, handler *FilteredSearchHandler) (keys []Key, distances []float32, err error) {
if index.handle == nil {
panic("index is uninitialized")
}

if len(query) == 0 {
return nil, nil, errors.New("query vector cannot be empty")
}
if uint(len(query)) != index.config.Dimensions {
return nil, nil, fmt.Errorf("query dimension mismatch: got %d, expected %d", len(query), index.config.Dimensions)
}
if handler == nil {
return nil, nil, errors.New("filtered search handler cannot be nil")
}
if limit == 0 {
return []Key{}, []float32{}, nil
}

keys = make([]Key, limit)
distances = make([]float32, limit)
var errorMessage *C.char
resultCount := uint(C.usearch_filtered_search(index.handle, unsafe.Pointer(&query[0]), C.usearch_scalar_f32_k, (C.size_t)(limit),
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler),
(*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage)))
runtime.KeepAlive(query)
runtime.KeepAlive(keys)
runtime.KeepAlive(distances)
runtime.KeepAlive(handler)
if errorMessage != nil {
return nil, nil, errors.New(C.GoString(errorMessage))
}

keys = keys[:resultCount]
distances = distances[:resultCount]
return keys, distances, nil
}

// SearchUnsafe performs k-Approximate Nearest Neighbors Search using an unsafe pointer.
//
// SAFETY REQUIREMENTS:
Expand Down Expand Up @@ -675,6 +731,48 @@ func (index *Index) SearchUnsafe(query unsafe.Pointer, limit uint) (keys []Key,
return keys, distances, nil
}

//export goFilteredSearchCallback
func goFilteredSearchCallback(key C.usearch_key_t, ptr unsafe.Pointer) C.int {
handler := (*FilteredSearchHandler)(ptr)
return C.int(handler.Callback(Key(key), handler))
}

// Filtred Search performs k-Approximate Nearest Neighbors Search for the closest vectors to the query vector with filtering.
func (index *Index) FilteredSearchUnsafe(query unsafe.Pointer, limit uint, handler *FilteredSearchHandler) (keys []Key, distances []float32, err error) {
if index.handle == nil {
panic("index is uninitialized")
}

if query == nil {
return nil, nil, errors.New("query pointer cannot be nil")
}

if handler == nil {
return nil, nil, errors.New("filtered search handler cannot be nil")
}

if limit == 0 {
return []Key{}, []float32{}, nil
}

keys = make([]Key, limit)
distances = make([]float32, limit)
var errorMessage *C.char
resultCount := uint(C.usearch_filtered_search(index.handle, query, index.config.Quantization.CValue(), (C.size_t)(limit),
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler),
(*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage)))
runtime.KeepAlive(keys)
runtime.KeepAlive(distances)
runtime.KeepAlive(handler)
if errorMessage != nil {
return nil, nil, errors.New(C.GoString(errorMessage))
}

keys = keys[:resultCount]
distances = distances[:resultCount]
return keys, distances, nil
}

// ExactSearch performs multithreaded exact nearest neighbors search.
// Unlike the index-based search, this computes distances to all vectors in the dataset.
//
Expand Down Expand Up @@ -819,6 +917,43 @@ func (index *Index) SearchI8(query []int8, limit uint) (keys []Key, distances []
return keys, distances, nil
}

func (index *Index) FilteredSearchI8(query []int8, limit uint, handler *FilteredSearchHandler) (keys []Key, distances []float32, err error) {
if index.handle == nil {
panic("index is uninitialized")
}

if len(query) == 0 {
return nil, nil, errors.New("query vector cannot be empty")
}
if uint(len(query)) != index.config.Dimensions {
return nil, nil, fmt.Errorf("query dimension mismatch: got %d, expected %d", len(query), index.config.Dimensions)
}
if handler == nil {
return nil, nil, errors.New("filtered search handler cannot be nil")
}
if limit == 0 {
return []Key{}, []float32{}, nil
}

keys = make([]Key, limit)
distances = make([]float32, limit)
var errorMessage *C.char
resultCount := uint(C.usearch_filtered_search(index.handle, unsafe.Pointer(&query[0]), C.usearch_scalar_i8_k, (C.size_t)(limit),
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler),
(*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage)))
runtime.KeepAlive(query)
runtime.KeepAlive(keys)
runtime.KeepAlive(distances)
runtime.KeepAlive(handler)
if errorMessage != nil {
return nil, nil, errors.New(C.GoString(errorMessage))
}

keys = keys[:resultCount]
distances = distances[:resultCount]
return keys, distances, nil
}

// DistanceI8 computes the distance between two int8 vectors.
//
// Example:
Expand Down
Loading