Skip to content

Commit

Permalink
feature: make pdf2md ocr optional
Browse files Browse the repository at this point in the history
  • Loading branch information
cdxker committed Nov 19, 2024
1 parent 5735ef8 commit 1f655cf
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 100 deletions.
8 changes: 1 addition & 7 deletions pdf2md/server/src/workers/chunk-worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,7 @@ pub async fn chunk_sub_pdf(
.as_slice()
.to_vec();

let result = chunk_sub_pages(
file_data,
task.clone(),
&clickhouse_client,
&redis_pool,
)
.await?;
let result = chunk_sub_pages(file_data, task.clone(), &clickhouse_client, &redis_pool).await?;

log::info!("Got {} pages for {:?}", result.len(), task.task_id);

Expand Down
191 changes: 99 additions & 92 deletions server/src/bin/file-worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,15 @@ use std::sync::{
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
use trieve_server::{
data::models::{self, ChunkGroup, FileWorkerMessage},
data::models::{self, FileWorkerMessage},
errors::ServiceError,
establish_connection, get_env,
handlers::chunk_handler::ChunkReqPayload,
operators::{
clickhouse_operator::{ClickHouseEvent, EventQueue},
dataset_operator::get_dataset_and_organization_from_dataset_id_query,
file_operator::{
create_file_chunks, create_file_query, get_aws_bucket, preprocess_file_to_chunks,
},
group_operator::{create_group_from_file_query, create_groups_query},
},
};

Expand Down Expand Up @@ -354,113 +352,122 @@ async fn upload_file(
.await?;

if file_name.ends_with(".pdf") {
// Send file to router PDF2MD
let pdf2md_url = std::env::var("PDF2MD_URL")
.expect("PDF2MD_URL must be set")
.to_string();

let pdf2md_auth = std::env::var("PDF2MD_AUTH").unwrap_or("".to_string());

let pdf2md_client = reqwest::Client::new();
let encoded_file = base64::prelude::BASE64_STANDARD.encode(file_data.clone());

let json_value = serde_json::json!({
"base64_file": encoded_file.clone()
});

let pdf2md_response = pdf2md_client
.post(format!("{}/api/task", pdf2md_url))
.header("Content-Type", "application/json")
.header("Authorization", &pdf2md_auth)
.json(&json_value)
.send()
.await
.map_err(|err| {
log::error!("Could not send file to pdf2md {:?}", err);
ServiceError::BadRequest("Could not send file to pdf2md".to_string())
})?;
if let Some(true) = file_worker_message.upload_file_data.use_pdf2md_ocr {
// Send file to router PDF2MD
let pdf2md_url = std::env::var("PDF2MD_URL")
.expect("PDF2MD_URL must be set")
.to_string();

let response = pdf2md_response.json::<CreateFileTaskResponse>().await;
let pdf2md_auth = std::env::var("PDF2MD_AUTH").unwrap_or("".to_string());

let task_id = match response {
Ok(response) => response.task_id,
Err(err) => {
log::error!("Could not parse task_id from pdf2md {:?}", err);
return Err(ServiceError::BadRequest(format!(
"Could not parse task_id from pdf2md {:?}",
err
)));
}
};
let pdf2md_client = reqwest::Client::new();
let encoded_file = base64::prelude::BASE64_STANDARD.encode(file_data.clone());

log::info!("Waiting on Task {}", task_id);
let mut completed_task: Option<PollTaskResponse> = None;
let json_value = serde_json::json!({
"base64_file": encoded_file.clone()
});

loop {
let request = pdf2md_client
.get(format!("{}/api/task/{}", pdf2md_url, task_id).as_str())
log::info!("Sending file to pdf2md");
let pdf2md_response = pdf2md_client
.post(format!("{}/api/task", pdf2md_url))
.header("Content-Type", "application/json")
.header("Authorization", &pdf2md_auth)
.json(&json_value)
.send()
.await
.map_err(|err| {
log::error!("Could not send poll request to pdf2md {:?}", err);
ServiceError::BadRequest(format!("Could not send request to pdf2md {:?}", err))
log::error!("Could not send file to pdf2md {:?}", err);
ServiceError::BadRequest("Could not send file to pdf2md".to_string())
})?;

let response = request.json::<PollTaskResponse>().await.map_err(|err| {
log::error!("Could not parse response from pdf2md {:?}", err);
ServiceError::BadRequest(format!("Could not parse response from pdf2md {:?}", err))
})?;

if (response.status == "Completed" && response.total_document_pages != 0)
&& response.pages.is_some()
{
log::info!("Got job back from task {}", task_id);
completed_task = Some(response);
break;
} else {
log::info!("Polling on task {}... {:?}", task_id, response);
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
continue;
}
}

if let Some(task) = completed_task {
// Poll Chunks from pdf chunks from service
let file_size_mb = (file_data.len() as f64 / 1024.0 / 1024.0).round() as i64;
let created_file = create_file_query(
file_id,
file_size_mb,
file_worker_message.upload_file_data.clone(),
file_worker_message.dataset_id,
web_pool.clone(),
)
.await?;
let response = pdf2md_response.json::<CreateFileTaskResponse>().await;

let mut chunk_htmls: Vec<String> = vec![];
let task_id = match response {
Ok(response) => response.task_id,
Err(err) => {
log::error!("Could not parse task_id from pdf2md {:?}", err);
return Err(ServiceError::BadRequest(format!(
"Could not parse task_id from pdf2md {:?}",
err
)));
}
};

log::info!("Waiting on Task {}", task_id);
#[allow(unused_assignments)]
let mut completed_task: Option<PollTaskResponse> = None;

loop {
let request = pdf2md_client
.get(format!("{}/api/task/{}", pdf2md_url, task_id).as_str())
.header("Content-Type", "application/json")
.header("Authorization", &pdf2md_auth)
.send()
.await
.map_err(|err| {
log::error!("Could not send poll request to pdf2md {:?}", err);
ServiceError::BadRequest(format!(
"Could not send request to pdf2md {:?}",
err
))
})?;

let response = request.json::<PollTaskResponse>().await.map_err(|err| {
log::error!("Could not parse response from pdf2md {:?}", err);
ServiceError::BadRequest(format!(
"Could not parse response from pdf2md {:?}",
err
))
})?;

log::info!("Chunks got {:?}", task);
if let Some(pages) = task.pages {
for page in pages {
chunk_htmls.push(page.content.clone());
if (response.status == "Completed" && response.total_document_pages != 0)
&& response.pages.is_some()
{
log::info!("Got job back from task {}", task_id);
completed_task = Some(response);
break;
} else {
log::info!("Polling on task {}... {:?}", task_id, response);
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
continue;
}
}

log::info!("Chunks got {}", chunk_htmls.len());
if let Some(task) = completed_task {
// Poll Chunks from pdf chunks from service
let file_size_mb = (file_data.len() as f64 / 1024.0 / 1024.0).round() as i64;
let created_file = create_file_query(
file_id,
file_size_mb,
file_worker_message.upload_file_data.clone(),
file_worker_message.dataset_id,
web_pool.clone(),
)
.await?;

create_file_chunks(
created_file.id,
file_worker_message.upload_file_data,
chunk_htmls,
dataset_org_plan_sub,
web_pool.clone(),
event_queue.clone(),
redis_conn,
)
.await?;
let mut chunk_htmls: Vec<String> = vec![];

log::info!("Got {} pages from pdf2ocr", chunk_htmls.len());

return Ok(Some(file_id));
if let Some(pages) = task.pages {
for page in pages {
chunk_htmls.push(page.content.clone());
}
}

create_file_chunks(
created_file.id,
file_worker_message.upload_file_data,
chunk_htmls,
dataset_org_plan_sub,
web_pool.clone(),
event_queue.clone(),
redis_conn,
)
.await?;

return Ok(Some(file_id));
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions server/src/handlers/file_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ pub struct UploadFileReqPayload {
pub target_splits_per_chunk: Option<usize>,
/// Group tracking id is an optional field which allows you to specify the tracking id of the group that is created from the file. Chunks created will be created with the tracking id of `group_tracking_id|<index of chunk>`
pub group_tracking_id: Option<String>,
/// Whether to run PDF uploads through pdf2md OCR. If true, the file will be converted to markdown using gpt-4o.
/// Only takes effect for files whose name ends in `.pdf`. Defaults to false when omitted.
pub use_pdf2md_ocr: Option<bool>,
}

#[derive(Debug, Serialize, Deserialize, Clone, ToSchema)]
Expand Down
2 changes: 1 addition & 1 deletion server/src/operators/file_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ pub async fn create_file_chunks(
) -> Result<(), ServiceError> {
let mut chunks: Vec<ChunkReqPayload> = [].to_vec();

let name = format!("{}", upload_file_data.file_name);
let name = upload_file_data.file_name.clone();

let chunk_group = ChunkGroup::from_details(
Some(name.clone()),
Expand Down

0 comments on commit 1f655cf

Please sign in to comment.