Skip to content

Commit

Permalink
feature: add scores to chunk recommendation responses
Browse files Browse the repository at this point in the history
  • Loading branch information
skeptrunedev committed Apr 9, 2024
1 parent 1f6c09b commit c340e6b
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 50 deletions.
106 changes: 104 additions & 2 deletions server/src/data/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,61 @@ impl ChunkCollision {
}
}

/// A full chunk metadata record paired with the relevance score it received
/// from a search / recommendation query (see `get_recommended_chunks`, where
/// the score is taken from the matching Qdrant recommendation result).
#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
#[schema(example = json!({
    "id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
    "content": "Hello, world!",
    "link": "https://trieve.ai",
    "qdrant_point_id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
    "created_at": "2021-01-01T00:00:00",
    "updated_at": "2021-01-01T00:00:00",
    "tag_set": "tag1,tag2",
    "chunk_html": "<p>Hello, world!</p>",
    "metadata": {"key": "value"},
    "tracking_id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
    "time_stamp": "2021-01-01T00:00:00",
    "dataset_id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
    "weight": 0.5,
    "score": 0.9,
}))]
pub struct ChunkMetadataWithScore {
    /// Unique identifier of the chunk.
    pub id: uuid::Uuid,
    /// Raw text content of the chunk.
    pub content: String,
    /// Optional link associated with the chunk.
    pub link: Option<String>,
    /// Id of the corresponding point in Qdrant, if one exists.
    pub qdrant_point_id: Option<uuid::Uuid>,
    /// Creation timestamp (naive, no timezone — presumably UTC; confirm against writer).
    pub created_at: chrono::NaiveDateTime,
    /// Last-update timestamp (naive, no timezone).
    pub updated_at: chrono::NaiveDateTime,
    /// Comma-separated tag list, e.g. "tag1,tag2" (see schema example).
    pub tag_set: Option<String>,
    /// HTML rendering of the chunk content, if available.
    pub chunk_html: Option<String>,
    /// Arbitrary user-supplied JSON metadata.
    pub metadata: Option<serde_json::Value>,
    /// Caller-assigned external identifier for the chunk.
    pub tracking_id: Option<String>,
    /// Optional user-supplied timestamp for the chunk.
    pub time_stamp: Option<NaiveDateTime>,
    /// Dataset this chunk belongs to.
    pub dataset_id: uuid::Uuid,
    /// Ranking weight applied to this chunk.
    pub weight: f64,
    /// Relevance score from the query that produced this result; defaults to 0.0
    /// in the recommendation handler when no matching Qdrant result is found.
    pub score: f32,
}

impl From<(ChunkMetadata, f32)> for ChunkMetadataWithScore {
fn from((chunk, score): (ChunkMetadata, f32)) -> Self {
ChunkMetadataWithScore {
id: chunk.id,
content: chunk.content,
link: chunk.link,
qdrant_point_id: chunk.qdrant_point_id,
created_at: chunk.created_at,
updated_at: chunk.updated_at,
tag_set: chunk.tag_set,
chunk_html: chunk.chunk_html,
metadata: chunk.metadata,
tracking_id: chunk.tracking_id,
time_stamp: chunk.time_stamp,
dataset_id: chunk.dataset_id,
weight: chunk.weight,
score,
}
}
}

#[derive(Debug, Serialize, Deserialize, Clone, Queryable, ToSchema)]
#[schema(example = json!({
"id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
Expand Down Expand Up @@ -392,6 +447,53 @@ impl From<ChunkMetadata> for SlimChunkMetadata {
}
}

