Skip to content

Commit

Permalink
Simplify torch handling to support more backends.
Browse files Browse the repository at this point in the history
Disables named-axis renaming if the tensor is already nameless, and
replaces calls to the `isposinf`/`isneginf` primitives with simple
equality checks. This reduces the set of operations that tensors
must support, allowing Treescope to support more backends.

With this change it should be possible to use Treescope with torch
tensors that use the "mps" (Metal Performance Shaders / Apple Silicon)
device backend.

PiperOrigin-RevId: 667676888
  • Loading branch information
danieldjohnson authored and Treescope Developers committed Aug 27, 2024
1 parent 1b8eca6 commit 0891698
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions treescope/external/torch_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def get_array_summary(self, array: torch.Tensor, fast: bool) -> str:
output_parts.append("(" + ", ".join(name_parts) + ")")

# Drop axis names.
array = array.rename(None)
if any(name is not None for name in array.names):
array = array.rename(None)
size = np.prod(array.shape)
if size > 0 and size < 100_000 and not fast:
is_floating = array.dtype.is_floating_point
Expand Down Expand Up @@ -230,11 +231,11 @@ def get_array_summary(self, array: torch.Tensor, fast: bool) -> str:
if ct_nan:
output_parts.append(f" nan:{ct_nan:_d}")

ct_inf = torch.count_nonzero(torch.isposinf(array))
ct_inf = torch.count_nonzero(array == torch.inf)
if ct_inf:
output_parts.append(f" inf:{ct_inf:_d}")

ct_neginf = torch.count_nonzero(torch.isneginf(array))
ct_neginf = torch.count_nonzero(array == -torch.inf)
if ct_neginf:
output_parts.append(f" -inf:{ct_neginf:_d}")

Expand Down

0 comments on commit 0891698

Please sign in to comment.