
Commit ca0fb39

update UniHDSA's configurations
1 parent 8b7afb6 commit ca0fb39

File tree

5 files changed: +514 -0 lines changed


UniHDSA/README.md

Lines changed: 9 additions & 0 deletions
# UniHDSA: A Unified Relation Prediction Approach for Hierarchical Document Structure Analysis

## Introduction

Document structure analysis is essential for understanding both the physical layout and logical structure of documents, aiding tasks such as information retrieval, document summarization, and knowledge extraction. Hierarchical Document Structure Analysis (HDSA) aims to restore the hierarchical structure of documents created with hierarchical schemas. Traditional approaches either tackle individual subtasks in isolation or use separate branches for distinct tasks. In this work, we introduce UniHDSA, a unified relation prediction approach for HDSA that treats the various subtasks as relation prediction problems within a consolidated label space, so a single module can handle multiple tasks simultaneously, improving efficiency, scalability, and adaptability. Our multimodal Transformer-based system achieves state-of-the-art performance on the Comp-HRDoc benchmark and competitive results on the DocLayNet dataset, demonstrating the effectiveness of our method across all subtasks.

## Reproduction

This project is built on [detrex](https://github.com/IDEA-Research/detrex/tree/main), a research platform for Transformer-based detection algorithms. Due to company policy, we cannot release the model code. However, we provide the full configuration, including the model architecture, training hyperparameters, and data-processing pipeline, as well as the code for evaluating the model.
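
The configurations are detectron2 LazyConfig Python files, the format detrex uses throughout, so they can be loaded and inspected independently of the model code. A minimal sketch (the config path is illustrative, and building the dataloader assumes the `COMP_HRDOC_HR_*` datasets have been registered):

```python
# Minimal sketch, not an official entry point: loading one of the released
# LazyConfig files with detectron2's config utilities. Building the dataloader
# assumes the COMP_HRDOC_HR_* datasets are registered in DatasetCatalog.
from detectron2.config import LazyConfig, instantiate

cfg = LazyConfig.load("UniHDSA/configs/data/hdsa_bert.py")
train_loader = instantiate(cfg.dataloader.train)   # detectron2 train loader
evaluator = instantiate(cfg.dataloader.evaluator)  # UniLayoutEvaluator
```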

UniHDSA/configs/data/hdsa_bert.py

Lines changed: 69 additions & 0 deletions
```python
from omegaconf import OmegaConf

import detectron2.data.transforms as T
from detectron2.config import LazyCall as L
from detectron2.data import (
    build_detection_test_loader,
    build_detection_train_loader,
    get_detection_dataset_dicts,
)
from projects.unified_layout_analysis_v2.evaluation.unified_layout_evaluation import UniLayoutEvaluator
from projects.unified_layout_analysis_v2.modeling.backbone.bert import TextTokenizer

from detrex.data.dataset_mappers import PODDatasetMapper, pod_transform_gen
from detrex.data.dataset_mappers import HRDocDatasetMapper

dataloader = OmegaConf.create()

dataloader.train = L(build_detection_train_loader)(
    dataset=L(get_detection_dataset_dicts)(names="COMP_HRDOC_HR_TRAIN"),
    mapper=L(HRDocDatasetMapper)(
        augmentation=L(pod_transform_gen)(
            min_size_train=(320, 416, 512, 608, 704, 800),
            max_size_train=1024,
            min_size_train_sampling="choice",
            min_size_test=512,
            max_size_test=1024,
            random_resize_type="ResizeShortestEdge",
            random_flip=False,
            is_train=True,
        ),
        TextTokenizer=L(TextTokenizer)(
            model_type="bert-base-uncased",
            text_max_len=512,
            input_overlap_stride=0,
        ),
        is_train=True,
        image_format="BGR",
    ),
    total_batch_size=1,
    num_workers=4,
)

dataloader.test = L(build_detection_test_loader)(
    dataset=L(get_detection_dataset_dicts)(names="COMP_HRDOC_HR_TEST"),
    mapper=L(HRDocDatasetMapper)(
        augmentation=L(pod_transform_gen)(
            min_size_train=(320, 416, 512, 608, 704, 800),
            max_size_train=1024,
            min_size_train_sampling="choice",
            min_size_test=512,
            max_size_test=1024,
            random_resize_type="ResizeShortestEdge",
            random_flip=False,
            is_train=False,
        ),
        TextTokenizer=L(TextTokenizer)(
            model_type="bert-base-uncased",
            text_max_len=512,
            input_overlap_stride=0,
        ),
        is_train=False,
        image_format="BGR",
    ),
    num_workers=4,
)

dataloader.evaluator = L(UniLayoutEvaluator)(
    dataset_name="${..test.dataset.names}",
)
```
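
One detail worth noting: the evaluator's `dataset_name` is an OmegaConf relative interpolation, so it always tracks the test split named above rather than hard-coding it. A self-contained sketch of how it resolves:

```python
from omegaconf import OmegaConf

# Minimal reproduction of the interpolation used by dataloader.evaluator:
# "${..test.dataset.names}" resolves relative to the evaluator node, walking
# up one level to the dataloader root and reading test.dataset.names.
dataloader = OmegaConf.create(
    {
        "test": {"dataset": {"names": "COMP_HRDOC_HR_TEST"}},
        "evaluator": {"dataset_name": "${..test.dataset.names}"},
    }
)
assert dataloader.evaluator.dataset_name == "COMP_HRDOC_HR_TEST"
```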
Lines changed: 186 additions & 0 deletions
```python
import torch.nn as nn
import copy
from detrex.layers import PositionEmbeddingSine
from detrex.modeling.backbone import ResNet, BasicStem
from detrex.modeling.neck import ChannelMapper
from detectron2.layers import ShapeSpec
from detectron2.config import LazyCall as L
from detrex.modeling.matcher import HungarianMatcher

from projects.unified_layout_analysis_v2.modeling import (
    UniDETRMultiScales,
    DabDeformableDetrTransformer,
    DabDeformableDetrTransformerEncoder,
    DabDeformableDetrTransformerDecoder,
    TwoStageCriterion,
    DeepStem,
)

from projects.unified_layout_analysis_v2.modeling.uni_relation_prediction_head import (
    UniRelationPredictionHead,
    HRIPNHead,
)

from projects.unified_layout_analysis_v2.modeling.doc_transformer import (
    DocTransformerEncoder,
    DocTransformer,
)

from projects.unified_layout_analysis_v2.modeling.backbone.bert import (
    Bert,
    TextTokenizer,
)

# Define the main model
model = L(UniDETRMultiScales)(
    backbone=L(ResNet)(
        stem=L(DeepStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
        stages=L(ResNet.make_default_stages)(
            depth=18,
            norm="FrozenBN",
        ),
        out_features=["res2", "res3", "res4", "res5"],
        freeze_at=1,
    ),
    position_embedding=L(PositionEmbeddingSine)(
        num_pos_feats=128,
        temperature=10000,
        normalize=True,
        offset=-0.5,
    ),
    neck=L(ChannelMapper)(
        input_shapes={
            "res3": ShapeSpec(channels=128),
            "res4": ShapeSpec(channels=256),
            "res5": ShapeSpec(channels=512),
        },
        in_features=["res3", "res4", "res5"],
        out_channels=256,
        num_outs=4,
        kernel_size=1,
        norm_layer=L(nn.GroupNorm)(num_groups=32, num_channels=256),
    ),
    transformer=L(DabDeformableDetrTransformer)(
        encoder=L(DabDeformableDetrTransformerEncoder)(
            embed_dim=256,
            num_heads=8,
            feedforward_dim=2048,
            attn_dropout=0.0,
            ffn_dropout=0.0,
            num_layers=3,
            post_norm=False,
            num_feature_levels=4,
        ),
        decoder=L(DabDeformableDetrTransformerDecoder)(
            embed_dim=256,
            num_heads=8,
            feedforward_dim=2048,
            attn_dropout=0.0,
            ffn_dropout=0.0,
            num_layers=3,
            return_intermediate=True,
            num_feature_levels=4,
        ),
        as_two_stage=True,
        num_feature_levels=4,
        decoder_in_feature_level=[0, 1, 2, 3],
    ),
    embed_dim=256,
    num_classes=14,
    num_graphical_classes=2,
    num_types=3,
    relation_prediction_head=L(UniRelationPredictionHead)(
        relation_num_classes=2,
        embed_dim=256,
        hidden_dim=1024,
    ),  # 0: a->a, 1: intra, 2: inter
    aux_loss=True,
    criterion=L(TwoStageCriterion)(
        num_classes=2,
        matcher=L(HungarianMatcher)(
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
            cost_class_type="focal_loss_cost",
            alpha=0.25,
            gamma=2.0,
        ),
        weight_dict={
            "loss_class": 1,
            "loss_bbox": 5.0,
            "loss_giou": 2.0,
        },
        loss_class_type="focal_loss",
        alpha=0.25,
        gamma=2.0,
        two_stage_binary_cls=False,
    ),
    as_two_stage=True,
    pixel_mean=[123.675, 116.280, 103.530],
    pixel_std=[58.395, 57.120, 57.375],
    device="cuda",
    windows_size=[6, 8],
    freeze_language_model=False,
)

model.logical_role_relation_prediction_head = L(UniRelationPredictionHead)(
    relation_num_classes=1,
    embed_dim=256,
    hidden_dim=1024,
)

# Update auxiliary loss weight dictionary: each intermediate decoder layer
# gets its own copy of the loss weights, suffixed with the layer index.
base_weight_dict = copy.deepcopy(model.criterion.weight_dict)
if model.aux_loss:
    weight_dict = model.criterion.weight_dict
    aux_weight_dict = {
        f"{k}_{i}": v
        for i in range(model.transformer.decoder.num_layers - 1)
        for k, v in base_weight_dict.items()
    }
    weight_dict.update(aux_weight_dict)
    model.criterion.weight_dict = weight_dict

# Additional loss weights for the encoder (two-stage) proposal losses
model.criterion.weight_dict.update({
    "loss_class_enc": 1.0,
    "loss_bbox_enc": 5.0,
    "loss_giou_enc": 2.0,
})

# Add document transformer module
model.doc_transformer = L(DocTransformer)(
    encoder=L(DocTransformerEncoder)(
        embed_dim=256,
        num_heads=8,
        feedforward_dim=2048,
        attn_dropout=0.0,
        ffn_dropout=0.0,
        num_layers=3,
        post_norm=False,
        batch_first=True,
    ),
    decoder=None,
)

# Add relation prediction head
model.doc_relation_prediction_head = L(HRIPNHead)(
    relation_num_classes=2,
    embed_dim=256,
    hidden_dim=1024,
)

# Add language model
model.language_model = L(Bert)(
    bert_model_type="bert-base-uncased",
    text_max_len=512,
    input_overlap_stride=0,
    output_embedding_dim=1024,
    max_batch_size=1,
    used_layers=12,
    used_hidden_idxs=[12],
    hidden_embedding_dim=768,
)

# Add tokenizer
model.tokenizer = L(TextTokenizer)(
    model_type="bert-base-uncased",
    text_max_len=512,
    input_overlap_stride=0,
)
```
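
Since the model code itself is not released, `UniRelationPredictionHead` appears above only by name and signature. For orientation, here is a hypothetical sketch of a pairwise relation head with the same `(relation_num_classes, embed_dim, hidden_dim)` interface; it illustrates the general technique of scoring every ordered pair of query embeddings and is not the authors' implementation:

```python
# Hypothetical sketch only: a generic pairwise relation head matching the
# (relation_num_classes, embed_dim, hidden_dim) signature in the config.
# It is NOT the released UniHDSA code.
import torch
import torch.nn as nn


class PairwiseRelationHead(nn.Module):
    def __init__(self, relation_num_classes: int, embed_dim: int, hidden_dim: int):
        super().__init__()
        self.head_proj = nn.Linear(embed_dim, hidden_dim)  # "head" of a relation pair
        self.tail_proj = nn.Linear(embed_dim, hidden_dim)  # "tail" of a relation pair
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_dim, relation_num_classes),
        )

    def forward(self, queries: torch.Tensor) -> torch.Tensor:
        # queries: (batch, num_queries, embed_dim) from the decoder
        heads = self.head_proj(queries)                 # (B, N, H)
        tails = self.tail_proj(queries)                 # (B, N, H)
        pair = heads.unsqueeze(2) + tails.unsqueeze(1)  # (B, N, N, H), one feature per ordered pair
        return self.classifier(pair)                    # (B, N, N, relation_num_classes)


# Usage with the dimensions from the config above:
head = PairwiseRelationHead(relation_num_classes=2, embed_dim=256, hidden_dim=1024)
logits = head(torch.randn(1, 100, 256))
print(logits.shape)  # torch.Size([1, 100, 100, 2])
```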
