@@ -35,140 +35,18 @@ def lru_cache():
 
 from .file_utils import cached_path
 
-logger = logging.getLogger(__name__)
-
-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
-}
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
-}
-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
-    'gpt2': 1024,
-}
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
-SPECIAL_TOKENS_NAME = 'special_tokens.txt'
-
-@lru_cache()
-def bytes_to_unicode():
-    """
-    Returns list of utf-8 byte and a corresponding list of unicode strings.
-    The reversible bpe codes work on unicode strings.
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a signficant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    And avoids mapping to whitespace/control characters the bpe code barfs on.
-    """
-    _chr = unichr if sys.version_info[0] == 2 else chr
-    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8+n)
-            n += 1
-    cs = [_chr(n) for n in cs]
-    return dict(zip(bs, cs))
-
-def get_pairs(word):
-    """Return set of symbol pairs in a word.
-
-    Word is represented as tuple of symbols (symbols being variable-length strings).
-    """
-    pairs = set()
-    prev_char = word[0]
-    for char in word[1:]:
-        pairs.add((prev_char, char))
-        prev_char = char
-    return pairs
-
 class GPT2Tokenizer(object):
-    """
-    GPT-2 BPE tokenizer. Peculiarities:
-        - Byte-level BPE
-    """
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a PreTrainedBertModel from a pre-trained model file.
-        Download and cache the pre-trained model file if needed.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
-            special_tokens_file = None
-        else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
-            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
-            if not os.path.exists(special_tokens_file):
-                special_tokens_file = None
-            else:
-                logger.info("loading special tokens file {}".format(special_tokens_file))
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                "at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    pretrained_model_name_or_path,
-                    vocab_file, merges_file))
-            return None
-        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-            logger.info("loading merges file {}".format(merges_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(
-                vocab_file, resolved_vocab_file))
-            logger.info("loading merges file {} from cache at {}".format(
-                merges_file, resolved_merges_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
-        # Instantiate tokenizer.
-        if special_tokens_file and 'special_tokens' not in kwargs:
-            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
-        else:
-            special_tokens = kwargs.pop('special_tokens', [])
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
-        return tokenizer
 
-    def __init__(self, vocab_file, merges_file, model_file, errors='replace', special_tokens=None, max_len=None):
+    def __init__(self, vocab_file, model_file, max_len=None):
         self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v:k for k,v in self.encoder.items()}
-        self.errors = errors  # how to handle errors in decoding
-        self.byte_encoder = bytes_to_unicode()
-        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
-        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
-        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
-        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
-        self.cache = {}
-
-        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
-        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
-
-        self.special_tokens = {}
-        self.special_tokens_decoder = {}
-        self.set_special_tokens(special_tokens)
 
         self.sp = spm.SentencePieceProcessor(model_file=model_file)
         self.translator = str.maketrans(" \n", "\u2582\u2583")
 
         self.eod_id = self.encoder['<eod>']
 
-        self.decoder[7] = '\u2584'
-
     @property
     def vocab_size(self):
         return len(self.encoder)
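
Note (not part of the patch): after this hunk the tokenizer no longer needs merges or special-tokens files; it is driven by the JSON vocab plus a SentencePiece model. A minimal construction sketch under that assumption — the file names below are placeholders, and the vocab is assumed to contain an '<eod>' entry, since __init__ reads self.encoder['<eod>']:

    # Hypothetical paths; the real checkpoint layout is not shown in this diff.
    tokenizer = GPT2Tokenizer(vocab_file="vocab.json",
                              model_file="chinese_vocab.model",
                              max_len=1024)
    print(tokenizer.vocab_size)  # size of the JSON vocab
    print(tokenizer.eod)         # id of the '<eod>' token cached by __init__
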
@@ -180,142 +58,18 @@ def __len__(self):
     def eod(self):
         return self.eod_id
 
-    def set_special_tokens(self, special_tokens):
-        """ Add a list of additional tokens to the encoder.
-            The additional tokens are indexed starting from the last index of the
-            current vocabulary in the order of the `special_tokens` list.
-        """
-        if not special_tokens:
-            self.special_tokens = {}
-            self.special_tokens_decoder = {}
-            return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
-        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
-        logger.info("Special tokens {}".format(self.special_tokens))
-
-    def bpe(self, token):
-        if token in self.cache:
-            return self.cache[token]
-        word = tuple(token)
-        pairs = get_pairs(word)
-
-        if not pairs:
-            return token
-
-        while True:
-            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
-            if bigram not in self.bpe_ranks:
-                break
-            first, second = bigram
-            new_word = []
-            i = 0
-            while i < len(word):
-                try:
-                    j = word.index(first, i)
-                    new_word.extend(word[i:j])
-                    i = j
-                except:
-                    new_word.extend(word[i:])
-                    break
-
-                if word[i] == first and i < len(word)-1 and word[i+1] == second:
-                    new_word.append(first+second)
-                    i += 2
-                else:
-                    new_word.append(word[i])
-                    i += 1
-            new_word = tuple(new_word)
-            word = new_word
-            if len(word) == 1:
-                break
-            else:
-                pairs = get_pairs(word)
-        word = ' '.join(word)
-        self.cache[token] = word
-        return word
-
     def tokenize(self, text):
         """ Tokenize a string. """
-        bpe_tokens = []
         seg_list = [x.translate(self.translator) for x in jieba.cut(text, cut_all=False)]
         new_seg = " ".join(seg_list)
-        tmp_bpe_tokens = self.sp.encode(new_seg, out_type=str)
-        bpe_tokens.extend(tmp_bpe_tokens)
-        return bpe_tokens
-
-    def convert_tokens_to_ids(self, tokens):
-        """ Converts a sequence of tokens into ids using the vocab. """
-        ids = []
-        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
-            if tokens in self.special_tokens:
-                return self.special_tokens[tokens]
-            else:
-                return self.encoder.get(tokens, 0)
-        for token in tokens:
-            if token in self.special_tokens:
-                ids.append(self.special_tokens[token])
-            else:
-                ids.append(self.encoder.get(token, 0))
-        if len(ids) > self.max_len:
-            logger.warning(
-                "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this OpenAI GPT model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
-            )
-        return ids
-
-    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
-        """Converts a sequence of ids in BPE tokens using the vocab."""
-        tokens = []
-        for i in ids:
-            if i in self.special_tokens_decoder:
-                if not skip_special_tokens:
-                    tokens.append(self.special_tokens_decoder[i])
-            else:
-                tokens.append(self.decoder[i])
-        return tokens
+        return self.sp.encode(new_seg)
 
     def encode(self, text):
-        res = self.convert_tokens_to_ids(self.tokenize(text))
+        res = self.tokenize(text)
         return res
 
     def decode(self, tokens):
-        text = self.sp.decode([self.decoder[x] for x in tokens])
-        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n').replace('\u2584', '<eod>')
-        #text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        text = self.sp.decode(tokens)
+        text = text.replace(' ', '').replace('\u2582', ' ').replace('\u2583', '\n')
         return text
 
-    def save_vocabulary(self, vocab_path):
-        """Save the tokenizer vocabulary and merge files to a directory."""
-        if not os.path.isdir(vocab_path):
-            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
-            return
-        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
-        merge_file = os.path.join(vocab_path, MERGES_NAME)
-        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
-
-        with open(vocab_file, 'w', encoding='utf-8') as f:
-            f.write(json.dumps(self.encoder, ensure_ascii=False))
-
-        index = 0
-        with open(merge_file, "w", encoding="utf-8") as writer:
-            writer.write(u'#version: 0.2\n')
-            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
-                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
-                    index = token_index
-                writer.write(' '.join(bpe_tokens) + u'\n')
-                index += 1
-
-        index = len(self.encoder)
-        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
-            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
-                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
-                    index = token_index
-                writer.write(token + u'\n')
-                index += 1
-
-        return vocab_file, merge_file, special_tokens_file
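
Note (not part of the patch): the new tokenize/decode pair routes text through jieba word segmentation, a space/newline-to-"\u2582"/"\u2583" translation, and SentencePiece, then reverses the mapping on decode. A standalone sketch of that round trip, assuming jieba and sentencepiece are installed; the model path and sample sentence are placeholders:

    import jieba
    import sentencepiece as spm

    sp = spm.SentencePieceProcessor(model_file="chinese_vocab.model")  # placeholder path
    translator = str.maketrans(" \n", "\u2582\u2583")

    text = "今天天气不错\n适合出门散步"
    # Segment, protect spaces/newlines as \u2582/\u2583, join segments with spaces, encode to ids.
    segments = [seg.translate(translator) for seg in jieba.cut(text, cut_all=False)]
    ids = sp.encode(" ".join(segments))

    # Decoding drops the joining spaces and restores ' ' and '\n' from the placeholder symbols.
    decoded = sp.decode(ids)
    decoded = decoded.replace(" ", "").replace("\u2582", " ").replace("\u2583", "\n")
    print(decoded)  # intended to reproduce `text` when the model covers its characters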