From 0891698e9cb171ed6c5359e2f18b635a5fc239f5 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Mon, 26 Aug 2024 12:36:44 -0700 Subject: [PATCH] Simplify torch handling to support more backends. Disables named-axis renaming if the tensor is already nameless, and replaces calls to the `isposinf`/`isneginf` primitives with simple equality checks. This reduces the set of operations that tensors must support, allowing Treescope to support more backends. With this change it should be possible to use Treescope with torch tensors that use the "mps" (Metal Performance Shaders / Apple Silicon) device backend. PiperOrigin-RevId: 667676888 --- treescope/external/torch_support.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/treescope/external/torch_support.py b/treescope/external/torch_support.py index 0b00e70..0da96d2 100644 --- a/treescope/external/torch_support.py +++ b/treescope/external/torch_support.py @@ -189,7 +189,8 @@ def get_array_summary(self, array: torch.Tensor, fast: bool) -> str: output_parts.append("(" + ", ".join(name_parts) + ")") # Drop axis names. - array = array.rename(None) + if any(name is not None for name in array.names): + array = array.rename(None) size = np.prod(array.shape) if size > 0 and size < 100_000 and not fast: is_floating = array.dtype.is_floating_point @@ -230,11 +231,11 @@ def get_array_summary(self, array: torch.Tensor, fast: bool) -> str: if ct_nan: output_parts.append(f" nan:{ct_nan:_d}") - ct_inf = torch.count_nonzero(torch.isposinf(array)) + ct_inf = torch.count_nonzero(array == torch.inf) if ct_inf: output_parts.append(f" inf:{ct_inf:_d}") - ct_neginf = torch.count_nonzero(torch.isneginf(array)) + ct_neginf = torch.count_nonzero(array == -torch.inf) if ct_neginf: output_parts.append(f" -inf:{ct_neginf:_d}")