Skip to content

Commit d27729c

Browse files
committed
bugfix: too many batched embeddings causes error
1 parent 89690b4 commit d27729c

File tree

2 files changed

+94
-70
lines changed

2 files changed

+94
-70
lines changed

server/src/bin/ingestion-microservice.rs

Lines changed: 28 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -147,13 +147,12 @@ fn main() {
147147
let web_redis_pool = web_redis_pool.clone();
148148
let should_terminate = Arc::clone(&should_terminate);
149149

150-
tokio::spawn(
151-
async move { ingestion_service(i, should_terminate, web_redis_pool, web_pool).await },
152-
)
150+
tokio::spawn(async move {
151+
ingestion_service(i, should_terminate, web_redis_pool, web_pool).await
152+
})
153153
})
154154
.collect();
155155

156-
157156
while !should_terminate.load(Ordering::Relaxed) {}
158157
log::info!("Shutdown signal received, killing all children...");
159158
futures::future::join_all(threads).await
@@ -171,18 +170,36 @@ async fn ingestion_service(
171170
) {
172171
log::info!("Starting ingestion service thread");
173172

174-
let mut redis_connection = match redis_pool.get().await {
175-
Ok(redis_connection) => redis_connection,
176-
Err(err) => {
177-
log::error!("Failed to get redis connection outside of loop: {:?}", err);
178-
return;
173+
let mut sleep_time = std::time::Duration::from_secs(1);
174+
175+
#[allow(unused_assignments)]
176+
let mut opt_redis_connection = None;
177+
178+
loop {
179+
let borrowed_redis_connection = match redis_pool.get().await {
180+
Ok(redis_connection) => Some(redis_connection),
181+
Err(err) => {
182+
log::error!("Failed to get redis connection outside of loop: {:?}", err);
183+
None
184+
}
185+
};
186+
187+
if borrowed_redis_connection.is_some() {
188+
opt_redis_connection = borrowed_redis_connection;
189+
break;
179190
}
180-
};
191+
192+
tokio::time::sleep(sleep_time).await;
193+
sleep_time = std::cmp::min(sleep_time * 2, std::time::Duration::from_secs(60));
194+
}
195+
196+
let mut redis_connection =
197+
opt_redis_connection.expect("Failed to get redis connection outside of loop");
181198

182199
loop {
183200
if should_terminate.load(Ordering::Relaxed) {
184201
log::info!("Shutting down");
185-
break
202+
break;
186203
}
187204

188205
let payload_result: Result<Vec<String>, redis::RedisError> = redis::cmd("brpoplpush")

server/src/operators/model_operator.rs

Lines changed: 66 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ pub struct EmbeddingParameters {
2121

2222
#[tracing::instrument]
2323
pub async fn create_embeddings(
24-
message: Vec<String>,
24+
messages: Vec<String>,
2525
embed_type: &str,
2626
dataset_config: ServerDatasetConfiguration,
2727
) -> Result<Vec<Vec<f32>>, ServiceError> {
@@ -68,63 +68,67 @@ pub async fn create_embeddings(
6868
organization: None,
6969
};
7070

71-
let clipped_messages = message
72-
.iter()
73-
.map(|msg| {
74-
if msg.len() > 7000 {
75-
msg.chars().take(20000).collect()
76-
} else {
77-
msg.clone()
78-
}
79-
})
80-
.collect::<Vec<String>>();
81-
82-
let input = match embed_type {
83-
"doc" => EmbeddingInput::StringArray(clipped_messages),
84-
"query" => EmbeddingInput::String(
85-
format!(
86-
"{}{}",
87-
dataset_config.EMBEDDING_QUERY_PREFIX,
88-
clipped_messages
89-
.first()
90-
.unwrap_or(&"Arbitrary because query is empty".to_string())
91-
)
92-
.to_string(),
93-
),
94-
_ => EmbeddingInput::StringArray(clipped_messages),
95-
};
96-
97-
// Vectorize
98-
let parameters = EmbeddingParameters {
99-
model: dataset_config.EMBEDDING_MODEL_NAME.to_string(),
100-
input,
101-
};
102-
103-
let embeddings_resp = ureq::post(&format!(
104-
"{}/embeddings?api-version=2023-05-15",
105-
client.base_url
106-
))
107-
.set("Authorization", &format!("Bearer {}", client.api_key))
108-
.set("api-key", &client.api_key)
109-
.set("Content-Type", "application/json")
110-
.send_json(serde_json::to_value(parameters).unwrap())
111-
.map_err(|e| {
112-
ServiceError::InternalServerError(format!(
113-
"Could not get embeddings from server: {:?}, {:?}",
114-
e,
115-
e.to_string()
71+
let mut all_vectors = vec![];
72+
let thirty_message_groups = messages.chunks(30).collect::<Vec<_>>();
73+
74+
for thirty_messages in thirty_message_groups {
75+
let clipped_messages = thirty_messages
76+
.iter()
77+
.map(|msg| {
78+
if msg.len() > 7000 {
79+
msg.chars().take(20000).collect()
80+
} else {
81+
msg.clone()
82+
}
83+
})
84+
.collect::<Vec<String>>();
85+
86+
let input = match embed_type {
87+
"doc" => EmbeddingInput::StringArray(clipped_messages),
88+
"query" => EmbeddingInput::String(
89+
format!(
90+
"{}{}",
91+
dataset_config.EMBEDDING_QUERY_PREFIX,
92+
clipped_messages
93+
.first()
94+
.unwrap_or(&"Arbitrary because query is empty".to_string())
95+
)
96+
.to_string(),
97+
),
98+
_ => EmbeddingInput::StringArray(clipped_messages),
99+
};
100+
101+
// Vectorize
102+
let parameters = EmbeddingParameters {
103+
model: dataset_config.EMBEDDING_MODEL_NAME.to_string(),
104+
input,
105+
};
106+
107+
let embeddings_resp = ureq::post(&format!(
108+
"{}/embeddings?api-version=2023-05-15",
109+
client.base_url
116110
))
117-
})?;
118-
119-
let embeddings: EmbeddingResponse = format_response(embeddings_resp.into_string().unwrap())
111+
.set("Authorization", &format!("Bearer {}", client.api_key))
112+
.set("api-key", &client.api_key)
113+
.set("Content-Type", "application/json")
114+
.send_json(serde_json::to_value(parameters).unwrap())
120115
.map_err(|e| {
121-
log::error!("Failed to format response from embeddings server {:?}", e);
122-
ServiceError::InternalServerError(
123-
"Failed to format response from embeddings server".to_owned(),
124-
)
116+
ServiceError::InternalServerError(format!(
117+
"Could not get embeddings from server: {:?}, {:?}",
118+
e,
119+
e.to_string()
120+
))
125121
})?;
126122

127-
let vectors: Vec<Vec<f32>> = embeddings
123+
let embeddings: EmbeddingResponse = format_response(embeddings_resp.into_string().unwrap())
124+
.map_err(|e| {
125+
log::error!("Failed to format response from embeddings server {:?}", e);
126+
ServiceError::InternalServerError(
127+
"Failed to format response from embeddings server".to_owned(),
128+
)
129+
})?;
130+
131+
let vectors: Vec<Vec<f32>> = embeddings
128132
.data
129133
.into_iter()
130134
.map(|x| match x.embedding {
@@ -136,14 +140,17 @@ pub async fn create_embeddings(
136140
})
137141
.collect();
138142

139-
if vectors.iter().any(|x| x.is_empty()) {
140-
return Err(ServiceError::InternalServerError(
141-
"Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(),
142-
));
143+
if vectors.iter().any(|x| x.is_empty()) {
144+
return Err(ServiceError::InternalServerError(
145+
"Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(),
146+
));
147+
}
148+
149+
all_vectors.extend(vectors);
143150
}
144151

145152
transaction.finish();
146-
Ok(vectors)
153+
Ok(all_vectors)
147154
}
148155

149156
#[derive(Debug, Serialize, Deserialize)]

0 commit comments

Comments
 (0)