From 77b10d7971f180dd98cd07ef7129053401f85c45 Mon Sep 17 00:00:00 2001
From: Daniel Johnson
Date: Thu, 5 Sep 2024 17:51:54 -0700
Subject: [PATCH] Add support for PmapSharding and other integer device index
 maps.

Allows sharding device index maps to contain integer values, corresponding
to axes along which the arrays should be stacked instead of concatenated.
This is required to support showing arrays with PmapSharding, including
those built using jax.device_put_sharded and jax.device_put_replicated.

PiperOrigin-RevId: 671571309
---
 treescope/_internal/api/arrayviz.py | 37 ++++++++++++++++++++---
 treescope/ndarray_adapters.py       | 47 +++++++----------------------
 2 files changed, 44 insertions(+), 40 deletions(-)

diff --git a/treescope/_internal/api/arrayviz.py b/treescope/_internal/api/arrayviz.py
index 2a1ad8f..2487fed 100644
--- a/treescope/_internal/api/arrayviz.py
+++ b/treescope/_internal/api/arrayviz.py
@@ -271,7 +271,7 @@ def render_array(
     )
   if not 1 <= pixels_per_cell <= 21:
     raise ValueError(
-        f"pixels_per_cell must be between 1 and 21 inclusive, got"
+        "pixels_per_cell must be between 1 and 21 inclusive, got"
         f" {pixels_per_cell}"
     )
 
@@ -741,8 +741,37 @@ def render_sharding_info(
       raise ValueError(f"Unrecognized axis info {type(info)}")
 
   array_shape = [info.size for info in array_axis_info]
-  shard_shape = sharding_info.shard_shape
-  num_shards = np.prod(array_shape) // np.prod(shard_shape)
+  orig_shard_shape = sharding_info.shard_shape
+  num_shards = np.prod(array_shape) // np.prod(orig_shard_shape)
+  orig_device_indices_map = sharding_info.device_index_to_shard_slices
+  # Possibly adjust the shard shape so that its length is the same as the array
+  # shape, and so that all items in device_indices_map are slices.
+  device_indices_map = {}
+  shard_shape = []
+  orig_shard_shape_index = 0
+  first = True
+  for key, ints_or_slices in orig_device_indices_map.items():
+    new_slices = []
+    for i, int_or_slc in enumerate(ints_or_slices):
+      if isinstance(int_or_slc, int):
+        new_slices.append(slice(int_or_slc, int_or_slc + 1))
+        if first:
+          shard_shape.append(1)
+      elif isinstance(int_or_slc, slice):
+        new_slices.append(int_or_slc)
+        if first:
+          shard_shape.append(orig_shard_shape[orig_shard_shape_index])
+          orig_shard_shape_index += 1
+      else:
+        raise ValueError(
+            f"Unrecognized axis slice in sharding info: {int_or_slc} at index"
+            f" {i} for device {key}"
+        )
+    device_indices_map[key] = tuple(new_slices)
+    first = False
+
+  assert len(shard_shape) == len(array_shape)
+  assert orig_shard_shape_index == len(orig_shard_shape)
   # Compute a truncation for visualizing a single shard. Each shard will be
   # shown as a shrunken version of the actual shard dimensions, roughly
   # proportional to the shard sizes.
@@ -776,7 +805,6 @@
       vec = np.array([True] * candidate + [False] + [True] * candidate)
       shard_mask = shard_mask[..., None] * vec
   # Figure out which device is responsible for each shard.
-  device_indices_map = sharding_info.device_index_to_shard_slices
   device_to_shard_offsets = {}
   shard_offsets_to_devices = collections.defaultdict(list)
   for device_index, slices in device_indices_map.items():
@@ -789,6 +817,7 @@
       else:
         assert slc.stop == slc.start + shard_shape[i]
         shard_offsets.append(slc.start // shard_shape[i])
+    shard_offsets = tuple(shard_offsets)
     device_to_shard_offsets[device_index] = shard_offsets
     shard_offsets_to_devices[shard_offsets].append(device_index)
 
diff --git a/treescope/ndarray_adapters.py b/treescope/ndarray_adapters.py
index 5cfee05..7eb2299 100644
--- a/treescope/ndarray_adapters.py
+++ b/treescope/ndarray_adapters.py
@@ -82,54 +82,29 @@ def logical_key(self) -> int:
     return self.axis_logical_index
 
 
-@dataclasses.dataclass(frozen=True)
-class ArraySummary:
-  """Summary of the contents of an array.
-
-  Any of the attributes of this summary can be None to indicate that the
-  corresponding statistic is not applicable to the array (either because of
-  dtype or because there are no finite values).
-
-  Attributes:
-    finite_mean: The mean of the finite values in the array.
-    finite_stddev: The standard deviation of the finite values in the array.
-    finite_min: The minimum of the finite values in the array.
-    finite_max: The maximum of the finite values in the array.
-    count_zero: The number of zero values in the array.
-    count_nonzero: The number of nonzero values in the array.
-    count_nan: The number of NaN values in the array.
-    count_posinf: The number of positive infinity values in the array.
-    count_neginf: The number of negative infinity values in the array.
-  """
-
-  finite_mean: float | None
-  finite_stddev: float | None
-  finite_min: float | None
-  finite_max: float | None
-  count_zero: int | None
-  count_nonzero: float | None
-  count_nan: float | None
-  count_posinf: float | None
-  count_neginf: float | None
-
-
 @dataclasses.dataclass(frozen=True)
 class ShardingInfo:
   """Summary of the sharding of an array.
 
   Attributes:
-    shard_shape: Shape of a single shard.
+    shard_shape: Shape of a single shard. Should be the same length as the
+      number of slices in each value of `device_index_to_shard_slices`.
     device_index_to_shard_slices: A mapping from device index to the tuple of
-      per-axis slices of the original array that is assigned to that device. The
-      length of each axis slice must match the `shard_shape` along that axis (or
-      be the full slice ``slice(None)``).
+      per-axis indices or slices of the original array that is assigned to that
+      device. Each entry of this tuple should either be an int or a slice
+      object. If an int, that axis should not appear in shard_shape (e.g. the
+      full array is formed by stacking the shards along a new axis). If a slice,
+      the corresponding axis should appear in shard_shape, and the slice should
+      be the full slice ``slice(None)`` (if the array is not sharded over this
+      axis) or a slice that matches the corresponding entry in `shard_shape` (if
+      the full array is formed by concatenating the shards along this axis).
     device_type: The type of device that the array is sharded across, as a
       string (e.g. "CPU", "TPU", "GPU").
     fully_replicated: Whether the array is fully replicated across all devices.
   """
 
   shard_shape: tuple[int, ...]
-  device_index_to_shard_slices: dict[int, tuple[slice, ...]]
+  device_index_to_shard_slices: dict[int, tuple[slice | int, ...]]
   device_type: str
   fully_replicated: bool = False
 
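As an illustration of the new integer entries, the sketch below (not part of the patch itself) constructs a `ShardingInfo` of the kind `render_sharding_info` can now accept, together with the equivalent all-slice form it is normalized to. The shapes and device indices are made up for the example; a PmapSharding built by jax.device_put_sharded or jax.device_put_replicated would produce an index map of the first kind.

from treescope import ndarray_adapters

# Hypothetical example: a (4, 8) array stacked across four devices along a
# new leading axis. Axis 0 is addressed by an integer per device, so it does
# not appear in shard_shape; axis 1 is unsharded, so its entry is slice(None).
stacked = ndarray_adapters.ShardingInfo(
    shard_shape=(8,),
    device_index_to_shard_slices={
        0: (0, slice(None)),
        1: (1, slice(None)),
        2: (2, slice(None)),
        3: (3, slice(None)),
    },
    device_type="CPU",
)

# render_sharding_info rewrites each integer i as slice(i, i + 1) and inserts
# a 1 into the shard shape, so the map above is treated the same as this
# all-slice description of a (4, 8) array split into (1, 8) shards:
concatenated = ndarray_adapters.ShardingInfo(
    shard_shape=(1, 8),
    device_index_to_shard_slices={
        0: (slice(0, 1), slice(None)),
        1: (slice(1, 2), slice(None)),
        2: (slice(2, 3), slice(None)),
        3: (slice(3, 4), slice(None)),
    },
    device_type="CPU",
)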