Making KFAC use pjit instead of pmap
PiperOrigin-RevId: 526595657
botev authored and KfacJaxDev committed May 5, 2023
1 parent 3be9b1a commit 1c5ed39
Showing 13 changed files with 626 additions and 534 deletions.
examples/datasets.py — 23 changes: 12 additions & 11 deletions
@@ -35,6 +35,14 @@
 _IMAGENET_STDDEV_RGB = (0.229, 0.224, 0.225)


+def sharded_iterator(
+    dataset: tf.data.Dataset,
+    sharding: jax.sharding.NamedSharding,
+) -> Iterator[Batch]:
+  for batch in iter(tensorflow_datasets.as_numpy(dataset)):
+    yield jax.device_put(batch, sharding)
+
+
 def mnist_dataset(
     split: str,
     has_labels: bool,
@@ -43,6 +51,7 @@ def mnist_dataset(
     repeat: bool,
     shuffle: bool,
     drop_remainder: bool,
+    sharding: jax.sharding.NamedSharding,
     seed: Optional[int] = None,
     multi_device: bool = True,
     reshuffle_each_iteration: bool = True,
@@ -59,13 +68,13 @@
     shuffle: Whether to shuffle the dataset.
     drop_remainder: Whether to drop the remainder of the dataset if the number
       of data points is not divisible by the total batch size.
+    sharding: Sharding spec for each batch.
     seed: Any seed to use for random pre-processing.
     multi_device: If the returned batch should take into account the number of
       devices present, in which case it will return an array with shape
       `(num_device, device_batch_size, ...)`.
     reshuffle_each_iteration: Whether to reshuffle the dataset in a new order
       after each iteration.
-
     dtype: The returned data type of the images.

   Returns:
     The MNIST dataset as a tensorflow dataset.
@@ -74,14 +83,7 @@
   # Set for multi devices vs single device
   num_devices = jax.device_count() if multi_device else 1
   num_local_devices = jax.local_device_count() if multi_device else 1
-
-  if multi_device:
-    host_batch_shape = [num_local_devices, device_batch_size]
-  else:
-    host_batch_shape = [device_batch_size]
-
   host_batch_size = num_local_devices * device_batch_size
-
   num_examples = tfds.builder("mnist").info.splits[split].num_examples

   if num_examples % num_devices != 0:
@@ -95,8 +97,7 @@ def preprocess_batch(
     """Standard reshaping of the images to (28, 28)."""
     images = tf.image.convert_image_dtype(images, dtype)
     single_example_shape = [784] if flatten_images else [28, 28]
-    images = tf.reshape(images, host_batch_shape + single_example_shape)
-    labels = tf.reshape(labels, host_batch_shape)
+    images = tf.reshape(images, [host_batch_size] + single_example_shape)
     if has_labels:
       return dict(images=images, labels=labels)
     else:
@@ -123,7 +124,7 @@ def preprocess_batch(

   ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

-  return iter(tensorflow_datasets.as_numpy(ds))
+  return sharded_iterator(ds, sharding)

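Taken together, the datasets.py changes drop pmap's per-device leading axis: batches keep a flat (host_batch_size, ...) shape and are placed onto devices with jax.device_put under a NamedSharding. Below is a minimal sketch of how a caller might build the new sharding argument; the "data" mesh axis name and the exact keyword list are assumptions for illustration, not taken from this commit.

  import jax
  import numpy as np
  from jax.sharding import Mesh, NamedSharding, PartitionSpec

  # 1-D mesh over all available devices; the axis name "data" is hypothetical.
  mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

  # Split each batch along its leading (batch) dimension across the mesh.
  sharding = NamedSharding(mesh, PartitionSpec("data"))

  # Hypothetical call mirroring the new signature: each yielded batch is
  # already device_put with `sharding`, instead of being reshaped to the
  # pmap-style (num_local_devices, device_batch_size, ...) layout.
  train_iter = mnist_dataset(
      split="train",
      has_labels=True,
      flatten_images=True,
      device_batch_size=128,
      repeat=True,
      shuffle=True,
      drop_remainder=True,
      sharding=sharding,
  )
  batch = next(train_iter)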
examples/optimizers.py — 5 changes: 4 additions & 1 deletion
@@ -85,6 +85,10 @@ def __init__(
         axis_name=self.pmap_axis_name,
     )

+  @property
+  def state_sharding(self) -> jax.sharding.NamedSharding:
+    raise NotImplementedError()
+
   def init(
       self,
       params: Params,
@@ -438,7 +442,6 @@ def create_optimizer(
         value_func_has_aux=has_aux,
         value_func_has_state=has_func_state,
         value_func_has_rng=has_rng,
-        multi_device=True,
         **kwargs,
     )
   elif name == "sgd":
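The new state_sharding property gives the optimizer wrapper a hook for telling pjit how its state is laid out; the base class leaves it unimplemented. A minimal sketch of one possible override follows; the subclass name, the `Optimizer` base name, and the `self._mesh` attribute are hypothetical, and fully replicated state is just one valid choice.

  import jax
  from jax.sharding import NamedSharding, PartitionSpec

  class ReplicatedOptimizer(Optimizer):  # `Optimizer` stands in for the examples' base wrapper.

    @property
    def state_sharding(self) -> jax.sharding.NamedSharding:
      # Replicate the entire optimizer state on every device in the mesh;
      # a real implementation could instead partition large state arrays.
      return NamedSharding(self._mesh, PartitionSpec())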
(The remaining 11 changed files are not shown.)
