train.py
# Causal language-model training script: packs raw text into fixed-length
# blocks and fine-tunes a Llama checkpoint with the Hugging Face Trainer.
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

from datasets import load_dataset
from transformers import (
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    LlamaForCausalLM,
    LlamaTokenizer,
    Trainer,
    TrainingArguments,
)

# Label value ignored by the loss computation in transformers.
IGNORE_INDEX = -100


@dataclass
class ModelArguments:
    base_model: Optional[str] = field(default="base-model")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    max_seq_length: int = 4096


@dataclass
class TrainArguments(TrainingArguments):
    # Type annotations are required here; without them these values would not
    # override the corresponding TrainingArguments fields.
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 1
    num_train_epochs: float = 3.0
    learning_rate: float = 2e-5
    fp16: bool = True
    logging_steps: int = 10
    optim: str = "adamw_torch"
    save_strategy: str = "epoch"
    output_dir: str = "output"
    save_total_limit: int = 5
    report_to: str = "wandb"
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95


def load_data(tokenizer, dataset, max_length):
    def preprocess_pretrain_dataset(examples):
        # Tokenize each document and concatenate all token ids into one stream.
        text_ids = tokenizer(
            examples["text"],
            add_special_tokens=False)["input_ids"]
        concatenated_ids = list(chain(*text_ids))
        total_length = len(concatenated_ids)
        block_size = max_length
        # Drop the small remainder; if total_length < block_size, this batch is excluded.
        total_length = (total_length // block_size) * block_size
        # Split the stream into chunks of block_size tokens.
        result = [concatenated_ids[i: i + block_size]
                  for i in range(0, total_length, block_size)]
        return {
            "input_ids": result,
            "labels": result.copy()
        }
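
    # Illustration of the packing above (not part of the original logic): with
    # max_length = 4, concatenated ids [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] become
    # two blocks [1, 2, 3, 4] and [5, 6, 7, 8]; the trailing ids 9, 10 are dropped.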

    def print_supervised_dataset_example(example):
        # Print one packed example so the preprocessing can be sanity-checked.
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(tokenizer.decode([
            token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id for token_id in example["labels"]
        ], skip_special_tokens=False)))

    column_names = dataset.column_names
    dataset = dataset.map(
        preprocess_pretrain_dataset,
        batched=True,
        remove_columns=column_names,
        num_proc=64,
    )
    print_supervised_dataset_example(dataset[0])
    print(len(dataset))
    return {
        "train_dataset": dataset
    }


def train():
    # Parse the TrainArguments subclass so its training defaults take effect.
    parser = HfArgumentParser(
        (ModelArguments, DataArguments, TrainArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    tokenizer = LlamaTokenizer.from_pretrained(model_args.base_model)
    tokenizer.pad_token_id = 0

    model = LlamaForCausalLM.from_pretrained(
        model_args.base_model,
        use_flash_attention_2=True
    )
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("total params: ", total_params)

    # Local JSON/JSONL files go through the "json" builder; anything else is
    # treated as a dataset name or directory for load_dataset.
    if data_args.data_path.endswith(".json") or data_args.data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_args.data_path)
    else:
        data = load_dataset(data_args.data_path)
    dataset = load_data(tokenizer, data["train"], data_args.max_seq_length)

    # Pads input_ids with the pad token and labels with IGNORE_INDEX (-100).
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        data_collator=data_collator,
        **dataset
    )
    # Disable the KV cache during training; it is only useful for generation.
    model.config.use_cache = False
    trainer.train(resume_from_checkpoint=False)
    trainer.save_model()


if __name__ == "__main__":
    train()
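
# Example launch (illustrative only; the model and data paths below are
# placeholders, and the JSON/JSONL file is expected to contain a "text"
# field per record):
#
#   python train.py \
#       --base_model /path/to/llama-checkpoint \
#       --data_path /path/to/corpus.jsonl \
#       --output_dir output
#
# Flag names follow the dataclass fields parsed by HfArgumentParser above.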