Commit f8eb45f

v0.5.1
1 parent 706b7da commit f8eb45f

File tree: 9 files changed (+77 −46 lines)

README.md

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ pip install git+https://github.com/Tongjilibo/bert4torch
 ### 4.1 Version history
 |Date| bert4torch | torch4keras | Release notes |
 |------| ---------------- | ----------------- |----------- |
+|20240619| 0.5.1 | 0.2.4 | Add Qwen1.5, Qwen2, glm4; add SWA/convert_lm_logits_dtype; rework each trainer (DPOTrainer in particular); segment_ids in generation; repetition_penalty now requires the query to be included; fix dtype-cast bug in RMSNorm|
 |20240418| 0.5.0 | 0.2.2 | Fix chatglm3 bug; fix multi-file bug in save_pretrained; add CausalLMLoss; rework deepspeed argument passing; fix Text2Vec bug; polish the openai client; add get_weight_decay_optim_groups|
 |20240317| 0.4.9.post2 | 0.2.1.post2 |Add get_weight_decay_optim_groups; allow is_causal in attention; fix repetition_penalty bug; split baichuan out of llama; fix config_path bug; allow the num_key_value_heads parameter; [torch4keras-v0.2.1.post2](https://github.com/Tongjilibo/torch4keras/releases/tag/v0.2.1.post2) feature updates|
 |20240221| 0.4.8 | 0.2.0|fastapi serving can offload to cpu when idle; `build_transformer_model` can download from hf; add `FillMask` pipeline; add `SequenceClassificationTrainer`|

bert4torch/snippets/import_utils.py

Lines changed: 0 additions & 12 deletions
@@ -10,18 +10,6 @@
 import importlib.metadata as importlib_metadata


-def is_accelerate_available(check_partial_state=False):
-    '''Whether accelerate is available'''
-    accelerate_available = importlib.util.find_spec("accelerate") is not None
-    if accelerate_available:
-        if check_partial_state:
-            return version.parse(importlib_metadata.version("accelerate")) >= version.parse("0.17.0")
-        else:
-            return True
-    else:
-        return False
-
-
 def is_flash_attn_available():
     '''Whether the flash_attn package is available'''
     _flash_attn_available = is_package_available("flash_attn") and \
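
The removed helper follows the availability-check pattern used throughout import_utils.py: probe importability with importlib, then optionally gate on the installed version. A minimal sketch of that pattern using only the standard library plus packaging; the function name here is illustrative, not part of the module:

```python
import importlib.util
import importlib.metadata as importlib_metadata
from packaging import version

def is_available(name: str, min_version: str = None) -> bool:
    '''Illustrative: find_spec answers "is it importable?",
    importlib.metadata answers "which version is installed?".'''
    if importlib.util.find_spec(name) is None:
        return False
    if min_version is None:
        return True
    return version.parse(importlib_metadata.version(name)) >= version.parse(min_version)

# The removed is_accelerate_available(check_partial_state=True)
# was equivalent to is_available("accelerate", "0.17.0").
```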

bert4torch/trainer/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -4,6 +4,6 @@

 from torch4keras.trainer import * # torch4keras>=0.1.2.post2
 from .ppo_trainer import PPOTrainer
-from .dpo_trainer import DPOTrainer
-from .ptuningv2_trainer import PtuningV2Trainer
-from .sequence_classification_trainer import SequenceClassificationTrainer
+from .dpo_trainer import DPOTrainer, DPOModel
+from .ptuningv2_trainer import PtuningV2Trainer, PtuningV2Model
+from .sequence_classification_trainer import SequenceClassificationTrainer, SequenceClassificationModel
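
Each trainer module now also exports its underlying Model class, so the wrapped module can be constructed directly when the AutoTrainer facade is not wanted. What the new exports allow, taken directly from the import lines above:

```python
# Both the Trainer facade and the bare Model are now importable:
from bert4torch.trainer import DPOTrainer, DPOModel
from bert4torch.trainer import PtuningV2Trainer, PtuningV2Model
from bert4torch.trainer import SequenceClassificationTrainer, SequenceClassificationModel
```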

bert4torch/trainer/dpo_trainer.py

Lines changed: 36 additions & 4 deletions
@@ -7,6 +7,7 @@
 from contextlib import contextmanager, nullcontext
 import warnings
 import inspect
+from torch.nn.modules import Module
 from torch4keras.trainer import AutoTrainer, Trainer
 from bert4torch.models import BaseModel, build_transformer_model
 from bert4torch.snippets import is_peft_available, disable_dropout_in_model, peft_module_casting_to_bf16
@@ -20,6 +21,14 @@ class DPOModel(BaseModel):

     :param model: the model to be trained
     :param ref_model: the reference model
+    :param args: selected DPO training options
+    :param model_init_kwargs: build_transformer_model kwargs for model
+    :param ref_model_init_kwargs: build_transformer_model kwargs for ref_model
+    :param model_adapter_name: adapter_name of model
+    :param ref_adapter_name: adapter_name of ref_model
+    :param peft_config: peft configuration
+    :param disable_dropout: whether to disable dropout
+    :param force_use_ref_model: force the use of ref_model
     '''
     def __init__(
         self,
@@ -163,6 +172,14 @@ class DPOTrainer(AutoTrainer):
     '''DPOTrainer
     :param model: the model to be trained
     :param ref_model: the reference model
+    :param args: selected DPO training options
+    :param model_init_kwargs: build_transformer_model kwargs for model
+    :param ref_model_init_kwargs: build_transformer_model kwargs for ref_model
+    :param model_adapter_name: adapter_name of model
+    :param ref_adapter_name: adapter_name of ref_model
+    :param peft_config: peft configuration
+    :param disable_dropout: whether to disable dropout
+    :param force_use_ref_model: force the use of ref_model

     Examples
     ```python
@@ -175,11 +192,26 @@ class DPOTrainer(AutoTrainer):
     >>> model.to('cuda')
     ```
     '''
+    def __init__(self,
+                 model: Optional[Union[BaseModel, str]],
+                 *trainer_args,
+                 ref_model:BaseModel=None,
+                 args: Optional[DottableDict] = DottableDict(),
+                 model_init_kwargs: Optional[Dict] = None,
+                 ref_model_init_kwargs: Optional[Dict] = None,
+                 model_adapter_name: Optional[str] = None,
+                 ref_adapter_name: Optional[str] = None,
+                 peft_config: Optional[Dict] = None,
+                 disable_dropout: bool = True,
+                 force_use_ref_model: bool = False,
+                 **kwargs):
+        pass
+
     def __new__(cls,
                 model: Optional[Union[BaseModel, str]],
-                *args,
+                *trainer_args,
                 ref_model:BaseModel=None,
-                dpo_args: Optional[DottableDict] = DottableDict(),
+                args: Optional[DottableDict] = DottableDict(),
                 model_init_kwargs: Optional[Dict] = None,
                 ref_model_init_kwargs: Optional[Dict] = None,
                 model_adapter_name: Optional[str] = None,
@@ -189,7 +221,7 @@ def __new__(cls,
                 force_use_ref_model: bool = False,
                 **kwargs
                 ) -> Trainer:
-        module = DPOModel(model, ref_model, dpo_args, model_init_kwargs, ref_model_init_kwargs,
+        module = DPOModel(model, ref_model, args, model_init_kwargs, ref_model_init_kwargs,
                           model_adapter_name, ref_adapter_name, peft_config, disable_dropout, force_use_ref_model)
         module.to(model.device)
-        return super().__new__(cls, module, *args, **kwargs)
+        return super().__new__(cls, module, *trainer_args, **kwargs)
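
The rename is the substance here: the keyword `dpo_args` becomes `args`, and the positional catch-all becomes `*trainer_args` so trainer options no longer collide with DPO options (the stub `__init__` only mirrors the signature for IDEs). A minimal usage sketch under the new signature, following the docstring example above; the empty paths are placeholders to fill in:

```python
from bert4torch.models import build_transformer_model
from bert4torch.trainer import DPOTrainer

config_path = ''      # path to bert4torch_config.json
checkpoint_path = ''  # path to the model folder

model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path)
ref_model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path)

# Options passed as dpo_args= before this commit are passed as args= now.
trainer = DPOTrainer(model, ref_model=ref_model, disable_dropout=True)
trainer.to('cuda')
```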

bert4torch/trainer/ptuningv2_trainer.py

Lines changed: 4 additions & 0 deletions
@@ -133,6 +133,10 @@ class PtuningV2Trainer(AutoTrainer):
     >>> model = PtuningV2Trainer(encoder).to('cuda')
     ```
     '''
+    def __init__(self, encoder:nn.Module, *args, pre_seq_len:int=128, prefix_projection:bool=False, **kwargs):
+        pass
+
     def __new__(cls, encoder:nn.Module, *args, pre_seq_len:int=128, prefix_projection:bool=False, **kwargs) -> Trainer:
         module = PtuningV2Model(encoder, *args, pre_seq_len=pre_seq_len, prefix_projection=prefix_projection, **kwargs)
+        module.to(encoder.device)
         return super().__new__(cls, module, *args, **kwargs)
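
Beyond the stub `__init__`, the functional change is `module.to(encoder.device)`: the wrapped PtuningV2Model now lands on the encoder's device at construction time. A hedged usage sketch following the docstring example above; the empty paths are placeholders:

```python
from bert4torch.models import build_transformer_model
from bert4torch.trainer import PtuningV2Trainer

config_path = ''      # path to bert4torch_config.json
checkpoint_path = ''  # path to the model folder

encoder = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path)
# pre_seq_len and prefix_projection come from the signature above; defaults shown.
model = PtuningV2Trainer(encoder, pre_seq_len=128, prefix_projection=False).to('cuda')
```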

bert4torch/trainer/sequence_classification_trainer.py

Lines changed: 8 additions & 5 deletions
@@ -61,12 +61,15 @@ class SequenceClassificationTrainer(AutoTrainer):
     >>> config_path = ''  # path to bert4torch_config.json
     >>> checkpoint_path = ''  # path to the model folder
     >>> bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_pool=True)
-    >>> model = SequenceClassificationTrainer(bert)
-    >>> model.to('cuda')
+    >>> model = SequenceClassificationTrainer(bert).to('cuda')
     ```
     '''
+    def __init__(self, module:BaseModel, *args, num_labels:int=2, classifier_dropout:float=None,
+                 pool_strategy:Literal['pooler', 'cls', 'last-avg', 'mean', 'last-max', 'max', 'first-last-avg', 'custom']='cls', **kwargs):
+        pass
+
     def __new__(cls, module:BaseModel, *args, num_labels:int=2, classifier_dropout:float=None,
                 pool_strategy:Literal['pooler', 'cls', 'last-avg', 'mean', 'last-max', 'max', 'first-last-avg', 'custom']='cls', **kwargs) -> Trainer:
-        module = SequenceClassificationModel(module, num_labels, classifier_dropout, pool_strategy, **kwargs)
-        module.to(model.device)
-        return super().__new__(cls, module, *args, **kwargs)
+        model = SequenceClassificationModel(module, num_labels, classifier_dropout, pool_strategy, **kwargs)
+        model.to(module.device)
+        return super().__new__(cls, model, *args, **kwargs)
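
This three-line change fixes a real bug: the old `__new__` rebound its own `module` argument and then called `.to(model.device)` on a name that was never defined, so construction raised a NameError. With distinct names, the wrapped model is placed on the incoming module's device. A hedged sketch of the call that now works, with placeholder paths as in the docstring:

```python
from bert4torch.models import build_transformer_model
from bert4torch.trainer import SequenceClassificationTrainer

config_path = ''      # path to bert4torch_config.json
checkpoint_path = ''  # path to the model folder

bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_pool=True)
# num_labels and pool_strategy are from the signature above; 'cls' is the default strategy.
model = SequenceClassificationTrainer(bert, num_labels=2, pool_strategy='cls').to('cuda')
```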

docs/History.md

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 ## Update history

+- **20240619**: Add Qwen1.5, Qwen2, glm4; add SWA/convert_lm_logits_dtype; rework each trainer (DPOTrainer in particular); segment_ids in generation; repetition_penalty now requires the query to be included
 - **20240426**: Simplify the LLM demos; read generation_config from config; add Qwen2 and SWA; fix dtype-cast bug in RMSNorm
 - **20240418**: Fix Text2Vec bug; polish the openai client; add get_weight_decay_optim_groups
 - **20240331**: Fix chatglm3 bug; fix multi-file bug in save_pretrained; add CausalLMLoss; rework deepspeed argument passing

examples/sentence_classfication/task_sentiment_classification.py

Lines changed: 23 additions & 21 deletions
@@ -63,27 +63,29 @@ def collate_fn(batch):
 valid_dataloader = DataLoader(MyDataset([f'{data_dir}/sentiment.valid.data']), batch_size=batch_size, collate_fn=collate_fn)
 test_dataloader = DataLoader(MyDataset([f'{data_dir}/sentiment.test.data']), batch_size=batch_size, collate_fn=collate_fn)

-# Approach 1
-class Model(BaseModel):
-    def __init__(self, pool_method='cls') -> None:
-        super().__init__()
-        self.pool_method = pool_method
-        self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_pool=True, gradient_checkpoint=True)
-        self.dropout = nn.Dropout(0.1)
-        self.dense = nn.Linear(self.bert.configs['hidden_size'], 2)
-
-    def forward(self, token_ids, segment_ids):
-        hidden_states, pooling = self.bert([token_ids, segment_ids])
-        pooled_output = get_pool_emb(hidden_states, pooling, token_ids.gt(0).long(), self.pool_method)
-        output = self.dropout(pooled_output)
-        output = self.dense(output)
-        return output
-model = Model().to(device)
-
-# Approach 2
-# from bert4torch.trainer import SequenceClassificationTrainer
-# bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_pool=True, gradient_checkpoint=True)
-# model = SequenceClassificationTrainer(bert).to(device)
+if False:
+    # Approach 1: hand-written classification head
+    class Model(BaseModel):
+        def __init__(self, pool_method='cls') -> None:
+            super().__init__()
+            self.pool_method = pool_method
+            self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_pool=True, gradient_checkpoint=True)
+            self.dropout = nn.Dropout(0.1)
+            self.dense = nn.Linear(self.bert.configs['hidden_size'], 2)
+
+        def forward(self, token_ids, segment_ids):
+            hidden_states, pooling = self.bert([token_ids, segment_ids])
+            pooled_output = get_pool_emb(hidden_states, pooling, token_ids.gt(0).long(), self.pool_method)
+            output = self.dropout(pooled_output)
+            output = self.dense(output)
+            return output
+    model = Model().to(device)
+
+else:
+    # Approach 2: the built-in SequenceClassificationTrainer
+    from bert4torch.trainer import SequenceClassificationTrainer
+    bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_pool=True, gradient_checkpoint=True)
+    model = SequenceClassificationTrainer(bert).to(device)

 # Define the loss and optimizer; custom choices are supported
 model.compile(
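
The compile call is truncated by the diff context. For orientation only, a hedged sketch of what a compile step in this example family typically looks like; the specific loss, optimizer, and metric choices below are illustrative, not read from the file:

```python
import torch.nn as nn
import torch.optim as optim

# Illustrative arguments; the real ones follow the truncated line in the file.
model.compile(
    loss=nn.CrossEntropyLoss(),
    optimizer=optim.Adam(model.parameters(), lr=2e-5),
    metrics=['accuracy'],
)
```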

setup.py

Lines changed: 1 addition & 1 deletion
@@ -14,6 +14,6 @@
     license='MIT Licence',
     url='https://github.com/Tongjilibo/bert4torch',
     author='Tongjilibo',
-    install_requires=['numpy', 'tqdm', 'torch>1.6', 'torch4keras==0.2.3', 'six'],
+    install_requires=['numpy', 'tqdm', 'torch>1.6', 'torch4keras==0.2.4', 'six'],
     packages=find_packages()
 )
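
This exact pin matches the pairing recorded in the README version table above: bert4torch 0.5.1 is built against torch4keras 0.2.4.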