/// A trimmed-down chunk metadata record (no `content` / `chunk_html` /
/// `dataset_id`) paired with the relevance score it received from a search /
/// recommendation query.
///
/// Fix: the OpenAPI example previously included a `"dataset_id"` key even
/// though this struct has no such field, producing a misleading example in
/// the generated schema; the stray key is removed here.
#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
#[schema(example = json!({
    "id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
    "link": "https://trieve.ai",
    "qdrant_point_id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
    "created_at": "2021-01-01T00:00:00",
    "updated_at": "2021-01-01T00:00:00",
    "tag_set": "tag1,tag2",
    "metadata": {"key": "value"},
    "tracking_id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
    "time_stamp": "2021-01-01T00:00:00",
    "weight": 0.5,
    "score": 0.9,
}))]
pub struct SlimChunkMetadataWithScore {
    /// Unique identifier of the chunk.
    pub id: uuid::Uuid,
    /// Optional link associated with the chunk.
    pub link: Option<String>,
    /// Id of the corresponding point in Qdrant, if one exists.
    pub qdrant_point_id: Option<uuid::Uuid>,
    /// Creation timestamp (naive, no timezone — presumably UTC; confirm against writer).
    pub created_at: chrono::NaiveDateTime,
    /// Last-update timestamp (naive, no timezone).
    pub updated_at: chrono::NaiveDateTime,
    /// Comma-separated tag list, e.g. "tag1,tag2" (see schema example).
    pub tag_set: Option<String>,
    /// Arbitrary user-supplied JSON metadata.
    pub metadata: Option<serde_json::Value>,
    /// Caller-assigned external identifier for the chunk.
    pub tracking_id: Option<String>,
    /// Optional user-supplied timestamp for the chunk.
    pub time_stamp: Option<NaiveDateTime>,
    /// Ranking weight applied to this chunk.
    pub weight: f64,
    /// Relevance score from the query that produced this result.
    pub score: f32,
}

impl From<ChunkMetadataWithScore> for SlimChunkMetadataWithScore {
fn from(chunk: ChunkMetadataWithScore) -> Self {
SlimChunkMetadataWithScore {
id: chunk.id,
link: chunk.link,
qdrant_point_id: chunk.qdrant_point_id,
created_at: chunk.created_at,
updated_at: chunk.updated_at,
tag_set: chunk.tag_set,
metadata: chunk.metadata,
tracking_id: chunk.tracking_id,
time_stamp: chunk.time_stamp,
weight: chunk.weight,
score: chunk.score,
}
}
}

