Skip to content

Commit

Permalink
bugfix: too many batched embeddings causes error
Browse files — browse the repository at this point in the history
  • Loading branch information
skeptrunedev committed Apr 4, 2024
1 parent 2d592ff commit b10d7aa
Showing 1 changed file with 45 additions and 12 deletions.
57 changes: 45 additions & 12 deletions server/src/bin/ingestion-microservice.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -147,13 +147,12 @@ fn main() {
let web_redis_pool = web_redis_pool.clone();
let should_terminate = Arc::clone(&should_terminate);

tokio::spawn(
async move { ingestion_service(i, should_terminate, web_redis_pool, web_pool).await },
)
tokio::spawn(async move {
ingestion_service(i, should_terminate, web_redis_pool, web_pool).await
})
})
.collect();


while !should_terminate.load(Ordering::Relaxed) {}
log::info!("Shutdown signal received, killing all children...");
futures::future::join_all(threads).await
Expand All @@ -171,18 +170,36 @@ async fn ingestion_service(
) {
log::info!("Starting ingestion service thread");

let mut redis_connection = match redis_pool.get().await {
Ok(redis_connection) => redis_connection,
Err(err) => {
log::error!("Failed to get redis connection outside of loop: {:?}", err);
return;
let mut sleep_time = std::time::Duration::from_secs(1);

#[allow(unused_assignments)]
let mut opt_redis_connection = None;

loop {
let borrowed_redis_connection = match redis_pool.get().await {
Ok(redis_connection) => Some(redis_connection),
Err(err) => {
log::error!("Failed to get redis connection outside of loop: {:?}", err);
None
}
};

if borrowed_redis_connection.is_some() {
opt_redis_connection = borrowed_redis_connection;
break;
}
};

tokio::time::sleep(sleep_time).await;
sleep_time = std::cmp::min(sleep_time * 2, std::time::Duration::from_secs(60));
}

let mut redis_connection =
opt_redis_connection.expect("Failed to get redis connection outside of loop");

loop {
if should_terminate.load(Ordering::Relaxed) {
log::info!("Shutting down");
break
break;
}

let payload_result: Result<Vec<String>, redis::RedisError> = redis::cmd("brpoplpush")
Expand Down Expand Up @@ -373,7 +390,23 @@ async fn upload_chunk(
true => {
let chunks = coarse_doc_chunker(content.clone());

let embeddings = create_embeddings(chunks, "doc", dataset_config.clone()).await?;
let thirty_chunks_grouped = chunks.chunks(30).collect::<Vec<_>>();

let mut embeddings = vec![];

for thirty_chunks in thirty_chunks_grouped {
let cur_embeddings =
create_embeddings(thirty_chunks.to_vec(), "doc", dataset_config.clone())
.await
.map_err(|err| {
ServiceError::InternalServerError(format!(
"Failed to create embedding: {:?}",
err
))
})?;

embeddings.extend(cur_embeddings);
}

average_embeddings(embeddings)?
}
Expand Down

0 comments on commit b10d7aa

Please sign in to comment.