diff --git a/treescope/_internal/api/arrayviz.py b/treescope/_internal/api/arrayviz.py
index 2a1ad8f..2487fed 100644
--- a/treescope/_internal/api/arrayviz.py
+++ b/treescope/_internal/api/arrayviz.py
@@ -271,7 +271,7 @@ def render_array(
   )
   if not 1 <= pixels_per_cell <= 21:
     raise ValueError(
-        f"pixels_per_cell must be between 1 and 21 inclusive, got"
+        "pixels_per_cell must be between 1 and 21 inclusive, got"
         f" {pixels_per_cell}"
     )
 
@@ -741,8 +741,37 @@ def render_sharding_info(
       raise ValueError(f"Unrecognized axis info {type(info)}")
 
   array_shape = [info.size for info in array_axis_info]
-  shard_shape = sharding_info.shard_shape
-  num_shards = np.prod(array_shape) // np.prod(shard_shape)
+  orig_shard_shape = sharding_info.shard_shape
+  num_shards = np.prod(array_shape) // np.prod(orig_shard_shape)
+  orig_device_indices_map = sharding_info.device_index_to_shard_slices
+  # Possibly adjust the shard shape so that its length is the same as the array
+  # shape, and so that all items in device_indices_map are slices.
+  device_indices_map = {}
+  shard_shape = []
+  orig_shard_shape_index = 0
+  first = True
+  for key, ints_or_slices in orig_device_indices_map.items():
+    new_slices = []
+    for i, int_or_slc in enumerate(ints_or_slices):
+      if isinstance(int_or_slc, int):
+        new_slices.append(slice(int_or_slc, int_or_slc + 1))
+        if first:
+          shard_shape.append(1)
+      elif isinstance(int_or_slc, slice):
+        new_slices.append(int_or_slc)
+        if first:
+          shard_shape.append(orig_shard_shape[orig_shard_shape_index])
+          orig_shard_shape_index += 1
+      else:
+        raise ValueError(
+            f"Unrecognized axis slice in sharding info: {int_or_slc} at index"
+            f" {i} for device {key}"
+        )
+    device_indices_map[key] = tuple(new_slices)
+    first = False
+
+  assert len(shard_shape) == len(array_shape)
+  assert orig_shard_shape_index == len(orig_shard_shape)
   # Compute a truncation for visualizing a single shard. Each shard will be
   # shown as a shrunken version of the actual shard dimensions, roughly
   # proportional to the shard sizes.
@@ -776,7 +805,6 @@ def render_sharding_info(
       vec = np.array([True] * candidate + [False] + [True] * candidate)
       shard_mask = shard_mask[..., None] * vec
   # Figure out which device is responsible for each shard.
-  device_indices_map = sharding_info.device_index_to_shard_slices
   device_to_shard_offsets = {}
   shard_offsets_to_devices = collections.defaultdict(list)
   for device_index, slices in device_indices_map.items():
@@ -789,6 +817,7 @@ def render_sharding_info(
     else:
       assert slc.stop == slc.start + shard_shape[i]
       shard_offsets.append(slc.start // shard_shape[i])
+    shard_offsets = tuple(shard_offsets)
     device_to_shard_offsets[device_index] = shard_offsets
     shard_offsets_to_devices[shard_offsets].append(device_index)
 
diff --git a/treescope/ndarray_adapters.py b/treescope/ndarray_adapters.py
index 5cfee05..7eb2299 100644
--- a/treescope/ndarray_adapters.py
+++ b/treescope/ndarray_adapters.py
@@ -82,54 +82,29 @@ def logical_key(self) -> int:
     return self.axis_logical_index
 
 
-@dataclasses.dataclass(frozen=True)
-class ArraySummary:
-  """Summary of the contents of an array.
-
-  Any of the attributes of this summary can be None to indicate that the
-  corresponding statistic is not applicable to the array (either because of
-  dtype or because there are no finite values).
-
-  Attributes:
-    finite_mean: The mean of the finite values in the array.
-    finite_stddev: The standard deviation of the finite values in the array.
-    finite_min: The minimum of the finite values in the array.
-    finite_max: The maximum of the finite values in the array.
-    count_zero: The number of zero values in the array.
-    count_nonzero: The number of nonzero values in the array.
-    count_nan: The number of NaN values in the array.
-    count_posinf: The number of positive infinity values in the array.
-    count_neginf: The number of negative infinity values in the array.
-  """
-
-  finite_mean: float | None
-  finite_stddev: float | None
-  finite_min: float | None
-  finite_max: float | None
-  count_zero: int | None
-  count_nonzero: float | None
-  count_nan: float | None
-  count_posinf: float | None
-  count_neginf: float | None
-
-
 @dataclasses.dataclass(frozen=True)
 class ShardingInfo:
   """Summary of the sharding of an array.
 
   Attributes:
-    shard_shape: Shape of a single shard.
+    shard_shape: Shape of a single shard. Its length should match the number
+      of slice entries in each value of `device_index_to_shard_slices`.
     device_index_to_shard_slices: A mapping from device index to the tuple of
-      per-axis slices of the original array that is assigned to that device. The
-      length of each axis slice must match the `shard_shape` along that axis (or
-      be the full slice ``slice(None)``).
+      per-axis indices or slices of the original array that is assigned to that
+      device. Each entry of this tuple should be either an int or a slice
+      object. If an int, that axis should not appear in shard_shape (i.e. the
+      full array is formed by stacking the shards along a new axis). If a slice,
+      the corresponding axis should appear in shard_shape, and the slice should
+      be the full slice ``slice(None)`` (if the array is not sharded over this
+      axis) or a slice that matches the corresponding entry in `shard_shape` (if
+      the full array is formed by concatenating the shards along this axis).
     device_type: The type of device that the array is sharded across, as a
       string (e.g. "CPU", "TPU", "GPU").
     fully_replicated: Whether the array is fully replicated across all devices.
   """
 
   shard_shape: tuple[int, ...]
-  device_index_to_shard_slices: dict[int, tuple[slice, ...]]
+  device_index_to_shard_slices: dict[int, tuple[slice | int, ...]]
   device_type: str
   fully_replicated: bool = False
 
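Note on the new `ShardingInfo` convention (not part of the diff): a minimal
sketch, assuming only the `ShardingInfo` dataclass from
`treescope/ndarray_adapters.py` above; the array shapes and device indices are
illustrative, not taken from treescope. It constructs one sharding for each
case the updated docstring describes: slice entries for axes along which shards
are concatenated, and an int entry for an axis along which shards are stacked
(which therefore has no entry in `shard_shape`).

    from treescope.ndarray_adapters import ShardingInfo

    # A (4, 8) array split along axis 1 across two devices: each device holds
    # a (4, 4) shard, and the full array is the concatenation of the shards.
    # Axis 0 is unsharded, so its entry is the full slice ``slice(None)``.
    concatenated = ShardingInfo(
        shard_shape=(4, 4),
        device_index_to_shard_slices={
            0: (slice(None), slice(0, 4)),
            1: (slice(None), slice(4, 8)),
        },
        device_type="CPU",
    )

    # A (2, 3) array formed by stacking two (3,) shards along a new leading
    # axis: the leading entry is an int index, so it does not appear in
    # ``shard_shape``, which has one entry per slice entry only.
    stacked = ShardingInfo(
        shard_shape=(3,),
        device_index_to_shard_slices={
            0: (0, slice(None)),
            1: (1, slice(None)),
        },
        device_type="CPU",
    )

The normalization loop added to `render_sharding_info` above converts both
forms to the all-slices representation (e.g. the int 0 becomes `slice(0, 1)`
and contributes a 1 to the adjusted shard shape), so the downstream
shard-offset logic only ever sees slices.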