From 3450c0092a644da25c42a06a56b2834484f16d47 Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Tue, 27 Aug 2024 14:33:01 -0700 Subject: [PATCH] Simplify torch handling to support more backends. Disables named-axis renaming if the tensor is already nameless, and replaces calls to the `isposinf`/`isneginf` primitives with simple equality checks. This reduces the set of operations that tensors must support, allowing Treescope to support more backends. With this change it should be possible to use Treescope with torch tensors that use the "mps" (Metal Performance Shaders / Apple Silicon) device backend. PiperOrigin-RevId: 668141265 --- treescope/external/torch_support.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/treescope/external/torch_support.py b/treescope/external/torch_support.py index 0b00e70..0da96d2 100644 --- a/treescope/external/torch_support.py +++ b/treescope/external/torch_support.py @@ -189,7 +189,8 @@ def get_array_summary(self, array: torch.Tensor, fast: bool) -> str: output_parts.append("(" + ", ".join(name_parts) + ")") # Drop axis names. - array = array.rename(None) + if any(name is not None for name in array.names): + array = array.rename(None) size = np.prod(array.shape) if size > 0 and size < 100_000 and not fast: is_floating = array.dtype.is_floating_point @@ -230,11 +231,11 @@ def get_array_summary(self, array: torch.Tensor, fast: bool) -> str: if ct_nan: output_parts.append(f" nan:{ct_nan:_d}") - ct_inf = torch.count_nonzero(torch.isposinf(array)) + ct_inf = torch.count_nonzero(array == torch.inf) if ct_inf: output_parts.append(f" inf:{ct_inf:_d}") - ct_neginf = torch.count_nonzero(torch.isneginf(array)) + ct_neginf = torch.count_nonzero(array == -torch.inf) if ct_neginf: output_parts.append(f" -inf:{ct_neginf:_d}")