Skip to content

Commit

Permalink
add more functions to test
Browse files Browse the repository at this point in the history
  • Loading branch information
Aarsh2001 committed Nov 5, 2023
1 parent 92289f7 commit 46e70da
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions test_scripts/sample_torch_source_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -2985,3 +2985,35 @@ def fmax(
"""
x1, x2 = promote_types_of_inputs(x1, x2)
return torch.fmax(x1, x2, out=None)


# no docstring
def count_nonzero(
a: torch.Tensor,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
dtype: Optional[torch.dtype] = None,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(axis, list):
axis = tuple(axis)
if dtype is None:
x = torch.count_nonzero(a, dim=axis)
else:
x = torch.tensor(torch.count_nonzero(a, dim=axis), dtype=dtype)
if not keepdims:
return x
if isinstance(axis, int):
if axis == -1:
temp = x.dim() - 1
if temp < -1:
temp = 0
return x.unsqueeze(temp)
return x.unsqueeze(axis)
elif axis is not None:
for d in sorted(axis):
x = x.unsqueeze(d)
return x
return x

0 comments on commit 46e70da

Please sign in to comment.