#[derive(Debug, Serialize, Deserialize, Clone, Queryable, ToSchema)]
#[schema(
example = json!({
Expand Down Expand Up @@ -456,7 +558,7 @@ pub struct SearchSlimChunkQueryResponseBody {
}
],
}))]
pub struct GroupSlimChunksDTO {
pub struct GroupScoreSlimChunks {
pub group_id: uuid::Uuid,
pub metadata: Vec<ScoreSlimChunks>,
}
Expand All @@ -470,7 +572,7 @@ pub struct SearchGroupSlimChunksResult {

#[derive(Serialize, Deserialize, ToSchema)]
pub struct SearchOverGroupsSlimChunksResponseBody {
pub group_chunks: Vec<GroupSlimChunksDTO>,
pub group_chunks: Vec<GroupScoreSlimChunks>,
pub total_chunk_pages: i64,
}

Expand Down
61 changes: 43 additions & 18 deletions server/src/handlers/chunk_handler.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use super::auth_handler::{AdminOnly, LoggedUser};
use crate::data::models::{
ChatMessageProxy, ChunkMetadata, DatasetAndOrgWithSubAndPlan, IngestSpecificChunkMetadata,
Pool, RedisPool, ScoreSlimChunks, SearchSlimChunkQueryResponseBody, ServerDatasetConfiguration,
SlimChunkMetadata, UnifiedId,
ChatMessageProxy, ChunkMetadata, ChunkMetadataWithScore, DatasetAndOrgWithSubAndPlan,
IngestSpecificChunkMetadata, Pool, RedisPool, ScoreSlimChunks,
SearchSlimChunkQueryResponseBody, ServerDatasetConfiguration, SlimChunkMetadata,
SlimChunkMetadataWithScore, UnifiedId,
};
use crate::errors::ServiceError;
use crate::get_env;
Expand Down Expand Up @@ -1449,9 +1450,9 @@ pub struct RecommendChunksRequest {
}

#[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct RecommendChunkMetadata(Vec<ChunkMetadata>);
pub struct RecommendChunkMetadata(Vec<ChunkMetadataWithScore>);
#[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct RecommendSlimChunkMetadata(Vec<SlimChunkMetadata>);
pub struct RecommendSlimChunkMetadata(Vec<SlimChunkMetadataWithScore>);

#[derive(Serialize, Deserialize, Debug, ToSchema)]
#[serde(untagged)]
Expand All @@ -1470,6 +1471,7 @@ pub enum RecommendChunksResponseTypes {
"time_stamp": "2021-01-01T00:00:00",
"dataset_id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
"weight": 0.5,
"score": 0.9,
}]))]
Chunks(RecommendChunkMetadata),
#[schema(example = json!([{
Expand All @@ -1484,6 +1486,7 @@ pub enum RecommendChunksResponseTypes {
"time_stamp": "2021-01-01T00:00:00",
"dataset_id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
"weight": 0.5,
"score": 0.9,
}]))]
#[schema(title = "SlimChunkMetadata")]
SlimChunks(RecommendSlimChunkMetadata),
Expand Down Expand Up @@ -1617,7 +1620,7 @@ pub async fn get_recommended_chunks(

timer.add("finish extending tracking_ids and chunk_ids to qdrant_point_ids; start recommend_qdrant_query");

let recommended_qdrant_point_ids = recommend_qdrant_query(
let recommended_qdrant_results = recommend_qdrant_query(
positive_qdrant_ids,
negative_qdrant_ids,
data.strategy.clone(),
Expand All @@ -1634,30 +1637,52 @@ pub async fn get_recommended_chunks(

timer.add("finish recommend_qdrant_query; start get_metadata_from_point_ids");

let recommended_chunk_metadatas =
get_metadata_from_point_ids(recommended_qdrant_point_ids, pool)
.await
.map_err(|err| {
ServiceError::BadRequest(format!(
"Could not get recommended chunk_metadas from qdrant_point_ids: {}",
err
))
})?;
let recommended_chunk_metadatas = get_metadata_from_point_ids(
recommended_qdrant_results
.clone()
.into_iter()
.map(|recommend_qdrant_result| recommend_qdrant_result.point_id)
.collect(),
pool,
)
.await
.map_err(|err| {
ServiceError::BadRequest(format!(
"Could not get recommended chunk_metadas from qdrant_point_ids: {}",
err
))
})?;

let recommended_chunk_metadatas_with_score = recommended_chunk_metadatas
.into_iter()
.map(|chunk_metadata| {
let score = recommended_qdrant_results
.iter()
.find(|recommend_qdrant_result| {
recommend_qdrant_result.point_id
== chunk_metadata.qdrant_point_id.unwrap_or_default()
})
.map(|recommend_qdrant_result| recommend_qdrant_result.score)
.unwrap_or(0.0);

ChunkMetadataWithScore::from((chunk_metadata, score))
})
.collect::<Vec<ChunkMetadataWithScore>>();

timer.add("finish get_metadata_from_point_ids and return results");

if data.slim_chunks.unwrap_or(false) {
let res = recommended_chunk_metadatas
let res = recommended_chunk_metadatas_with_score
.into_iter()
.map(|chunk| chunk.into())
.collect::<Vec<SlimChunkMetadata>>();
.collect::<Vec<SlimChunkMetadataWithScore>>();

return Ok(HttpResponse::Ok().json(res));
}

Ok(HttpResponse::Ok()
.insert_header((Timer::header_key(), timer.header_value()))
.json(recommended_chunk_metadatas))
.json(recommended_chunk_metadatas_with_score))
}

#[derive(Debug, Serialize, Deserialize, ToSchema)]
Expand Down
20 changes: 10 additions & 10 deletions server/src/handlers/group_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::{
use crate::{
data::models::{
ChunkGroup, ChunkGroupAndFile, ChunkGroupBookmark, ChunkMetadata,
DatasetAndOrgWithSubAndPlan, GroupSlimChunksDTO, Pool, ScoreSlimChunks,
DatasetAndOrgWithSubAndPlan, GroupScoreSlimChunks, Pool, ScoreSlimChunks,
SearchGroupSlimChunksResult, SearchOverGroupsSlimChunksResponseBody,
ServerDatasetConfiguration, UnifiedId,
},
Expand All @@ -20,7 +20,7 @@ use crate::{
search_operator::{
full_text_search_over_groups, get_metadata_from_groups, hybrid_search_over_groups,
search_full_text_groups, search_hybrid_groups, search_semantic_groups,
semantic_search_over_groups, GroupScoreChunkDTO, SearchOverGroupsQueryResult,
semantic_search_over_groups, GroupScoreChunk, SearchOverGroupsQueryResult,
SearchOverGroupsResponseBody,
},
},
Expand Down Expand Up @@ -896,9 +896,9 @@ pub struct RecommendGroupChunksRequest {
}

#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
pub struct RecommendGroupChunksDTO(pub Vec<GroupScoreChunkDTO>);
pub struct RecommendGroupChunks(pub Vec<GroupScoreChunk>);
#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
pub struct RecommendGroupSlimChunksDTO(pub Vec<GroupSlimChunksDTO>);
pub struct RecommendGroupSlimChunks(pub Vec<GroupScoreSlimChunks>);

#[derive(Serialize, Deserialize, Debug, ToSchema)]
#[serde(untagged)]
Expand All @@ -925,7 +925,7 @@ pub enum RecommendGroupChunkResponseTypes {
}
]
}]))]
GroupSlimChunksDTO(RecommendGroupSlimChunksDTO),
GroupSlimChunksDTO(RecommendGroupSlimChunks),
#[schema(example = json!({
"group_id": "e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
"metadata": [
Expand All @@ -947,7 +947,7 @@ pub enum RecommendGroupChunkResponseTypes {
}
]
}))]
GroupScoreChunkDTO(RecommendGroupChunksDTO),
GroupScoreChunkDTO(RecommendGroupChunks),
}

