
Commit c800d89

[*.py] Linting low-hanging fruit across whole codebase
1 parent b5d1a1b commit c800d89


59 files changed: +533 -439 lines changed

MaxText/convert_deepseek_ckpt.py

Lines changed: 2 additions & 1 deletion
@@ -34,9 +34,10 @@
 import psutil
 from tqdm import tqdm
 
+from safetensors import safe_open
+
 from MaxText import max_logging
 from MaxText.inference_utils import str2bool
-from safetensors import safe_open
 from MaxText import llama_or_mistral_ckpt

MaxText/convert_gemma2_chkpt.py

Lines changed: 1 addition & 1 deletion
@@ -27,11 +27,11 @@
 
 from typing import Any
 import sys
-from MaxText import max_logging
 
 
 import orbax
 
+from MaxText import max_logging
 from MaxText import checkpointing
 from MaxText.train import save_checkpoint

MaxText/convert_gemma3_chkpt.py

Lines changed: 1 addition & 1 deletion
@@ -22,12 +22,12 @@
 
 from typing import Any
 import sys
-from MaxText import max_logging
 
 
 import orbax
 
 from MaxText import checkpointing
+from MaxText import max_logging
 from MaxText.train import save_checkpoint
 
 Params = dict[str, Any]

MaxText/convert_gemma_chkpt.py

Lines changed: 1 addition & 1 deletion
@@ -27,12 +27,12 @@
 
 from typing import Any
 import sys
-from MaxText import max_logging
 
 
 import orbax
 
 from MaxText import checkpointing
+from MaxText import max_logging
 from MaxText.train import save_checkpoint
 
 Params = dict[str, Any]

MaxText/inference_microbenchmark.py

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ def write_results(results, filename, flatten_microbenchmark_results):
   if flatten_microbenchmark_results:
     results["flattened_results"] = flatten_dict(results)
   if filename:
-    with open(filename, "w", encoding="utf-8") as f:
+    with open(filename, "wt", encoding="utf-8") as f:
       json.dump(results, f, indent=2)
   return results

MaxText/input_pipeline/_distillation_data_processing.py

Lines changed: 3 additions & 3 deletions
@@ -21,11 +21,11 @@
 query [{'role': 'user', 'content': '...'}] from the target output [{'role': 'assistant', 'content': '...'}].
 """
 
-import datasets
-import transformers
-
 from dataclasses import dataclass, field
 from typing import List
+
+import datasets
+
 from MaxText import max_logging
 from MaxText.input_pipeline import _input_pipeline_utils

MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 6 additions & 3 deletions
@@ -58,9 +58,11 @@ def get_datasets(
   dataset = dataset.shuffle(seed=shuffle_seed)
   dataset = dataset.repeat(num_epoch)
   dataset = dataset[dataloading_host_index::dataloading_host_count]  # sharding
-  assert grain_worker_count <= len(
-      dataset
-  ), f"grain worker count is currently {grain_worker_count}, exceeding the max allowable value {len(dataset)} (file shard count of a data loading host) for your dataset. Please lower grain_worker_count or increase file shard count."
+  assert grain_worker_count <= len(dataset), (
+      f"grain worker count is currently {grain_worker_count}, exceeding the max allowable value {len(dataset)} "
+      f"(file shard count of a data loading host) for your dataset. "
+      f"Please lower grain_worker_count or increase file shard count."
+  )
   dataset = dataset.map(grain.experimental.ParquetIterDataset)
   dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(dataset))
   dataset = grain.experimental.WindowShuffleIterDataset(dataset, window_size=100, seed=shuffle_seed)
@@ -232,6 +234,7 @@ def make_grain_eval_iterator(
     global_mesh,
     process_indices,
 ):
+  """Load, preprocess dataset and return iterators"""
   assert (
       config.global_batch_size_to_load_eval % global_mesh.size == 0
   ), "Batch size should be divisible number of global devices."

MaxText/input_pipeline/_grain_tokenizer.py

Lines changed: 4 additions & 4 deletions
@@ -44,17 +44,17 @@ def __post_init__(self):
     if isinstance(self.sequence_length, int):
       self.sequence_length = [self.sequence_length] * len(self.feature_names)
 
-  def map(self, features: dict[str, Any]) -> dict[str, Any]:
+  def map(self, element: dict[str, Any]) -> dict[str, Any]:
     """Maps to each element."""
     if self._processor is None:
       with self._initialize_processor_lock:
         if self._processor is None:  # Ensures only one thread initializes SPP.
           self._processor = self.tokenizer
     for feature_name, sequence_length in zip(self.feature_names, self.sequence_length, strict=True):
-      text = features[feature_name]
+      text = element[feature_name]
       token_ids = self._processor.encode(text)[:sequence_length]
-      features[feature_name] = np.asarray(token_ids, dtype=np.int32)
-    return features
+      element[feature_name] = np.asarray(token_ids, dtype=np.int32)
+    return element
 
   def __getstate__(self):
     state = self.__dict__.copy()

MaxText/input_pipeline/_hf_data_processing.py

Lines changed: 5 additions & 1 deletion
@@ -140,7 +140,11 @@ def preprocessing_pipeline(
     )
     data_column_names = ("inputs", "targets")
   elif use_dpo:
-    lists2array = lambda x: jax.tree.map(np.asarray, x, is_leaf=lambda x: isinstance(x, (list, tuple)))
+
+    def lists2array(x):
+      """Convert lists/tuples to array"""
+      return jax.tree.map(np.asarray, x, is_leaf=lambda y: isinstance(y, (list, tuple)))
+
     operations.append(grain.MapOperation(lists2array))
   else:
     assert len(data_column_names) == 1

MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 40 additions & 49 deletions
@@ -83,7 +83,8 @@ def is_conversational(features, data_columns):
   """Check if data is in a conversational format.
   Examples:
 
