Skip to content

Commit

Permalink
feature: add suggestion type and context options to suggested queries; fix randomness bug in random chunk selection
Browse files Browse the repository at this point in the history
  • Loading branch information
skeptrunedev authored and cdxker committed Aug 27, 2024
1 parent 632f29d commit 3483f82
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 17 deletions.
11 changes: 5 additions & 6 deletions server/src/bin/clone-qdrant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@ use qdrant_client::{
Qdrant,
};
use trieve_server::{
errors::ServiceError, get_env,
operators::qdrant_operator::scroll_qdrant_collection_ids_custom_url,
errors::ServiceError, operators::qdrant_operator::scroll_qdrant_collection_ids_custom_url,
};
#[allow(clippy::print_stdout)]
#[tokio::main]
async fn main() -> Result<(), ServiceError> {
dotenvy::dotenv().ok();

let origin_qdrant_url =
get_env!("ORIGIN_QDRANT_URL", "ORIGIN_QDRANT_URL is not set").to_string();
let new_qdrant_url = get_env!("NEW_QDRANT_URL", "NEW_QDRANT_URL is not set").to_string();
let qdrant_api_key = get_env!("QDRANT_API_KEY", "QDRANT_API_KEY is not set").to_string();
std::env::var("ORIGIN_QDRANT_URL").expect("ORIGIN_QDRANT_URL is not set");
let new_qdrant_url = std::env::var("NEW_QDRANT_URL").expect("NEW_QDRANT_URL is not set");
let qdrant_api_key = std::env::var("QDRANT_API_KEY").expect("QDRANT_API_KEY is not set");
let collection_to_clone =
get_env!("COLLECTION_TO_CLONE", "COLLECTION_TO_CLONE is not set").to_string();
std::env::var("COLLECTION_TO_CLONE").expect("COLLECTION_TO_CLONE is not set");

let mut offset = Some(uuid::Uuid::nil().to_string());

Expand Down
34 changes: 34 additions & 0 deletions server/src/data/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use itertools::Itertools;
use openai_dive::v1::resources::chat::{ChatMessage, ChatMessageContent, Role};
use qdrant_client::qdrant::{GeoBoundingBox, GeoLineString, GeoPoint, GeoPolygon, GeoRadius};
use qdrant_client::{prelude::Payload, qdrant, qdrant::RetrievedPoint};
use rand::Rng;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
Expand All @@ -54,6 +55,28 @@ use utoipa::ToSchema;
pub type Pool = diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>;
pub type RedisPool = bb8_redis::bb8::Pool<bb8_redis::RedisConnectionManager>;

/// Returns a uniformly random UUID in the inclusive range between `uuid1` and `uuid2`.
///
/// The bounds may be passed in either order; the result always lies between the
/// smaller and the larger of the two, inclusive. UUID byte order is big-endian,
/// so the `u128` value of a UUID preserves its ordering — sampling an integer in
/// the range and converting back yields a uniformly distributed UUID in range.
pub fn uuid_between(uuid1: uuid::Uuid, uuid2: uuid::Uuid) -> uuid::Uuid {
    // `as_u128` is equivalent to `u128::from_be_bytes(*uuid.as_bytes())`.
    let num1 = uuid1.as_u128();
    let num2 = uuid2.as_u128();

    let (min_num, max_num) = if num1 <= num2 {
        (num1, num2)
    } else {
        (num2, num1)
    };

    // An inclusive range handles the degenerate min == max case correctly
    // (it simply returns that single value).
    let result_num = rand::thread_rng().gen_range(min_num..=max_num);

    uuid::Uuid::from_u128(result_num)
}

#[derive(Debug, Serialize, Deserialize, Queryable, Insertable, Selectable, Clone, ToSchema)]
#[schema(example = json!({
"id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
Expand Down Expand Up @@ -4139,6 +4162,17 @@ pub enum SearchMethod {
BM25,
}

/// Style of follow-up query suggestions to generate for a dataset.
///
/// Serialized in lowercase for both serde (`rename_all = "lowercase"`) and
/// `Display` (`#[display(fmt = ...)]`), so the wire format and the formatted
/// string agree (e.g. `"question"`).
#[derive(Debug, Serialize, Deserialize, ToSchema, Display, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SuggestType {
/// Suggestions phrased as natural-language questions.
#[display(fmt = "question")]
Question,
/// Suggestions phrased as short keyword searches.
#[display(fmt = "keyword")]
Keyword,
/// Suggestions phrased as semantic search queries (statements, not questions).
#[display(fmt = "semantic")]
Semantic,
}

