12 changes: 12 additions & 0 deletions olmocr/bench/README.md
@@ -96,6 +96,18 @@ to run it against your own OCR tools. Your tool just needs to support Markdown o
<td align="center">99.3</td>
<td align="center">64.5 ± 1.1</td>
</tr>
<tr>
<td align="left">Dots OCR</td>
<td align="center">65.2</td>
<td align="center">69.7</td>
<td align="center"><strong>84.8</strong></td>
<td align="center">38.6</td>
<td align="center">79.5</td>
<td align="center">72.9</td>
<td align="center">46.2</td>
<td align="center">97.8</td>
<td align="center">69.3 ± 1.1</td>
</tr>
<tr>
<td align="left">GPT-4o (No Anchor)</td>
<td align="center">51.5</td>
1 change: 1 addition & 0 deletions olmocr/bench/convert.py
@@ -261,6 +261,7 @@ async def process_with_semaphore(task):
"docling": ("olmocr.bench.runners.run_docling", "run_docling"),
"rolmocr": ("olmocr.bench.runners.run_rolmocr", "run_rolmocr"),
"paddlepaddle": ("olmocr.bench.runners.run_paddlepaddle", "run_paddlepaddle"),
"dotsocr": ("olmocr.bench.runners.run_dotsocr", "run_dotsocr"),
"transformers": ("olmocr.bench.runners.run_transformers", "run_transformers"),
"server": ("olmocr.bench.runners.run_server", "run_server"),
}
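Each entry in this registry maps a method name to a `(module path, function name)` pair, so the new `dotsocr` method plugs into the bench harness without being imported up front. A minimal sketch of how such an entry could be resolved into a callable, assuming `importlib`-based lazy loading (the actual dispatch code lives elsewhere in `convert.py` and is not shown in this diff; `resolve_runner` is a hypothetical helper):

```python
import importlib


def resolve_runner(registry: dict, method: str):
    """Hypothetical helper: turn a registry entry like
    ("olmocr.bench.runners.run_dotsocr", "run_dotsocr") into a callable."""
    module_path, func_name = registry[method]
    module = importlib.import_module(module_path)  # import only when the method is requested
    return getattr(module, func_name)
```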
91 changes: 91 additions & 0 deletions olmocr/bench/runners/run_dotsocr.py
@@ -0,0 +1,91 @@
import base64
from io import BytesIO

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor

from olmocr.data.renderpdf import render_pdf_to_base64png

_model = None
_processor = None


def load_model(model_name: str = "./weights/DotsOCR"):
"""
Load the DotsOCR model and processor if they haven't been loaded already.

Args:
        model_name: Hugging Face model name or local path to the DotsOCR weights

Returns:
model: The DotsOCR model loaded on the appropriate device.
processor: The corresponding processor.
"""
global _model, _processor
if _model is None or _processor is None:
_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="flash_attention_2",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
)
_processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
return _model, _processor


def run_dotsocr(pdf_path: str, page_num: int = 1, model_name: str = "./weights/DotsOCR", target_longest_image_dim: int = 1024) -> str:
"""
    Convert a page of a PDF file to plain text using DotsOCR.

    This function renders the specified page of the PDF to an image, runs DotsOCR on that image,
    and returns the extracted text.

    Args:
        pdf_path (str): The local path to the PDF file.
        page_num (int): The page number to process (default: 1).
        model_name (str): Hugging Face model name or local path to the DotsOCR weights (default: "./weights/DotsOCR").
        target_longest_image_dim (int): Target dimension for the longest side of the image (default: 1024).

    Returns:
        str: The text extracted from the page.
"""
# Ensure the model is loaded (cached across calls)
model, processor = load_model(model_name)

# Convert the specified page of the PDF to a base64-encoded PNG image.
image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=target_longest_image_dim)

# Create PIL Image from base64
image = Image.open(BytesIO(base64.b64decode(image_base64)))

    # Define the prompt for plain-text extraction
    prompt = "Extract the text content from this image."

messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]

# Preparation for inference
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)

    # Move inputs to the model's device (the model is loaded with device_map="auto")
    inputs = inputs.to(model.device)

with torch.no_grad():
generated_ids = model.generate(**inputs, max_new_tokens=4096)

    # Strip the prompt tokens so only the newly generated tokens are decoded
    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]

output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    # Release GPU memory held by the inference tensors before returning
    del inputs
    del generated_ids
    del generated_ids_trimmed
    torch.cuda.empty_cache()

return output_text[0] if output_text else ""
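For reference, a minimal usage sketch of the new runner, assuming the DotsOCR weights have been downloaded to the default `./weights/DotsOCR` path and a CUDA device is available (`sample.pdf` is a placeholder path):

```python
from olmocr.bench.runners.run_dotsocr import run_dotsocr

# The model and processor are loaded on the first call and cached for subsequent pages.
text = run_dotsocr("sample.pdf", page_num=1)
print(text)
```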