-  features = {'prompt': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}], 'completion': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}]}
+  features = {'prompt': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}],
+              'completion': [{'content': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None)}]}
   data_columns = ["prompt", "completion"]
   is_conversational(features, data_columns) return True.
@@ -149,11 +150,11 @@ def __init__(
     self.eos_id = eos_id
     self.unk_id = unk_id
 
-  def map(self, features):
+  def map(self, element):
     inputs, targets = [], []
-    for i, text in enumerate(features[self.text_column_name]):
+    for i, text in enumerate(element[self.text_column_name]):
       inputs += text
-      targets += [self.unk_id] * len(text) if self.completion_only and features["is_prompt"][i] else text
+      targets += [self.unk_id] * len(text) if self.completion_only and element["is_prompt"][i] else text
     if self.add_bos:
       inputs = [self.bos_id] + inputs
       targets = [self.bos_id] + targets
@@ -173,10 +174,10 @@ class HFNormalizeFeatures(grain.MapTransform):
   def __init__(self, column_name):
     self.column_name = column_name
 
-  def map(self, features):
+  def map(self, element):
     return {
-        "inputs": np.asarray(features[self.column_name], dtype=np.int32),
-        "targets": np.asarray(features[self.column_name], dtype=np.int32),
+        "inputs": np.asarray(element[self.column_name], dtype=np.int32),
+        "targets": np.asarray(element[self.column_name], dtype=np.int32),
     }
@@ -214,8 +215,8 @@ def _check_shard_count(self):
     if self.n_shards < (self.dataloading_host_count * self.num_threads):
       warnings.warn(
           f"WARNING: Inefficient dataloading. Your train or eval dataset contains {self.n_shards} shards, "
-          "smaller than number of host loading data. This is known to lead to inefficient dataloading. "
-          "see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice"
+          "smaller than number of host loading data. This is known to lead to inefficient dataloading. See"
+          "github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice"
       )
       self.n_shards = self.dataloading_host_count * self.num_threads
@@ -277,15 +278,15 @@ def __init__(self, data_columns, tokenize):
     else:
       self.dtype = tf.int64
 
-  def map(self, features):
+  def map(self, element):
     def _parse(example):
       parsed = tf.io.parse_example(
           example,
           {col: tf.io.FixedLenSequenceFeature([], dtype=self.dtype, allow_missing=True) for col in self.data_columns},
       )
       return parsed
 
-    return _parse(features)
+    return _parse(element)
 
 
 @dataclasses.dataclass
@@ -296,11 +297,11 @@ def __init__(self, column_names, tokenize):
     self.column_names = column_names
     self.tokenize = tokenize
 
-  def map(self, features):
+  def map(self, element):
     if self.tokenize:
-      return {col: features[col].numpy()[0].decode() for col in self.column_names}
+      return {col: element[col].numpy()[0].decode() for col in self.column_names}
     else:
-      return {col: features[col].numpy() for col in self.column_names}
+      return {col: element[col].numpy() for col in self.column_names}
 
 
 @dataclasses.dataclass
@@ -311,15 +312,15 @@ def __init__(self, mapping_dict, keep_old_keys=False):
     self.mapping_dict = mapping_dict
     self.keep_old_keys = keep_old_keys
 
-  def map(self, features):
+  def map(self, element):
     old_keys = set()
     for new_key, old_key in self.mapping_dict.items():
-      features[new_key] = features[old_key]
+      element[new_key] = element[old_key]
       old_keys.add(old_key)
     if not self.keep_old_keys:
       for key in old_keys:
-        del features[key]
-    return features
+        del element[key]
+    return element
 
 
 @dataclasses.dataclass
@@ -329,12 +330,12 @@ class ReformatPacking(grain.MapTransform):
   def __init__(self, column_names):
     self.column_names = column_names
 
-  def map(self, data):
+  def map(self, element):
     ret = {}
     for col in self.column_names:
-      ret[f"{col}"] = data[0][col]
-      ret[f"{col}_segmentation"] = data[1][col]
-      ret[f"{col}_position"] = data[2][col]
+      ret[f"{col}"] = element[0][col]
+      ret[f"{col}_segmentation"] = element[1][col]
+      ret[f"{col}_position"] = element[2][col]
     return ret
@@ -347,35 +348,25 @@ class PadOrTrimToMaxLength(grain.MapTransform):
   def __init__(self, max_length):
     self.max_length = max_length
 
-  def map(self, data: dict[str, np.ndarray]):
+  def map(self, element: dict[str, np.ndarray]):
     """map to each element"""
 
-    def _max_true_length(prompts, pad_token_id):
-      true_lengths = []
-      for prompt in prompts:
-        matches = np.where(prompt == pad_token_id)[0]
-        if matches.size != 0:
-          true_lengths.append(matches[0])
-        else:
-          true_lengths.append(prompts.shape[0])
-      return true_lengths
-
     def _pad(x, max_length):
       pad_amount = max(max_length - x.shape[0], 0)
       pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
       return np.pad(x, pad_amount)[:max_length]
 
-    data_columns = list(data.keys())
+    data_columns = list(element.keys())
     for data_column in data_columns:
-      data[f"{data_column}_segmentation"] = (data[data_column] != 0).astype(np.int32)
-      data[f"{data_column}_position"] = np.arange(data[data_column].shape[0], dtype=np.int32)
-      data[f"{data_column}_true_length"] = np.array(data[data_column].shape[0], dtype=np.int32)
-    for key, _ in data.items():
+      element[f"{data_column}_segmentation"] = (element[data_column] != 0).astype(np.int32)
+      element[f"{data_column}_position"] = np.arange(element[data_column].shape[0], dtype=np.int32)
+      element[f"{data_column}_true_length"] = np.array(element[data_column].shape[0], dtype=np.int32)
+    for key, _ in element.items():
       if "true_length" not in key:
-        data[key] = _pad(data[key], self.max_length)
+        element[key] = _pad(element[key], self.max_length)
     # for data_column in data_columns:
     #   data[f"{data_column}_true_length"] = _max_true_length(data[data_column], 0)
-    return data
+    return element
 
 
 @dataclasses.dataclass
@@ -386,21 +377,21 @@ def __init__(self, max_length, pad_id):
     self.max_length = max_length
     self.pad_id = pad_id
 
-  def map(self, data: dict[str, np.ndarray]):
+  def map(self, element: dict[str, np.ndarray]):
     """map to each element"""
 
    def _pad(x, max_length, pad_id):
      pad_amount = max(max_length - x.shape[0], 0)
      pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
      return np.pad(x, pad_amount, constant_values=pad_id)
 
-    data_columns = list(data.keys())
+    data_columns = list(element.keys())
     for data_column in data_columns:
-      data[f"{data_column}_segmentation"] = (data[data_column] != self.pad_id).astype(np.int32)
-      data[f"{data_column}_position"] = np.arange(data[data_column].shape[0], dtype=np.int32)
-    for key, _ in data.items():
-      data[key] = _pad(data[key], self.max_length, self.pad_id)
-    return data
+      element[f"{data_column}_segmentation"] = (element[data_column] != self.pad_id).astype(np.int32)
+      element[f"{data_column}_position"] = np.arange(element[data_column].shape[0], dtype=np.int32)
+    for key, _ in element.items():
+      element[key] = _pad(element[key], self.max_length, self.pad_id)
+    return element
 
 
 def shift_right(x, axis=1):
@@ -444,5 +435,5 @@ def __init__(self, ignored_ids, axis=1):
     self.ignored_ids = ignored_ids
     self.axis = axis
 
-  def map(self, data):
-    return shift_and_refine(data, ignored_ids=self.ignored_ids, axis=self.axis)
+  def map(self, element):
+    return shift_and_refine(element, ignored_ids=self.ignored_ids, axis=self.axis)

0 commit comments
