@@ -83,7 +83,8 @@ def is_conversational(features, data_columns):
83
83
"""Check if data is in a conversational format.
84
84
Examples:
85
85
86
- features = {'prompt': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}], 'completion': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}]}
86
+ features = {'prompt': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}],
87
+ 'completion': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}]}
87
88
data_columns = ["prompt", "completion"]
88
89
is_conversational(features, data_columns) return True.
89
90
@@ -149,11 +150,11 @@ def __init__(
149
150
self .eos_id = eos_id
150
151
self .unk_id = unk_id
151
152
152
- def map (self , features ):
153
+ def map (self , element ):
153
154
inputs , targets = [], []
154
- for i , text in enumerate (features [self .text_column_name ]):
155
+ for i , text in enumerate (element [self .text_column_name ]):
155
156
inputs += text
156
- targets += [self .unk_id ] * len (text ) if self .completion_only and features ["is_prompt" ][i ] else text
157
+ targets += [self .unk_id ] * len (text ) if self .completion_only and element ["is_prompt" ][i ] else text
157
158
if self .add_bos :
158
159
inputs = [self .bos_id ] + inputs
159
160
targets = [self .bos_id ] + targets
@@ -173,10 +174,10 @@ class HFNormalizeFeatures(grain.MapTransform):
173
174
def __init__ (self , column_name ):
174
175
self .column_name = column_name
175
176
176
- def map (self , features ):
177
+ def map (self , element ):
177
178
return {
178
- "inputs" : np .asarray (features [self .column_name ], dtype = np .int32 ),
179
- "targets" : np .asarray (features [self .column_name ], dtype = np .int32 ),
179
+ "inputs" : np .asarray (element [self .column_name ], dtype = np .int32 ),
180
+ "targets" : np .asarray (element [self .column_name ], dtype = np .int32 ),
180
181
}
181
182
182
183
@@ -214,8 +215,8 @@ def _check_shard_count(self):
214
215
if self .n_shards < (self .dataloading_host_count * self .num_threads ):
215
216
warnings .warn (
216
217
f"WARNING: Inefficient dataloading. Your train or eval dataset contains { self .n_shards } shards, "
217
- "smaller than number of host loading data. This is known to lead to inefficient dataloading. "
218
- "see https:// github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice"
218
+ "smaller than number of host loading data. This is known to lead to inefficient dataloading. See "
219
+ "github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice"
219
220
)
220
221
self .n_shards = self .dataloading_host_count * self .num_threads
221
222
@@ -277,15 +278,15 @@ def __init__(self, data_columns, tokenize):
277
278
else :
278
279
self .dtype = tf .int64
279
280
280
- def map (self , features ):
281
+ def map (self , element ):
281
282
def _parse (example ):
282
283
parsed = tf .io .parse_example (
283
284
example ,
284
285
{col : tf .io .FixedLenSequenceFeature ([], dtype = self .dtype , allow_missing = True ) for col in self .data_columns },
285
286
)
286
287
return parsed
287
288
288
- return _parse (features )
289
+ return _parse (element )
289
290
290
291
291
292
@dataclasses .dataclass
@@ -296,11 +297,11 @@ def __init__(self, column_names, tokenize):
296
297
self .column_names = column_names
297
298
self .tokenize = tokenize
298
299
299
- def map (self , features ):
300
+ def map (self , element ):
300
301
if self .tokenize :
301
- return {col : features [col ].numpy ()[0 ].decode () for col in self .column_names }
302
+ return {col : element [col ].numpy ()[0 ].decode () for col in self .column_names }
302
303
else :
303
- return {col : features [col ].numpy () for col in self .column_names }
304
+ return {col : element [col ].numpy () for col in self .column_names }
304
305
305
306
306
307
@dataclasses .dataclass
@@ -311,15 +312,15 @@ def __init__(self, mapping_dict, keep_old_keys=False):
311
312
self .mapping_dict = mapping_dict
312
313
self .keep_old_keys = keep_old_keys
313
314
314
- def map (self , features ):
315
+ def map (self , element ):
315
316
old_keys = set ()
316
317
for new_key , old_key in self .mapping_dict .items ():
317
- features [new_key ] = features [old_key ]
318
+ element [new_key ] = element [old_key ]
318
319
old_keys .add (old_key )
319
320
if not self .keep_old_keys :
320
321
for key in old_keys :
321
- del features [key ]
322
- return features
322
+ del element [key ]
323
+ return element
323
324
324
325
325
326
@dataclasses .dataclass
@@ -329,12 +330,12 @@ class ReformatPacking(grain.MapTransform):
329
330
def __init__ (self , column_names ):
330
331
self .column_names = column_names
331
332
332
- def map (self , data ):
333
+ def map (self , element ):
333
334
ret = {}
334
335
for col in self .column_names :
335
- ret [f"{ col } " ] = data [0 ][col ]
336
- ret [f"{ col } _segmentation" ] = data [1 ][col ]
337
- ret [f"{ col } _position" ] = data [2 ][col ]
336
+ ret [f"{ col } " ] = element [0 ][col ]
337
+ ret [f"{ col } _segmentation" ] = element [1 ][col ]
338
+ ret [f"{ col } _position" ] = element [2 ][col ]
338
339
return ret
339
340
340
341
@@ -347,35 +348,25 @@ class PadOrTrimToMaxLength(grain.MapTransform):
347
348
def __init__ (self , max_length ):
348
349
self .max_length = max_length
349
350
350
- def map (self , data : dict [str , np .ndarray ]):
351
+ def map (self , element : dict [str , np .ndarray ]):
351
352
"""map to each element"""
352
353
353
- def _max_true_length (prompts , pad_token_id ):
354
- true_lengths = []
355
- for prompt in prompts :
356
- matches = np .where (prompt == pad_token_id )[0 ]
357
- if matches .size != 0 :
358
- true_lengths .append (matches [0 ])
359
- else :
360
- true_lengths .append (prompts .shape [0 ])
361
- return true_lengths
362
-
363
354
def _pad (x , max_length ):
364
355
pad_amount = max (max_length - x .shape [0 ], 0 )
365
356
pad_amount = [(0 , pad_amount )] + [(0 , 0 )] * (len (x .shape ) - 1 )
366
357
return np .pad (x , pad_amount )[:max_length ]
367
358
368
- data_columns = list (data .keys ())
359
+ data_columns = list (element .keys ())
369
360
for data_column in data_columns :
370
- data [f"{ data_column } _segmentation" ] = (data [data_column ] != 0 ).astype (np .int32 )
371
- data [f"{ data_column } _position" ] = np .arange (data [data_column ].shape [0 ], dtype = np .int32 )
372
- data [f"{ data_column } _true_length" ] = np .array (data [data_column ].shape [0 ], dtype = np .int32 )
373
- for key , _ in data .items ():
361
+ element [f"{ data_column } _segmentation" ] = (element [data_column ] != 0 ).astype (np .int32 )
362
+ element [f"{ data_column } _position" ] = np .arange (element [data_column ].shape [0 ], dtype = np .int32 )
363
+ element [f"{ data_column } _true_length" ] = np .array (element [data_column ].shape [0 ], dtype = np .int32 )
364
+ for key , _ in element .items ():
374
365
if "true_length" not in key :
375
- data [key ] = _pad (data [key ], self .max_length )
366
+ element [key ] = _pad (element [key ], self .max_length )
376
367
# for data_column in data_columns:
377
368
# data[f"{data_column}_true_length"] = _max_true_length(data[data_column], 0)
378
- return data
369
+ return element
379
370
380
371
381
372
@dataclasses .dataclass
@@ -386,21 +377,21 @@ def __init__(self, max_length, pad_id):
386
377
self .max_length = max_length
387
378
self .pad_id = pad_id
388
379
389
- def map (self , data : dict [str , np .ndarray ]):
380
+ def map (self , element : dict [str , np .ndarray ]):
390
381
"""map to each element"""
391
382
392
383
def _pad (x , max_length , pad_id ):
393
384
pad_amount = max (max_length - x .shape [0 ], 0 )
394
385
pad_amount = [(0 , pad_amount )] + [(0 , 0 )] * (len (x .shape ) - 1 )
395
386
return np .pad (x , pad_amount , constant_values = pad_id )
396
387
397
- data_columns = list (data .keys ())
388
+ data_columns = list (element .keys ())
398
389
for data_column in data_columns :
399
- data [f"{ data_column } _segmentation" ] = (data [data_column ] != self .pad_id ).astype (np .int32 )
400
- data [f"{ data_column } _position" ] = np .arange (data [data_column ].shape [0 ], dtype = np .int32 )
401
- for key , _ in data .items ():
402
- data [key ] = _pad (data [key ], self .max_length , self .pad_id )
403
- return data
390
+ element [f"{ data_column } _segmentation" ] = (element [data_column ] != self .pad_id ).astype (np .int32 )
391
+ element [f"{ data_column } _position" ] = np .arange (element [data_column ].shape [0 ], dtype = np .int32 )
392
+ for key , _ in element .items ():
393
+ element [key ] = _pad (element [key ], self .max_length , self .pad_id )
394
+ return element
404
395
405
396
406
397
def shift_right (x , axis = 1 ):
@@ -444,5 +435,5 @@ def __init__(self, ignored_ids, axis=1):
444
435
self .ignored_ids = ignored_ids
445
436
self .axis = axis
446
437
447
- def map (self , data ):
448
- return shift_and_refine (data , ignored_ids = self .ignored_ids , axis = self .axis )
438
+ def map (self , element ):
439
+ return shift_and_refine (element , ignored_ids = self .ignored_ids , axis = self .axis )
0 commit comments