Skip to content

Commit

Permalink
feature: add strategies to recommendations
Browse files Browse the repository at this point in the history
  • Loading branch information
skeptrunedev committed Apr 9, 2024
1 parent be5f173 commit c79de28
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 18 deletions.
3 changes: 3 additions & 0 deletions server/src/handlers/chunk_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,8 @@ pub struct RecommendChunksRequest {
pub positive_tracking_ids: Option<Vec<String>>,
/// The tracking_ids of the chunks to be used as negative examples for the recommendation. The chunks in this array will be used to filter out similar chunks.
pub negative_tracking_ids: Option<Vec<String>>,
/// Strategy to use for recommendations, either "average_vector" or "best_score". The default is "average_vector". The "average_vector" strategy will construct a single average vector from the positive and negative samples then use it to perform a pseudo-search. The "best_score" strategy is more advanced and navigates the HNSW with a heuristic of picking edges where the point is closer to the positive samples than it is the negatives.
pub strategy: Option<String>,
/// Filters to apply to the chunks to be recommended. This is a JSON object which contains the filters to apply to the chunks to be recommended. The default is None.
pub filters: Option<ChunkFilter>,
/// The number of chunks to return. This is the number of chunks which will be returned in the response. The default is 10.
Expand Down Expand Up @@ -1618,6 +1620,7 @@ pub async fn get_recommended_chunks(
let recommended_qdrant_point_ids = recommend_qdrant_query(
positive_qdrant_ids,
negative_qdrant_ids,
data.strategy.clone(),
data.filters.clone(),
limit,
dataset_org_plan_sub.dataset.id,
Expand Down
13 changes: 7 additions & 6 deletions server/src/handlers/group_handler.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use super::{
auth_handler::{AdminOnly, LoggedUser},
chunk_handler::{
parse_query, ChunkFilter, ScoreChunkDTO, SearchChunkData,
},
chunk_handler::{parse_query, ChunkFilter, ScoreChunkDTO, SearchChunkData},
};
use crate::{
data::models::{
Expand Down Expand Up @@ -876,7 +874,7 @@ pub struct GenerateOffGroupData {
}

#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
pub struct ReccomendGroupChunksRequest {
pub struct RecommendGroupChunksRequest {
/// The ids of the groups to be used as positive examples for the recommendation. The groups in this array will be used to find similar groups.
pub positive_group_ids: Option<Vec<uuid::Uuid>>,
/// The ids of the groups to be used as negative examples for the recommendation. The groups in this array will be used to filter out similar groups.
Expand All @@ -885,6 +883,8 @@ pub struct ReccomendGroupChunksRequest {
pub positive_group_tracking_ids: Option<Vec<String>>,
/// The ids of the groups to be used as negative examples for the recommendation. The groups in this array will be used to filter out similar groups.
pub negative_group_tracking_ids: Option<Vec<String>>,
/// Strategy to use for recommendations, either "average_vector" or "best_score". The default is "average_vector". The "average_vector" strategy will construct a single average vector from the positive and negative samples then use it to perform a pseudo-search. The "best_score" strategy is more advanced and navigates the HNSW with a heuristic of picking edges where the point is closer to the positive samples than it is the negatives.
pub strategy: Option<String>,
/// Filters to apply to the chunks to be recommended. This is a JSON object which contains the filters to apply to the chunks to be recommended. The default is None.
pub filters: Option<ChunkFilter>,
/// The number of groups to return. This is the number of groups which will be returned in the response. The default is 10.
Expand Down Expand Up @@ -958,7 +958,7 @@ pub enum RecommendGroupChunkResponseTypes {
path = "/chunk_group/recommend",
context_path = "/api",
tag = "chunk_group",
request_body(content = ReccomendGroupChunksRequest, description = "JSON request payload to get recommendations of chunks similar to the chunks in the request", content_type = "application/json"),
request_body(content = RecommendGroupChunksRequest, description = "JSON request payload to get recommendations of chunks similar to the chunks in the request", content_type = "application/json"),
responses(
(status = 200, description = "JSON body representing the groups which are similar to the groups in the request", body = RecommendGroupChunkResponseTypes),
(status = 400, description = "Service error relating to to getting similar chunks", body = ErrorResponseBody),
Expand All @@ -972,7 +972,7 @@ pub enum RecommendGroupChunkResponseTypes {
)]
#[tracing::instrument(skip(pool))]
pub async fn get_recommended_groups(
data: web::Json<ReccomendGroupChunksRequest>,
data: web::Json<RecommendGroupChunksRequest>,
pool: web::Data<Pool>,
_user: LoggedUser,
dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan,
Expand Down Expand Up @@ -1080,6 +1080,7 @@ pub async fn get_recommended_groups(
let recommended_qdrant_point_ids = recommend_qdrant_groups_query(
positive_qdrant_ids,
negative_qdrant_ids,
data.strategy.clone(),
data.filters.clone(),
limit,
data.group_size.unwrap_or(10),
Expand Down
4 changes: 2 additions & 2 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl Modify for SecurityAddon {
name = "BSL",
url = "https://github.com/devflowinc/trieve/blob/main/LICENSE.txt",
),
version = "0.5.9",
version = "0.6.0",
),
servers(
(url = "https://api.trieve.ai",
Expand Down Expand Up @@ -227,7 +227,7 @@ impl Modify for SecurityAddon {
handlers::chunk_handler::ReturnQueuedChunk,
handlers::chunk_handler::UpdateChunkData,
handlers::chunk_handler::RecommendChunksRequest,
handlers::group_handler::ReccomendGroupChunksRequest,
handlers::group_handler::RecommendGroupChunksRequest,
handlers::chunk_handler::UpdateChunkByTrackingIdData,
handlers::chunk_handler::SearchChunkQueryResponseBody,
handlers::chunk_handler::GenerateChunksRequest,
Expand Down
39 changes: 29 additions & 10 deletions server/src/operators/qdrant_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ use qdrant_client::{
group_id::Kind, payload_index_params::IndexParams, point_id::PointIdOptions,
quantization_config::Quantization, BinaryQuantization, CountPoints, CreateCollection,
Distance, FieldType, Filter, HnswConfigDiff, PayloadIndexParams, PointId, PointStruct,
QuantizationConfig, RecommendPointGroups, RecommendPoints, SearchPointGroups, SearchPoints,
SparseIndexConfig, SparseVectorConfig, SparseVectorParams, TextIndexParams, TokenizerType,
Value, Vector, VectorParams, VectorParamsMap, VectorsConfig,
QuantizationConfig, RecommendPointGroups, RecommendPoints, RecommendStrategy,
SearchPointGroups, SearchPoints, SparseIndexConfig, SparseVectorConfig, SparseVectorParams,
TextIndexParams, TokenizerType, Value, Vector, VectorParams, VectorParamsMap,
VectorsConfig,
},
};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -799,6 +800,7 @@ pub async fn search_qdrant_query(
pub async fn recommend_qdrant_query(
positive_ids: Vec<uuid::Uuid>,
negative_ids: Vec<uuid::Uuid>,
strategy: Option<String>,
filters: Option<ChunkFilter>,
limit: u64,
dataset_id: uuid::Uuid,
Expand All @@ -807,10 +809,15 @@ pub async fn recommend_qdrant_query(
) -> Result<Vec<uuid::Uuid>, ServiceError> {
let qdrant_collection = config.QDRANT_COLLECTION_NAME;

let filter = assemble_qdrant_filter(filters, None, None, dataset_id, pool).await?;
let recommend_strategy = match strategy {
Some(strategy) => match strategy.as_str() {
"best_score" => RecommendStrategy::BestScore,
_ => RecommendStrategy::AverageVector,
},
None => RecommendStrategy::AverageVector,
};

let qdrant =
get_qdrant_connection(Some(&config.QDRANT_URL), Some(&config.QDRANT_API_KEY)).await?;
let filter = assemble_qdrant_filter(filters, None, None, dataset_id, pool).await?;

let positive_point_ids: Vec<PointId> = positive_ids
.iter()
Expand Down Expand Up @@ -851,11 +858,14 @@ pub async fn recommend_qdrant_query(
read_consistency: None,
positive_vectors: vec![],
negative_vectors: vec![],
strategy: None,
strategy: Some(recommend_strategy.into()),
timeout: None,
shard_key_selector: None,
};

let qdrant =
get_qdrant_connection(Some(&config.QDRANT_URL), Some(&config.QDRANT_API_KEY)).await?;

let recommended_point_ids = qdrant
.recommend(&recommend_points)
.await
Expand All @@ -880,6 +890,7 @@ pub async fn recommend_qdrant_query(
pub async fn recommend_qdrant_groups_query(
positive_ids: Vec<uuid::Uuid>,
negative_ids: Vec<uuid::Uuid>,
strategy: Option<String>,
filter: Option<ChunkFilter>,
limit: u64,
group_size: u32,
Expand All @@ -889,8 +900,13 @@ pub async fn recommend_qdrant_groups_query(
) -> Result<Vec<GroupSearchResults>, ServiceError> {
let qdrant_collection = config.QDRANT_COLLECTION_NAME;

let qdrant =
get_qdrant_connection(Some(&config.QDRANT_URL), Some(&config.QDRANT_API_KEY)).await?;
let recommend_strategy = match strategy {
Some(strategy) => match strategy.as_str() {
"best_score" => RecommendStrategy::BestScore,
_ => RecommendStrategy::AverageVector,
},
None => RecommendStrategy::AverageVector,
};

let filters = assemble_qdrant_filter(filter, None, None, dataset_id, pool).await?;

Expand Down Expand Up @@ -932,14 +948,17 @@ pub async fn recommend_qdrant_groups_query(
read_consistency: None,
positive_vectors: vec![],
negative_vectors: vec![],
strategy: None,
strategy: Some(recommend_strategy.into()),
timeout: None,
shard_key_selector: None,
group_by: "group_ids".to_string(),
group_size: if group_size == 0 { 1 } else { group_size },
with_lookup: None,
};

let qdrant =
get_qdrant_connection(Some(&config.QDRANT_URL), Some(&config.QDRANT_API_KEY)).await?;

let data = qdrant
.recommend_groups(&recommend_points)
.await
Expand Down

0 comments on commit c79de28

Please sign in to comment.