Skip to content

Commit a38c087

Browse files
zhangyubo0722 and zhangyubo0722 authored
add ppocr v5 (#15121)
Co-authored-by: zhangyubo0722 <[email protected]>
1 parent 0cc9870 commit a38c087

File tree

6 files changed

+18600
-13
lines changed

6 files changed

+18600
-13
lines changed

configs/rec/PP-FormuaNet/PP-FormulaNet-S.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Architecture:
3939
in_channels: 3
4040
Transform:
4141
Backbone:
42-
name: PPHGNetV2_B4
42+
name: PPHGNetV2_B4_Formula
4343
class_num: 1024
4444

4545
Head:
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
Global:
2+
debug: false
3+
use_gpu: true
4+
epoch_num: 75
5+
log_smooth_window: 20
6+
print_batch_step: 10
7+
save_model_dir: ./output/PP-OCRv5_server_rec
8+
save_epoch_step: 1
9+
eval_batch_step: [0, 2000]
10+
cal_metric_during_train: true
11+
calc_epoch_interval: 1
12+
pretrained_model:
13+
checkpoints:
14+
save_inference_dir:
15+
use_visualdl: false
16+
infer_img: doc/imgs_words/ch/word_1.jpg
17+
character_dict_path: ./ppocr/utils/dict/ppocrv5_dict.txt
18+
max_text_length: &max_text_length 25
19+
infer_mode: false
20+
use_space_char: true
21+
distributed: true
22+
save_res_path: ./output/rec/predicts_ppocrv5.txt
23+
d2s_train_image_shape: [3, 48, 320]
24+
25+
26+
Optimizer:
27+
name: Adam
28+
beta1: 0.9
29+
beta2: 0.999
30+
lr:
31+
name: Cosine
32+
learning_rate: 0.0005
33+
warmup_epoch: 1
34+
regularizer:
35+
name: L2
36+
factor: 3.0e-05
37+
38+
39+
Architecture:
40+
model_type: rec
41+
algorithm: SVTR_HGNet
42+
Transform:
43+
Backbone:
44+
name: PPHGNetV2_B4
45+
text_rec: True
46+
Head:
47+
name: MultiHead
48+
head_list:
49+
- CTCHead:
50+
Neck:
51+
name: svtr
52+
dims: 120
53+
depth: 2
54+
hidden_dims: 120
55+
kernel_size: [1, 3]
56+
use_guide: True
57+
Head:
58+
fc_decay: 0.00001
59+
- NRTRHead:
60+
nrtr_dim: 384
61+
max_text_length: *max_text_length
62+
63+
Loss:
64+
name: MultiLoss
65+
loss_config_list:
66+
- CTCLoss:
67+
- NRTRLoss:
68+
69+
PostProcess:
70+
name: CTCLabelDecode
71+
72+
Metric:
73+
name: RecMetric
74+
main_indicator: acc
75+
76+
Train:
77+
dataset:
78+
name: MultiScaleDataSet
79+
ds_width: false
80+
data_dir: ./train_data/
81+
ext_op_transform_idx: 1
82+
label_file_list:
83+
- ./train_data/train_list.txt
84+
transforms:
85+
- DecodeImage:
86+
img_mode: BGR
87+
channel_first: false
88+
- RecAug:
89+
- MultiLabelEncode:
90+
gtc_encode: NRTRLabelEncode
91+
- KeepKeys:
92+
keep_keys:
93+
- image
94+
- label_ctc
95+
- label_gtc
96+
- length
97+
- valid_ratio
98+
sampler:
99+
name: MultiScaleSampler
100+
scales: [[320, 32], [320, 48], [320, 64]]
101+
first_bs: &bs 64
102+
fix_bs: false
103+
divided_factor: [8, 16] # w, h
104+
is_training: True
105+
loader:
106+
shuffle: true
107+
batch_size_per_card: *bs
108+
drop_last: true
109+
num_workers: 16
110+
Eval:
111+
dataset:
112+
name: SimpleDataSet
113+
data_dir: ./train_data
114+
label_file_list:
115+
- ./train_data/val_list.txt
116+
transforms:
117+
- DecodeImage:
118+
img_mode: BGR
119+
channel_first: false
120+
- MultiLabelEncode:
121+
gtc_encode: NRTRLabelEncode
122+
- RecResizeImg:
123+
image_shape: [3, 48, 320]
124+
- KeepKeys:
125+
keep_keys:
126+
- image
127+
- label_ctc
128+
- label_gtc
129+
- length
130+
- valid_ratio
131+
loader:
132+
shuffle: false
133+
drop_last: false
134+
batch_size_per_card: 128
135+
num_workers: 4

ppocr/modeling/backbones/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def build_backbone(config, model_type):
7171
from .rec_repvit import RepSVTR
7272
from .rec_svtrv2 import SVTRv2
7373
from .rec_vary_vit import Vary_VIT_B, Vary_VIT_B_Formula
74-
from .rec_pphgnetv2 import PPHGNetV2_B4
74+
from .rec_pphgnetv2 import PPHGNetV2_B4, PPHGNetV2_B4_Formula
7575

7676
support_dict = [
7777
"MobileNetV1Enhance",
@@ -101,6 +101,7 @@ def build_backbone(config, model_type):
101101
"DonutSwinModel",
102102
"Vary_VIT_B",
103103
"PPHGNetV2_B4",
104+
"PPHGNetV2_B4_Formula",
104105
"Vary_VIT_B_Formula",
105106
]
106107
elif model_type == "e2e":

ppocr/modeling/backbones/rec_pphgnetv2.py

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,13 @@ class StemBlock(TheseusLayer):
10611061
"""
10621062

10631063
def __init__(
1064-
self, in_channels, mid_channels, out_channels, use_lab=False, lr_mult=1.0
1064+
self,
1065+
in_channels,
1066+
mid_channels,
1067+
out_channels,
1068+
use_lab=False,
1069+
lr_mult=1.0,
1070+
text_rec=False,
10651071
):
10661072
super().__init__()
10671073
self.stem1 = ConvBNAct(
@@ -1094,7 +1100,7 @@ def __init__(
10941100
in_channels=mid_channels * 2,
10951101
out_channels=mid_channels,
10961102
kernel_size=3,
1097-
stride=2,
1103+
stride=1 if text_rec else 2,
10981104
use_lab=use_lab,
10991105
lr_mult=lr_mult,
11001106
)
@@ -1230,6 +1236,7 @@ def __init__(
12301236
light_block=True,
12311237
kernel_size=3,
12321238
use_lab=False,
1239+
stride=2,
12331240
lr_mult=1.0,
12341241
):
12351242

@@ -1240,7 +1247,7 @@ def __init__(
12401247
in_channels=in_channels,
12411248
out_channels=in_channels,
12421249
kernel_size=3,
1243-
stride=2,
1250+
stride=stride,
12441251
groups=in_channels,
12451252
use_act=False,
12461253
use_lab=use_lab,
@@ -1298,13 +1305,20 @@ def __init__(
12981305
dropout_prob=0.0,
12991306
class_num=1000,
13001307
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
1308+
det=False,
1309+
text_rec=False,
1310+
out_indices=None,
13011311
**kwargs,
13021312
):
13031313
super().__init__()
1314+
self.det = det
1315+
self.text_rec = text_rec
13041316
self.use_lab = use_lab
13051317
self.use_last_conv = use_last_conv
13061318
self.class_expand = class_expand
13071319
self.class_num = class_num
1320+
self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
1321+
self.out_channels = []
13081322

13091323
# stem
13101324
self.stem = StemBlock(
@@ -1313,6 +1327,7 @@ def __init__(
13131327
out_channels=stem_channels[2],
13141328
use_lab=use_lab,
13151329
lr_mult=lr_mult_list[0],
1330+
text_rec=text_rec,
13161331
)
13171332

13181333
# stages
@@ -1327,6 +1342,7 @@ def __init__(
13271342
light_block,
13281343
kernel_size,
13291344
layer_num,
1345+
stride,
13301346
) = stage_config[k]
13311347
self.stages.append(
13321348
HGV2_Stage(
@@ -1339,9 +1355,14 @@ def __init__(
13391355
light_block,
13401356
kernel_size,
13411357
use_lab,
1358+
stride,
13421359
lr_mult=lr_mult_list[i + 1],
13431360
)
13441361
)
1362+
if i in self.out_indices:
1363+
self.out_channels.append(out_channels)
1364+
if not self.det:
1365+
self.out_channels = stage_config["stage4"][2]
13451366

13461367
self.avg_pool = AdaptiveAvgPool2D(1)
13471368

@@ -1378,8 +1399,19 @@ def _init_weights(self):
13781399

13791400
def forward(self, x):
13801401
x = self.stem(x)
1381-
for stage in self.stages:
1402+
out = []
1403+
for i, stage in enumerate(self.stages):
13821404
x = stage(x)
1405+
if self.det and i in self.out_indices:
1406+
out.append(x)
1407+
if self.det:
1408+
return out
1409+
1410+
if self.text_rec:
1411+
if self.training:
1412+
x = F.adaptive_avg_pool2d(x, [1, 40])
1413+
else:
1414+
x = F.avg_pool2d(x, [3, 2])
13831415
return x
13841416

13851417

@@ -1479,6 +1511,42 @@ def PPHGNetV2_B3(pretrained=False, use_ssld=False, **kwargs):
14791511
return model
14801512

14811513

1514+
def PPHGNetV2_B4(pretrained=False, use_ssld=False, det=False, text_rec=False, **kwargs):
1515+
"""
1516+
PPHGNetV2_B4
1517+
Args:
1518+
pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
1519+
If str, means the path of the pretrained model.
1520+
use_ssld (bool): Whether to use the ssld pretrained model when pretrained is True.
1521+
Returns:
1522+
model: nn.Layer. Specific `PPHGNetV2_B4` model depends on args.
1523+
"""
1524+
stage_config_rec = {
1525+
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
1526+
"stage1": [48, 48, 128, 1, True, False, 3, 6, [2, 1]],
1527+
"stage2": [128, 96, 512, 1, True, False, 3, 6, [1, 2]],
1528+
"stage3": [512, 192, 1024, 3, True, True, 5, 6, [2, 1]],
1529+
"stage4": [1024, 384, 2048, 1, True, True, 5, 6, [2, 1]],
1530+
}
1531+
1532+
stage_config_det = {
1533+
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
1534+
"stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
1535+
"stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
1536+
"stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
1537+
"stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
1538+
}
1539+
model = PPHGNetV2(
1540+
stem_channels=[3, 32, 48],
1541+
stage_config=stage_config_det if det else stage_config_rec,
1542+
use_lab=False,
1543+
det=det,
1544+
text_rec=text_rec,
1545+
**kwargs,
1546+
)
1547+
return model
1548+
1549+
14821550
def PPHGNetV2_B5(pretrained=False, use_ssld=False, **kwargs):
14831551
"""
14841552
PPHGNetV2_B5
@@ -1527,7 +1595,7 @@ def PPHGNetV2_B6(pretrained=False, use_ssld=False, **kwargs):
15271595
return model
15281596

15291597

1530-
class PPHGNetV2_B4(nn.Layer):
1598+
class PPHGNetV2_B4_Formula(nn.Layer):
15311599
"""
15321600
PPHGNetV2_B4_Formula
15331601
Args:
@@ -1543,10 +1611,10 @@ def __init__(self, in_channels=3, class_num=1000):
15431611
self.out_channels = 2048
15441612
stage_config = {
15451613
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
1546-
"stage1": [48, 48, 128, 1, False, False, 3, 6],
1547-
"stage2": [128, 96, 512, 1, True, False, 3, 6],
1548-
"stage3": [512, 192, 1024, 3, True, True, 5, 6],
1549-
"stage4": [1024, 384, 2048, 1, True, True, 5, 6],
1614+
"stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
1615+
"stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
1616+
"stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
1617+
"stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
15501618
}
15511619

15521620
self.pphgnet_b4 = PPHGNetV2(

0 commit comments

Comments
 (0)