-
Notifications
You must be signed in to change notification settings - Fork 290
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
fix: Add EngineConfig & StrategyHandler (#211)
* fix: Add EngineConfig & StrategyHandler * fix: Configs settings
Showing
9 changed files
with
258 additions
and
153 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from enum import Enum | ||
|
||
from pydantic import BaseModel | ||
from pydantic_settings import BaseSettings, SettingsConfigDict | ||
|
||
|
||
class TextDetConfig(BaseModel): | ||
det_arch: str = "fast_base" | ||
batch_size: int = 2 | ||
assume_straight_pages: bool = True | ||
preserve_aspect_ratio: bool = True | ||
symmetric_pad: bool = True | ||
load_in_8_bit: bool = False | ||
|
||
|
||
class AutoStrategyConfig(BaseModel): | ||
auto_page_threshold: float = 0.6 | ||
auto_document_threshold: float = 0.2 | ||
|
||
|
||
class TextRecoConfig(BaseModel): | ||
reco_arch: str = "crnn_vgg16_bn" | ||
batch_size: int = 512 | ||
|
||
|
||
class DeviceEnum(str, Enum): | ||
CPU = "cpu" | ||
CUDA = "cuda" | ||
COREML = "coreml" | ||
|
||
|
||
class MegaParseConfig(BaseSettings): | ||
""" | ||
Configuration for Megaparse. | ||
""" | ||
|
||
model_config = SettingsConfigDict( | ||
env_prefix="MEGAPARSE_", | ||
env_file=(".env.local", ".env"), | ||
env_nested_delimiter="__", | ||
extra="ignore", | ||
use_enum_values=True, | ||
) | ||
text_det_config: TextDetConfig = TextDetConfig() | ||
text_reco_config: TextRecoConfig = TextRecoConfig() | ||
auto_parse_config: AutoStrategyConfig = AutoStrategyConfig() | ||
device: DeviceEnum = DeviceEnum.CPU |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,108 +1,167 @@ | ||
import logging | ||
import random | ||
import warnings | ||
from pathlib import Path | ||
from typing import Any, List | ||
|
||
import numpy as np | ||
import onnxruntime as rt | ||
import pypdfium2 as pdfium | ||
from megaparse_sdk.schema.parser_config import StrategyEnum | ||
from onnxtr.io import DocumentFile | ||
from onnxtr.models import detection_predictor | ||
from onnxtr.models.engine import EngineConfig | ||
from pypdfium2._helpers.page import PdfPage | ||
|
||
from megaparse.configs.auto import AutoStrategyConfig, DeviceEnum, TextDetConfig | ||
from megaparse.predictor.doctr_layout_detector import LayoutPredictor | ||
from megaparse.predictor.models.base import PageLayout | ||
|
||
logger = logging.getLogger("megaparse") | ||
|
||
|
||
def get_strategy_page( | ||
pdfium_page: PdfPage, onnxtr_page: PageLayout, threshold: float | ||
) -> StrategyEnum: | ||
# assert ( | ||
# p_width == onnxtr_page.dimensions[1] | ||
# and p_height == onnxtr_page.dimensions[0] | ||
# ), "Page dimensions do not match" | ||
text_coords = [] | ||
# Get all the images in the page | ||
for obj in pdfium_page.get_objects(): | ||
if obj.type == 1: | ||
text_coords.append(obj.get_pos()) | ||
|
||
p_width, p_height = int(pdfium_page.get_width()), int(pdfium_page.get_height()) | ||
|
||
pdfium_canva = np.zeros((int(p_height), int(p_width))) | ||
|
||
for coords in text_coords: | ||
# (left,bottom,right, top) | ||
# 0---l--------------R-> y | ||
# | | ||
# B (x0,y0) | ||
# | | ||
# T (x1,y1) | ||
# ^ | ||
# x | ||
x0, y0, x1, y1 = ( | ||
p_height - coords[3], | ||
coords[0], | ||
p_height - coords[1], | ||
coords[2], | ||
class StrategyHandler: | ||
def __init__( | ||
self, | ||
auto_config: AutoStrategyConfig = AutoStrategyConfig(), | ||
text_det_config: TextDetConfig = TextDetConfig(), | ||
device: DeviceEnum = DeviceEnum.CPU, | ||
) -> None: | ||
self.config = auto_config | ||
self.device = device | ||
general_options = rt.SessionOptions() | ||
providers = self._get_providers() | ||
engine_config = EngineConfig( | ||
session_options=general_options, | ||
providers=providers, | ||
) | ||
x0 = max(0, min(p_height, int(x0))) | ||
y0 = max(0, min(p_width, int(y0))) | ||
x1 = max(0, min(p_height, int(x1))) | ||
y1 = max(0, min(p_width, int(y1))) | ||
pdfium_canva[x0:x1, y0:y1] = 1 | ||
|
||
onnxtr_canva = np.zeros((int(p_height), int(p_width))) | ||
for block in onnxtr_page.bboxes: | ||
x0, y0 = block.bbox[0] | ||
x1, y1 = block.bbox[1] | ||
x0 = max(0, min(int(x0 * p_width), int(p_width))) | ||
y0 = max(0, min(int(y0 * p_height), int(p_height))) | ||
x1 = max(0, min(int(x1 * p_width), int(p_width))) | ||
y1 = max(0, min(int(y1 * p_height), int(p_height))) | ||
onnxtr_canva[y0:y1, x0:x1] = 1 | ||
|
||
intersection = np.logical_and(pdfium_canva, onnxtr_canva) | ||
union = np.logical_or(pdfium_canva, onnxtr_canva) | ||
iou = np.sum(intersection) / np.sum(union) | ||
if iou < threshold: | ||
return StrategyEnum.HI_RES | ||
return StrategyEnum.FAST | ||
|
||
|
||
def determine_strategy( | ||
file: str | ||
| Path | ||
| bytes, # FIXME : Careful here on removing BinaryIO (not handled by onnxtr) | ||
threshold_pages_ocr: float, | ||
threshold_per_page: float, | ||
) -> StrategyEnum: | ||
logger.info("Determining strategy...") | ||
need_ocr = 0 | ||
|
||
onnxtr_document = DocumentFile.from_pdf(file) | ||
det_predictor = detection_predictor() | ||
layout_predictor = LayoutPredictor(det_predictor) | ||
|
||
pdfium_document = pdfium.PdfDocument(file) | ||
|
||
onnxtr_document_layout = layout_predictor(onnxtr_document) | ||
|
||
for pdfium_page, onnxtr_page in zip( | ||
pdfium_document, onnxtr_document_layout, strict=True | ||
): | ||
strategy = get_strategy_page( | ||
pdfium_page, onnxtr_page, threshold=threshold_per_page | ||
) | ||
need_ocr += strategy == StrategyEnum.HI_RES | ||
|
||
doc_need_ocr = (need_ocr / len(pdfium_document)) > threshold_pages_ocr | ||
if isinstance(pdfium_document, pdfium.PdfDocument): | ||
pdfium_document.close() | ||
self.det_predictor = detection_predictor( | ||
arch=text_det_config.det_arch, | ||
assume_straight_pages=text_det_config.assume_straight_pages, | ||
preserve_aspect_ratio=text_det_config.preserve_aspect_ratio, | ||
symmetric_pad=text_det_config.symmetric_pad, | ||
batch_size=text_det_config.batch_size, | ||
load_in_8_bit=text_det_config.load_in_8_bit, | ||
engine_cfg=engine_config, | ||
) | ||
|
||
if doc_need_ocr: | ||
logger.info("Using HI_RES strategy") | ||
return StrategyEnum.HI_RES | ||
logger.info("Using FAST strategy") | ||
return StrategyEnum.FAST | ||
def _get_providers(self) -> List[str]: | ||
prov = rt.get_available_providers() | ||
logger.info("Available providers:", prov) | ||
if self.device == DeviceEnum.CUDA: | ||
# TODO: support openvino, directml etc | ||
if "CUDAExecutionProvider" not in prov: | ||
raise ValueError( | ||
"onnxruntime can't find CUDAExecutionProvider in list of available providers" | ||
) | ||
return ["TensorrtExecutionProvider", "CUDAExecutionProvider"] | ||
elif self.device == DeviceEnum.COREML: | ||
if "CoreMLExecutionProvider" not in prov: | ||
raise ValueError( | ||
"onnxruntime can't find CoreMLExecutionProvider in list of available providers" | ||
) | ||
return ["CoreMLExecutionProvider"] | ||
elif self.device == DeviceEnum.CPU: | ||
return ["CPUExecutionProvider"] | ||
else: | ||
warnings.warn( | ||
"Device not supported, using CPU", | ||
UserWarning, | ||
stacklevel=2, | ||
) | ||
return ["CPUExecutionProvider"] | ||
|
||
def get_strategy_page( | ||
self, pdfium_page: PdfPage, onnxtr_page: PageLayout | ||
) -> StrategyEnum: | ||
# assert ( | ||
# p_width == onnxtr_page.dimensions[1] | ||
# and p_height == onnxtr_page.dimensions[0] | ||
# ), "Page dimensions do not match" | ||
text_coords = [] | ||
# Get all the images in the page | ||
for obj in pdfium_page.get_objects(): | ||
if obj.type == 1: | ||
text_coords.append(obj.get_pos()) | ||
|
||
p_width, p_height = int(pdfium_page.get_width()), int(pdfium_page.get_height()) | ||
|
||
pdfium_canva = np.zeros((int(p_height), int(p_width))) | ||
|
||
for coords in text_coords: | ||
# (left,bottom,right, top) | ||
# 0---l--------------R-> y | ||
# | | ||
# B (x0,y0) | ||
# | | ||
# T (x1,y1) | ||
# ^ | ||
# x | ||
x0, y0, x1, y1 = ( | ||
p_height - coords[3], | ||
coords[0], | ||
p_height - coords[1], | ||
coords[2], | ||
) | ||
x0 = max(0, min(p_height, int(x0))) | ||
y0 = max(0, min(p_width, int(y0))) | ||
x1 = max(0, min(p_height, int(x1))) | ||
y1 = max(0, min(p_width, int(y1))) | ||
pdfium_canva[x0:x1, y0:y1] = 1 | ||
|
||
onnxtr_canva = np.zeros((int(p_height), int(p_width))) | ||
for block in onnxtr_page.bboxes: | ||
x0, y0 = block.bbox[0] | ||
x1, y1 = block.bbox[1] | ||
x0 = max(0, min(int(x0 * p_width), int(p_width))) | ||
y0 = max(0, min(int(y0 * p_height), int(p_height))) | ||
x1 = max(0, min(int(x1 * p_width), int(p_width))) | ||
y1 = max(0, min(int(y1 * p_height), int(p_height))) | ||
onnxtr_canva[y0:y1, x0:x1] = 1 | ||
|
||
intersection = np.logical_and(pdfium_canva, onnxtr_canva) | ||
union = np.logical_or(pdfium_canva, onnxtr_canva) | ||
iou = np.sum(intersection) / np.sum(union) | ||
if iou < self.config.auto_page_threshold: | ||
return StrategyEnum.HI_RES | ||
return StrategyEnum.FAST | ||
|
||
def determine_strategy( | ||
self, | ||
file: str | ||
| Path | ||
| bytes, # FIXME : Careful here on removing BinaryIO (not handled by onnxtr) | ||
max_samples: int = 5, | ||
) -> StrategyEnum: | ||
logger.info("Determining strategy...") | ||
need_ocr = 0 | ||
|
||
onnxtr_document = DocumentFile.from_pdf(file) | ||
layout_predictor = LayoutPredictor(self.det_predictor) | ||
pdfium_document = pdfium.PdfDocument(file) | ||
|
||
if len(pdfium_document) > max_samples: | ||
sample_pages_index = random.sample(range(len(onnxtr_document)), max_samples) | ||
onnxtr_document = [onnxtr_document[i] for i in sample_pages_index] | ||
pdfium_document = [pdfium_document[i] for i in sample_pages_index] | ||
|
||
onnxtr_document_layout = layout_predictor(onnxtr_document) | ||
|
||
for pdfium_page, onnxtr_page in zip( | ||
pdfium_document, onnxtr_document_layout, strict=True | ||
): | ||
strategy = self.get_strategy_page(pdfium_page, onnxtr_page) | ||
need_ocr += strategy == StrategyEnum.HI_RES | ||
|
||
doc_need_ocr = ( | ||
need_ocr / len(pdfium_document) | ||
) > self.config.auto_document_threshold | ||
if isinstance(pdfium_document, pdfium.PdfDocument): | ||
pdfium_document.close() | ||
|
||
if doc_need_ocr: | ||
logger.info("Using HI_RES strategy") | ||
return StrategyEnum.HI_RES | ||
logger.info("Using FAST strategy") | ||
return StrategyEnum.FAST |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,26 @@ | ||
import os | ||
|
||
import pytest | ||
from megaparse.parser.strategy import determine_strategy | ||
from megaparse.parser.strategy import StrategyHandler | ||
from megaparse_sdk.schema.parser_config import StrategyEnum | ||
from megaparse_sdk.config import MegaParseConfig | ||
|
||
ocr_pdfs = os.listdir("./tests/pdf/ocr") | ||
native_pdfs = os.listdir("./tests/pdf/native") | ||
config = MegaParseConfig() | ||
|
||
strategy_handler = StrategyHandler() | ||
|
||
|
||
@pytest.mark.parametrize("hi_res_pdf", ocr_pdfs) | ||
def test_hi_res_strategy(hi_res_pdf): | ||
strategy = determine_strategy( | ||
strategy = strategy_handler.determine_strategy( | ||
f"./tests/pdf/ocr/{hi_res_pdf}", | ||
threshold_per_page=config.auto_page_threshold, | ||
threshold_pages_ocr=config.auto_document_threshold, | ||
) | ||
assert strategy == StrategyEnum.HI_RES | ||
|
||
|
||
@pytest.mark.parametrize("native_pdf", native_pdfs) | ||
def test_fast_strategy(native_pdf): | ||
strategy = determine_strategy( | ||
strategy = strategy_handler.determine_strategy( | ||
f"./tests/pdf/native/{native_pdf}", | ||
threshold_per_page=config.auto_page_threshold, | ||
threshold_pages_ocr=config.auto_document_threshold, | ||
) | ||
assert strategy == StrategyEnum.FAST |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters