Skip to content

Commit

Permalink
add more functions to test
Browse files Browse the repository at this point in the history
  • Loading branch information
Aarsh2001 committed Nov 5, 2023
1 parent 1d43f1e commit 9602c9f
Showing 1 changed file with 9 additions and 25 deletions.
34 changes: 9 additions & 25 deletions test_scripts/sample_torch_source_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -2987,33 +2987,17 @@ def fmax(
return torch.fmax(x1, x2, out=None)


# no docstring
def count_nonzero(
a: torch.Tensor,

@with_unsupported_dtypes({"2.0.1 and below": ("complex",)}, backend_version)
def nansum(
x: torch.Tensor,
/,
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
axis: Optional[Union[Tuple[int, ...], int]] = None,
dtype: Optional[torch.dtype] = None,
keepdims: bool = False,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(axis, list):
axis = tuple(axis)
if dtype is None:
x = torch.count_nonzero(a, dim=axis)
else:
x = torch.tensor(torch.count_nonzero(a, dim=axis), dtype=dtype)
if not keepdims:
return x
if isinstance(axis, int):
if axis == -1:
temp = x.dim() - 1
if temp < -1:
temp = 0
return x.unsqueeze(temp)
return x.unsqueeze(axis)
elif axis is not None:
for d in sorted(axis):
x = x.unsqueeze(d)
return x
return x
dtype = ivy.as_native_dtype(dtype)
return torch.nansum(x, dim=axis, keepdim=keepdims, dtype=dtype)

0 comments on commit 9602c9f

Please sign in to comment.