Skip to content

Commit

Permalink
feature: heading based chunking
Browse files Browse the repository at this point in the history
  • Loading branch information
densumesh committed Dec 14, 2024
1 parent a1f0d24 commit 3b169c9
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 44 deletions.
39 changes: 38 additions & 1 deletion frontends/search/src/components/UploadFile.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ export const UploadFile = () => {
const [targetSplitsPerChunk, setTargetSplitsPerChunk] = createSignal(20);
const [rebalanceChunks, setRebalanceChunks] = createSignal(false);
const [useGptChunking, setUseGptChunking] = createSignal(false);
const [useHeadingBasedChunking, setUseHeadingBasedChunking] =
createSignal(false);
const [groupTrackingId, setGroupTrackingId] = createSignal("");
const [systemPrompt, setSystemPrompt] = createSignal("");

const [showFileInput, setShowFileInput] = createSignal(true);
const [showFolderInput, setShowFolderInput] = createSignal(false);
Expand Down Expand Up @@ -149,7 +152,11 @@ export const UploadFile = () => {
split_delimiters: splitDelimiters(),
target_splits_per_chunk: targetSplitsPerChunk(),
rebalance_chunks: rebalanceChunks(),
pdf2md_options: { use_pdf2md_ocr: useGptChunking() },
pdf2md_options: {
use_pdf2md_ocr: useGptChunking(),
split_headings: useHeadingBasedChunking(),
system_prompt: systemPrompt(),
},
group_tracking_id:
groupTrackingId() === "" ? undefined : groupTrackingId(),
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
Expand Down Expand Up @@ -343,6 +350,36 @@ export const UploadFile = () => {
onInput={(e) => setUseGptChunking(e.currentTarget.checked)}
class="h-4 w-4 rounded-md border border-gray-300 bg-neutral-100 px-4 py-1 dark:bg-neutral-700"
/>
<div class="flex flex-row items-center space-x-2">
<div>Heading Based Chunking</div>
<Tooltip
body={<BsInfoCircle />}
tooltipText="If set to true, Trieve will use the headings in the document to chunk the text."
/>
</div>
<input
type="checkbox"
checked={useHeadingBasedChunking()}
onInput={(e) =>
setUseHeadingBasedChunking(e.currentTarget.checked)
}
class="h-4 w-4 rounded-md border border-gray-300 bg-neutral-100 px-4 py-1 dark:bg-neutral-700"
/>
<div class="flex flex-col space-y-2">
<div class="flex flex-row items-center space-x-2">
<div>System Prompt</div>
<Tooltip
body={<BsInfoCircle />}
tooltipText="System prompt to use when chunking. This is an optional field which allows you to specify the system prompt to use when chunking the text. If not specified, the default system prompt is used. However, you may want to use a different system prompt."
/>
</div>
<textarea
placeholder="optional system prompt to use when chunking"
value={systemPrompt()}
onInput={(e) => setSystemPrompt(e.target.value)}
class="w-full rounded-md border border-gray-300 bg-neutral-100 px-4 py-1 dark:bg-neutral-700"
/>
</div>
</div>
</Show>
<div class="m-1 mb-1 flex flex-row gap-2">
Expand Down
16 changes: 12 additions & 4 deletions pdf2md/server/src/operators/pdf_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,17 @@ use regex::Regex;
use s3::creds::time::OffsetDateTime;

const CHUNK_SYSTEM_PROMPT: &str = "
Convert the following PDF page to markdown.
Return only the markdown with no explanation text.
Do not exclude any content from the page.";
Convert this PDF page to markdown formatting, following these requirements:
1. Break the content into logical sections with clear markdown headings (# for main sections, ## for subsections, etc.)
2. Create section headers that accurately reflect the content and hierarchy of each part
3. Include all body content from the page
4. Exclude any PDF headers and footers
5. Return only the formatted markdown without any explanatory text
6. Match the original document's content organization but with explicit markdown structure
Please provide the markdown version using this structured approach.
";

fn get_data_url_from_image(img: DynamicImage) -> Result<String, ServiceError> {
let mut encoded = Vec::new();
Expand Down Expand Up @@ -108,7 +116,7 @@ async fn get_markdown_from_image(
if let Some(prev_md_doc) = prev_md_doc {
let prev_md_doc_message = ChatMessage::System {
content: ChatMessageContent::Text(format!(
"Markdown must maintain consistent formatting with the following page: \n\n {}",
"Markdown must maintain consistent formatting with the following page, DO NOT INCLUDE CONTENT FROM THIS PAGE IN YOUR RESPONSE: \n\n {}",
prev_md_doc
)),
name: None,
Expand Down
26 changes: 26 additions & 0 deletions server/src/bin/file-worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ use trieve_server::{
},
};

const HEADING_CHUNKING_SYSTEM_PROMPT: &str = "
Analyze this PDF page and restructure it into clear markdown sections based on the content topics. For each distinct topic or theme discussed:
1. Create a meaningful section heading using markdown (# for main topics, ## for subtopics)
2. Group related content under each heading
3. Break up dense paragraphs into more readable chunks where appropriate
4. Maintain the key information but organize it by subject matter
5. Skip headers, footers, and page numbers
6. Focus on semantic organization rather than matching the original layout
Please provide just the reorganized markdown without any explanatory text
";

fn main() {
dotenvy::dotenv().ok();
env_logger::builder()
Expand Down Expand Up @@ -330,6 +343,19 @@ async fn upload_file(
json_value["system_prompt"] = serde_json::json!(system_prompt);
}

if file_worker_message
.upload_file_data
.pdf2md_options
.as_ref()
.is_some_and(|options| options.split_headings.unwrap_or(false))
{
json_value["system_prompt"] = serde_json::json!(format!(
"{}\n\n{}",
json_value["system_prompt"].as_str().unwrap_or(""),
HEADING_CHUNKING_SYSTEM_PROMPT
));
}

log::info!("Sending file to pdf2md");
let pdf2md_response = pdf2md_client
.post(format!("{}/api/task", pdf2md_url))
Expand Down
1 change: 1 addition & 0 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ impl Modify for SecurityAddon {
handlers::file_handler::CreatePresignedUrlForCsvJsonlReqPayload,
handlers::file_handler::CreatePresignedUrlForCsvJsonResponseBody,
handlers::file_handler::UploadHtmlPageReqPayload,
handlers::file_handler::Pdf2MdOptions,
handlers::invitation_handler::InvitationData,
handlers::event_handler::GetEventsData,
handlers::organization_handler::CreateOrganizationReqPayload,
Expand Down
202 changes: 163 additions & 39 deletions server/src/operators/file_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,61 @@ pub fn preprocess_file_to_chunks(
Ok(chunk_htmls)
}

pub fn split_markdown_by_headings(markdown_text: &str) -> Vec<String> {
let lines: Vec<&str> = markdown_text
.trim()
.lines()
.filter(|x| !x.trim().is_empty())
.collect();
let mut chunks = Vec::new();
let mut current_content = Vec::new();
let mut pending_heading: Option<String> = None;

fn is_heading(line: &str) -> bool {
line.trim().starts_with('#')
}

fn save_chunk(chunks: &mut Vec<String>, content: &[String]) {
if !content.is_empty() {
chunks.push(content.join("\n").trim().to_string());
}
}

for (i, line) in lines.iter().enumerate() {
if is_heading(line) {
if !current_content.is_empty() {
save_chunk(&mut chunks, &current_content);
current_content.clear();
}

if i + 1 < lines.len() && !is_heading(lines[i + 1]) {
if let Some(heading) = pending_heading.take() {
current_content.push(heading);
}
current_content.push(line.to_string());
} else {
pending_heading = Some(line.to_string());
}
} else if !line.trim().is_empty() || !current_content.is_empty() {
current_content.push(line.to_string());
}
}

if !current_content.is_empty() {
save_chunk(&mut chunks, &current_content);
}

if let Some(heading) = pending_heading {
chunks.push(heading);
}

if chunks.is_empty() && !lines.is_empty() {
chunks.push(lines.join("\n").trim().to_string());
}

chunks
}

#[allow(clippy::too_many_arguments)]
pub async fn create_file_chunks(
created_file_id: uuid::Uuid,
Expand All @@ -134,47 +189,116 @@ pub async fn create_file_chunks(
) -> Result<(), ServiceError> {
let name = upload_file_data.file_name.clone();

let chunk_group = ChunkGroup::from_details(
Some(name.clone()),
upload_file_data.description.clone(),
dataset_org_plan_sub.dataset.id,
upload_file_data.group_tracking_id.clone(),
None,
upload_file_data
.tag_set
.clone()
.map(|tag_set| tag_set.into_iter().map(Some).collect()),
);

let chunk_group_option = create_groups_query(vec![chunk_group], true, pool.clone())
.await
.map_err(|e| {
log::error!("Could not create group {:?}", e);
ServiceError::BadRequest("Could not create group".to_string())
})?
.pop();

let chunk_group = match chunk_group_option {
Some(group) => group,
None => {
return Err(ServiceError::BadRequest(
"Could not create group from file".to_string(),
));
if upload_file_data
.pdf2md_options
.is_some_and(|x| x.split_headings.unwrap_or(false))
{
let mut new_chunks = Vec::new();

for chunk in chunks {
let chunk_group = ChunkGroup::from_details(
Some(format!(
"{}-page-{}",
name,
chunk.metadata.as_ref().unwrap_or(&serde_json::json!({
"page_num": 0
}))["page_num"]
.as_i64()
.unwrap_or(0)
)),
upload_file_data.description.clone(),
dataset_org_plan_sub.dataset.id,
upload_file_data.group_tracking_id.clone(),
None,
upload_file_data
.tag_set
.clone()
.map(|tag_set| tag_set.into_iter().map(Some).collect()),
);

let chunk_group_option = create_groups_query(vec![chunk_group], true, pool.clone())
.await
.map_err(|e| {
log::error!("Could not create group {:?}", e);
ServiceError::BadRequest("Could not create group".to_string())
})?
.pop();

let chunk_group = match chunk_group_option {
Some(group) => group,
None => {
return Err(ServiceError::BadRequest(
"Could not create group from file".to_string(),
));
}
};

let group_id = chunk_group.id;

create_group_from_file_query(group_id, created_file_id, pool.clone())
.await
.map_err(|e| {
log::error!("Could not create group from file {:?}", e);
e
})?;

let split_chunks =
split_markdown_by_headings(chunk.chunk_html.as_ref().unwrap_or(&String::new()));

for (i, split_chunk) in split_chunks.into_iter().enumerate() {
new_chunks.push(ChunkReqPayload {
chunk_html: Some(split_chunk),
tracking_id: chunk.tracking_id.clone().map(|x| format!("{}-{}", x, i)),
group_ids: Some(vec![group_id]),
..chunk.clone()
});
}
}
};

let group_id = chunk_group.id;

chunks.iter_mut().for_each(|chunk| {
chunk.group_ids = Some(vec![group_id]);
});

create_group_from_file_query(group_id, created_file_id, pool.clone())
.await
.map_err(|e| {
log::error!("Could not create group from file {:?}", e);
e
})?;
chunks = new_chunks;
} else {
let chunk_group = ChunkGroup::from_details(
Some(name.clone()),
upload_file_data.description.clone(),
dataset_org_plan_sub.dataset.id,
upload_file_data.group_tracking_id.clone(),
None,
upload_file_data
.tag_set
.clone()
.map(|tag_set| tag_set.into_iter().map(Some).collect()),
);

let chunk_group_option = create_groups_query(vec![chunk_group], true, pool.clone())
.await
.map_err(|e| {
log::error!("Could not create group {:?}", e);
ServiceError::BadRequest("Could not create group".to_string())
})?
.pop();

let chunk_group = match chunk_group_option {
Some(group) => group,
None => {
return Err(ServiceError::BadRequest(
"Could not create group from file".to_string(),
));
}
};

let group_id = chunk_group.id;

chunks.iter_mut().for_each(|chunk| {
chunk.group_ids = Some(vec![group_id]);
});

create_group_from_file_query(group_id, created_file_id, pool.clone())
.await
.map_err(|e| {
log::error!("Could not create group from file {:?}", e);
e
})?;
}

let chunk_count = get_row_count_for_organization_id_query(
dataset_org_plan_sub.organization.organization.id,
Expand Down

0 comments on commit 3b169c9

Please sign in to comment.