Skip to content

Commit

Permalink
fix: Add EngineConfig & StrategyHandler (#211)
Browse files Browse the repository at this point in the history
* fix: Add EngineConfig & StrategyHandler

* fix: Configs settings
chloedia authored Jan 8, 2025
1 parent 03c7ada commit 2e1c6dd
Showing 9 changed files with 258 additions and 153 deletions.
47 changes: 47 additions & 0 deletions libs/megaparse/src/megaparse/configs/auto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from enum import Enum

from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict


class TextDetConfig(BaseModel):
det_arch: str = "fast_base"
batch_size: int = 2
assume_straight_pages: bool = True
preserve_aspect_ratio: bool = True
symmetric_pad: bool = True
load_in_8_bit: bool = False


class AutoStrategyConfig(BaseModel):
auto_page_threshold: float = 0.6
auto_document_threshold: float = 0.2


class TextRecoConfig(BaseModel):
reco_arch: str = "crnn_vgg16_bn"
batch_size: int = 512


class DeviceEnum(str, Enum):
CPU = "cpu"
CUDA = "cuda"
COREML = "coreml"


class MegaParseConfig(BaseSettings):
"""
Configuration for Megaparse.
"""

model_config = SettingsConfigDict(
env_prefix="MEGAPARSE_",
env_file=(".env.local", ".env"),
env_nested_delimiter="__",
extra="ignore",
use_enum_values=True,
)
text_det_config: TextDetConfig = TextDetConfig()
text_reco_config: TextRecoConfig = TextRecoConfig()
auto_parse_config: AutoStrategyConfig = AutoStrategyConfig()
device: DeviceEnum = DeviceEnum.CPU
2 changes: 1 addition & 1 deletion libs/megaparse/src/megaparse/examples/parse_file.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ def main():
parser = UnstructuredParser()
megaparse = MegaParse(parser=parser)

file_path = "./tests/pdf/ocr/0168126.pdf"
file_path = "./tests/pdf/native/0168029.pdf"

parsed_file = megaparse.load(file_path)
print(f"\n----- File Response : {file_path} -----\n")
37 changes: 23 additions & 14 deletions libs/megaparse/src/megaparse/megaparse.py
Original file line number Diff line number Diff line change
@@ -4,36 +4,51 @@
from pathlib import Path
from typing import IO, BinaryIO

from megaparse_sdk.config import MegaParseConfig
from megaparse.configs.auto import DeviceEnum, MegaParseConfig
from megaparse_sdk.schema.extensions import FileExtension
from megaparse_sdk.schema.parser_config import StrategyEnum

from megaparse.checker.format_checker import FormatChecker
from megaparse.exceptions.base import ParsingException
from megaparse.parser.base import BaseParser
from megaparse.parser.doctr_parser import DoctrParser
from megaparse.parser.strategy import determine_strategy
from megaparse.parser.strategy import StrategyHandler
from megaparse.parser.unstructured_parser import UnstructuredParser

logger = logging.getLogger("megaparse")


class MegaParse:
config: MegaParseConfig = MegaParseConfig()
config = MegaParseConfig()

def __init__(
self,
parser: BaseParser = UnstructuredParser(strategy=StrategyEnum.FAST),
ocr_parser: BaseParser = DoctrParser(),
parser: BaseParser | None = None,
ocr_parser: BaseParser | None = None,
strategy: StrategyEnum = StrategyEnum.AUTO,
format_checker: FormatChecker | None = None,
) -> None:
if not parser:
parser = UnstructuredParser(strategy=StrategyEnum.FAST)
if not ocr_parser:
ocr_parser = DoctrParser(
text_det_config=self.config.text_det_config,
text_reco_config=self.config.text_reco_config,
device=self.config.device,
)

self.strategy = strategy
self.parser = parser
self.ocr_parser = ocr_parser
self.format_checker = format_checker
self.last_parsed_document: str = ""

self.strategy_handler = StrategyHandler(
text_det_config=self.config.text_det_config,
auto_config=self.config.auto_parse_config,
device=self.config.device,
)

def validate_input(
self,
file_path: Path | str | None = None,
@@ -132,17 +147,11 @@ def _select_parser(
if self.strategy != StrategyEnum.AUTO or file_extension != FileExtension.PDF:
return self.parser
if file:
local_strategy = determine_strategy(
file=file,
threshold_pages_ocr=self.config.auto_document_threshold,
threshold_per_page=self.config.auto_page_threshold,
local_strategy = self.strategy_handler.determine_strategy(
file=file, # type: ignore #FIXME: Careful here on removing BinaryIO (not handled by onnxtr)
)
if file_path:
local_strategy = determine_strategy(
file=file_path,
threshold_pages_ocr=self.config.auto_document_threshold,
threshold_per_page=self.config.auto_page_threshold,
)
local_strategy = self.strategy_handler.determine_strategy(file=file_path)

if local_strategy == StrategyEnum.HI_RES:
return self.ocr_parser
39 changes: 25 additions & 14 deletions libs/megaparse/src/megaparse/parser/doctr_parser.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
from pathlib import Path
from typing import IO, BinaryIO, List

from megaparse.configs.auto import DeviceEnum, TextRecoConfig, TextDetConfig
import onnxruntime as rt
from megaparse_sdk.schema.extensions import FileExtension
from onnxtr.io import DocumentFile
@@ -19,16 +20,13 @@ class DoctrParser(BaseParser):

def __init__(
self,
det_predictor_model: str = "db_resnet50",
reco_predictor_model: str = "crnn_vgg16_bn",
det_bs: int = 2,
reco_bs: int = 512,
assume_straight_pages: bool = True,
text_det_config: TextDetConfig = TextDetConfig(),
text_reco_config: TextRecoConfig = TextRecoConfig(),
device: DeviceEnum = DeviceEnum.CPU,
straighten_pages: bool = False,
use_gpu: bool = False,
**kwargs,
):
self.use_gpu = use_gpu
self.device = device
general_options = rt.SessionOptions()
providers = self._get_providers()
engine_config = EngineConfig(
@@ -37,11 +35,11 @@ def __init__(
)
# TODO: set in config or pass as kwargs
self.predictor = ocr_predictor(
det_arch=det_predictor_model,
reco_arch=reco_predictor_model,
det_bs=det_bs,
reco_bs=reco_bs,
assume_straight_pages=assume_straight_pages,
det_arch=text_det_config.det_arch,
reco_arch=text_reco_config.reco_arch,
det_bs=text_det_config.batch_size,
reco_bs=text_reco_config.batch_size,
assume_straight_pages=text_det_config.assume_straight_pages,
straighten_pages=straighten_pages,
# Preprocessing related parameters
det_engine_cfg=engine_config,
@@ -53,14 +51,27 @@ def __init__(
def _get_providers(self) -> List[str]:
prov = rt.get_available_providers()
logger.info("Available providers:", prov)
if self.use_gpu:
if self.device == DeviceEnum.CUDA:
# TODO: support openvino, directml etc
if "CUDAExecutionProvider" not in prov:
raise ValueError(
"onnxruntime can't find CUDAExecutionProvider in list of available providers"
)
return ["CUDAExecutionProvider"]
return ["TensorrtExecutionProvider", "CUDAExecutionProvider"]
elif self.device == DeviceEnum.COREML:
if "CoreMLExecutionProvider" not in prov:
raise ValueError(
"onnxruntime can't find CoreMLExecutionProvider in list of available providers"
)
return ["CoreMLExecutionProvider"]
elif self.device == DeviceEnum.CPU:
return ["CPUExecutionProvider"]
else:
warnings.warn(
"Device not supported, using CPU",
UserWarning,
stacklevel=2,
)
return ["CPUExecutionProvider"]

def convert(
235 changes: 147 additions & 88 deletions libs/megaparse/src/megaparse/parser/strategy.py
Original file line number Diff line number Diff line change
@@ -1,108 +1,167 @@
import logging
import random
import warnings
from pathlib import Path
from typing import Any, List

import numpy as np
import onnxruntime as rt
import pypdfium2 as pdfium
from megaparse_sdk.schema.parser_config import StrategyEnum
from onnxtr.io import DocumentFile
from onnxtr.models import detection_predictor
from onnxtr.models.engine import EngineConfig
from pypdfium2._helpers.page import PdfPage

from megaparse.configs.auto import AutoStrategyConfig, DeviceEnum, TextDetConfig
from megaparse.predictor.doctr_layout_detector import LayoutPredictor
from megaparse.predictor.models.base import PageLayout

logger = logging.getLogger("megaparse")


def get_strategy_page(
pdfium_page: PdfPage, onnxtr_page: PageLayout, threshold: float
) -> StrategyEnum:
# assert (
# p_width == onnxtr_page.dimensions[1]
# and p_height == onnxtr_page.dimensions[0]
# ), "Page dimensions do not match"
text_coords = []
# Get all the images in the page
for obj in pdfium_page.get_objects():
if obj.type == 1:
text_coords.append(obj.get_pos())

p_width, p_height = int(pdfium_page.get_width()), int(pdfium_page.get_height())

pdfium_canva = np.zeros((int(p_height), int(p_width)))

for coords in text_coords:
# (left,bottom,right, top)
# 0---l--------------R-> y
# |
# B (x0,y0)
# |
# T (x1,y1)
# ^
# x
x0, y0, x1, y1 = (
p_height - coords[3],
coords[0],
p_height - coords[1],
coords[2],
class StrategyHandler:
def __init__(
self,
auto_config: AutoStrategyConfig = AutoStrategyConfig(),
text_det_config: TextDetConfig = TextDetConfig(),
device: DeviceEnum = DeviceEnum.CPU,
) -> None:
self.config = auto_config
self.device = device
general_options = rt.SessionOptions()
providers = self._get_providers()
engine_config = EngineConfig(
session_options=general_options,
providers=providers,
)
x0 = max(0, min(p_height, int(x0)))
y0 = max(0, min(p_width, int(y0)))
x1 = max(0, min(p_height, int(x1)))
y1 = max(0, min(p_width, int(y1)))
pdfium_canva[x0:x1, y0:y1] = 1

onnxtr_canva = np.zeros((int(p_height), int(p_width)))
for block in onnxtr_page.bboxes:
x0, y0 = block.bbox[0]
x1, y1 = block.bbox[1]
x0 = max(0, min(int(x0 * p_width), int(p_width)))
y0 = max(0, min(int(y0 * p_height), int(p_height)))
x1 = max(0, min(int(x1 * p_width), int(p_width)))
y1 = max(0, min(int(y1 * p_height), int(p_height)))
onnxtr_canva[y0:y1, x0:x1] = 1

intersection = np.logical_and(pdfium_canva, onnxtr_canva)
union = np.logical_or(pdfium_canva, onnxtr_canva)
iou = np.sum(intersection) / np.sum(union)
if iou < threshold:
return StrategyEnum.HI_RES
return StrategyEnum.FAST


def determine_strategy(
file: str
| Path
| bytes, # FIXME : Careful here on removing BinaryIO (not handled by onnxtr)
threshold_pages_ocr: float,
threshold_per_page: float,
) -> StrategyEnum:
logger.info("Determining strategy...")
need_ocr = 0

onnxtr_document = DocumentFile.from_pdf(file)
det_predictor = detection_predictor()
layout_predictor = LayoutPredictor(det_predictor)

pdfium_document = pdfium.PdfDocument(file)

onnxtr_document_layout = layout_predictor(onnxtr_document)

for pdfium_page, onnxtr_page in zip(
pdfium_document, onnxtr_document_layout, strict=True
):
strategy = get_strategy_page(
pdfium_page, onnxtr_page, threshold=threshold_per_page
)
need_ocr += strategy == StrategyEnum.HI_RES

doc_need_ocr = (need_ocr / len(pdfium_document)) > threshold_pages_ocr
if isinstance(pdfium_document, pdfium.PdfDocument):
pdfium_document.close()
self.det_predictor = detection_predictor(
arch=text_det_config.det_arch,
assume_straight_pages=text_det_config.assume_straight_pages,
preserve_aspect_ratio=text_det_config.preserve_aspect_ratio,
symmetric_pad=text_det_config.symmetric_pad,
batch_size=text_det_config.batch_size,
load_in_8_bit=text_det_config.load_in_8_bit,
engine_cfg=engine_config,
)

if doc_need_ocr:
logger.info("Using HI_RES strategy")
return StrategyEnum.HI_RES
logger.info("Using FAST strategy")
return StrategyEnum.FAST
def _get_providers(self) -> List[str]:
prov = rt.get_available_providers()
logger.info("Available providers:", prov)
if self.device == DeviceEnum.CUDA:
# TODO: support openvino, directml etc
if "CUDAExecutionProvider" not in prov:
raise ValueError(
"onnxruntime can't find CUDAExecutionProvider in list of available providers"
)
return ["TensorrtExecutionProvider", "CUDAExecutionProvider"]
elif self.device == DeviceEnum.COREML:
if "CoreMLExecutionProvider" not in prov:
raise ValueError(
"onnxruntime can't find CoreMLExecutionProvider in list of available providers"
)
return ["CoreMLExecutionProvider"]
elif self.device == DeviceEnum.CPU:
return ["CPUExecutionProvider"]
else:
warnings.warn(
"Device not supported, using CPU",
UserWarning,
stacklevel=2,
)
return ["CPUExecutionProvider"]

def get_strategy_page(
self, pdfium_page: PdfPage, onnxtr_page: PageLayout
) -> StrategyEnum:
# assert (
# p_width == onnxtr_page.dimensions[1]
# and p_height == onnxtr_page.dimensions[0]
# ), "Page dimensions do not match"
text_coords = []
# Get all the images in the page
for obj in pdfium_page.get_objects():
if obj.type == 1:
text_coords.append(obj.get_pos())

p_width, p_height = int(pdfium_page.get_width()), int(pdfium_page.get_height())

pdfium_canva = np.zeros((int(p_height), int(p_width)))

for coords in text_coords:
# (left,bottom,right, top)
# 0---l--------------R-> y
# |
# B (x0,y0)
# |
# T (x1,y1)
# ^
# x
x0, y0, x1, y1 = (
p_height - coords[3],
coords[0],
p_height - coords[1],
coords[2],
)
x0 = max(0, min(p_height, int(x0)))
y0 = max(0, min(p_width, int(y0)))
x1 = max(0, min(p_height, int(x1)))
y1 = max(0, min(p_width, int(y1)))
pdfium_canva[x0:x1, y0:y1] = 1

onnxtr_canva = np.zeros((int(p_height), int(p_width)))
for block in onnxtr_page.bboxes:
x0, y0 = block.bbox[0]
x1, y1 = block.bbox[1]
x0 = max(0, min(int(x0 * p_width), int(p_width)))
y0 = max(0, min(int(y0 * p_height), int(p_height)))
x1 = max(0, min(int(x1 * p_width), int(p_width)))
y1 = max(0, min(int(y1 * p_height), int(p_height)))
onnxtr_canva[y0:y1, x0:x1] = 1

intersection = np.logical_and(pdfium_canva, onnxtr_canva)
union = np.logical_or(pdfium_canva, onnxtr_canva)
iou = np.sum(intersection) / np.sum(union)
if iou < self.config.auto_page_threshold:
return StrategyEnum.HI_RES
return StrategyEnum.FAST

def determine_strategy(
self,
file: str
| Path
| bytes, # FIXME : Careful here on removing BinaryIO (not handled by onnxtr)
max_samples: int = 5,
) -> StrategyEnum:
logger.info("Determining strategy...")
need_ocr = 0

onnxtr_document = DocumentFile.from_pdf(file)
layout_predictor = LayoutPredictor(self.det_predictor)
pdfium_document = pdfium.PdfDocument(file)

if len(pdfium_document) > max_samples:
sample_pages_index = random.sample(range(len(onnxtr_document)), max_samples)
onnxtr_document = [onnxtr_document[i] for i in sample_pages_index]
pdfium_document = [pdfium_document[i] for i in sample_pages_index]

onnxtr_document_layout = layout_predictor(onnxtr_document)

for pdfium_page, onnxtr_page in zip(
pdfium_document, onnxtr_document_layout, strict=True
):
strategy = self.get_strategy_page(pdfium_page, onnxtr_page)
need_ocr += strategy == StrategyEnum.HI_RES

doc_need_ocr = (
need_ocr / len(pdfium_document)
) > self.config.auto_document_threshold
if isinstance(pdfium_document, pdfium.PdfDocument):
pdfium_document.close()

if doc_need_ocr:
logger.info("Using HI_RES strategy")
return StrategyEnum.HI_RES
logger.info("Using FAST strategy")
return StrategyEnum.FAST
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import logging
from typing import Any, List

import numpy as np
from megaparse.predictor.models.base import (
BlockLayout,
PageLayout,
BBOX,
Point2D,
BlockLayout,
BlockType,
PageLayout,
)
from onnxtr.models.detection.predictor import DetectionPredictor
from onnxtr.models.engine import EngineConfig
from onnxtr.models.predictor.base import _OCRPredictor
from onnxtr.utils.geometry import detach_scores
from onnxtr.utils.repr import NestedObject

logger = logging.getLogger("megaparse")


class LayoutPredictor(NestedObject, _OCRPredictor):
"""Implements an object able to localize and identify text elements in a set of documents
@@ -42,6 +44,7 @@ def __init__(
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
detect_orientation: bool = False,
use_gpu: bool = False,
clf_engine_cfg: EngineConfig | None = None,
**kwargs: Any,
):
14 changes: 5 additions & 9 deletions libs/megaparse/tests/pdf/test_detect_ocr.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,26 @@
import os

import pytest
from megaparse.parser.strategy import determine_strategy
from megaparse.parser.strategy import StrategyHandler
from megaparse_sdk.schema.parser_config import StrategyEnum
from megaparse_sdk.config import MegaParseConfig

ocr_pdfs = os.listdir("./tests/pdf/ocr")
native_pdfs = os.listdir("./tests/pdf/native")
config = MegaParseConfig()

strategy_handler = StrategyHandler()


@pytest.mark.parametrize("hi_res_pdf", ocr_pdfs)
def test_hi_res_strategy(hi_res_pdf):
strategy = determine_strategy(
strategy = strategy_handler.determine_strategy(
f"./tests/pdf/ocr/{hi_res_pdf}",
threshold_per_page=config.auto_page_threshold,
threshold_pages_ocr=config.auto_document_threshold,
)
assert strategy == StrategyEnum.HI_RES


@pytest.mark.parametrize("native_pdf", native_pdfs)
def test_fast_strategy(native_pdf):
strategy = determine_strategy(
strategy = strategy_handler.determine_strategy(
f"./tests/pdf/native/{native_pdf}",
threshold_per_page=config.auto_page_threshold,
threshold_pages_ocr=config.auto_document_threshold,
)
assert strategy == StrategyEnum.FAST
13 changes: 4 additions & 9 deletions libs/megaparse/tests/pdf/test_pdf_processing.py
Original file line number Diff line number Diff line change
@@ -2,13 +2,12 @@

import pytest
from megaparse.megaparse import MegaParse
from megaparse.parser.strategy import determine_strategy
from megaparse.parser.strategy import StrategyHandler
from megaparse.parser.unstructured_parser import UnstructuredParser
from megaparse_sdk.config import MegaParseConfig
from megaparse_sdk.schema.extensions import FileExtension
from megaparse_sdk.schema.parser_config import StrategyEnum

config = MegaParseConfig()
strategy_handler = StrategyHandler()


@pytest.fixture
@@ -56,16 +55,12 @@ async def test_megaparse_pdf_processor_file(pdf_name, request):


def test_strategy(scanned_pdf, native_pdf):
strategy = determine_strategy(
strategy = strategy_handler.determine_strategy(
scanned_pdf,
threshold_per_page=config.auto_page_threshold,
threshold_pages_ocr=config.auto_document_threshold,
)
assert strategy == StrategyEnum.HI_RES

strategy = determine_strategy(
strategy = strategy_handler.determine_strategy(
native_pdf,
threshold_per_page=config.auto_page_threshold,
threshold_pages_ocr=config.auto_document_threshold,
)
assert strategy == StrategyEnum.FAST
15 changes: 0 additions & 15 deletions libs/megaparse_sdk/megaparse_sdk/config.py
Original file line number Diff line number Diff line change
@@ -14,21 +14,6 @@ class MegaParseSDKConfig(BaseSettings):
max_retries: int = 3


class MegaParseConfig(BaseSettings):
"""
Configuration for Megaparse.
"""

model_config = SettingsConfigDict(
env_prefix="MEGAPARSE_",
env_file=(".env.local", ".env"),
env_nested_delimiter="__",
extra="ignore",
)
auto_page_threshold: float = 0.6
auto_document_threshold: float = 0.2


class SSLConfig(BaseModel):
ssl_key_file: FilePath
ssl_cert_file: FilePath

0 comments on commit 2e1c6dd

Please sign in to comment.