Skip to content

Commit 67a045c

Browse files
linjiecccsijunhe
andauthored
[ModelZoo] Refactor ERNIE-Layout Usage and add unittest (PaddlePaddle#4170)
* Add unittest * Apply suggestions from code review * Delete test_tokenizer.py * Apply suggestions from code review Co-authored-by: Sijun He <[email protected]>
1 parent dabf033 commit 67a045c

File tree

9 files changed

+771
-198
lines changed

9 files changed

+771
-198
lines changed

model_zoo/ernie-layout/run_cls.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,22 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import os
1716
import collections
17+
import os
1818
from functools import partial
1919

20-
import paddle
21-
from paddlenlp.trainer import PdArgumentParser, TrainingArguments
22-
from paddlenlp.trainer import get_last_checkpoint
23-
from paddlenlp.transformers import AutoTokenizer, AutoModelForSequenceClassification
24-
from paddlenlp.utils.log import logger
25-
from paddle.metric import Accuracy
26-
from datasets import load_dataset, load_metric
2720
import datasets
21+
import paddle
2822
from data_collator import DataCollator
29-
23+
from datasets import load_dataset
3024
from finetune_args import DataArguments, ModelArguments
31-
from utils import PreProcessor, PostProcessor, get_label_ld
3225
from layout_trainer import LayoutTrainer
26+
from paddle.metric import Accuracy
27+
from utils import PostProcessor, PreProcessor, get_label_ld
28+
29+
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
30+
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
31+
from paddlenlp.utils.log import logger
3332

3433

3534
def main():

model_zoo/ernie-layout/run_mrc.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,21 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import os
1716
import collections
17+
import os
1818
from functools import partial
1919

20-
import paddle
21-
from paddlenlp.trainer import PdArgumentParser, TrainingArguments
22-
from paddlenlp.trainer import get_last_checkpoint
23-
from paddlenlp.transformers import AutoTokenizer, AutoModelForQuestionAnswering
24-
from paddlenlp.utils.log import logger
2520
import datasets
26-
from datasets import load_dataset, load_metric
21+
import paddle
2722
from data_collator import DataCollator
28-
23+
from datasets import load_dataset
2924
from finetune_args import DataArguments, ModelArguments
30-
from utils import PreProcessor, PostProcessor, anls_score
3125
from layout_trainer import LayoutTrainer
26+
from utils import PostProcessor, PreProcessor, anls_score
27+
28+
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
29+
from paddlenlp.transformers import AutoModelForQuestionAnswering, AutoTokenizer
30+
from paddlenlp.utils.log import logger
3231

3332

3433
def main():

model_zoo/ernie-layout/run_ner.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,21 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import os
1716
import collections
17+
import os
1818
from functools import partial
1919

2020
import paddle
21-
from paddlenlp.trainer import PdArgumentParser, TrainingArguments
22-
from paddlenlp.trainer import get_last_checkpoint
23-
from paddlenlp.transformers import AutoTokenizer, AutoModelForTokenClassification
24-
from seqeval.metrics import classification_report
25-
26-
from datasets import load_dataset
2721
from data_collator import DataCollator
28-
22+
from datasets import load_dataset
2923
from finetune_args import DataArguments, ModelArguments
30-
from utils import PreProcessor, PostProcessor, get_label_ld
3124
from layout_trainer import LayoutTrainer
25+
from seqeval.metrics import classification_report
26+
from utils import PostProcessor, PreProcessor, get_label_ld
27+
28+
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
29+
from paddlenlp.transformers import AutoModelForTokenClassification, AutoTokenizer
30+
from paddlenlp.utils.log import logger
3231

3332

3433
def main():

paddlenlp/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
from .ernie_gram.tokenizer import *
7878
from .ernie_layout.modeling import *
7979
from .ernie_layout.tokenizer import *
80+
from .ernie_layout.configuration import *
8081
from .ernie_m.configuration import *
8182
from .ernie_m.modeling import *
8283
from .ernie_m.tokenizer import *
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
""" ERNIE-Layout model configuration"""
15+
16+
from typing import Dict
17+
18+
from ..configuration_utils import PretrainedConfig
19+
20+
__all__ = [
21+
"ERNIE_LAYOUT_PRETRAINED_INIT_CONFIGURATION",
22+
"ErnieLayoutConfig",
23+
"ERNIE_LAYOUT_PRETRAINED_RESOURCE_FILES_MAP",
24+
]
25+
26+
ERNIE_LAYOUT_PRETRAINED_INIT_CONFIGURATION = {
27+
"ernie-layoutx-base-uncased": {
28+
"attention_probs_dropout_prob": 0.1,
29+
"bos_token_id": 0,
30+
"coordinate_size": 128,
31+
"eos_token_id": 2,
32+
"gradient_checkpointing": False,
33+
"has_relative_attention_bias": True,
34+
"has_spatial_attention_bias": True,
35+
"has_visual_segment_embedding": False,
36+
"hidden_act": "gelu",
37+
"hidden_dropout_prob": 0.1,
38+
"hidden_size": 768,
39+
"image_feature_pool_shape": [7, 7, 256],
40+
"initializer_range": 0.02,
41+
"intermediate_size": 3072,
42+
"layer_norm_eps": 1e-12,
43+
"max_2d_position_embeddings": 1024,
44+
"max_position_embeddings": 514,
45+
"max_rel_2d_pos": 256,
46+
"max_rel_pos": 128,
47+
"model_type": "ernie_layout",
48+
"num_attention_heads": 12,
49+
"num_hidden_layers": 12,
50+
"output_past": True,
51+
"pad_token_id": 1,
52+
"shape_size": 128,
53+
"rel_2d_pos_bins": 64,
54+
"rel_pos_bins": 32,
55+
"type_vocab_size": 100,
56+
"vocab_size": 250002,
57+
},
58+
"uie-x-base": {
59+
"attention_probs_dropout_prob": 0.1,
60+
"bos_token_id": 0,
61+
"coordinate_size": 128,
62+
"eos_token_id": 2,
63+
"gradient_checkpointing": False,
64+
"has_relative_attention_bias": True,
65+
"has_spatial_attention_bias": True,
66+
"has_visual_segment_embedding": False,
67+
"hidden_act": "gelu",
68+
"hidden_dropout_prob": 0.1,
69+
"hidden_size": 768,
70+
"image_feature_pool_shape": [7, 7, 256],
71+
"initializer_range": 0.02,
72+
"intermediate_size": 3072,
73+
"layer_norm_eps": 1e-12,
74+
"max_2d_position_embeddings": 1024,
75+
"max_position_embeddings": 514,
76+
"max_rel_2d_pos": 256,
77+
"max_rel_pos": 128,
78+
"model_type": "ernie_layout",
79+
"num_attention_heads": 12,
80+
"num_hidden_layers": 12,
81+
"output_past": True,
82+
"pad_token_id": 1,
83+
"shape_size": 128,
84+
"rel_2d_pos_bins": 64,
85+
"rel_pos_bins": 32,
86+
"type_vocab_size": 100,
87+
"vocab_size": 250002,
88+
},
89+
}
90+
91+
ERNIE_LAYOUT_PRETRAINED_RESOURCE_FILES_MAP = {
92+
"model_state": {
93+
"ernie-layoutx-base-uncased": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_layout/ernie_layoutx_base_uncased.pdparams",
94+
"uie-x-base": "https://bj.bcebos.com/paddlenlp/models/transformers/uie_x/uie_x_base.pdparams",
95+
},
96+
}
97+
98+
99+
class ErnieLayoutConfig(PretrainedConfig):
100+
r"""
101+
This is the configuration class to store the configuration of a [`ErnieLayoutModel`]. It is used to
102+
instantiate a ErnieLayout model according to the specified arguments, defining the model architecture. Instantiating a
103+
configuration with the defaults will yield a similar configuration to that of the ErnieLayout
104+
ernie-layoutx-base-uncased architecture.
105+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
106+
documentation from [`PretrainedConfig`] for more information.
107+
Args:
108+
vocab_size (`int`, *optional*, defaults to 250002):
109+
Vocabulary size of the ErnieLayout model. Defines the number of different tokens that can be represented by the
110+
`inputs_ids` passed when calling [`ErnieLayoutModel`].
111+
hidden_size (`int`, *optional*, defaults to 768):
112+
Dimensionality of the encoder layers and the pooler layer.
113+
num_hidden_layers (`int`, *optional*, defaults to 12):
114+
Number of hidden layers in the Transformer encoder.
115+
num_attention_heads (`int`, *optional*, defaults to 12):
116+
Number of attention heads for each attention layer in the Transformer encoder.
117+
intermediate_size (`int`, *optional*, defaults to 3072):
118+
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
119+
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
120+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
121+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
122+
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
123+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
124+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
125+
The dropout ratio for the attention probabilities.
126+
max_position_embeddings (`int`, *optional*, defaults to 514):
127+
The maximum sequence length that this model might ever be used with. Typically set this to something large
128+
just in case (e.g., 514 or 1028 or 2056).
129+
type_vocab_size (`int`, *optional*, defaults to 100):
130+
The vocabulary size of the `token_type_ids` passed when calling [`ErnieModel`].
131+
initializer_range (`float`, *optional*, defaults to 0.02):
132+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
133+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
134+
The epsilon used by the layer normalization layers.
135+
use_cache (`bool`, *optional*, defaults to `True`):
136+
Whether or not the model should return the last key/values attentions (not used by all models). Only
137+
relevant if `config.is_decoder=True`.
138+
classifier_dropout (`float`, *optional*):
139+
The dropout ratio for classifier.
140+
has_visual_segment_embedding (`bool`, *optional*, defaults to `False`):
141+
Whether or not the model has visual segment embedding.
142+
Examples:
143+
```python
144+
>>> from paddlenlp.transformers import ErnieLayoutModel, ErnieLayoutConfig
145+
>>> # Initializing a ErnieLayout ernie-layoutx-base-uncased configuration
146+
>>> configuration = ErnieLayoutConfig()
147+
>>> # Initializing a model from the style configuration
148+
>>> model = ErnieLayoutModel(configuration)
149+
>>> # Accessing the model configuration
150+
>>> configuration = model.config
151+
```"""
152+
model_type = "ernie_layout"
153+
attribute_map: Dict[str, str] = {"num_classes": "num_labels", "dropout": "classifier_dropout"}
154+
pretrained_init_configuration = ERNIE_LAYOUT_PRETRAINED_INIT_CONFIGURATION
155+
156+
def __init__(
157+
self,
158+
vocab_size: int = 30522,
159+
hidden_size: int = 768,
160+
num_hidden_layers: int = 12,
161+
num_attention_heads: int = 12,
162+
task_id=0,
163+
intermediate_size: int = 3072,
164+
hidden_act: str = "gelu",
165+
hidden_dropout_prob: float = 0.1,
166+
attention_probs_dropout_prob: float = 0.1,
167+
max_position_embeddings: int = 512,
168+
max_2d_position_embeddings: int = 1024,
169+
task_type_vocab_size: int = 3,
170+
type_vocab_size: int = 16,
171+
initializer_range: float = 0.02,
172+
pad_token_id: int = 0,
173+
pool_act: str = "tanh",
174+
fuse: bool = False,
175+
image_feature_pool_shape=[7, 7, 256],
176+
layer_norm_eps=1e-12,
177+
use_cache=False,
178+
use_task_id=True,
179+
enable_recompute=False,
180+
classifier_dropout=None,
181+
has_visual_segment_embedding=False,
182+
**kwargs
183+
):
184+
super().__init__(pad_token_id=pad_token_id, **kwargs)
185+
self.vocab_size = vocab_size
186+
self.hidden_size = hidden_size
187+
self.num_hidden_layers = num_hidden_layers
188+
self.num_attention_heads = num_attention_heads
189+
self.task_id = task_id
190+
self.intermediate_size = intermediate_size
191+
self.hidden_act = hidden_act
192+
self.hidden_dropout_prob = hidden_dropout_prob
193+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
194+
self.max_position_embeddings = max_position_embeddings
195+
self.max_2d_position_embeddings = max_2d_position_embeddings
196+
self.task_type_vocab_size = task_type_vocab_size
197+
self.type_vocab_size = type_vocab_size
198+
self.initializer_range = initializer_range
199+
self.pool_act = pool_act
200+
self.fuse = fuse
201+
self.image_feature_pool_shape = image_feature_pool_shape
202+
self.layer_norm_eps = layer_norm_eps
203+
self.use_cache = use_cache
204+
self.use_task_id = use_task_id
205+
self.classifier_dropout = classifier_dropout
206+
self.has_visual_segment_embedding = has_visual_segment_embedding
207+
self.enable_recompute = enable_recompute

0 commit comments

Comments
 (0)