Commit e3afd70

Simplify tokenization.
1 parent 81039e3 commit e3afd70

8 files changed: +27 -258 lines

arguments.py

Lines changed: 1 addition & 0 deletions
@@ -215,6 +215,7 @@ def add_text_generate_args(parser):
     group.add_argument("--top_p", type=float, default=0.0)
     group.add_argument("--top_k", type=int, default=0)
     group.add_argument("--out-seq-length", type=int, default=256)
+    group.add_argument("--input-text", type=str, default=None)
     return parser
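For context, a quick sketch of how the new flag behaves (a minimal, hypothetical standalone parser; only the --input-text line comes from this diff):

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('Text generation')
    group.add_argument("--input-text", type=str, default=None)

    # argparse exposes the flag as args.input_text; it stays None (interactive
    # prompting) unless a file path is supplied on the command line.
    args = parser.parse_args(["--input-text", "example.txt"])
    assert args.input_text == "example.txt"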

bpe_3w_new/chinese_vocab.model

-2 Bytes
Binary file not shown.

bpe_3w_new/chinese_vocab.vocab

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 <unk> 0
 <s> 0
 </s> 0
-<cls> 0
+ 0
 <sep> 0
 <pad> 0
 <mask> 0

bpe_3w_new/merges.txt

Whitespace-only changes.

data_utils/tokenization_gpt2.py

Lines changed: 5 additions & 251 deletions
@@ -35,140 +35,18 @@ def lru_cache():
 
 from .file_utils import cached_path
 
-logger = logging.getLogger(__name__)
-
-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
-}
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
-}
-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
-    'gpt2': 1024,
-}
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
-SPECIAL_TOKENS_NAME = 'special_tokens.txt'
-
-@lru_cache()
-def bytes_to_unicode():
-    """
-    Returns list of utf-8 byte and a corresponding list of unicode strings.
-    The reversible bpe codes work on unicode strings.
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a signficant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    And avoids mapping to whitespace/control characters the bpe code barfs on.
-    """
-    _chr = unichr if sys.version_info[0] == 2 else chr
-    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8+n)
-            n += 1
-    cs = [_chr(n) for n in cs]
-    return dict(zip(bs, cs))
-
-def get_pairs(word):
-    """Return set of symbol pairs in a word.
-
-    Word is represented as tuple of symbols (symbols being variable-length strings).
-    """
-    pairs = set()
-    prev_char = word[0]
-    for char in word[1:]:
-        pairs.add((prev_char, char))
-        prev_char = char
-    return pairs
-
 class GPT2Tokenizer(object):
-    """
-    GPT-2 BPE tokenizer. Peculiarities:
-        - Byte-level BPE
-    """
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a PreTrainedBertModel from a pre-trained model file.
-        Download and cache the pre-trained model file if needed.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
-            special_tokens_file = None
-        else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
-            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
-            if not os.path.exists(special_tokens_file):
-                special_tokens_file = None
-            else:
-                logger.info("loading special tokens file {}".format(special_tokens_file))
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                "at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    pretrained_model_name_or_path,
-                    vocab_file, merges_file))
-            return None
-        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-            logger.info("loading merges file {}".format(merges_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-            logger.info("loading merges file {} from cache at {}".format(
-                merges_file, resolved_merges_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
-        # Instantiate tokenizer.
-        if special_tokens_file and 'special_tokens' not in kwargs:
-            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
-        else:
-            special_tokens = kwargs.pop('special_tokens', [])
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
-        return tokenizer
 
-    def __init__(self, vocab_file, merges_file, model_file, errors='replace', special_tokens=None, max_len=None):
+    def __init__(self, vocab_file, model_file, max_len=None):
         self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v:k for k,v in self.encoder.items()}
-        self.errors = errors # how to handle errors in decoding
-        self.byte_encoder = bytes_to_unicode()
-        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
-        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
-        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
-        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
-        self.cache = {}
-
-        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
-        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
-
-        self.special_tokens = {}
-        self.special_tokens_decoder = {}
-        self.set_special_tokens(special_tokens)
 
         self.sp = spm.SentencePieceProcessor(model_file=model_file)
         self.translator = str.maketrans(" \n", "\u2582\u2583")
 
         self.eod_id = self.encoder['<eod>']
 
-        self.decoder[7] = '\u2584'
-
     @property
     def vocab_size(self):
         return len(self.encoder)
@@ -180,142 +58,18 @@ def __len__(self):
     def eod(self):
         return self.eod_id
 
-    def set_special_tokens(self, special_tokens):
-        """ Add a list of additional tokens to the encoder.
-            The additional tokens are indexed starting from the last index of the
-            current vocabulary in the order of the `special_tokens` list.
-        """
-        if not special_tokens:
-            self.special_tokens = {}
-            self.special_tokens_decoder = {}
-            return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
-        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
-        logger.info("Special tokens {}".format(self.special_tokens))
-
-    def bpe(self, token):
-        if token in self.cache:
-            return self.cache[token]
-        word = tuple(token)
-        pairs = get_pairs(word)
-
-        if not pairs:
-            return token
-
-        while True:
-            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
-            if bigram not in self.bpe_ranks:
-                break
-            first, second = bigram
-            new_word = []
-            i = 0
-            while i < len(word):
-                try:
-                    j = word.index(first, i)
-                    new_word.extend(word[i:j])
-                    i = j
-                except:
-                    new_word.extend(word[i:])
-                    break
-
-                if word[i] == first and i < len(word)-1 and word[i+1] == second:
-                    new_word.append(first+second)
-                    i += 2
-                else:
-                    new_word.append(word[i])
-                    i += 1
-            new_word = tuple(new_word)
-            word = new_word
-            if len(word) == 1:
-                break
-            else:
-                pairs = get_pairs(word)
-        word = ' '.join(word)
-        self.cache[token] = word
-        return word
-
     def tokenize(self, text):
         """ Tokenize a string. """
-        bpe_tokens = []
         seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
         new_seg = " ".join(seg_list)
-        tmp_bpe_tokens = self.sp.encode(new_seg, out_type=str)
-        bpe_tokens.extend(tmp_bpe_tokens)
-        return bpe_tokens
-
-    def convert_tokens_to_ids(self, tokens):
-        """ Converts a sequence of tokens into ids using the vocab. """
-        ids = []
-        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
-            if tokens in self.special_tokens:
-                return self.special_tokens[tokens]
-            else:
-                return self.encoder.get(tokens, 0)
-        for token in tokens:
-            if token in self.special_tokens:
-                ids.append(self.special_tokens[token])
-            else:
-                ids.append(self.encoder.get(token, 0))
-        if len(ids) > self.max_len:
-            logger.warning(
-                "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this OpenAI GPT model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
-            )
-        return ids
-
-    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
-        """Converts a sequence of ids in BPE tokens using the vocab."""
-        tokens = []
-        for i in ids:
-            if i in self.special_tokens_decoder:
-                if not skip_special_tokens:
-                    tokens.append(self.special_tokens_decoder[i])
-            else:
-                tokens.append(self.decoder[i])
-        return tokens
+        return self.sp.encode(new_seg)
 
     def encode(self, text):
-        res = self.convert_tokens_to_ids(self.tokenize(text))
+        res = self.tokenize(text)
         return res
 
     def decode(self, tokens):
-        text = self.sp.decode([self.decoder[x] for x in tokens])
-        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n').replace('\u2584', '<eod>')
-        #text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        text = self.sp.decode(tokens)
+        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
         return text
 
-    def save_vocabulary(self, vocab_path):
-        """Save the tokenizer vocabulary and merge files to a directory."""
-        if not os.path.isdir(vocab_path):
-            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
-            return
-        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
-        merge_file = os.path.join(vocab_path, MERGES_NAME)
-        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
-
-        with open(vocab_file, 'w', encoding='utf-8') as f:
-            f.write(json.dumps(self.encoder, ensure_ascii=False))
-
-        index = 0
-        with open(merge_file, "w", encoding="utf-8") as writer:
-            writer.write(u'#version: 0.2\n')
-            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
-                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
-                    index = token_index
-                writer.write(' '.join(bpe_tokens) + u'\n')
-                index += 1
-
-        index = len(self.encoder)
-        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
-            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
-                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
-                    index = token_index
-                writer.write(token + u'\n')
-                index += 1
-
-        return vocab_file, merge_file, special_tokens_file
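
For reference, a minimal round-trip through the simplified tokenizer could look like the sketch below (a sketch only; the vocab and model paths are taken from generate_samples.py and scripts/generate_text.sh further down):

    from data_utils.tokenization_gpt2 import GPT2Tokenizer

    # New two-argument constructor: a JSON vocab plus a SentencePiece model.
    tokenizer = GPT2Tokenizer('bpe_3w_new/vocab.json', 'bpe_3w_new/chinese_vocab.model')

    # encode(): jieba segments the text, spaces/newlines are mapped to the
    # \u2582/\u2583 placeholders, and SentencePiece returns token ids directly.
    ids = tokenizer.encode("中国的首都是北京")

    # decode(): SentencePiece detokenizes the ids, inter-piece spaces are
    # stripped, and the placeholders are mapped back to a space and a newline.
    print(tokenizer.decode(ids))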

example.txt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+中国的首都是北京
+日本的首都是东京
+美国的首都是

(Three Chinese prompts: "The capital of China is Beijing", "The capital of Japan is Tokyo", "The capital of the United States is".)

generate_samples.py

Lines changed: 10 additions & 4 deletions
@@ -172,10 +172,13 @@ def generate_samples(model, tokenizer, args, device):
         terminate_runs=0
 
         if mpu.get_model_parallel_rank() == 0:
-            raw_text = input("\nContext prompt (stop to exit) >>> ")
-            while not raw_text:
-                print('Prompt should not be empty!')
+            if args.input_text:
+                raw_text = open(args.input_text).read().strip()
+            else:
                 raw_text = input("\nContext prompt (stop to exit) >>> ")
+                while not raw_text:
+                    print('Prompt should not be empty!')
+                    raw_text = input("\nContext prompt (stop to exit) >>> ")
 
             if "stop" in raw_text:
                 terminate_runs = 1
@@ -264,6 +267,9 @@ def generate_samples(model, tokenizer, args, device):
         torch.distributed.barrier(group=mpu.get_model_parallel_group())
         context_count += 1
 
+        if args.input_text:
+            break
+
 def prepare_tokenizer(args):
 
     tokenizer_args = {
@@ -357,7 +363,7 @@ def main():
     set_random_seed(args.seed)
 
     #get the tokenizer
-    tokenizer = GPT2Tokenizer(os.path.join(args.tokenizer_path, 'vocab.json'), os.path.join(args.tokenizer_path, 'merges.txt'), os.path.join(args.tokenizer_path, 'chinese_vocab.model'))
+    tokenizer = GPT2Tokenizer(os.path.join(args.tokenizer_path, 'vocab.json'), os.path.join(args.tokenizer_path, 'chinese_vocab.model'))
 
     # Model
     model = setup_model(args)
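
Put together, the new non-interactive path reduces to roughly the following sketch (argument handling simplified for illustration; the tokenizer path follows scripts/generate_text.sh below):

    import os
    from data_utils.tokenization_gpt2 import GPT2Tokenizer

    tokenizer_path = 'bpe_3w_new/'   # --tokenizer-path in the launch script
    tokenizer = GPT2Tokenizer(os.path.join(tokenizer_path, 'vocab.json'),
                              os.path.join(tokenizer_path, 'chinese_vocab.model'))

    # With --input-text example.txt, the whole file is read once as a single
    # prompt, generation runs one pass, and the loop breaks instead of re-prompting.
    raw_text = open('example.txt').read().strip()
    context_tokens = tokenizer.encode(raw_text)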

scripts/generate_text.sh

Lines changed: 7 additions & 2 deletions
@@ -13,7 +13,7 @@ TEMP=0.9
 TOPK=0
 TOPP=0
 
-python -m torch.distributed.launch --nproc_per_node 2 generate_samples.py \
+CMD="python -m torch.distributed.launch --nproc_per_node 2 generate_samples.py \
     --model-parallel-size $MPSIZE \
     --num-layers $NLAYERS \
     --hidden-size $NHIDDEN \
@@ -28,5 +28,10 @@ python -m torch.distributed.launch --nproc_per_node 2 generate_samples.py \
     --top_k $TOPK \
     --top_p $TOPP \
     --tokenizer-path bpe_3w_new/ \
-    --vocab-size 30000
+    --vocab-size 30000 "
+
+if [ ! -z $2 ]; then
+    CMD+="--input-text $2"
+fi
 
+$CMD
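
Usage note: the prompt file is passed as the script's second positional argument ($2), presumably after whatever first argument the script already expects; when $2 is omitted, --input-text is never appended and generation stays interactive. The trailing space left inside the CMD string after --vocab-size 30000 is what keeps the appended flag separated from the preceding option.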

0 commit comments
