Skip to content

Commit

Permalink
feature: re-add recency bias for groups
Browse files (browse the repository at this point in the history)
  • Loading branch information
densumesh authored and skeptrunedev committed Nov 30, 2024
1 parent c1bd516 commit 962ff9c
Showing 1 changed file with 62 additions and 40 deletions.
102 changes: 62 additions & 40 deletions server/src/operators/search_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ use super::typo_operator::correct_query;
use crate::data::models::{
convert_to_date_time, ChunkGroup, ChunkGroupAndFileId, ChunkMetadata,
ChunkMetadataStringTagSet, ChunkMetadataTypes, ConditionType, ContentChunkMetadata, Dataset,
DatasetConfiguration, GeoInfoWithBias, HasIDCondition, QdrantChunkMetadata, QdrantSortBy,
QueryTypes, ReRankOptions, RedisPool, ScoreChunk, ScoreChunkDTO, SearchMethod,
SlimChunkMetadata, SortByField, SortBySearchType, SortOptions, UnifiedId,
DatasetConfiguration, HasIDCondition, QdrantChunkMetadata, QdrantSortBy, QueryTypes,
ReRankOptions, RedisPool, ScoreChunk, ScoreChunkDTO, SearchMethod, SlimChunkMetadata,
SortByField, SortBySearchType, SortOptions, UnifiedId,
};
use crate::handlers::chunk_handler::{
AutocompleteReqPayload, ChunkFilter, CountChunkQueryResponseBody, CountChunksReqPayload,
Expand Down Expand Up @@ -1630,12 +1630,15 @@ pub fn rerank_chunks(

pub fn rerank_groups(
groups: Vec<GroupScoreChunk>,
tag_weights: Option<HashMap<String, f32>>,
use_weights: Option<bool>,
query_location: Option<GeoInfoWithBias>,
sort_options: Option<SortOptions>,
) -> Vec<GroupScoreChunk> {
let mut reranked_groups = Vec::new();
if use_weights.unwrap_or(true) {
if sort_options.is_none() {
return groups;
}

let sort_options = sort_options.unwrap();
if sort_options.use_weights.unwrap_or(true) {
groups.into_iter().for_each(|mut group| {
let first_chunk = group.metadata.get_mut(0).unwrap();
if first_chunk.metadata[0].metadata().weight == 0.0 {
Expand All @@ -1649,8 +1652,55 @@ pub fn rerank_groups(
reranked_groups = groups;
}

if query_location.is_some() && query_location.unwrap().bias > 0.0 {
let info_with_bias = query_location.unwrap();
if sort_options.recency_bias.is_some() && sort_options.recency_bias.unwrap() > 0.0 {
let recency_weight = sort_options.recency_bias.unwrap();
let min_timestamp = reranked_groups
.iter()
.filter_map(|group| group.metadata[0].metadata[0].metadata().time_stamp)
.min();
let max_timestamp = reranked_groups
.iter()
.filter_map(|group| group.metadata[0].metadata[0].metadata().time_stamp)
.max();
let max_score = reranked_groups
.iter()
.map(|group| group.metadata[0].score)
.max_by(|a, b| a.partial_cmp(b).unwrap());
let min_score = reranked_groups
.iter()
.map(|group| group.metadata[0].score)
.min_by(|a, b| a.partial_cmp(b).unwrap());

if let (Some(min), Some(max)) = (min_timestamp, max_timestamp) {
let min_duration = chrono::Utc::now().signed_duration_since(min.and_utc());
let max_duration = chrono::Utc::now().signed_duration_since(max.and_utc());

reranked_groups = reranked_groups
.iter_mut()
.map(|group| {
let first_chunk = group.metadata.get_mut(0).unwrap();
if let Some(time_stamp) = first_chunk.metadata[0].metadata().time_stamp {
let duration =
chrono::Utc::now().signed_duration_since(time_stamp.and_utc());
let normalized_recency_score = (duration.num_seconds() as f32
- min_duration.num_seconds() as f32)
/ (max_duration.num_seconds() as f32
- min_duration.num_seconds() as f32);

let normalized_chunk_score = (first_chunk.score - min_score.unwrap_or(0.0))
/ (max_score.unwrap_or(1.0) - min_score.unwrap_or(0.0));

first_chunk.score = (normalized_chunk_score * (1.0 / recency_weight) as f64)
+ (recency_weight * normalized_recency_score) as f64
}
group.clone()
})
.collect::<Vec<GroupScoreChunk>>();
}
}

if sort_options.location_bias.is_some() && sort_options.location_bias.unwrap().bias > 0.0 {
let info_with_bias = sort_options.location_bias.unwrap();
let query_location = info_with_bias.location;
let location_bias = info_with_bias.bias;
let distances = reranked_groups
Expand Down Expand Up @@ -1688,7 +1738,7 @@ pub fn rerank_groups(
.collect::<Vec<GroupScoreChunk>>();
}

if let Some(tag_weights) = tag_weights {
if let Some(tag_weights) = sort_options.tag_weights {
reranked_groups = reranked_groups
.iter_mut()
.map(|group| {
Expand Down Expand Up @@ -2545,21 +2595,7 @@ pub async fn search_over_groups_query(

timer.add("fetched from postgres");

result_chunks.group_chunks = rerank_groups(
result_chunks.group_chunks,
data.sort_options
.as_ref()
.map(|d| d.tag_weights.clone())
.unwrap_or_default(),
data.sort_options
.as_ref()
.map(|d| d.use_weights)
.unwrap_or_default(),
data.sort_options
.as_ref()
.map(|d| d.location_bias)
.unwrap_or_default(),
);
result_chunks.group_chunks = rerank_groups(result_chunks.group_chunks, data.sort_options);

result_chunks.corrected_query = corrected_query.map(|c| c.query);

Expand Down Expand Up @@ -2771,21 +2807,7 @@ pub async fn hybrid_search_over_groups(
});
}

reranked_chunks = rerank_groups(
reranked_chunks,
data.sort_options
.as_ref()
.map(|d| d.tag_weights.clone())
.unwrap_or_default(),
data.sort_options
.as_ref()
.map(|d| d.use_weights)
.unwrap_or_default(),
data.sort_options
.as_ref()
.map(|d| d.location_bias)
.unwrap_or_default(),
);
reranked_chunks = rerank_groups(reranked_chunks, data.sort_options);

let result_chunks = DeprecatedSearchOverGroupsResponseBody {
group_chunks: reranked_chunks,
Expand Down

0 comments on commit 962ff9c

Please sign in to comment.