Making KFAC use pjit instead of pmap #121

Open · wants to merge 1 commit into base: main
23 changes: 12 additions & 11 deletions examples/datasets.py
@@ -35,6 +35,14 @@
 _IMAGENET_STDDEV_RGB = (0.229, 0.224, 0.225)
 
 
+def sharded_iterator(
+    dataset: tf.data.Dataset,
+    sharding: jax.sharding.NamedSharding,
+) -> Iterator[Batch]:
+  for batch in iter(tensorflow_datasets.as_numpy(dataset)):
+    yield jax.device_put(batch, sharding)
+
+
 def mnist_dataset(
     split: str,
     has_labels: bool,
@@ -43,6 +51,7 @@ def mnist_dataset(
     repeat: bool,
     shuffle: bool,
     drop_remainder: bool,
+    sharding: jax.sharding.NamedSharding,
     seed: Optional[int] = None,
     multi_device: bool = True,
     reshuffle_each_iteration: bool = True,
@@ -59,13 +68,13 @@
     shuffle: Whether to shuffle the dataset.
     drop_remainder: Whether to drop the remainder of the dataset if the number
       of data points is not divisible by the total batch size.
+    sharding: Sharding spec for each batch.
     seed: Any seed to use for random pre-processing.
     multi_device: If the returned batch should take into account the number of
       devices present, in which case it will return an array with shape
       `(num_device, device_batch_size, ...)`.
     reshuffle_each_iteration: Whether to reshuffle the dataset in a new order
       after each iteration.
-    dtype: The returned data type of the images.
 
   Returns:
     The MNIST dataset as a tensorflow dataset.
@@ -74,14 +83,7 @@
   # Set for multi devices vs single device
   num_devices = jax.device_count() if multi_device else 1
   num_local_devices = jax.local_device_count() if multi_device else 1
-
-  if multi_device:
-    host_batch_shape = [num_local_devices, device_batch_size]
-  else:
-    host_batch_shape = [device_batch_size]
-
   host_batch_size = num_local_devices * device_batch_size
-
   num_examples = tfds.builder("mnist").info.splits[split].num_examples
 
   if num_examples % num_devices != 0:
@@ -95,8 +97,7 @@ def preprocess_batch(
     """Standard reshaping of the images to (28, 28)."""
     images = tf.image.convert_image_dtype(images, dtype)
     single_example_shape = [784] if flatten_images else [28, 28]
-    images = tf.reshape(images, host_batch_shape + single_example_shape)
-    labels = tf.reshape(labels, host_batch_shape)
+    images = tf.reshape(images, [host_batch_size] + single_example_shape)
     if has_labels:
       return dict(images=images, labels=labels)
     else:
@@ -123,7 +124,7 @@ def preprocess_batch(
 
   ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
 
-  return iter(tensorflow_datasets.as_numpy(ds))
+  return sharded_iterator(ds, sharding)
 
 
 def imagenet_num_examples_and_split(
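
Note: for context, here is a minimal sketch of how a caller might build the `sharding` argument that `mnist_dataset` and `sharded_iterator` now expect. It is not part of this PR; the mesh axis name "batch" and the exact set of keyword arguments are assumptions, since the updated call sites are not shown in the diff.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1D mesh over all devices; batches are split along their leading axis.
mesh = Mesh(np.asarray(jax.devices()), axis_names=("batch",))
batch_sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Under the new layout, each yielded batch is a single
# [host_batch_size, ...] array that `sharded_iterator` distributes with
# `jax.device_put`, replacing pmap's
# [num_local_devices, device_batch_size, ...] layout.
train_iter = mnist_dataset(
    split="train",
    has_labels=True,
    flatten_images=True,
    device_batch_size=128,
    repeat=True,
    shuffle=True,
    drop_remainder=True,
    sharding=batch_sharding,
)
batch = next(train_iter)  # images: (num_devices * 128, 784), sharded on "batch"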
5 changes: 4 additions & 1 deletion examples/optimizers.py
@@ -85,6 +85,10 @@ def __init__(
         axis_name=self.pmap_axis_name,
     )
 
+  @property
+  def state_sharding(self) -> jax.sharding.NamedSharding:
+    raise NotImplementedError()
+
   def init(
       self,
       params: Params,
@@ -438,7 +442,6 @@ def create_optimizer(
         value_func_has_aux=has_aux,
         value_func_has_state=has_func_state,
         value_func_has_rng=has_rng,
-        multi_device=True,
         **kwargs,
     )
   elif name == "sgd":
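
Note: the base class only declares `state_sharding`, so here is a minimal sketch of one possible override. It is not part of this PR: the enclosing class name is not visible in the diff (`_OptimizerSketch` below is a stand-in), and fully replicating the state is only one choice; large state arrays could instead be given a non-trivial PartitionSpec.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

class _OptimizerSketch:

  @property
  def state_sharding(self) -> jax.sharding.NamedSharding:
    # An empty PartitionSpec replicates the optimizer state on every device.
    mesh = Mesh(np.asarray(jax.devices()), axis_names=("batch",))
    return NamedSharding(mesh, PartitionSpec())

# The property can then drive jit output placement in place of pmap, e.g.:
#   jitted_init = jax.jit(opt.init, out_shardings=opt.state_sharding)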