diff --git a/treescope/external/torch_support.py b/treescope/external/torch_support.py index 0b00e70..0da96d2 100644 --- a/treescope/external/torch_support.py +++ b/treescope/external/torch_support.py @@ -189,7 +189,8 @@ def get_array_summary(self, array: torch.Tensor, fast: bool) -> str: output_parts.append("(" + ", ".join(name_parts) + ")") # Drop axis names. - array = array.rename(None) + if any(name is not None for name in array.names): + array = array.rename(None) size = np.prod(array.shape) if size > 0 and size < 100_000 and not fast: is_floating = array.dtype.is_floating_point @@ -230,11 +231,11 @@ def get_array_summary(self, array: torch.Tensor, fast: bool) -> str: if ct_nan: output_parts.append(f" nan:{ct_nan:_d}") - ct_inf = torch.count_nonzero(torch.isposinf(array)) + ct_inf = torch.count_nonzero(array == torch.inf) if ct_inf: output_parts.append(f" inf:{ct_inf:_d}") - ct_neginf = torch.count_nonzero(torch.isneginf(array)) + ct_neginf = torch.count_nonzero(array == -torch.inf) if ct_neginf: output_parts.append(f" -inf:{ct_neginf:_d}")