Skip to content

Commit 3483f82

Browse files
skeptrunedev and cdxker
authored and committed
feature: add suggestion type and context plus bugfix randomness
1 parent 632f29d commit 3483f82

File tree

5 files changed

+86
-17
lines changed

5 files changed

+86
-17
lines changed

server/src/bin/clone-qdrant.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,19 @@ use qdrant_client::{
33
Qdrant,
44
};
55
use trieve_server::{
6-
errors::ServiceError, get_env,
7-
operators::qdrant_operator::scroll_qdrant_collection_ids_custom_url,
6+
errors::ServiceError, operators::qdrant_operator::scroll_qdrant_collection_ids_custom_url,
87
};
98
#[allow(clippy::print_stdout)]
109
#[tokio::main]
1110
async fn main() -> Result<(), ServiceError> {
1211
dotenvy::dotenv().ok();
1312

1413
let origin_qdrant_url =
15-
get_env!("ORIGIN_QDRANT_URL", "ORIGIN_QDRANT_URL is not set").to_string();
16-
let new_qdrant_url = get_env!("NEW_QDRANT_URL", "NEW_QDRANT_URL is not set").to_string();
17-
let qdrant_api_key = get_env!("QDRANT_API_KEY", "QDRANT_API_KEY is not set").to_string();
14+
std::env::var("ORIGIN_QDRANT_URL").expect("ORIGIN_QDRANT_URL is not set");
15+
let new_qdrant_url = std::env::var("NEW_QDRANT_URL").expect("NEW_QDRANT_URL is not set");
16+
let qdrant_api_key = std::env::var("QDRANT_API_KEY").expect("QDRANT_API_KEY is not set");
1817
let collection_to_clone =
19-
get_env!("COLLECTION_TO_CLONE", "COLLECTION_TO_CLONE is not set").to_string();
18+
std::env::var("COLLECTION_TO_CLONE").expect("COLLECTION_TO_CLONE is not set");
2019

2120
let mut offset = Some(uuid::Uuid::nil().to_string());
2221

server/src/data/models.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ use itertools::Itertools;
4343
use openai_dive::v1::resources::chat::{ChatMessage, ChatMessageContent, Role};
4444
use qdrant_client::qdrant::{GeoBoundingBox, GeoLineString, GeoPoint, GeoPolygon, GeoRadius};
4545
use qdrant_client::{prelude::Payload, qdrant, qdrant::RetrievedPoint};
46+
use rand::Rng;
4647
use serde::{Deserialize, Serialize};
4748
use serde_json::{json, Value};
4849
use std::collections::HashMap;
@@ -54,6 +55,28 @@ use utoipa::ToSchema;
5455
pub type Pool = diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>;
5556
pub type RedisPool = bb8_redis::bb8::Pool<bb8_redis::RedisConnectionManager>;
5657

58+
/// Returns a random UUID whose 128-bit big-endian numeric value lies between
/// `uuid1` and `uuid2`, both bounds inclusive.
///
/// The arguments may be supplied in either order; they are compared as `u128`
/// values and swapped if necessary before sampling. When both inputs are equal
/// the only possible result is that same UUID.
///
/// NOTE(review): the result is assembled with `Uuid::from_bytes`, so presumably
/// no version/variant bits are enforced on it — confirm against the `uuid`
/// crate documentation if callers need a well-formed v4 UUID.
pub fn uuid_between(uuid1: uuid::Uuid, uuid2: uuid::Uuid) -> uuid::Uuid {
59+
// Interpret each UUID's 16 bytes as a big-endian unsigned 128-bit integer.
let num1 = u128::from_be_bytes(*uuid1.as_bytes());
60+
let num2 = u128::from_be_bytes(*uuid2.as_bytes());
61+
62+
// Order the two numbers so the range below is always low..=high.
let (min_num, max_num) = if num1 < num2 {
63+
(num1, num2)
64+
} else {
65+
(num2, num1)
66+
};
67+
68+
// Cannot underflow: max_num >= min_num by construction above.
let diff = max_num - min_num;
69+
let mut rng = rand::thread_rng();
70+
71+
// Inclusive range, so both endpoints are reachable and diff == 0 is valid.
let random_offset = rng.gen_range(0..=diff);
72+
73+
// random_offset <= diff, so the sum never exceeds max_num and cannot overflow.
let result_num = min_num + random_offset;
74+
75+
let result_bytes = result_num.to_be_bytes();
76+
77+
uuid::Uuid::from_bytes(result_bytes)
78+
}
79+
5780
#[derive(Debug, Serialize, Deserialize, Queryable, Insertable, Selectable, Clone, ToSchema)]
5881
#[schema(example = json!({
5982
"id": "e3e3e3e3-e3e3-e3e3-e3e3-e3e3e3e3e3e3",
@@ -4139,6 +4162,17 @@ pub enum SearchMethod {
41394162
BM25,
41404163
}
41414164

4165+
/// Style of suggested queries a caller can request.
///
/// Serde (de)serializes the variants in lowercase (`"question"`, `"keyword"`,
/// `"semantic"`), and the `Display` output uses the same lowercase strings.
/// The suggested-queries handler falls back to `Keyword` when the request
/// does not specify a type.
#[derive(Debug, Serialize, Deserialize, ToSchema, Display, Clone, PartialEq)]
4166+
#[serde(rename_all = "lowercase")]
4167+
pub enum SuggestType {
4168+
/// Displayed as "question".
#[display(fmt = "question")]
4169+
Question,
4170+
/// Displayed as "keyword".
#[display(fmt = "keyword")]
4171+
Keyword,
4172+
/// Displayed as "semantic".
#[display(fmt = "semantic")]
4173+
Semantic,
4174+
}
4175+
41424176
#[derive(Debug, Serialize, Deserialize, ToSchema, Display, Clone, PartialEq)]
41434177
#[serde(rename_all = "lowercase")]
41444178
pub enum CountSearchMethod {

server/src/handlers/message_handler.rs

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use super::{
55
use crate::{
66
data::models::{
77
self, ChunkMetadata, DatasetAndOrgWithSubAndPlan, DatasetConfiguration, HighlightOptions,
8-
LLMOptions, Pool, SearchMethod,
8+
LLMOptions, Pool, SearchMethod, SuggestType,
99
},
1010
errors::ServiceError,
1111
get_env,
@@ -604,6 +604,10 @@ pub struct SuggestedQueriesReqPayload {
604604
pub query: Option<String>,
605605
/// Can be either "semantic", "fulltext", "hybrid", or "bm25". If specified as "hybrid", it will pull in one page of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page of the nearest vectors by cosine distance. "fulltext" will pull in one page of full-text results based on SPLADE. "bm25" will get one page of results scored using BM25 with the terms OR'd together.
606606
pub search_type: Option<SearchMethod>,
607+
/// Type of suggestions. Can be "question", "keyword", or "semantic". If not specified, this defaults to "keyword".
608+
pub suggestion_type: Option<SuggestType>,
609+
/// Context is the context of the query. This can be any string of at most 15 words and 200 characters (length is counted in UTF-8 bytes). The context will be used to generate the suggested queries. Defaults to None.
610+
pub context: Option<String>,
607611
/// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata.
608612
pub filters: Option<ChunkFilter>,
609613
}
@@ -756,9 +760,36 @@ pub async fn get_suggested_queries(
756760
.collect::<Vec<String>>()
757761
.join("\n\n");
758762

763+
let query_style = match data.suggestion_type.clone().unwrap_or(SuggestType::Keyword) {
764+
SuggestType::Question => "question",
765+
SuggestType::Keyword => "keyword",
766+
SuggestType::Semantic => "semantic while not question",
767+
};
768+
let context_sentence = match data.context.clone() {
769+
Some(context) => {
770+
if context.split_whitespace().count() > 15 || context.len() > 200 {
771+
return Err(ServiceError::BadRequest(
772+
"Context must be under 15 words and 200 characters".to_string(),
773+
));
774+
}
775+
776+
format!(
777+
"\n\nSuggest things with the following context in mind: {}.\n\n",
778+
context
779+
)
780+
}
781+
None => "".to_string(),
782+
};
783+
759784
let content = ChatMessageContent::Text(format!(
760-
"Here is some context for the dataset for which the user is querying for {}. Generate 10 suggested followup keyword searches based off the domain of this dataset. Your only response should be the 10 followup keyword searches which are separated by new lines and are just text and you do not add any other context or information about the followup keyword searches. This should not be a list, so do not number each keyword search. These followup keyword searches should be related to the domain of the dataset.",
761-
rag_content
785+
"Here is some context for the dataset for which the user is querying for {}{}. Generate 10 suggested followup {} style queries based off the domain of this dataset. Your only response should be the 10 followup {} style queries which are separated by new lines and are just text and you do not add any other context or information about the followup {} style queries. This should not be a list, so do not number each {} style queries. These followup {} style queries should be related to the domain of the dataset.",
786+
rag_content,
787+
context_sentence,
788+
query_style,
789+
query_style,
790+
query_style,
791+
query_style,
792+
query_style
762793
));
763794

764795
let message = ChatMessage {
@@ -826,7 +857,14 @@ pub async fn get_suggested_queries(
826857
_ => "".to_string(),
827858
}
828859
.split('\n')
829-
.map(|query| query.to_string().trim().trim_matches('\n').to_string())
860+
.filter_map(|query| {
861+
let cleaned_query = query.to_string().trim().trim_matches('\n').to_string();
862+
if cleaned_query.is_empty() {
863+
None
864+
} else {
865+
Some(cleaned_query)
866+
}
867+
})
830868
.collect();
831869

832870
while queries.len() < 3 {

server/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ impl Modify for SecurityAddon {
360360
data::models::CountSearchMethod,
361361
data::models::SearchMethod,
362362
data::models::SearchType,
363+
data::models::SuggestType,
363364
data::models::ApiKeyRespBody,
364365
data::models::UsageGraphPoint,
365366
data::models::SearchResultType,

server/src/operators/chunk_operator.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::data::models::{
2-
ChunkData, ChunkGroup, ChunkGroupBookmark, ChunkMetadataTable, ChunkMetadataTags,
2+
uuid_between, ChunkData, ChunkGroup, ChunkGroupBookmark, ChunkMetadataTable, ChunkMetadataTags,
33
ChunkMetadataTypes, ContentChunkMetadata, Dataset, DatasetConfiguration, DatasetTags,
44
IngestSpecificChunkMetadata, SlimChunkMetadata, SlimChunkMetadataTable, UnifiedId,
55
};
@@ -345,7 +345,6 @@ pub async fn get_random_chunk_metadatas_query(
345345
pool: web::Data<Pool>,
346346
) -> Result<Vec<ChunkMetadata>, ServiceError> {
347347
use crate::data::schema::chunk_metadata::dsl as chunk_metadata_columns;
348-
let mut random_uuid = uuid::Uuid::new_v4();
349348

350349
let mut conn = pool
351350
.get()
@@ -355,7 +354,7 @@ pub async fn get_random_chunk_metadatas_query(
355354
let get_lowest_id_future = chunk_metadata_columns::chunk_metadata
356355
.filter(chunk_metadata_columns::dataset_id.eq(dataset_id))
357356
.select(chunk_metadata_columns::id)
358-
.order_by(chunk_metadata_columns::id.desc())
357+
.order_by(chunk_metadata_columns::id.asc())
359358
.first::<uuid::Uuid>(&mut conn);
360359
let get_highest_ids_future = chunk_metadata_columns::chunk_metadata
361360
.filter(chunk_metadata_columns::dataset_id.eq(dataset_id))
@@ -384,14 +383,12 @@ pub async fn get_random_chunk_metadatas_query(
384383
))
385384
}
386385
};
387-
while (random_uuid < lowest_id) || (random_uuid > highest_id) {
388-
random_uuid = uuid::Uuid::new_v4();
389-
}
386+
let random_uuid = uuid_between(lowest_id, highest_id);
390387

391388
let chunk_metadatas: Vec<ChunkMetadataTable> = chunk_metadata_columns::chunk_metadata
392389
.filter(chunk_metadata_columns::dataset_id.eq(dataset_id))
393390
.filter(chunk_metadata_columns::id.gt(random_uuid))
394-
.order_by(chunk_metadata_columns::id.desc())
391+
.order_by(chunk_metadata_columns::id.asc())
395392
.limit(limit.try_into().unwrap_or(10))
396393
.select(ChunkMetadataTable::as_select())
397394
.load::<ChunkMetadataTable>(&mut conn)

0 commit comments

Comments
 (0)