
Commit 000d9b6

Xingyi Zhou authored and Scenic Authors committed
Implement jit sharding.
PiperOrigin-RevId: 617009006
1 parent 592970e · commit 000d9b6

File tree

2 files changed: +131 −11 lines changed

scenic/dataset_lib/flexio/flexio.py

Lines changed: 62 additions & 11 deletions
@@ -24,6 +24,7 @@
 import jax
 import jax.numpy as jnp
 import ml_collections
+import numpy as np
 from scenic.dataset_lib import dataset_utils
 from scenic.dataset_lib import datasets
 import tensorflow as tf
@@ -37,6 +38,7 @@
 Features = preprocess_spec.Features
 TfFeature = Union[tf.io.FixedLenFeature, tf.io.VarLenFeature,
                   tf.io.FixedLenSequenceFeature]
+PyTree = Any

 # From grain/_src/core/constants.py
 GRAIN_META_DATA = [
@@ -77,6 +79,35 @@ def tf2jax_dtype(dtype: tf.dtypes.DType) -> Union[jnp.dtype, tf.dtypes.DType]:
 
 
 
+def shard_jit(data: PyTree, global_devices: np.ndarray) -> PyTree:
+  """Shards data for use in jit-based pipelines.
+
+  Note that the order of global devices for sharding data is important and
+  should be compatible with the device order used in the rest of the trainer
+  for model params, state, etc.
+
+  Args:
+    data: PyTree of data. Assumed to already contain Numpy arrays.
+    global_devices: List of global devices to shard over.
+
+  Returns:
+    Sharded data.
+  """
+
+  def _shard_array(x):
+    mesh = jax.sharding.Mesh(global_devices, ('devices',))
+    sharding = jax.sharding.NamedSharding(
+        mesh, jax.sharding.PartitionSpec('devices'))
+    local_ds = mesh.local_devices
+
+    xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds)
+
+    global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:])
+    return jax.make_array_from_single_device_arrays(global_shape, sharding, xs)
+
+  return jax.tree_util.tree_map(_shard_array, data)
+
+
 def apply_process_fn_with_populated_seed(ds: tf.data.Dataset,
                                          preprocess_fn: Callable[[Features],
                                                                  Features], *,
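The new `shard_jit` is the core of this change: it builds a one-dimensional `'devices'` mesh, splits each per-host numpy batch across the local devices, and reassembles the pieces into one globally-sharded `jax.Array`. A minimal sketch of the call pattern (editor's illustration; the shapes and names below are made up, not from the commit):

```python
import jax
import numpy as np

# Assumed setup: a 1-D array of all global devices, in the same order the
# trainer uses for params/state; a mismatched order would mis-shard the data.
devices = np.asarray(jax.devices())

# Per-host batch: the leading dim must divide evenly by the local device count.
local_batch = {'image': np.zeros((8, 64, 64, 3), np.float32)}

global_batch = shard_jit(local_batch, devices)
# Each leaf is now a jax.Array with a process-global leading dimension:
# global_batch['image'].shape == (8 * jax.process_count(), 64, 64, 3)
```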
@@ -294,7 +325,7 @@ def _build_pipeline(
     num_local_shards: int,
     rng: Union[None, jnp.ndarray, tf.Tensor] = None,
     global_rng: Union[None, jnp.ndarray, tf.Tensor] = None,
-    shuffle: bool = False
+    shuffle: bool = False,
 ) -> Optional[Union[tf.data.Dataset, Dict[str, tf.data.Dataset]]]:
   """Build a tf.data.Dataset pipeline using clu.deterministic_data or DMVR.
@@ -304,6 +335,8 @@
     dataset_configs: Dataset configurations.
     batch_size: Total batch size (sum for all devices).
     num_local_shards: Number of local shards (usually num local devices).
+      <= 0 means we don't shard batches across devices, and use 1 batch dim
+      instead of 2.
     rng: Per-host random seed (JAX format).
     global_rng: Global random seed (JAX format).
     shuffle: Whether to shuffle.
@@ -421,11 +454,14 @@ def _batch_and_prefetch(ds, batch_size):
     return ds

   # Batch to the desired output batch size:
-  if batch_size % num_local_shards != 0:
+  if num_local_shards > 0 and batch_size % num_local_shards != 0:
     raise ValueError(
         f'Local (host) batch size of {batch_size} is not divisible'
         f'to num_local_shard={num_local_shards}.')
-  batch_dims = [num_local_shards, batch_size // num_local_shards]
+  if num_local_shards > 0:
+    batch_dims = [num_local_shards, batch_size // num_local_shards]
+  else:
+    batch_dims = [batch_size]
   for batch_size in reversed(batch_dims):
     if dataset_configs.get('padded_batch'):
       ds = ds.padded_batch(batch_size, drop_remainder=True)
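To make the new branch concrete, here is a small editor's sketch (values invented) of the batch dimensions each mode produces:

```python
# Editor's sketch: batch_size=16 on a host with 8 local devices.
num_local_shards, batch_size = 8, 16
pmap_dims = [num_local_shards, batch_size // num_local_shards]  # [8, 2]
jit_dims = [batch_size]                                         # [16]
# pmap-style batches arrive as (8, 2, ...): one leading dim per local device.
# jit-style batches arrive as (16, ...): shard_jit splits dim 0 across the
# devices afterwards, so no per-device leading dim is needed here.
```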
@@ -488,7 +524,8 @@ def get_iterator(
     ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]],
     configs=ml_collections.ConfigDict,
     *,
-    return_iterator: bool = False
+    return_iterator: bool = False,
+    devices_jit: Optional[np.ndarray] = None,
 ) -> Tuple[Union[Iterable[Any] | None, Dict[str, Iterable[Any] | None]], Union[
     Tuple[Any, ...], Dict[str, Tuple[Any, ...]]], Union[int, Dict[str, int]]]:
   """Given a (dict of) Dataset object(s), returns iterators and metadata.
@@ -498,6 +535,7 @@
     configs: A Config dict.
     return_iterator: If False, the function returns a None instead of an
       iterator.
+    devices_jit: List of devices to shard the data over for jit-based pipelines.

   Returns:
     Iterators, input specification and num_examples.
@@ -522,6 +560,10 @@ def _get_input_spec(ds):
       else:
         ds_it = iter(dataset)
       ds_iter[dataset_name] = map(dataset_utils.tf_to_numpy, ds_it)
+      if devices_jit is not None:
+        ds_iter[dataset_name] = map(
+            functools.partial(shard_jit, global_devices=devices_jit),
+            ds_iter[dataset_name])
       input_spec[dataset_name] = _get_input_spec(dataset)
       # TODO(dehghani): Add support for having different input specs.
     first_input_spec = list(input_spec.values())[0]
@@ -536,6 +578,10 @@ def _get_input_spec(ds):
     else:
       ds_it = iter(ds)
     ds_iter = map(dataset_utils.tf_to_numpy, ds_it)
+    if devices_jit is not None:
+      ds_iter = map(
+          functools.partial(
+              shard_jit, global_devices=devices_jit), ds_iter)
     total_examples = sum(list(total_examples.values()))
     input_spec = _get_input_spec(ds)
   else:
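Both the dict and single-dataset branches use the same `map`-based wrapping; a self-contained sketch of that pattern (editor's illustration, with a stand-in batch iterator and `shard_jit` from the diff above):

```python
import functools

import jax
import numpy as np

# Stand-in for the numpy iterator produced by dataset_utils.tf_to_numpy.
numpy_batches = iter([{'x': np.zeros((8, 4), np.float32)} for _ in range(3)])

# map() keeps the pipeline lazy: each batch is sharded only when pulled.
devices_jit = np.asarray(jax.devices())  # assumes 8 % device count == 0
ds_iter = map(
    functools.partial(shard_jit, global_devices=devices_jit), numpy_batches)
batch = next(ds_iter)  # batch['x'] is now a globally-sharded jax.Array
```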
@@ -557,7 +603,8 @@ def get_dataset(
     start_step: Optional[int] = None,
     dtype_str: str = 'float32',
     shuffle_seed: int = 0,
-    dataset_service_address: Optional[str] = None) -> dataset_utils.Dataset:
+    dataset_service_address: Optional[str] = None,
+    devices: Optional[np.ndarray] = None) -> dataset_utils.Dataset:
   """Returns generators for video datasets.

   Args:
@@ -571,6 +618,7 @@
     dtype_str: Data type of the image. Only 'float32' is currently supported.
     shuffle_seed: Unsupported; use rng instead.
     dataset_service_address: Unsupported; must be None.
+    devices: List of devices to shard the data over for jit-based pipelines.

   Returns:
     A dataset_utils.Dataset() which includes a train_iter, a valid_iter,
@@ -601,7 +649,7 @@
       start_step=start_step,
       dataset_configs=dataset_configs,
       batch_size=batch_size,
-      num_local_shards=num_shards,
+      num_local_shards=num_shards if devices is None else -1,
       rng=train_rng,
       global_rng=global_rng,
       shuffle=True)
@@ -613,19 +661,21 @@
       start_step=0,
       dataset_configs=dataset_configs,
       batch_size=eval_batch_size,
-      num_local_shards=num_shards,
+      num_local_shards=num_shards if devices is None else -1,
       global_rng=global_rng,
       rng=eval_rng)

   return_iterators = dataset_configs.get('return_iterators', True)
   train_iter, train_input_spec, total_train_examples = get_iterator(
       train_ds,
       dataset_configs.get('train'),
-      return_iterator=return_iterators)
+      return_iterator=return_iterators,
+      devices_jit=devices)
   eval_iter, eval_input_spec, total_eval_examples = get_iterator(
       eval_ds,
       dataset_configs.get('eval'),
-      return_iterator=return_iterators)
+      return_iterator=return_iterators,
+      devices_jit=devices)

   # Testing dataset:
   rng, test_rng = jax.random.split(rng)
@@ -634,14 +684,15 @@
       start_step=0,
       dataset_configs=dataset_configs,
       batch_size=eval_batch_size,
-      num_local_shards=num_shards,
+      num_local_shards=num_shards if devices is None else -1,
       global_rng=global_rng,
       rng=test_rng)

   test_iter, test_input_spec, total_test_examples = get_iterator(
       test_ds,
       dataset_configs.get('test'),
-      return_iterator=return_iterators)
+      return_iterator=return_iterators,
+      devices_jit=devices)

   # Collect dataset metadata.
   meta_data = {
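Taken together, a trainer would opt into jit sharding roughly like this (editor's sketch; `dataset_configs` is a placeholder ConfigDict, and the new test below exercises the same path):

```python
import jax
from jax.experimental import mesh_utils

from scenic.dataset_lib.flexio import flexio

# Passing `devices` flips flexio from pmap-style (shards, per-device, ...)
# batches to flat jit-style batches that shard_jit turns into global Arrays.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
dataset = flexio.get_dataset(
    batch_size=8,
    eval_batch_size=8,
    num_shards=jax.local_device_count(),
    rng=jax.random.PRNGKey(0),
    dataset_configs=dataset_configs,  # placeholder, as in the test below
    devices=devices)
batch = next(dataset.train_iter)  # leaves: jax.Array, sharded over `devices`
```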

scenic/dataset_lib/flexio/tests/flexio_test.py

Lines changed: 69 additions & 0 deletions
@@ -19,6 +19,8 @@
 from grand_vision.preprocessing import image_ops
 from grand_vision.preprocessing import modalities
 import jax
+from jax._src import array as jax_array
+from jax.experimental import mesh_utils
 import ml_collections
 from scenic.dataset_lib.flexio import flexio
 import tensorflow as tf
@@ -92,6 +94,73 @@ def test_tfds_datasets(self, train_tfds_name, eval_tfds_name):
     self.assertDictEqual(
         jax.tree_util.tree_map(lambda x: x.shape, valid_data), expected_shapes)

+  @parameterized.named_parameters(
+      ('coco_coco', 'coco', 'coco'),
+  )
+  def test_sharded_tfds_datasets(self, train_tfds_name, eval_tfds_name):
+    """Test sharded TFDS dataset loading."""
+    dataset_configs = D({
+        'train': {
+            'sources': [D({
+                'source': 'tfds',
+                'tfds_name': train_tfds_name,
+                'split': 'train',
+                'shuffle_buffer_size': 2,
+                'cache': False,
+                'preproc_spec': 'decode_coco_example|crop_or_pad(64, 16)',
+            })],
+            'preproc_spec': 'crop_or_pad_meta_data(16, 16)',
+        },
+        'eval': {
+            'sources': [D({
+                'source': 'tfds',
+                'tfds_name': eval_tfds_name,
+                'split': 'validation',
+                'shuffle_buffer_size': 1,
+                'cache': False,
+                'preproc_spec': 'decode_coco_example',
+            })],
+            'preproc_spec': ('central_crop(64)'
+                             '|crop_or_pad(64, 16)'
+                             '|crop_or_pad_meta_data(16, 16)'),
+        },
+        'pp_libs': [  # We override the default ops.
+            'grand_vision.preprocessing.image_ops']
+    })
+    rng = jax.random.PRNGKey(0)
+    devices = mesh_utils.create_device_mesh((jax.device_count(),))
+    ds = flexio.get_dataset(
+        batch_size=8,
+        eval_batch_size=8,
+        num_shards=jax.local_device_count(),
+        rng=rng,
+        dataset_configs=dataset_configs,
+        devices=devices)
+    prefix_shape = (8,)
+    expected_shapes = {
+        modalities.ANNOTATION_ID: prefix_shape + (16,),
+        modalities.AREA: prefix_shape + (16,),
+        modalities.BOXES: prefix_shape + (16, 4),
+        modalities.CROWD: prefix_shape + (16,),
+        modalities.IMAGE: prefix_shape + (64, 64, 3),
+        modalities.IMAGE_ID: prefix_shape,
+        modalities.IMAGE_PADDING_MASK: prefix_shape + (64, 64),
+        modalities.INSTANCE_LABELS: prefix_shape + (16,),
+        modalities.ORIGINAL_SIZE: prefix_shape + (2,),
+        image_ops.SEED_KEY: prefix_shape + (2,)
+    }
+    train_data = next(ds.train_iter)
+    valid_data = next(ds.valid_iter)
+    self.assertDictEqual(
+        jax.tree_util.tree_map(lambda x: x.shape, train_data), expected_shapes)
+    self.assertDictEqual(
+        jax.tree_util.tree_map(lambda x: x.shape, valid_data), expected_shapes)
+    self.assertDictEqual(
+        jax.tree_util.tree_map(type, train_data),
+        jax.tree_util.tree_map(lambda x: jax_array.ArrayImpl, train_data))
+    self.assertDictEqual(
+        jax.tree_util.tree_map(type, valid_data),
+        jax.tree_util.tree_map(lambda x: jax_array.ArrayImpl, valid_data))

 if __name__ == '__main__':
   absltest.main()
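Two details worth noting in this test: the expected `prefix_shape` is a single global batch dimension of `(8,)` rather than the pmap-style `(num_devices, per_device)` prefix, and the final assertions check that every leaf is a `jax_array.ArrayImpl`, i.e. that the iterators now yield globally-sharded `jax.Array`s instead of numpy arrays.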