/// Get Recommended Groups
Expand Down Expand Up @@ -1108,15 +1108,15 @@ pub async fn get_recommended_groups(
if data.slim_chunks.unwrap_or(false) {
let res = recommended_chunk_metadatas
.into_iter()
.map(|metadata| GroupSlimChunksDTO {
.map(|metadata| GroupScoreSlimChunks {
group_id: metadata.group_id,
metadata: metadata
.metadata
.into_iter()
.map(|chunk| chunk.into())
.collect::<Vec<ScoreSlimChunks>>(),
})
.collect::<Vec<GroupSlimChunksDTO>>();
.collect::<Vec<GroupScoreSlimChunks>>();

return Ok(HttpResponse::Ok()
.insert_header((Timer::header_key(), timer.header_value()))
Expand Down Expand Up @@ -1431,15 +1431,15 @@ pub async fn search_over_groups(
let ids = result_chunks
.group_chunks
.into_iter()
.map(|metadata| GroupSlimChunksDTO {
.map(|metadata| GroupScoreSlimChunks {
group_id: metadata.group_id,
metadata: metadata
.metadata
.into_iter()
.map(|chunk| chunk.into())
.collect::<Vec<ScoreSlimChunks>>(),
})
.collect::<Vec<GroupSlimChunksDTO>>();
.collect::<Vec<GroupScoreSlimChunks>>();

let res = SearchOverGroupsSlimChunksResponseBody {
group_chunks: ids,
Expand Down
10 changes: 6 additions & 4 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ impl Modify for SecurityAddon {
handlers::chunk_handler::RecommendChunksResponseTypes,
handlers::chunk_handler::RecommendChunkMetadata,
handlers::chunk_handler::RecommendSlimChunkMetadata,
handlers::group_handler::RecommendGroupChunksDTO,
handlers::group_handler::RecommendGroupSlimChunksDTO,
handlers::group_handler::RecommendGroupChunks,
handlers::group_handler::RecommendGroupSlimChunks,
handlers::group_handler::SearchWithinGroupData,
handlers::group_handler::SearchOverGroupsData,
handlers::group_handler::SearchGroupsResult,
Expand Down Expand Up @@ -271,7 +271,7 @@ impl Modify for SecurityAddon {
handlers::organization_handler::UpdateOrganizationData,
operators::event_operator::EventReturn,
operators::search_operator::SearchOverGroupsResponseBody,
operators::search_operator::GroupScoreChunkDTO,
operators::search_operator::GroupScoreChunk,
handlers::dataset_handler::CreateDatasetRequest,
handlers::dataset_handler::UpdateDatasetRequest,
data::models::ApiKeyDTO,
Expand All @@ -280,6 +280,7 @@ impl Modify for SecurityAddon {
data::models::Topic,
data::models::Message,
data::models::ChunkMetadata,
data::models::ChunkMetadataWithScore,
data::models::ChatMessageProxy,
data::models::Event,
data::models::SlimGroup,
Expand All @@ -298,9 +299,10 @@ impl Modify for SecurityAddon {
data::models::SearchSlimChunkQueryResponseBody,
data::models::ScoreSlimChunks,
data::models::SlimChunkMetadata,
data::models::SlimChunkMetadataWithScore,
data::models::SearchGroupSlimChunksResult,
data::models::SearchOverGroupsSlimChunksResponseBody,
data::models::GroupSlimChunksDTO,
data::models::GroupScoreSlimChunks,
handlers::chunk_handler::RangeCondition,
errors::ErrorResponseBody,
)
Expand Down
Loading

0 comments on commit c340e6b

Please sign in to comment.