feature: allow to ignore qdrant count call

skeptrunedev committed May 7, 2024
1 parent 7e45570 commit f8297ab
Showing 5 changed files with 72 additions and 4 deletions.
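In short: this commit adds an optional get_total_pages flag to the search request payloads, threads it through the chunk, group, and over-groups search handlers, and uses it to short-circuit get_point_count_qdrant_query so the Qdrant count call is skipped entirely when the flag is false. The flag defaults to true, so existing callers see no behavior change.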
26 changes: 26 additions & 0 deletions server/src/handlers/chunk_handler.rs
@@ -1133,6 +1133,8 @@ pub struct SearchChunkData {
pub page: Option<u64>,
/// Page size is the number of chunks to fetch. This can be used to fetch more than 10 chunks at a time.
pub page_size: Option<u64>,
/// Get total page count for the query accounting for the applied filters. Defaults to true, but can be set to false to reduce latency in performance-sensitive edge cases.
pub get_total_pages: Option<bool>,
/// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata.
pub filters: Option<ChunkFilter>,
/// Set date_bias to true to bias search results towards more recent chunks. This will work best in hybrid search mode.
@@ -1153,6 +1155,26 @@ pub struct SearchChunkData {
pub slim_chunks: Option<bool>,
}

impl Default for SearchChunkData {
fn default() -> Self {
SearchChunkData {
search_type: "hybrid".to_string(),
query: "".to_string(),
page: Some(1),
get_total_pages: None,
page_size: Some(10),
filters: None,
date_bias: None,
use_weights: None,
get_collisions: None,
highlight_results: None,
highlight_delimiters: None,
score_threshold: None,
slim_chunks: None,
}
}
}
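
With this Default impl in place, a caller inside the crate can opt out of the count query while keeping every other default. A minimal sketch (not part of this commit; the query string is illustrative):

    let request = SearchChunkData {
        query: "pagination latency".to_string(),
        get_total_pages: Some(false), // skip the Qdrant count call
        ..Default::default()
    };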

#[derive(Serialize, Deserialize, Debug, ToSchema, Clone)]
#[schema(example = json!({
"metadata": [
@@ -1264,6 +1286,7 @@ pub async fn search_chunk(
);

let page = data.page.unwrap_or(1);
let get_total_pages = data.get_total_pages.unwrap_or(true);

let mut parsed_query = ParsedQuery {
query: data.query.clone(),
@@ -1292,6 +1315,7 @@
data.clone(),
parsed_query,
page,
get_total_pages,
pool,
dataset_org_plan_sub.dataset,
server_dataset_config,
@@ -1303,6 +1327,7 @@
data.clone(),
parsed_query,
page,
get_total_pages,
pool,
dataset_org_plan_sub.dataset,
server_dataset_config,
@@ -1314,6 +1339,7 @@
data.clone(),
parsed_query,
page,
get_total_pages,
pool,
dataset_org_plan_sub.dataset,
&mut timer,
14 changes: 13 additions & 1 deletion server/src/handlers/group_handler.rs
@@ -1085,6 +1085,8 @@ pub struct SearchWithinGroupData {
pub page: Option<u64>,
/// The page size is the number of chunks to fetch. This can be used to fetch more than 10 chunks at a time.
pub page_size: Option<u64>,
/// Get total page count for the query accounting for the applied filters. Defaults to true, but can be set to false to reduce latency in performance-sensitive edge cases.
pub get_total_pages: Option<bool>,
/// Filters is a JSON object which can be used to filter chunks. The values on each key in the object will be used to check for an exact substring match on the metadata values for each existing chunk. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata.
pub filters: Option<ChunkFilter>,
/// Group specifies the group to search within. Results will only consist of chunks which are bookmarks within the specified group.
@@ -1115,6 +1117,7 @@ impl From<SearchWithinGroupData> for SearchChunkData {
query: data.query,
page: data.page,
page_size: data.page_size,
get_total_pages: data.get_total_pages,
filters: data.filters,
search_type: data.search_type,
date_bias: data.date_bias,
@@ -1177,6 +1180,7 @@ pub async fn search_within_group(

//search over the links as well
let page = data.page.unwrap_or(1);
let get_total_pages = data.get_total_pages.unwrap_or(true);
let group_id = data.group_id;
let dataset_id = dataset_org_plan_sub.dataset.id;
let search_pool = pool.clone();
@@ -1217,6 +1221,7 @@
parsed_query,
group,
page,
get_total_pages,
search_pool,
dataset_org_plan_sub.dataset,
server_dataset_config,
@@ -1229,6 +1234,7 @@
parsed_query,
group,
page,
get_total_pages,
search_pool,
dataset_org_plan_sub.dataset,
server_dataset_config,
@@ -1241,6 +1247,7 @@
parsed_query,
group,
page,
get_total_pages,
search_pool,
dataset_org_plan_sub.dataset,
server_dataset_config,
@@ -1278,6 +1285,8 @@ pub struct SearchOverGroupsData {
pub page: Option<u64>,
/// Page size is the number of chunks to fetch. This can be used to fetch more than 10 chunks at a time.
pub page_size: Option<u32>,
/// Get total page count for the query accounting for the applied filters. Defaults to true, but can be set to false to reduce latency in performance-sensitive edge cases.
pub get_total_pages: Option<bool>,
/// Filters is a JSON object which can be used to filter chunks. The values on each key in the object will be used to check for an exact substring match on the metadata values for each existing chunk. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata.
pub filters: Option<ChunkFilter>,
/// Set get_collisions to true to get the collisions for each chunk. This will only apply if environment variable COLLISIONS_ENABLED is set to true.
@@ -1334,8 +1343,8 @@ pub async fn search_over_groups(
dataset_org_plan_sub.dataset.server_configuration.clone(),
);

//search over the links as well
let page = data.page.unwrap_or(1);
let get_total_pages = data.get_total_pages.unwrap_or(true);

let mut parsed_query = ParsedQuery {
query: data.query.clone(),
@@ -1359,6 +1368,7 @@ pub async fn search_over_groups(
data.clone(),
parsed_query,
page,
get_total_pages,
pool,
dataset_org_plan_sub.dataset,
server_dataset_config,
@@ -1370,6 +1380,7 @@
data.clone(),
parsed_query,
page,
get_total_pages,
pool,
dataset_org_plan_sub.dataset,
server_dataset_config,
@@ -1381,6 +1392,7 @@
data.clone(),
parsed_query,
page,
get_total_pages,
pool,
dataset_org_plan_sub.dataset,
server_dataset_config,
1 change: 1 addition & 0 deletions server/src/handlers/message_handler.rs
@@ -675,6 +675,7 @@ pub async fn stream_response(
negated_words: None,
},
dataset.id,
false,
pool.clone(),
config,
)
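The hardcoded false above is the first consumer of the new parameter: this internal search call in stream_response only uses the results as LLM context and does not appear to surface a page count to the client, so the Qdrant count call can be skipped unconditionally here.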
5 changes: 5 additions & 0 deletions server/src/operators/qdrant_operator.rs
@@ -996,7 +996,12 @@ pub async fn recommend_qdrant_groups_query(
pub async fn get_point_count_qdrant_query(
filters: Filter,
config: ServerDatasetConfiguration,
get_total_pages: bool,
) -> Result<u64, ServiceError> {
if !get_total_pages {
return Ok(0);
};

let qdrant_collection = config.QDRANT_COLLECTION_NAME;

let qdrant =
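This early return is the heart of the feature: when get_total_pages is false, the function resolves to Ok(0) before the Qdrant client is even constructed, so the futures::join! in each caller pays nothing for the count. A minimal sketch of the observable behavior (hypothetical usage, assuming the signature above):

    // With the flag false, the count resolves immediately and no
    // request is sent to Qdrant; the caller simply receives 0.
    let count = get_point_count_qdrant_query(filter, config, false).await?;
    assert_eq!(count, 0);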
30 changes: 27 additions & 3 deletions server/src/operators/search_operator.rs
@@ -191,6 +191,7 @@ pub async fn assemble_qdrant_filter(
pub async fn retrieve_qdrant_points_query(
vector: VectorType,
page: u64,
get_total_pages: bool,
limit: u64,
score_threshold: Option<f32>,
filters: Option<ChunkFilter>,
@@ -234,7 +235,7 @@ pub async fn retrieve_qdrant_points_query(
config.clone(),
);

let count_future = get_point_count_qdrant_query(filter, config);
let count_future = get_point_count_qdrant_query(filter, config, get_total_pages);

let (point_ids, count) = futures::join!(point_ids_future, count_future);

@@ -375,6 +376,7 @@ pub struct SearchOverGroupsQueryResult {
pub async fn retrieve_group_qdrant_points_query(
vector: VectorType,
page: u64,
get_total_pages: bool,
filters: Option<ChunkFilter>,
limit: u32,
score_threshold: Option<f32>,
@@ -405,7 +407,7 @@
config.clone(),
);

let count_future = get_point_count_qdrant_query(filter, config);
let count_future = get_point_count_qdrant_query(filter, config, get_total_pages);

let (point_ids, count) = futures::join!(point_id_future, count_future);

@@ -501,6 +503,7 @@ pub async fn global_unfiltered_top_match_query(
pub async fn search_within_chunk_group_query(
embedding_vector: VectorType,
page: u64,
get_total_pages: bool,
pool: web::Data<Pool>,
filters: Option<ChunkFilter>,
limit: u64,
@@ -533,7 +536,7 @@
config.clone(),
);

let count_future = get_point_count_qdrant_query(filter, config);
let count_future = get_point_count_qdrant_query(filter, config, get_total_pages);

let (point_ids, count) = futures::join!(point_ids_future, count_future);

@@ -1154,6 +1157,7 @@ pub async fn search_semantic_chunks(
data: SearchChunkData,
parsed_query: ParsedQuery,
page: u64,
get_total_pages: bool,
pool: web::Data<Pool>,
dataset: Dataset,
timer: &mut Timer,
@@ -1190,6 +1194,7 @@ pub async fn search_semantic_chunks(
let search_chunk_query_results = retrieve_qdrant_points_query(
VectorType::Dense(embedding_vector),
page,
get_total_pages,
data.page_size.unwrap_or(10),
data.score_threshold,
data.filters.clone(),
@@ -1219,6 +1224,7 @@ pub async fn search_full_text_chunks(
data: SearchChunkData,
parsed_query: ParsedQuery,
page: u64,
get_total_pages: bool,
pool: web::Data<Pool>,
dataset: Dataset,
config: ServerDatasetConfiguration,
@@ -1243,6 +1249,7 @@ pub async fn search_full_text_chunks(
let search_chunk_query_results = retrieve_qdrant_points_query(
VectorType::Sparse(embedding_vector),
page,
get_total_pages,
data.page_size.unwrap_or(10),
data.score_threshold,
data.filters.clone(),
@@ -1280,6 +1287,7 @@ pub async fn search_hybrid_chunks(
data: SearchChunkData,
parsed_query: ParsedQuery,
page: u64,
get_total_pages: bool,
pool: web::Data<Pool>,
dataset: Dataset,
config: ServerDatasetConfiguration,
@@ -1317,6 +1325,7 @@ pub async fn search_hybrid_chunks(
data.filters.clone(),
parsed_query.clone(),
dataset.id,
get_total_pages,
pool.clone(),
config.clone(),
);
@@ -1325,6 +1334,7 @@
data.clone(),
parsed_query,
page,
get_total_pages,
pool.clone(),
dataset,
config,
@@ -1485,6 +1495,7 @@ pub async fn search_semantic_groups(
parsed_query: ParsedQuery,
group: ChunkGroup,
page: u64,
get_total_pages: bool,
pool: web::Data<Pool>,
dataset: Dataset,
config: ServerDatasetConfiguration,
@@ -1505,6 +1516,7 @@ pub async fn search_semantic_groups(
let search_semantic_chunk_query_results = search_within_chunk_group_query(
VectorType::Dense(embedding_vector),
page,
get_total_pages,
pool.clone(),
data.filters.clone(),
data.page_size.unwrap_or(10),
@@ -1540,6 +1552,7 @@ pub async fn search_full_text_groups(
parsed_query: ParsedQuery,
group: ChunkGroup,
page: u64,
get_total_pages: bool,
pool: web::Data<Pool>,
dataset: Dataset,
config: ServerDatasetConfiguration,
@@ -1550,6 +1563,7 @@ pub async fn search_full_text_groups(
let search_chunk_query_results = search_within_chunk_group_query(
VectorType::Sparse(embedding_vector),
page,
get_total_pages,
pool.clone(),
data_inner.filters.clone(),
data.page_size.unwrap_or(10),
@@ -1585,6 +1599,7 @@ pub async fn search_hybrid_groups(
parsed_query: ParsedQuery,
group: ChunkGroup,
page: u64,
get_total_pages: bool,
pool: web::Data<Pool>,
dataset: Dataset,
config: ServerDatasetConfiguration,
@@ -1608,6 +1623,7 @@ pub async fn search_hybrid_groups(
let semantic_future = search_within_chunk_group_query(
VectorType::Dense(dense_embedding_vector),
page,
get_total_pages,
pool.clone(),
data.filters.clone(),
data.page_size.unwrap_or(10),
@@ -1621,6 +1637,7 @@ pub async fn search_hybrid_groups(
let full_text_future = search_within_chunk_group_query(
VectorType::Sparse(sparse_embedding_vector),
page,
get_total_pages,
pool.clone(),
data_inner.filters.clone(),
data.page_size.unwrap_or(10),
@@ -1712,6 +1729,7 @@ pub async fn semantic_search_over_groups(
data: SearchOverGroupsData,
parsed_query: ParsedQuery,
page: u64,
get_total_pages: bool,
pool: web::Data<Pool>,
dataset: Dataset,
config: ServerDatasetConfiguration,
@@ -1730,6 +1748,7 @@ pub async fn semantic_search_over_groups(
let search_chunk_query_results = retrieve_group_qdrant_points_query(
VectorType::Dense(embedding_vector),
page,
get_total_pages,
data.filters.clone(),
data.page_size.unwrap_or(10),
data.score_threshold,
@@ -1754,6 +1773,7 @@ pub async fn full_text_search_over_groups(
data: SearchOverGroupsData,
parsed_query: ParsedQuery,
page: u64,
get_total_pages: bool,
pool: web::Data<Pool>,
dataset: Dataset,
config: ServerDatasetConfiguration,
@@ -1765,6 +1785,7 @@ pub async fn full_text_search_over_groups(
let search_chunk_query_results = retrieve_group_qdrant_points_query(
VectorType::Sparse(embedding_vector),
page,
get_total_pages,
data.filters.clone(),
data.page_size.unwrap_or(10),
data.score_threshold,
@@ -1836,6 +1857,7 @@ pub async fn hybrid_search_over_groups(
data: SearchOverGroupsData,
parsed_query: ParsedQuery,
page: u64,
get_total_pages: bool,
pool: web::Data<Pool>,
dataset: Dataset,
config: ServerDatasetConfiguration,
@@ -1859,6 +1881,7 @@ pub async fn hybrid_search_over_groups(
let semantic_future = retrieve_group_qdrant_points_query(
VectorType::Dense(dense_embedding_vector),
page,
get_total_pages,
data.filters.clone(),
data.page_size.unwrap_or(10),
data.score_threshold,
@@ -1872,6 +1895,7 @@
let full_text_future = retrieve_group_qdrant_points_query(
VectorType::Sparse(sparse_embedding_vector),
page,
get_total_pages,
data.filters.clone(),
data.page_size.unwrap_or(10),
data.score_threshold,
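One consequence worth noting: because the skipped count comes back as Ok(0), the total page count derived from it in these search paths will presumably read 0 when get_total_pages is false, rather than being omitted from the response.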
