Skip to content

Commit c79de28

Browse files
committed
feature: add strategies to recommendations
1 parent be5f173 commit c79de28

File tree

4 files changed

+41
-18
lines changed

4 files changed

+41
-18
lines changed

server/src/handlers/chunk_handler.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,8 @@ pub struct RecommendChunksRequest {
14381438
pub positive_tracking_ids: Option<Vec<String>>,
14391439
/// The tracking_ids of the chunks to be used as negative examples for the recommendation. The chunks in this array will be used to filter out similar chunks.
14401440
pub negative_tracking_ids: Option<Vec<String>>,
1441+
/// Strategy to use for recommendations, either "average_vector" or "best_score". The default is "average_vector". The "average_vector" strategy will construct a single average vector from the positive and negative samples then use it to perform a pseudo-search. The "best_score" strategy is more advanced and navigates the HNSW with a heuristic of picking edges where the point is closer to the positive samples than it is the negatives.
1442+
pub strategy: Option<String>,
14411443
/// Filters to apply to the chunks to be recommended. This is a JSON object which contains the filters to apply to the chunks to be recommended. The default is None.
14421444
pub filters: Option<ChunkFilter>,
14431445
/// The number of chunks to return. This is the number of chunks which will be returned in the response. The default is 10.
@@ -1618,6 +1620,7 @@ pub async fn get_recommended_chunks(
16181620
let recommended_qdrant_point_ids = recommend_qdrant_query(
16191621
positive_qdrant_ids,
16201622
negative_qdrant_ids,
1623+
data.strategy.clone(),
16211624
data.filters.clone(),
16221625
limit,
16231626
dataset_org_plan_sub.dataset.id,

server/src/handlers/group_handler.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
use super::{
22
auth_handler::{AdminOnly, LoggedUser},
3-
chunk_handler::{
4-
parse_query, ChunkFilter, ScoreChunkDTO, SearchChunkData,
5-
},
3+
chunk_handler::{parse_query, ChunkFilter, ScoreChunkDTO, SearchChunkData},
64
};
75
use crate::{
86
data::models::{
@@ -876,7 +874,7 @@ pub struct GenerateOffGroupData {
876874
}
877875

878876
#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
879-
pub struct ReccomendGroupChunksRequest {
877+
pub struct RecommendGroupChunksRequest {
880878
/// The ids of the groups to be used as positive examples for the recommendation. The groups in this array will be used to find similar groups.
881879
pub positive_group_ids: Option<Vec<uuid::Uuid>>,
882880
/// The ids of the groups to be used as negative examples for the recommendation. The groups in this array will be used to filter out similar groups.
@@ -885,6 +883,8 @@ pub struct ReccomendGroupChunksRequest {
885883
pub positive_group_tracking_ids: Option<Vec<String>>,
886884
/// The ids of the groups to be used as negative examples for the recommendation. The groups in this array will be used to filter out similar groups.
887885
pub negative_group_tracking_ids: Option<Vec<String>>,
886+
/// Strategy to use for recommendations, either "average_vector" or "best_score". The default is "average_vector". The "average_vector" strategy will construct a single average vector from the positive and negative samples then use it to perform a pseudo-search. The "best_score" strategy is more advanced and navigates the HNSW with a heuristic of picking edges where the point is closer to the positive samples than it is the negatives.
887+
pub strategy: Option<String>,
888888
/// Filters to apply to the chunks to be recommended. This is a JSON object which contains the filters to apply to the chunks to be recommended. The default is None.
889889
pub filters: Option<ChunkFilter>,
890890
/// The number of groups to return. This is the number of groups which will be returned in the response. The default is 10.
@@ -958,7 +958,7 @@ pub enum RecommendGroupChunkResponseTypes {
958958
path = "/chunk_group/recommend",
959959
context_path = "/api",
960960
tag = "chunk_group",
961-
request_body(content = ReccomendGroupChunksRequest, description = "JSON request payload to get recommendations of chunks similar to the chunks in the request", content_type = "application/json"),
961+
request_body(content = RecommendGroupChunksRequest, description = "JSON request payload to get recommendations of chunks similar to the chunks in the request", content_type = "application/json"),
962962
responses(
963963
(status = 200, description = "JSON body representing the groups which are similar to the groups in the request", body = RecommendGroupChunkResponseTypes),
964964
(status = 400, description = "Service error relating to to getting similar chunks", body = ErrorResponseBody),
@@ -972,7 +972,7 @@ pub enum RecommendGroupChunkResponseTypes {
972972
)]
973973
#[tracing::instrument(skip(pool))]
974974
pub async fn get_recommended_groups(
975-
data: web::Json<ReccomendGroupChunksRequest>,
975+
data: web::Json<RecommendGroupChunksRequest>,
976976
pool: web::Data<Pool>,
977977
_user: LoggedUser,
978978
dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan,
@@ -1080,6 +1080,7 @@ pub async fn get_recommended_groups(
10801080
let recommended_qdrant_point_ids = recommend_qdrant_groups_query(
10811081
positive_qdrant_ids,
10821082
negative_qdrant_ids,
1083+
data.strategy.clone(),
10831084
data.filters.clone(),
10841085
limit,
10851086
data.group_size.unwrap_or(10),

server/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ impl Modify for SecurityAddon {
130130
name = "BSL",
131131
url = "https://github.com/devflowinc/trieve/blob/main/LICENSE.txt",
132132
),
133-
version = "0.5.9",
133+
version = "0.6.0",
134134
),
135135
servers(
136136
(url = "https://api.trieve.ai",
@@ -227,7 +227,7 @@ impl Modify for SecurityAddon {
227227
handlers::chunk_handler::ReturnQueuedChunk,
228228
handlers::chunk_handler::UpdateChunkData,
229229
handlers::chunk_handler::RecommendChunksRequest,
230-
handlers::group_handler::ReccomendGroupChunksRequest,
230+
handlers::group_handler::RecommendGroupChunksRequest,
231231
handlers::chunk_handler::UpdateChunkByTrackingIdData,
232232
handlers::chunk_handler::SearchChunkQueryResponseBody,
233233
handlers::chunk_handler::GenerateChunksRequest,

server/src/operators/qdrant_operator.rs

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ use qdrant_client::{
1212
group_id::Kind, payload_index_params::IndexParams, point_id::PointIdOptions,
1313
quantization_config::Quantization, BinaryQuantization, CountPoints, CreateCollection,
1414
Distance, FieldType, Filter, HnswConfigDiff, PayloadIndexParams, PointId, PointStruct,
15-
QuantizationConfig, RecommendPointGroups, RecommendPoints, SearchPointGroups, SearchPoints,
16-
SparseIndexConfig, SparseVectorConfig, SparseVectorParams, TextIndexParams, TokenizerType,
17-
Value, Vector, VectorParams, VectorParamsMap, VectorsConfig,
15+
QuantizationConfig, RecommendPointGroups, RecommendPoints, RecommendStrategy,
16+
SearchPointGroups, SearchPoints, SparseIndexConfig, SparseVectorConfig, SparseVectorParams,
17+
TextIndexParams, TokenizerType, Value, Vector, VectorParams, VectorParamsMap,
18+
VectorsConfig,
1819
},
1920
};
2021
use serde::{Deserialize, Serialize};
@@ -799,6 +800,7 @@ pub async fn search_qdrant_query(
799800
pub async fn recommend_qdrant_query(
800801
positive_ids: Vec<uuid::Uuid>,
801802
negative_ids: Vec<uuid::Uuid>,
803+
strategy: Option<String>,
802804
filters: Option<ChunkFilter>,
803805
limit: u64,
804806
dataset_id: uuid::Uuid,
@@ -807,10 +809,15 @@ pub async fn recommend_qdrant_query(
807809
) -> Result<Vec<uuid::Uuid>, ServiceError> {
808810
let qdrant_collection = config.QDRANT_COLLECTION_NAME;
809811

810-
let filter = assemble_qdrant_filter(filters, None, None, dataset_id, pool).await?;
812+
let recommend_strategy = match strategy {
813+
Some(strategy) => match strategy.as_str() {
814+
"best_score" => RecommendStrategy::BestScore,
815+
_ => RecommendStrategy::AverageVector,
816+
},
817+
None => RecommendStrategy::AverageVector,
818+
};
811819

812-
let qdrant =
813-
get_qdrant_connection(Some(&config.QDRANT_URL), Some(&config.QDRANT_API_KEY)).await?;
820+
let filter = assemble_qdrant_filter(filters, None, None, dataset_id, pool).await?;
814821

815822
let positive_point_ids: Vec<PointId> = positive_ids
816823
.iter()
@@ -851,11 +858,14 @@ pub async fn recommend_qdrant_query(
851858
read_consistency: None,
852859
positive_vectors: vec![],
853860
negative_vectors: vec![],
854-
strategy: None,
861+
strategy: Some(recommend_strategy.into()),
855862
timeout: None,
856863
shard_key_selector: None,
857864
};
858865

866+
let qdrant =
867+
get_qdrant_connection(Some(&config.QDRANT_URL), Some(&config.QDRANT_API_KEY)).await?;
868+
859869
let recommended_point_ids = qdrant
860870
.recommend(&recommend_points)
861871
.await
@@ -880,6 +890,7 @@ pub async fn recommend_qdrant_query(
880890
pub async fn recommend_qdrant_groups_query(
881891
positive_ids: Vec<uuid::Uuid>,
882892
negative_ids: Vec<uuid::Uuid>,
893+
strategy: Option<String>,
883894
filter: Option<ChunkFilter>,
884895
limit: u64,
885896
group_size: u32,
@@ -889,8 +900,13 @@ pub async fn recommend_qdrant_groups_query(
889900
) -> Result<Vec<GroupSearchResults>, ServiceError> {
890901
let qdrant_collection = config.QDRANT_COLLECTION_NAME;
891902

892-
let qdrant =
893-
get_qdrant_connection(Some(&config.QDRANT_URL), Some(&config.QDRANT_API_KEY)).await?;
903+
let recommend_strategy = match strategy {
904+
Some(strategy) => match strategy.as_str() {
905+
"best_score" => RecommendStrategy::BestScore,
906+
_ => RecommendStrategy::AverageVector,
907+
},
908+
None => RecommendStrategy::AverageVector,
909+
};
894910

895911
let filters = assemble_qdrant_filter(filter, None, None, dataset_id, pool).await?;
896912

@@ -932,14 +948,17 @@ pub async fn recommend_qdrant_groups_query(
932948
read_consistency: None,
933949
positive_vectors: vec![],
934950
negative_vectors: vec![],
935-
strategy: None,
951+
strategy: Some(recommend_strategy.into()),
936952
timeout: None,
937953
shard_key_selector: None,
938954
group_by: "group_ids".to_string(),
939955
group_size: if group_size == 0 { 1 } else { group_size },
940956
with_lookup: None,
941957
};
942958

959+
let qdrant =
960+
get_qdrant_connection(Some(&config.QDRANT_URL), Some(&config.QDRANT_API_KEY)).await?;
961+
943962
let data = qdrant
944963
.recommend_groups(&recommend_points)
945964
.await

0 commit comments

Comments
 (0)