import jax
import jax.numpy as jnp
import ml_collections
+import numpy as np
from scenic.dataset_lib import dataset_utils
from scenic.dataset_lib import datasets
import tensorflow as tf

Features = preprocess_spec.Features
TfFeature = Union[tf.io.FixedLenFeature, tf.io.VarLenFeature,
                  tf.io.FixedLenSequenceFeature]
+PyTree = Any

# From grain/_src/core/constants.py
GRAIN_META_DATA = [
@@ -77,6 +79,35 @@ def tf2jax_dtype(dtype: tf.dtypes.DType) -> Union[jnp.dtype, tf.dtypes.DType]:
+def shard_jit(data: PyTree, global_devices: np.ndarray) -> PyTree:
+  """Shards data for use in jit-based pipelines.
+
+  Note that the order of global devices for sharding data is important and
+  should be compatible with the device order used in the rest of the trainer
+  for model params, state, etc.
+
+  Args:
+    data: PyTree of data. Assumed to already contain NumPy arrays.
+    global_devices: List of global devices to shard over.
+
+  Returns:
+    Sharded data.
+  """
+
+  def _shard_array(x):
+    mesh = jax.sharding.Mesh(global_devices, ('devices',))
+    sharding = jax.sharding.NamedSharding(
+        mesh, jax.sharding.PartitionSpec('devices'))
+    local_ds = mesh.local_devices
+
+    xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds)
+
+    global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:])
+    return jax.make_array_from_single_device_arrays(global_shape, sharding, xs)
+
+  return jax.tree_util.tree_map(_shard_array, data)
+
+
def apply_process_fn_with_populated_seed(ds: tf.data.Dataset,
                                         preprocess_fn: Callable[[Features],
                                                                 Features], *,
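For context, a minimal sketch of how shard_jit is meant to be used in a jit-based input pipeline. The batch contents, shapes, and single-process setup below are illustrative assumptions, not part of this change:

import jax
import numpy as np

# Global devices, in the same order the trainer uses for param/state sharding.
global_devices = np.asarray(jax.devices())

# A host-local batch of NumPy arrays (e.g. the output of tf_to_numpy). The
# leading batch dimension must be divisible by the number of local devices.
batch = {
    'inputs': np.zeros((8, 16, 224, 224, 3), np.float32),
    'label': np.zeros((8, 400), np.float32),
}

# Each leaf is split along axis 0 across the local devices and reassembled
# into a single global jax.Array sharded over the 'devices' mesh axis, so it
# can be passed straight to a jitted train step (no leading device axis as in
# pmap-style pipelines).
sharded_batch = shard_jit(batch, global_devices)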
@@ -294,7 +325,7 @@ def _build_pipeline(
    num_local_shards: int,
    rng: Union[None, jnp.ndarray, tf.Tensor] = None,
    global_rng: Union[None, jnp.ndarray, tf.Tensor] = None,
-    shuffle: bool = False
+    shuffle: bool = False,
) -> Optional[Union[tf.data.Dataset, Dict[str, tf.data.Dataset]]]:
  """Build a tf.data.Dataset pipeline using clu.deterministic_data or DMVR.

@@ -304,6 +335,8 @@ def _build_pipeline(
    dataset_configs: Dataset configurations.
    batch_size: Total batch size (sum for all devices).
    num_local_shards: Number of local shards (usually num local devices).
+      <= 0 means we don't shard batches across devices, and use 1 batch dim
+      instead of 2.
    rng: Per-host random seed (JAX format).
    global_rng: Global random seed (JAX format).
    shuffle: Whether to shuffle.
@@ -421,11 +454,14 @@ def _batch_and_prefetch(ds, batch_size):
    return ds

  # Batch to the desired output batch size:
-  if batch_size % num_local_shards != 0:
+  if num_local_shards > 0 and batch_size % num_local_shards != 0:
    raise ValueError(
        f'Local (host) batch size of {batch_size} is not divisible '
        f'by num_local_shards={num_local_shards}.')
-  batch_dims = [num_local_shards, batch_size // num_local_shards]
+  if num_local_shards > 0:
+    batch_dims = [num_local_shards, batch_size // num_local_shards]
+  else:
+    batch_dims = [batch_size]
  for batch_size in reversed(batch_dims):
    if dataset_configs.get('padded_batch'):
      ds = ds.padded_batch(batch_size, drop_remainder=True)
@@ -488,7 +524,8 @@ def get_iterator(
    ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]],
    configs=ml_collections.ConfigDict,
    *,
-    return_iterator: bool = False
+    return_iterator: bool = False,
+    devices_jit: Optional[np.ndarray] = None,
) -> Tuple[Union[Iterable[Any] | None, Dict[str, Iterable[Any] | None]], Union[
    Tuple[Any, ...], Dict[str, Tuple[Any, ...]]], Union[int, Dict[str, int]]]:
  """Given a (dict of) Dataset object(s), returns iterators and metadata.
@@ -498,6 +535,7 @@ def get_iterator(
    configs: A Config dict.
    return_iterator: If False, the function returns None instead of an
      iterator.
+    devices_jit: List of devices to shard the data over for jit-based
+      pipelines.

  Returns:
    Iterators, input specification and num_examples.
@@ -522,6 +560,10 @@ def _get_input_spec(ds):
      else:
        ds_it = iter(dataset)
      ds_iter[dataset_name] = map(dataset_utils.tf_to_numpy, ds_it)
+      if devices_jit is not None:
+        ds_iter[dataset_name] = map(
+            functools.partial(shard_jit, global_devices=devices_jit),
+            ds_iter[dataset_name])
      input_spec[dataset_name] = _get_input_spec(dataset)
      # TODO(dehghani): Add support for having different input specs.
    first_input_spec = list(input_spec.values())[0]
@@ -536,6 +578,10 @@ def _get_input_spec(ds):
    else:
      ds_it = iter(ds)
    ds_iter = map(dataset_utils.tf_to_numpy, ds_it)
+    if devices_jit is not None:
+      ds_iter = map(
+          functools.partial(shard_jit, global_devices=devices_jit), ds_iter)
    total_examples = sum(list(total_examples.values()))
    input_spec = _get_input_spec(ds)
  else:
@@ -557,7 +603,8 @@ def get_dataset(
    start_step: Optional[int] = None,
    dtype_str: str = 'float32',
    shuffle_seed: int = 0,
-    dataset_service_address: Optional[str] = None) -> dataset_utils.Dataset:
+    dataset_service_address: Optional[str] = None,
+    devices: Optional[np.ndarray] = None) -> dataset_utils.Dataset:
  """Returns generators for video datasets.

  Args:
@@ -571,6 +618,7 @@ def get_dataset(
    dtype_str: Data type of the image. Only 'float32' is currently supported.
    shuffle_seed: Unsupported; use rng instead.
    dataset_service_address: Unsupported; must be None.
+    devices: List of devices to shard the data over for jit-based pipelines.

  Returns:
    A dataset_utils.Dataset() which includes a train_iter, a valid_iter,
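A rough end-to-end sketch of the jit path. The leading arguments of get_dataset are elided, and the names config, data_rng and dataset_configs below are hypothetical placeholders for illustration only:

# Device order must match the sharding used for the model params and state.
devices = np.asarray(jax.devices())

# With devices set, batches keep a single batch dimension and every iterator
# yields batches whose leaves are global jax.Arrays sharded by shard_jit.
dataset = get_dataset(
    config,                           # hypothetical config object
    data_rng,                         # hypothetical per-host PRNG key
    dataset_configs=dataset_configs,  # hypothetical dataset config
    devices=devices)
batch = next(dataset.train_iter)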
@@ -601,7 +649,7 @@ def get_dataset(
      start_step=start_step,
      dataset_configs=dataset_configs,
      batch_size=batch_size,
-      num_local_shards=num_shards,
+      num_local_shards=num_shards if devices is None else -1,
      rng=train_rng,
      global_rng=global_rng,
      shuffle=True)
@@ -613,19 +661,21 @@ def get_dataset(
      start_step=0,
      dataset_configs=dataset_configs,
      batch_size=eval_batch_size,
-      num_local_shards=num_shards,
+      num_local_shards=num_shards if devices is None else -1,
      global_rng=global_rng,
      rng=eval_rng)

  return_iterators = dataset_configs.get('return_iterators', True)
  train_iter, train_input_spec, total_train_examples = get_iterator(
      train_ds,
      dataset_configs.get('train'),
-      return_iterator=return_iterators)
+      return_iterator=return_iterators,
+      devices_jit=devices)
  eval_iter, eval_input_spec, total_eval_examples = get_iterator(
      eval_ds,
      dataset_configs.get('eval'),
-      return_iterator=return_iterators)
+      return_iterator=return_iterators,
+      devices_jit=devices)

  # Testing dataset:
  rng, test_rng = jax.random.split(rng)
@@ -634,14 +684,15 @@ def get_dataset(
      start_step=0,
      dataset_configs=dataset_configs,
      batch_size=eval_batch_size,
-      num_local_shards=num_shards,
+      num_local_shards=num_shards if devices is None else -1,
      global_rng=global_rng,
      rng=test_rng)

  test_iter, test_input_spec, total_test_examples = get_iterator(
      test_ds,
      dataset_configs.get('test'),
-      return_iterator=return_iterators)
+      return_iterator=return_iterators,
+      devices_jit=devices)

  # Collect dataset metadata.
  meta_data = {