Add support for PmapSharding and other integer device index maps.
Allows sharding device index maps to contain integer values,
corresponding to axes along which the arrays should be stacked
instead of concatenated. This is required to support showing
arrays with PmapSharding, including those built using
jax.device_put_sharded and jax.device_put_replicated.

PiperOrigin-RevId: 671571309
danieldjohnson authored and Treescope Developers committed Sep 6, 2024
1 parent fd287df commit 77b10d7
Showing 2 changed files with 44 additions and 40 deletions.
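
For context, the JAX behavior this commit handles can be sketched as follows. This is illustrative only: it assumes a runtime with multiple devices, and the printed map is an example of the output's shape, not a verbatim transcript.

import jax
import numpy as np

# jax.device_put_sharded stacks one per-device array along a new leading
# axis, producing an array whose sharding is a PmapSharding.
devices = jax.devices()
shards = [np.full((3,), i) for i in range(len(devices))]
arr = jax.device_put_sharded(shards, devices)  # shape (len(devices), 3)

# For PmapSharding, the per-device index map can contain plain ints for the
# stacked axis instead of slices, e.g.:
#   {device_0: (0, slice(None)), device_1: (1, slice(None)), ...}
print(arr.sharding.devices_indices_map(arr.shape))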
37 changes: 33 additions & 4 deletions treescope/_internal/api/arrayviz.py
@@ -271,7 +271,7 @@ def render_array(
   )
   if not 1 <= pixels_per_cell <= 21:
     raise ValueError(
-        f"pixels_per_cell must be between 1 and 21 inclusive, got"
+        "pixels_per_cell must be between 1 and 21 inclusive, got"
         f" {pixels_per_cell}"
     )

@@ -741,8 +741,37 @@ def render_sharding_info(
     raise ValueError(f"Unrecognized axis info {type(info)}")
 
   array_shape = [info.size for info in array_axis_info]
-  shard_shape = sharding_info.shard_shape
-  num_shards = np.prod(array_shape) // np.prod(shard_shape)
+  orig_shard_shape = sharding_info.shard_shape
+  num_shards = np.prod(array_shape) // np.prod(orig_shard_shape)
+  orig_device_indices_map = sharding_info.device_index_to_shard_slices
+  # Possibly adjust the shard shape so that its length is the same as the array
+  # shape, and so that all items in device_indices_map are slices.
+  device_indices_map = {}
+  shard_shape = []
+  orig_shard_shape_index = 0
+  first = True
+  for key, ints_or_slices in orig_device_indices_map.items():
+    new_slices = []
+    for i, int_or_slc in enumerate(ints_or_slices):
+      if isinstance(int_or_slc, int):
+        new_slices.append(slice(int_or_slc, int_or_slc + 1))
+        if first:
+          shard_shape.append(1)
+      elif isinstance(int_or_slc, slice):
+        new_slices.append(int_or_slc)
+        if first:
+          shard_shape.append(orig_shard_shape[orig_shard_shape_index])
+          orig_shard_shape_index += 1
+      else:
+        raise ValueError(
+            f"Unrecognized axis slice in sharding info: {int_or_slc} at index"
+            f" {i} for device {key}"
+        )
+    device_indices_map[key] = tuple(new_slices)
+    first = False
+
+  assert len(shard_shape) == len(array_shape)
+  assert orig_shard_shape_index == len(orig_shard_shape)
   # Compute a truncation for visualizing a single shard. Each shard will be
   # shown as a shrunken version of the actual shard dimensions, roughly
   # proportional to the shard sizes.
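
As a standalone sketch of the normalization above (the function name and signature here are illustrative, not part of the library's API): integer entries become length-1 slices, and the shard shape gains an extent of 1 for each stacked axis.

def normalize_sharding(orig_map, orig_shard_shape):
  """Rewrites int axis entries as length-1 slices and rebuilds shard_shape."""
  new_map = {}
  shard_shape = ()
  for device, entries in orig_map.items():
    new_slices = []
    shape = []
    src = 0  # position in the original shard shape (slice axes only)
    for entry in entries:
      if isinstance(entry, int):
        # Stacked axis: each shard has extent 1 along this axis.
        new_slices.append(slice(entry, entry + 1))
        shape.append(1)
      else:
        # Concatenated or unsharded axis: keep the slice unchanged.
        new_slices.append(entry)
        shape.append(orig_shard_shape[src])
        src += 1
    new_map[device] = tuple(new_slices)
    shard_shape = tuple(shape)
  return new_map, shard_shape

For example, {0: (0, slice(None)), 1: (1, slice(None))} with shard shape (3,) normalizes to {0: (slice(0, 1), slice(None)), 1: (slice(1, 2), slice(None))} with shard shape (1, 3).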
@@ -776,7 +805,6 @@ def render_sharding_info(
     vec = np.array([True] * candidate + [False] + [True] * candidate)
     shard_mask = shard_mask[..., None] * vec
   # Figure out which device is responsible for each shard.
-  device_indices_map = sharding_info.device_index_to_shard_slices
   device_to_shard_offsets = {}
   shard_offsets_to_devices = collections.defaultdict(list)
   for device_index, slices in device_indices_map.items():
@@ -789,6 +817,7 @@ def render_sharding_info(
       else:
         assert slc.stop == slc.start + shard_shape[i]
         shard_offsets.append(slc.start // shard_shape[i])
+
     shard_offsets = tuple(shard_offsets)
     device_to_shard_offsets[device_index] = shard_offsets
     shard_offsets_to_devices[shard_offsets].append(device_index)
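
The shard-offset bookkeeping in the last two hunks can likewise be sketched in isolation. This is a simplified reconstruction, and the branch hidden between the hunks is assumed to map the full slice ``slice(None)`` to offset 0:

import collections

def shard_offset_maps(device_indices_map, shard_shape):
  """Maps each device to its shard's grid offsets, and offsets to devices."""
  device_to_shard_offsets = {}
  shard_offsets_to_devices = collections.defaultdict(list)
  for device_index, slices in device_indices_map.items():
    shard_offsets = []
    for i, slc in enumerate(slices):
      if slc == slice(None):
        # Assumption: an axis covered by the full slice contributes offset 0.
        shard_offsets.append(0)
      else:
        assert slc.stop == slc.start + shard_shape[i]
        shard_offsets.append(slc.start // shard_shape[i])
    offsets = tuple(shard_offsets)
    device_to_shard_offsets[device_index] = offsets
    shard_offsets_to_devices[offsets].append(device_index)
  return device_to_shard_offsets, shard_offsets_to_devices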
47 changes: 11 additions & 36 deletions treescope/ndarray_adapters.py
@@ -82,54 +82,29 @@ def logical_key(self) -> int:
     return self.axis_logical_index
 
 
-@dataclasses.dataclass(frozen=True)
-class ArraySummary:
-  """Summary of the contents of an array.
-
-  Any of the attributes of this summary can be None to indicate that the
-  corresponding statistic is not applicable to the array (either because of
-  dtype or because there are no finite values).
-
-  Attributes:
-    finite_mean: The mean of the finite values in the array.
-    finite_stddev: The standard deviation of the finite values in the array.
-    finite_min: The minimum of the finite values in the array.
-    finite_max: The maximum of the finite values in the array.
-    count_zero: The number of zero values in the array.
-    count_nonzero: The number of nonzero values in the array.
-    count_nan: The number of NaN values in the array.
-    count_posinf: The number of positive infinity values in the array.
-    count_neginf: The number of negative infinity values in the array.
-  """
-
-  finite_mean: float | None
-  finite_stddev: float | None
-  finite_min: float | None
-  finite_max: float | None
-  count_zero: int | None
-  count_nonzero: float | None
-  count_nan: float | None
-  count_posinf: float | None
-  count_neginf: float | None
-
-
 @dataclasses.dataclass(frozen=True)
 class ShardingInfo:
   """Summary of the sharding of an array.
 
   Attributes:
-    shard_shape: Shape of a single shard.
+    shard_shape: Shape of a single shard. Should be the same length as the
+      number of slices in each value of `device_index_to_shard_slices`.
     device_index_to_shard_slices: A mapping from device index to the tuple of
-      per-axis slices of the original array that is assigned to that device. The
-      length of each axis slice must match the `shard_shape` along that axis (or
-      be the full slice ``slice(None)``).
+      per-axis indices or slices of the original array that is assigned to that
+      device. Each entry of this tuple should either be an int or a slice
+      object. If an int, that axis should not appear in shard_shape (e.g. the
+      full array is formed by stacking the shards along a new axis). If a slice,
+      the corresponding axis should appear in shard_shape, and the slice should
+      be the full slice ``slice(None)`` (if the array is not sharded over this
+      axis) or a slice that matches the corresponding entry in `shard_shape` (if
+      the full array is formed by concatenating the shards along this axis).
     device_type: The type of device that the array is sharded across, as a
       string (e.g. "CPU", "TPU", "GPU").
     fully_replicated: Whether the array is fully replicated across all devices.
   """
 
   shard_shape: tuple[int, ...]
-  device_index_to_shard_slices: dict[int, tuple[slice, ...]]
+  device_index_to_shard_slices: dict[int, tuple[slice | int, ...]]
   device_type: str
   fully_replicated: bool = False
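
For instance, under the updated contract (a sketch assuming the class is imported from treescope.ndarray_adapters, as the changed file suggests), an eight-element vector split over two devices can be described either way:

from treescope.ndarray_adapters import ShardingInfo

# Concatenation: each device holds a contiguous slice of an array of shape
# (8,); every entry is a slice, and each axis appears in shard_shape.
concatenated = ShardingInfo(
    shard_shape=(4,),
    device_index_to_shard_slices={0: (slice(0, 4),), 1: (slice(4, 8),)},
    device_type="CPU",
)

# Stacking (PmapSharding style): the full array has shape (2, 4); the int
# entry marks the stacked axis, which does not appear in shard_shape.
stacked = ShardingInfo(
    shard_shape=(4,),
    device_index_to_shard_slices={0: (0, slice(None)), 1: (1, slice(None))},
    device_type="CPU",
)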

