From 5531e2cb11f53941cf83d0c0414fb0ae2be7c6fe Mon Sep 17 00:00:00 2001 From: Dens Sumesh Date: Wed, 8 Jan 2025 10:06:38 -0800 Subject: [PATCH] feature: add prompt to image search --- server/src/data/models.rs | 5 ++++- server/src/operators/message_operator.rs | 8 ++++---- server/src/operators/search_operator.rs | 5 ++++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/server/src/data/models.rs b/server/src/data/models.rs index e4d53248e..4db72c7de 100644 --- a/server/src/data/models.rs +++ b/server/src/data/models.rs @@ -7656,7 +7656,10 @@ impl From<(ParsedQuery, f32)> for MultiQuery { #[derive(Debug, Serialize, Deserialize, ToSchema, Clone, PartialEq)] #[serde(untagged)] pub enum SearchModalities { - Image { image_url: String }, + Image { + image_url: String, + llm_prompt: Option, + }, Text(String), } diff --git a/server/src/operators/message_operator.rs b/server/src/operators/message_operator.rs index bbf397ca2..4be19d564 100644 --- a/server/src/operators/message_operator.rs +++ b/server/src/operators/message_operator.rs @@ -1245,6 +1245,7 @@ pub async fn get_topic_string( pub async fn get_text_from_image( image_url: String, + prompt: Option, dataset: &Dataset, ) -> Result { let dataset_config = DatasetConfiguration::from_json(dataset.server_configuration.clone()); @@ -1286,12 +1287,11 @@ pub async fn get_text_from_image( }, }; + let default_system_prompt = "Please describe the image and turn the description into a search query. DO NOT INCLUDE ANY OTHER CONTEXT OR INFORMATION. JUST OUTPUT THE SEARCH QUERY AND NOTHING ELSE".to_string(); + let messages = vec![ ChatMessage::System { - content: ChatMessageContent::Text( - "Please describe the image and turn the description into a search query. DO NOT INCLUDE ANY OTHER CONTEXT OR INFORMATION. JUST OUTPUT THE SEARCH QUERY AND NOTHING ELSE" - .to_string(), - ), + content: ChatMessageContent::Text(prompt.unwrap_or(default_system_prompt)), name: None, }, ChatMessage::User { diff --git a/server/src/operators/search_operator.rs b/server/src/operators/search_operator.rs index 120ff89e1..0bdd0f429 100644 --- a/server/src/operators/search_operator.rs +++ b/server/src/operators/search_operator.rs @@ -1618,7 +1618,10 @@ pub async fn parse_query( let stop_words = get_stop_words(); let query = match query { SearchModalities::Text(query) => query, - SearchModalities::Image { image_url } => get_text_from_image(image_url, dataset).await?, + SearchModalities::Image { + image_url, + llm_prompt, + } => get_text_from_image(image_url, llm_prompt, dataset).await?, }; let query = match remove_stop_words {