#[derive(Debug, Serialize, Deserialize, ToSchema, Display, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CountSearchMethod {
Expand Down
46 changes: 42 additions & 4 deletions server/src/handlers/message_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::{
use crate::{
data::models::{
self, ChunkMetadata, DatasetAndOrgWithSubAndPlan, DatasetConfiguration, HighlightOptions,
LLMOptions, Pool, SearchMethod,
LLMOptions, Pool, SearchMethod, SuggestType,
},
errors::ServiceError,
get_env,
Expand Down Expand Up @@ -604,6 +604,10 @@ pub struct SuggestedQueriesReqPayload {
pub query: Option<String>,
/// Can be either "semantic", "fulltext", "hybrid, or "bm25". If specified as "hybrid", it will pull in one page of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page of the nearest cosine distant vectors. "fulltext" will pull in one page of full-text results based on SPLADE. "bm25" will get one page of results scored using BM25 with the terms OR'd together.
pub search_type: Option<SearchMethod>,
/// Type of suggestions. Can be "question", "keyword", or "semantic". If not specified, this defaults to "keyword".
pub suggestion_type: Option<SuggestType>,
/// Context is the context of the query. This can be any string under 15 words and 200 characters. The context will be used to generate the suggested queries. Defaults to None.
pub context: Option<String>,
/// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata.
pub filters: Option<ChunkFilter>,
}
Expand Down Expand Up @@ -756,9 +760,36 @@ pub async fn get_suggested_queries(
.collect::<Vec<String>>()
.join("\n\n");

let query_style = match data.suggestion_type.clone().unwrap_or(SuggestType::Keyword) {
SuggestType::Question => "question",
SuggestType::Keyword => "keyword",
SuggestType::Semantic => "semantic while not question",
};
let context_sentence = match data.context.clone() {
Some(context) => {
if context.split_whitespace().count() > 15 || context.len() > 200 {
return Err(ServiceError::BadRequest(
"Context must be under 15 words and 200 characters".to_string(),
));
}

format!(
"\n\nSuggest things with the following context in mind: {}.\n\n",
context
)
}
None => "".to_string(),
};

let content = ChatMessageContent::Text(format!(
"Here is some context for the dataset for which the user is querying for {}. Generate 10 suggested followup keyword searches based off the domain of this dataset. Your only response should be the 10 followup keyword searches which are separated by new lines and are just text and you do not add any other context or information about the followup keyword searches. This should not be a list, so do not number each keyword search. These followup keyword searches should be related to the domain of the dataset.",
rag_content
"Here is some context for the dataset for which the user is querying for {}{}. Generate 10 suggested followup {} style queries based off the domain of this dataset. Your only response should be the 10 followup {} style queries which are separated by new lines and are just text and you do not add any other context or information about the followup {} style queries. This should not be a list, so do not number each {} style queries. These followup {} style queries should be related to the domain of the dataset.",
rag_content,
context_sentence,
query_style,
query_style,
query_style,
query_style,
query_style
));

let message = ChatMessage {
Expand Down Expand Up @@ -826,7 +857,14 @@ pub async fn get_suggested_queries(
_ => "".to_string(),
}
.split('\n')
.map(|query| query.to_string().trim().trim_matches('\n').to_string())
.filter_map(|query| {
let cleaned_query = query.to_string().trim().trim_matches('\n').to_string();
if cleaned_query.is_empty() {
None
} else {
Some(cleaned_query)
}
})
.collect();

while queries.len() < 3 {
Expand Down
1 change: 1 addition & 0 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ impl Modify for SecurityAddon {
data::models::CountSearchMethod,
data::models::SearchMethod,
data::models::SearchType,
data::models::SuggestType,
data::models::ApiKeyRespBody,
data::models::UsageGraphPoint,
data::models::SearchResultType,
Expand Down
11 changes: 4 additions & 7 deletions server/src/operators/chunk_operator.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::data::models::{
ChunkData, ChunkGroup, ChunkGroupBookmark, ChunkMetadataTable, ChunkMetadataTags,
uuid_between, ChunkData, ChunkGroup, ChunkGroupBookmark, ChunkMetadataTable, ChunkMetadataTags,
ChunkMetadataTypes, ContentChunkMetadata, Dataset, DatasetConfiguration, DatasetTags,
IngestSpecificChunkMetadata, SlimChunkMetadata, SlimChunkMetadataTable, UnifiedId,
};
Expand Down Expand Up @@ -345,7 +345,6 @@ pub async fn get_random_chunk_metadatas_query(
pool: web::Data<Pool>,
) -> Result<Vec<ChunkMetadata>, ServiceError> {
use crate::data::schema::chunk_metadata::dsl as chunk_metadata_columns;
let mut random_uuid = uuid::Uuid::new_v4();

let mut conn = pool
.get()
Expand All @@ -355,7 +354,7 @@ pub async fn get_random_chunk_metadatas_query(
let get_lowest_id_future = chunk_metadata_columns::chunk_metadata
.filter(chunk_metadata_columns::dataset_id.eq(dataset_id))
.select(chunk_metadata_columns::id)
.order_by(chunk_metadata_columns::id.desc())
.order_by(chunk_metadata_columns::id.asc())
.first::<uuid::Uuid>(&mut conn);
let get_highest_ids_future = chunk_metadata_columns::chunk_metadata
.filter(chunk_metadata_columns::dataset_id.eq(dataset_id))
Expand Down Expand Up @@ -384,14 +383,12 @@ pub async fn get_random_chunk_metadatas_query(
))
}
};
while (random_uuid < lowest_id) || (random_uuid > highest_id) {
random_uuid = uuid::Uuid::new_v4();
}
let random_uuid = uuid_between(lowest_id, highest_id);

let chunk_metadatas: Vec<ChunkMetadataTable> = chunk_metadata_columns::chunk_metadata
.filter(chunk_metadata_columns::dataset_id.eq(dataset_id))
.filter(chunk_metadata_columns::id.gt(random_uuid))
.order_by(chunk_metadata_columns::id.desc())
.order_by(chunk_metadata_columns::id.asc())
.limit(limit.try_into().unwrap_or(10))
.select(ChunkMetadataTable::as_select())
.load::<ChunkMetadataTable>(&mut conn)
Expand Down

0 comments on commit 3483f82

Please sign in to comment.