Skip to content

Commit

Permalink
bugfix: too many batched embeddings causes error
Browse files: browse the repository at this point in the history
  • Loading branch information
skeptrunedev authored and cdxker committed Apr 4, 2024
1 parent 89690b4 commit 3ea123d
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 70 deletions.
39 changes: 28 additions & 11 deletions server/src/bin/ingestion-microservice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,12 @@ fn main() {
let web_redis_pool = web_redis_pool.clone();
let should_terminate = Arc::clone(&should_terminate);

tokio::spawn(
async move { ingestion_service(i, should_terminate, web_redis_pool, web_pool).await },
)
tokio::spawn(async move {
ingestion_service(i, should_terminate, web_redis_pool, web_pool).await
})
})
.collect();


while !should_terminate.load(Ordering::Relaxed) {}
log::info!("Shutdown signal received, killing all children...");
futures::future::join_all(threads).await
Expand All @@ -171,18 +170,36 @@ async fn ingestion_service(
) {
log::info!("Starting ingestion service thread");

let mut redis_connection = match redis_pool.get().await {
Ok(redis_connection) => redis_connection,
Err(err) => {
log::error!("Failed to get redis connection outside of loop: {:?}", err);
return;
let mut sleep_time = std::time::Duration::from_secs(1);

#[allow(unused_assignments)]
let mut opt_redis_connection = None;

loop {
let borrowed_redis_connection = match redis_pool.get().await {
Ok(redis_connection) => Some(redis_connection),
Err(err) => {
log::error!("Failed to get redis connection outside of loop: {:?}", err);
None
}
};

if borrowed_redis_connection.is_some() {
opt_redis_connection = borrowed_redis_connection;
break;
}
};

tokio::time::sleep(sleep_time).await;
sleep_time = std::cmp::min(sleep_time * 2, std::time::Duration::from_secs(60));
}

let mut redis_connection =
opt_redis_connection.expect("Failed to get redis connection outside of loop");

loop {
if should_terminate.load(Ordering::Relaxed) {
log::info!("Shutting down");
break
break;
}

let payload_result: Result<Vec<String>, redis::RedisError> = redis::cmd("brpoplpush")
Expand Down
125 changes: 66 additions & 59 deletions server/src/operators/model_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub struct EmbeddingParameters {

#[tracing::instrument]
pub async fn create_embeddings(
message: Vec<String>,
messages: Vec<String>,
embed_type: &str,
dataset_config: ServerDatasetConfiguration,
) -> Result<Vec<Vec<f32>>, ServiceError> {
Expand Down Expand Up @@ -68,63 +68,67 @@ pub async fn create_embeddings(
organization: None,
};

let clipped_messages = message
.iter()
.map(|msg| {
if msg.len() > 7000 {
msg.chars().take(20000).collect()
} else {
msg.clone()
}
})
.collect::<Vec<String>>();

let input = match embed_type {
"doc" => EmbeddingInput::StringArray(clipped_messages),
"query" => EmbeddingInput::String(
format!(
"{}{}",
dataset_config.EMBEDDING_QUERY_PREFIX,
clipped_messages
.first()
.unwrap_or(&"Arbitrary because query is empty".to_string())
)
.to_string(),
),
_ => EmbeddingInput::StringArray(clipped_messages),
};

// Vectorize
let parameters = EmbeddingParameters {
model: dataset_config.EMBEDDING_MODEL_NAME.to_string(),
input,
};

let embeddings_resp = ureq::post(&format!(
"{}/embeddings?api-version=2023-05-15",
client.base_url
))
.set("Authorization", &format!("Bearer {}", client.api_key))
.set("api-key", &client.api_key)
.set("Content-Type", "application/json")
.send_json(serde_json::to_value(parameters).unwrap())
.map_err(|e| {
ServiceError::InternalServerError(format!(
"Could not get embeddings from server: {:?}, {:?}",
e,
e.to_string()
let mut all_vectors = vec![];
let thirty_message_groups = messages.chunks(30).collect::<Vec<_>>();

for thirty_messages in thirty_message_groups {
let clipped_messages = thirty_messages
.iter()
.map(|msg| {
if msg.len() > 7000 {
msg.chars().take(20000).collect()
} else {
msg.clone()
}
})
.collect::<Vec<String>>();

let input = match embed_type {
"doc" => EmbeddingInput::StringArray(clipped_messages),
"query" => EmbeddingInput::String(
format!(
"{}{}",
dataset_config.EMBEDDING_QUERY_PREFIX,
clipped_messages
.first()
.unwrap_or(&"Arbitrary because query is empty".to_string())
)
.to_string(),
),
_ => EmbeddingInput::StringArray(clipped_messages),
};

// Vectorize
let parameters = EmbeddingParameters {
model: dataset_config.EMBEDDING_MODEL_NAME.to_string(),
input,
};

let embeddings_resp = ureq::post(&format!(
"{}/embeddings?api-version=2023-05-15",
client.base_url
))
})?;

let embeddings: EmbeddingResponse = format_response(embeddings_resp.into_string().unwrap())
.set("Authorization", &format!("Bearer {}", client.api_key))
.set("api-key", &client.api_key)
.set("Content-Type", "application/json")
.send_json(serde_json::to_value(parameters).unwrap())
.map_err(|e| {
log::error!("Failed to format response from embeddings server {:?}", e);
ServiceError::InternalServerError(
"Failed to format response from embeddings server".to_owned(),
)
ServiceError::InternalServerError(format!(
"Could not get embeddings from server: {:?}, {:?}",
e,
e.to_string()
))
})?;

let vectors: Vec<Vec<f32>> = embeddings
let embeddings: EmbeddingResponse = format_response(embeddings_resp.into_string().unwrap())
.map_err(|e| {
log::error!("Failed to format response from embeddings server {:?}", e);
ServiceError::InternalServerError(
"Failed to format response from embeddings server".to_owned(),
)
})?;

let vectors: Vec<Vec<f32>> = embeddings
.data
.into_iter()
.map(|x| match x.embedding {
Expand All @@ -136,14 +140,17 @@ pub async fn create_embeddings(
})
.collect();

if vectors.iter().any(|x| x.is_empty()) {
return Err(ServiceError::InternalServerError(
"Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(),
));
if vectors.iter().any(|x| x.is_empty()) {
return Err(ServiceError::InternalServerError(
"Embedding server responded with Base64 and that is not currently supported for embeddings".to_owned(),
));
}

all_vectors.extend(vectors);
}

transaction.finish();
Ok(vectors)
Ok(all_vectors)
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down

0 comments on commit 3ea123d

Please sign in to comment.