Skip to content

Commit 6a46303

Browse files
authored
feat: use thread lock to prevent racing condition (#430)
This PR puts table model initialization behind a threadlock so it is thread safe. ## testing Use the following script to test (after changing extension) loading model in multithreading env. [test_threading.txt](https://github.com/user-attachments/files/20560179/test_threading.txt)
1 parent 54ed8ae commit 6a46303

File tree

5 files changed

+55
-23
lines changed

5 files changed

+55
-23
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
## 1.0.5-dev0
1+
## 1.0.5
22

3+
* feat: add thread lock to prevent racing condition when instantiating singletons
34
* feat: parametrize edge config for `DetrImageProcessor` with env variables
45

56
## 1.0.4

test_unstructured_inference/models/test_tables.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import os
2+
import threading
3+
from copy import deepcopy
24

35
import numpy as np
46
import pytest
@@ -7,7 +9,6 @@
79
from transformers.models.table_transformer.modeling_table_transformer import (
810
TableTransformerDecoder,
911
)
10-
from copy import deepcopy
1112

1213
import unstructured_inference.models.table_postprocess as postprocess
1314
from unstructured_inference.models import tables
@@ -572,7 +573,7 @@ def test_load_table_model_raises_when_not_available(model_path):
572573

573574

574575
@pytest.mark.parametrize(
575-
"bbox1, bbox2, expected_result",
576+
("bbox1", "bbox2", "expected_result"),
576577
[
577578
((0, 0, 5, 5), (2, 2, 7, 7), 0.36),
578579
((0, 0, 0, 0), (6, 6, 10, 10), 0),
@@ -921,7 +922,9 @@ def test_table_prediction_output_format(
921922
)
922923
if output_format:
923924
result = table_transformer.run_prediction(
924-
example_image, result_format=output_format, ocr_tokens=mocked_ocr_tokens
925+
example_image,
926+
result_format=output_format,
927+
ocr_tokens=mocked_ocr_tokens,
925928
)
926929
else:
927930
result = table_transformer.run_prediction(example_image, ocr_tokens=mocked_ocr_tokens)
@@ -952,7 +955,9 @@ def test_table_prediction_output_format_when_wrong_type_then_value_error(
952955
)
953956
with pytest.raises(ValueError):
954957
table_transformer.run_prediction(
955-
example_image, result_format="Wrong format", ocr_tokens=mocked_ocr_tokens
958+
example_image,
959+
result_format="Wrong format",
960+
ocr_tokens=mocked_ocr_tokens,
956961
)
957962

958963

@@ -991,7 +996,8 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image):
991996
],
992997
)
993998
def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold(
994-
thresholds, expected_object_number
999+
thresholds,
1000+
expected_object_number,
9951001
):
9961002
objects = [
9971003
{"label": "0", "score": 0.2},
@@ -1010,7 +1016,8 @@ def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_
10101016
],
10111017
)
10121018
def test_objects_are_filtered_based_on_class_thresholds_when_two_classes(
1013-
thresholds, expected_object_number
1019+
thresholds,
1020+
expected_object_number,
10141021
):
10151022
objects = [
10161023
{"label": "0", "score": 0.2},
@@ -1800,7 +1807,7 @@ def test_compute_confidence_score_zero_division_error_handling():
18001807

18011808

18021809
@pytest.mark.parametrize(
1803-
"column_span_score, row_span_score, expected_text_to_indexes",
1810+
("column_span_score", "row_span_score", "expected_text_to_indexes"),
18041811
[
18051812
(
18061813
0.9,
@@ -1827,7 +1834,9 @@ def test_compute_confidence_score_zero_division_error_handling():
18271834
],
18281835
)
18291836
def test_subcells_filtering_when_overlapping_spanning_cells(
1830-
column_span_score, row_span_score, expected_text_to_indexes
1837+
column_span_score,
1838+
row_span_score,
1839+
expected_text_to_indexes,
18311840
):
18321841
"""
18331842
# table
@@ -1894,3 +1903,17 @@ def test_subcells_filtering_when_overlapping_spanning_cells(
18941903

18951904
predicted_cells_after_reorder, _ = structure_to_cells(saved_table_structure, tokens=tokens)
18961905
assert predicted_cells_after_reorder == predicted_cells
1906+
1907+
1908+
def test_model_init_is_thread_safe():
1909+
threads = []
1910+
tables.tables_agent.model = None
1911+
for i in range(5):
1912+
thread = threading.Thread(target=tables.load_agent)
1913+
threads.append(thread)
1914+
thread.start()
1915+
1916+
for thread in threads:
1917+
thread.join()
1918+
1919+
assert tables.tables_agent.model is not None

unstructured_inference/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.5-dev0" # pragma: no cover
1+
__version__ = "1.0.5" # pragma: no cover

unstructured_inference/models/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import os
5+
import threading
56
from typing import Dict, Optional, Tuple, Type
67

78
from unstructured_inference.models.detectron2onnx import (
@@ -18,12 +19,15 @@
1819

1920
class Models(object):
2021
_instance = None
22+
_lock = threading.Lock()
2123

2224
def __new__(cls):
2325
"""return an instance if one already exists otherwise create an instance"""
2426
if cls._instance is None:
25-
cls._instance = super(Models, cls).__new__(cls)
26-
cls.models: Dict[str, UnstructuredModel] = {}
27+
with cls._lock:
28+
if cls._instance is None:
29+
cls._instance = super(Models, cls).__new__(cls)
30+
cls.models: Dict[str, UnstructuredModel] = {}
2731
return cls._instance
2832

2933
def __contains__(self, key):

unstructured_inference/models/tables.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# https://github.com/microsoft/table-transformer/blob/main/src/inference.py
22
# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Using_Table_Transformer_for_table_detection_and_table_structure_recognition.ipynb
3+
import threading
34
import xml.etree.ElementTree as ET
45
from collections import defaultdict
56
from pathlib import Path
@@ -23,20 +24,21 @@
2324

2425
from . import table_postprocess as postprocess
2526

27+
DEFAULT_MODEL = "microsoft/table-transformer-structure-recognition"
28+
2629

2730
class UnstructuredTableTransformerModel(UnstructuredModel):
2831
"""Unstructured model wrapper for table-transformer."""
2932

3033
_instance = None
34+
_lock = threading.Lock()
3135

32-
def __init__(self):
33-
pass
34-
35-
@classmethod
36-
def instance(cls):
36+
def __new__(cls):
3737
"""return an instance if one already exists otherwise create an instance"""
3838
if cls._instance is None:
39-
cls._instance = cls.__new__(cls)
39+
with cls._lock:
40+
if cls._instance is None:
41+
cls._instance = super(UnstructuredTableTransformerModel, cls).__new__(cls)
4042
return cls._instance
4143

4244
def predict(
@@ -70,7 +72,7 @@ def initialize(
7072
):
7173
"""Loads the donut model using the specified parameters"""
7274
self.device = device
73-
self.feature_extractor = DetrImageProcessor.from_pretrained(model)
75+
self.feature_extractor = DetrImageProcessor.from_pretrained(model, device_map=self.device)
7476
# value not set in the configuration and needed for newer models
7577
# https://huggingface.co/microsoft/table-transformer-structure-recognition-v1.1-all/discussions/1
7678
self.feature_extractor.size["shortest_edge"] = inference_config.IMG_PROCESSOR_SHORTEST_EDGE
@@ -145,15 +147,17 @@ def run_prediction(
145147
return prediction
146148

147149

148-
tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel.instance()
150+
tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel()
149151

150152

151153
def load_agent():
152154
"""Loads the Table agent."""
153155

154-
if not hasattr(tables_agent, "model"):
155-
logger.info("Loading the Table agent ...")
156-
tables_agent.initialize("microsoft/table-transformer-structure-recognition")
156+
if getattr(tables_agent, "model", None) is None:
157+
with tables_agent._lock:
158+
if getattr(tables_agent, "model", None) is None:
159+
logger.info("Loading the Table agent ...")
160+
tables_agent.initialize(DEFAULT_MODEL)
157161

158162
return
159163

0 commit comments

Comments
 (0)