Skip to content

Commit

Permalink
bugfix: order on recommendations and group search
Browse files Browse the repository at this point in the history
  • Loading branch information
skeptrunedev committed Apr 10, 2024
1 parent 8e49f6e commit b1e29be
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 24 deletions.
7 changes: 7 additions & 0 deletions search/src/components/ChunkMetadataDisplay.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export interface ChunkMetadataDisplayProps {
signedInUserId?: string;
viewingUserId?: string;
chunk: ChunkMetadata;
score?: number;
chunkGroups: ChunkGroupDTO[];
bookmarks: ChunkBookmarksDTO[];
setShowConfirmModal: Setter<boolean>;
Expand Down Expand Up @@ -282,6 +283,12 @@ const ChunkMetadataDisplay = (props: ChunkMetadataDisplayProps) => {
{props.chunk.link}
</a>
</Show>
<div class="grid w-fit auto-cols-min grid-cols-[1fr,3fr] gap-x-2 text-magenta-500 dark:text-magenta-400">
<Show when={props.score}>
<span class="font-semibold">Similarity: </span>
<span>{props.score?.toPrecision(3)}</span>
</Show>
</div>
<Show
when={
props.chunk.tag_set &&
Expand Down
4 changes: 1 addition & 3 deletions search/src/components/GroupPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ export const GroupPage = (props: GroupPageProps) => {
const [searchType, setSearchType] = createSignal<string>("hybrid");
const [filters, setFilters] = createSignal<Filters | undefined>(undefined);
const [searchLoading, setSearchLoading] = createSignal(false);
const [chunkMetadatas, setChunkMetadatas] = createSignal<
ChunkMetadata[]
>([]);
const [chunkMetadatas, setChunkMetadatas] = createSignal<ChunkMetadata[]>([]);
const [searchMetadatasWithVotes, setSearchMetadatasWithVotes] = createSignal<
ScoreChunkDTO[]
>(searchChunkMetadatasWithVotes);
Expand Down
13 changes: 8 additions & 5 deletions search/src/components/SingleChunkPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
isChunkGroupPageDTO,
ChunkMetadata,
ScoreChunkDTO,
ChunkMetadataWithScore,
} from "../../utils/apiTypes";
import ScoreChunk from "./ScoreChunk";
import { FullScreenModal } from "./Atoms/FullScreenModal";
Expand All @@ -35,8 +36,9 @@ export const SingleChunkPage = (props: SingleChunkPageProps) => {
const $dataset = datasetAndUserContext.currentDataset;
const initialChunkMetadata = props.defaultResultChunk.metadata;

const [chunkMetadata, setChunkMetadata] =
createSignal<ChunkMetadata | null>(initialChunkMetadata);
const [chunkMetadata, setChunkMetadata] = createSignal<ChunkMetadata | null>(
initialChunkMetadata,
);
const [error, setError] = createSignal("");
const [fetching, setFetching] = createSignal(true);
const [chunkGroups, setChunkGroups] = createSignal<ChunkGroupDTO[]>([]);
Expand All @@ -52,7 +54,7 @@ export const SingleChunkPage = (props: SingleChunkPageProps) => {
const [loadingRecommendations, setLoadingRecommendations] =
createSignal(false);
const [recommendedChunks, setRecommendedChunks] = createSignal<
ChunkMetadata[]
ChunkMetadataWithScore[]
>([]);
const [openChat, setOpenChat] = createSignal(false);
const [selectedIds, setSelectedIds] = createSignal<string[]>([]);
Expand Down Expand Up @@ -113,7 +115,7 @@ export const SingleChunkPage = (props: SingleChunkPageProps) => {

const fetchRecommendations = (
ids: string[],
prev_recommendations: ChunkMetadata[],
prev_recommendations: ChunkMetadataWithScore[],
) => {
setLoadingRecommendations(true);
const currentDataset = $dataset?.();
Expand All @@ -133,7 +135,7 @@ export const SingleChunkPage = (props: SingleChunkPageProps) => {
}).then((response) => {
if (response.ok) {
void response.json().then((data) => {
const typed_data = data as ChunkMetadata[];
const typed_data = data as ChunkMetadataWithScore[];
const deduped_data = typed_data.filter((d) => {
return !prev_recommendations.some((c) => c.id == d.id);
});
Expand Down Expand Up @@ -263,6 +265,7 @@ export const SingleChunkPage = (props: SingleChunkPageProps) => {
<ChunkMetadataDisplay
totalGroupPages={totalGroupPages()}
chunk={chunk}
score={chunk.score}
chunkGroups={chunkGroups()}
bookmarks={bookmarks()}
setShowConfirmModal={setShowConfirmDeleteModal}
Expand Down
17 changes: 17 additions & 0 deletions search/utils/apiTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,23 @@ export interface ChunkMetadata {
weight: number;
}

/**
 * A chunk record that additionally carries a numeric `score`.
 *
 * SingleChunkPage casts the recommendation response to an array of this type
 * and forwards `score` to ChunkMetadataDisplay, which renders it under the
 * "Similarity" label.
 *
 * NOTE(review): the non-score fields appear to mirror `ChunkMetadata`; if that
 * holds, the two declarations must be kept in sync by hand — confirm against
 * the full `ChunkMetadata` declaration.
 */
export interface ChunkMetadataWithScore {
id: string;
content: string;
// Optional; may be absent from the payload.
chunk_html?: string;
link: string | null;
qdrant_point_id: string;
// NOTE(review): timestamps are typed as plain strings; the exact format is
// not visible here — confirm with the server serialization.
created_at: string;
updated_at: string;
tag_set: string | null;
tracking_id: string | null;
time_stamp: string | null;
// NOTE(review): `Record<string, never>` only admits an empty object; verify
// this is the intended type for arbitrary chunk metadata.
metadata: Record<string, never> | null;
dataset_id: string;
weight: number;
// Displayed to the user as "Similarity" (formatted with toPrecision(3)).
// Presumably higher means more similar — the server-side recommendation
// handler sorts by descending score; confirm it feeds this type.
score: number;
}

/**
 * Reports whether `obj` has `prop` as an own (non-inherited) property.
 *
 * The check goes through `Object.prototype.hasOwnProperty.call` so it still
 * works for objects that shadow or lack their own `hasOwnProperty` method.
 *
 * @param obj - Any value; typically untyped data such as a parsed JSON body.
 * @param prop - The property name to look for.
 * @returns `true` when `prop` is an own property of `obj`; `false` otherwise,
 *   including when `obj` is `null` or `undefined`.
 */
export const indirectHasOwnProperty = (obj: unknown, prop: string): boolean => {
  // `hasOwnProperty.call(null | undefined, ...)` throws a TypeError. Because
  // the parameter is typed `unknown`, nullish input is legal at the call site,
  // so answer `false` instead of crashing the caller.
  if (obj === null || obj === undefined) {
    return false;
  }
  return Object.prototype.hasOwnProperty.call(obj, prop);
};
Expand Down
9 changes: 9 additions & 0 deletions server/src/handlers/chunk_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1628,6 +1628,15 @@ pub async fn get_recommended_chunks(
})
.collect::<Vec<ChunkMetadataWithScore>>();

let recommended_chunk_metadatas_with_score = recommended_chunk_metadatas_with_score
.into_iter()
.sorted_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
})
.collect::<Vec<ChunkMetadataWithScore>>();

timer.add("finish get_metadata_from_point_ids and return results");

if data.slim_chunks.unwrap_or(false) {
Expand Down
33 changes: 27 additions & 6 deletions server/src/handlers/group_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
search_operator::{
full_text_search_over_groups, get_metadata_from_groups, hybrid_search_over_groups,
search_full_text_groups, search_hybrid_groups, search_semantic_groups,
semantic_search_over_groups, SearchOverGroupsQueryResult,
semantic_search_over_groups, GroupScoreChunk, SearchOverGroupsQueryResult,
},
},
};
Expand Down Expand Up @@ -1022,7 +1022,7 @@ pub async fn get_recommended_groups(

timer.add("finish to extend qdrant_point_ids for group_tracking_ids and group_ids; start to recommend_qdrant_groups_query from qdrant");

let recommended_qdrant_point_ids = recommend_qdrant_groups_query(
let recommended_groups_from_qdrant = recommend_qdrant_groups_query(
positive_qdrant_ids,
negative_qdrant_ids,
data.strategy.clone(),
Expand All @@ -1038,15 +1038,36 @@ pub async fn get_recommended_groups(
ServiceError::BadRequest(format!("Could not get recommended groups: {}", err))
})?;

let group_query_result = SearchOverGroupsQueryResult {
search_results: recommended_qdrant_point_ids.clone(),
total_chunk_pages: (recommended_qdrant_point_ids.len() as f64 / 10.0).ceil() as i64,
let group_qdrant_query_result = SearchOverGroupsQueryResult {
search_results: recommended_groups_from_qdrant.clone(),
total_chunk_pages: (recommended_groups_from_qdrant.len() as f64 / 10.0).ceil() as i64,
};

timer.add("finish to recommend_qdrant_groups_query from qdrant; start to get_metadata_from_groups from postgres");

let recommended_chunk_metadatas =
get_metadata_from_groups(group_query_result, Some(false), pool).await?;
get_metadata_from_groups(group_qdrant_query_result.clone(), Some(false), pool).await?;

let recommended_chunk_metadatas = recommended_groups_from_qdrant
.into_iter()
.filter_map(|group| {
recommended_chunk_metadatas
.iter()
.find(|metadata| metadata.group_id == group.group_id)
.cloned()
})
.collect::<Vec<GroupScoreChunk>>();

let recommended_chunk_metadatas = group_qdrant_query_result
.search_results
.into_iter()
.filter_map(|group| {
recommended_chunk_metadatas
.iter()
.find(|metadata| metadata.group_id == group.group_id)
.cloned()
})
.collect::<Vec<GroupScoreChunk>>();

timer.add("finish to get_metadata_from_groups from postgres and return results");

Expand Down
4 changes: 2 additions & 2 deletions server/src/operators/qdrant_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ pub async fn recommend_qdrant_groups_query(
)
})?;

let recommended_point_ids = data
let group_recommendation_results = data
.result
.unwrap()
.groups
Expand Down Expand Up @@ -1020,7 +1020,7 @@ pub async fn recommend_qdrant_groups_query(
})
.collect();

Ok(recommended_point_ids)
Ok(group_recommendation_results)
}

#[tracing::instrument]
Expand Down
48 changes: 40 additions & 8 deletions server/src/operators/search_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::handlers::chunk_handler::{
SearchChunkData, SearchChunkQueryResponseBody,
};
use crate::handlers::group_handler::{
SearchWithinGroupResults, SearchOverGroupsData, SearchWithinGroupData,
SearchOverGroupsData, SearchWithinGroupData, SearchWithinGroupResults,
};
use crate::operators::model_operator::get_sparse_vector;
use crate::operators::qdrant_operator::{get_qdrant_connection, search_qdrant_query};
Expand Down Expand Up @@ -1825,7 +1825,7 @@ pub async fn semantic_search_over_groups(

timer.add("finish creating dense embedding vector; start to fetch from qdrant");

let search_chunk_query_results = retrieve_group_qdrant_points_query(
let search_over_groups_qdrant_result = retrieve_group_qdrant_points_query(
VectorType::Dense(embedding_vector),
page,
data.filters.clone(),
Expand All @@ -1841,8 +1841,24 @@ pub async fn semantic_search_over_groups(

timer.add("finish fetching from qdrant; start to fetch from postgres");

let result_chunks =
retrieve_chunks_for_groups(search_chunk_query_results, &data, pool.clone()).await?;
let mut result_chunks = retrieve_chunks_for_groups(
search_over_groups_qdrant_result.clone(),
&data,
pool.clone(),
)
.await?;

result_chunks.group_chunks = search_over_groups_qdrant_result
.search_results
.iter()
.filter_map(|search_result| {
result_chunks
.group_chunks
.iter()
.find(|group| group.group_id == search_result.group_id)
.cloned()
})
.collect();

timer.add("finish fetching from postgres; return results");

Expand All @@ -1869,7 +1885,7 @@ pub async fn full_text_search_over_groups(

timer.add("finish getting sparse vector; start to fetch from qdrant");

let search_chunk_query_results = retrieve_group_qdrant_points_query(
let search_over_groups_qdrant_result = retrieve_group_qdrant_points_query(
VectorType::Sparse(sparse_vector),
page,
data.filters.clone(),
Expand All @@ -1885,14 +1901,30 @@ pub async fn full_text_search_over_groups(

timer.add("finish fetching from qdrant; start to fetch from postgres");

let result_chunks =
retrieve_chunks_for_groups(search_chunk_query_results, &data, pool.clone()).await?;
let mut result_groups_with_chunk_hits = retrieve_chunks_for_groups(
search_over_groups_qdrant_result.clone(),
&data,
pool.clone(),
)
.await?;

result_groups_with_chunk_hits.group_chunks = search_over_groups_qdrant_result
.search_results
.iter()
.filter_map(|search_result| {
result_groups_with_chunk_hits
.group_chunks
.iter()
.find(|group| group.group_id == search_result.group_id)
.cloned()
})
.collect();

timer.add("finish fetching from postgres; return results");

//TODO: rerank for groups

Ok(result_chunks)
Ok(result_groups_with_chunk_hits)
}

async fn cross_encoder_for_groups(
Expand Down

0 comments on commit b1e29be

Please sign in to comment.