import os
import random

import torch
import numpy as np

SUPPORTED_DATASETS = ['rudrec', 'nerel_bio', 'conll2003']


# code from https://github.com/IlyaGusev/rulm/blob/master/self_instruct/src/util/dl.py
def set_random_seed(seed):
    # Seed every RNG used during training (Python, NumPy, PyTorch CPU and GPU)
    # and force deterministic cuDNN behaviour for reproducible runs.
    random.seed(seed)
    np.random.seed(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:2"
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def fix_tokenizer(tokenizer):
    # Fixing broken tokenizers: scan the first ids for special tokens that exist
    # in the vocabulary but are not exposed through the tokenizer's attributes.
    special_tokens = dict()
    for token_id in range(1000):
        token = tokenizer.convert_ids_to_tokens(token_id)
        if tokenizer.pad_token_id in (None, tokenizer.vocab_size) and "pad" in token:
            special_tokens["pad_token"] = token
        if tokenizer.bos_token_id in (None, tokenizer.vocab_size) and "<s>" in token:
            special_tokens["bos_token"] = token
        if tokenizer.eos_token_id in (None, tokenizer.vocab_size) and "</s>" in token:
            special_tokens["eos_token"] = token
        if tokenizer.unk_token_id in (None, tokenizer.vocab_size) and "unk" in token:
            special_tokens["unk_token"] = token
        if tokenizer.sep_token_id in (None, tokenizer.vocab_size) and "sep" in token:
            special_tokens["sep_token"] = token

    # Fall back to reasonable defaults when no suitable token was found above.
    if (
        tokenizer.sep_token_id in (None, tokenizer.vocab_size)
        and "bos_token" in special_tokens
    ):
        special_tokens["sep_token"] = special_tokens["bos_token"]

    if (
        tokenizer.pad_token_id in (None, tokenizer.vocab_size)
        and "pad_token" not in special_tokens
    ):
        if tokenizer.unk_token_id is not None:
            special_tokens["pad_token"] = tokenizer.unk_token
        else:
            special_tokens["pad_token"] = "<|pad|>"

    if (
        tokenizer.sep_token_id in (None, tokenizer.vocab_size)
        and "sep_token" not in special_tokens
    ):
        if tokenizer.bos_token_id is not None:
            special_tokens["sep_token"] = tokenizer.bos_token
        else:
            special_tokens["sep_token"] = "<|sep|>"

    tokenizer.add_special_tokens(special_tokens)
    print("Vocab size: ", tokenizer.vocab_size)
    print("PAD: ", tokenizer.pad_token_id, tokenizer.pad_token)
    print("BOS: ", tokenizer.bos_token_id, tokenizer.bos_token)
    print("EOS: ", tokenizer.eos_token_id, tokenizer.eos_token)
    print("UNK: ", tokenizer.unk_token_id, tokenizer.unk_token)
    print("SEP: ", tokenizer.sep_token_id, tokenizer.sep_token)
    return tokenizer

def fix_model(model, tokenizer, use_resize=True):
    # Propagate the (possibly newly added) special token ids to the model config.
    model.config.pad_token_id = tokenizer.pad_token_id
    assert model.config.pad_token_id is not None

    # Use the first available candidate as the BOS token.
    bos_candidates = (
        tokenizer.bos_token_id,
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.unk_token_id
    )
    for bos_candidate in bos_candidates:
        model.config.bos_token_id = bos_candidate
        if bos_candidate is not None:
            break
    assert model.config.bos_token_id is not None
    model.config.decoder_start_token_id = model.config.bos_token_id

    # Use the first available candidate as the EOS token.
    eos_candidates = (tokenizer.eos_token_id, tokenizer.sep_token_id)
    for eos_candidate in eos_candidates:
        model.config.eos_token_id = eos_candidate
        if eos_candidate is not None:
            break
    assert model.config.eos_token_id is not None

    # Resize the embedding matrix in case new special tokens were added to the tokenizer.
    if use_resize:
        model.resize_token_embeddings(len(tokenizer))
    return model
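

# Minimal usage sketch (an illustration added here, not part of the original
# training pipeline): the checkpoint name and seed below are assumptions, shown
# only to demonstrate how the helpers above are typically chained before fine-tuning.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    set_random_seed(42)
    # Hypothetical checkpoint; replace with the model actually being fine-tuned.
    model_name = "ai-forever/rugpt3large_based_on_gpt2"
    tokenizer = fix_tokenizer(AutoTokenizer.from_pretrained(model_name))
    model = fix_model(AutoModelForCausalLM.from_pretrained(model_name), tokenizer)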