
Assorted fixes and simplifications #121


Merged 4 commits on Apr 21, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -2,4 +2,5 @@
__pycache__/
*.py[cod]
.coverage
.hypothesis

158 changes: 37 additions & 121 deletions torch_np/_funcs_impl.py
@@ -14,7 +14,7 @@

import torch

from . import _dtypes_impl, _helpers
from . import _dtypes_impl
from . import _reductions as _impl
from . import _util
from ._normalizations import (
@@ -27,7 +27,7 @@
normalize_array_like,
)

###### array creation routines
# ###### array creation routines


def copy(
@@ -71,18 +71,16 @@ def atleast_3d(*arys: ArrayLike):


def _concat_check(tup, dtype, out):
"""Check inputs in concatenate et al."""
if tup == ():
# XXX:RuntimeError in torch, ValueError in numpy
raise ValueError("need at least one array to concatenate")

if out is not None:
if dtype is not None:
# mimic numpy
raise TypeError(
"concatenate() only takes `out` or `dtype` as an "
"argument, but both were provided."
)
"""Check inputs in concatenate et al."""
if out is not None and dtype is not None:
# mimic numpy
raise TypeError(
"concatenate() only takes `out` or `dtype` as an "
"argument, but both were provided."
)
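
For context, a minimal sketch of the numpy behavior being mimicked here (outside the diff; exact message wording may vary across numpy versions):

import numpy as np

a = np.ones(3)
try:
    # numpy rejects passing both `out` and `dtype` to concatenate
    np.concatenate((a, a), out=np.empty(6), dtype=np.float64)
except TypeError as e:
    print(e)  # concatenate() only takes `out` or `dtype` as an argument...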


def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
@@ -104,12 +102,7 @@ def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
# pure torch implementation, used below and in cov/corrcoef
tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
tensors = _concat_cast_helper(tensors, out, dtype, casting)

try:
result = torch.cat(tensors, axis)
except (IndexError, RuntimeError) as e:
raise _util.AxisError(*e.args)
return result
return torch.cat(tensors, axis)


def concatenate(
@@ -177,11 +170,7 @@ def stack(
tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting)
result_ndim = tensors[0].ndim + 1
axis = _util.normalize_axis_index(axis, result_ndim)
try:
result = torch.stack(tensors, axis=axis)
except RuntimeError as e:
raise ValueError(*e.args)
return result
return torch.stack(tensors, axis=axis)


# ### split ###
@@ -352,24 +341,17 @@ def arange(
dtype = _dtypes_impl.default_dtypes.int_dtype
dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
dt_list.append(dtype)
dtype = _dtypes_impl.result_type_impl(dt_list)
target_dtype = _dtypes_impl.result_type_impl(dt_list)

# work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'
if dtype.is_complex:
work_dtype, target_dtype = torch.float64, dtype
else:
work_dtype, target_dtype = dtype, dtype
work_dtype = torch.float64 if target_dtype.is_complex else target_dtype

if (step > 0 and start > stop) or (step < 0 and start < stop):
# empty range
return torch.empty(0, dtype=target_dtype)

try:
result = torch.arange(start, stop, step, dtype=work_dtype)
result = _util.cast_if_needed(result, target_dtype)
except RuntimeError:
raise ValueError("Maximum allowed size exceeded")

result = torch.arange(start, stop, step, dtype=work_dtype)
result = _util.cast_if_needed(result, target_dtype)
return result
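
A minimal sketch of the workaround above, assuming a complex target dtype (illustrative only, outside the diff):

import torch

target_dtype = torch.complex64
# "arange_cpu" is not implemented for complex dtypes, so build the range
# in float64 and cast the result to the complex target.
work_dtype = torch.float64 if target_dtype.is_complex else target_dtype
result = torch.arange(0, 4, 1, dtype=work_dtype).to(target_dtype)
print(result)  # tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j])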


@@ -593,8 +575,7 @@ def where(
y: Optional[ArrayLike] = None,
/,
):
selector = (x is None) == (y is None)
if not selector:
if (x is None) != (y is None):
raise ValueError("either both or neither of x and y should be given")

if condition.dtype != torch.bool:
@@ -603,14 +584,11 @@
if x is None and y is None:
result = torch.where(condition)
else:
try:
result = torch.where(condition, x, y)
except RuntimeError as e:
raise ValueError(*e.args)
result = torch.where(condition, x, y)
return result
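
For reference, the two call forms this both-or-neither check guards, in plain torch (outside the diff):

import torch

cond = torch.tensor([True, False, True])
x, y = torch.ones(3), torch.zeros(3)
print(torch.where(cond))        # one-arg form: a tuple of index tensors
print(torch.where(cond, x, y))  # three-arg form: elementwise selection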


###### module-level queries of object properties
# ###### module-level queries of object properties


def ndim(a: ArrayLike):
@@ -628,7 +606,7 @@ def size(a: ArrayLike, axis=None):
return a.shape[axis]


###### shape manipulations and indexing
# ###### shape manipulations and indexing


def expand_dims(a: ArrayLike, axis):
@@ -665,6 +643,7 @@ def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False):
return torch.broadcast_to(array, size=shape)


# This is a function from tuples to tuples, so we just reuse it
from torch import broadcast_shapes
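
A quick check that the torch function behaves like its numpy counterpart on plain shape tuples:

import torch

# Shape tuples in, broadcast shape out -- no tensors involved,
# which is why the torch function can be re-exported as-is.
print(torch.broadcast_shapes((3, 1), (1, 4)))  # torch.Size([3, 4])
print(torch.broadcast_shapes((2, 3), (3,)))    # torch.Size([2, 3])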


@@ -742,16 +721,15 @@ def triu_indices(n, k=0, m=None):
def tril_indices_from(arr: ArrayLike, k=0):
if arr.ndim != 2:
raise ValueError("input array must be 2-d")
result = torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
return tuple(result)
# Return a tensor rather than a tuple to avoid a graphbreak
return torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
Collaborator:

This is a break w.r.t. numpy, so should be added to the list in #73 I guess.
Also this should fail on CI, so those tests need to be adjusted.

Semantics-wise, should we not be worried about:

In [40]: a = np.arange(6).reshape(3, 2)

In [41]: idx = np.tril_indices_from(a)

In [42]: a[idx]
Out[42]: array([0, 2, 3, 4, 5])

In [43]: a[np.asarray(idx)]
Out[43]: 
array([[[0, 1],
        [2, 3],
        [2, 3],
        [4, 5],
        [4, 5]],

       [[0, 1],
        [0, 1],
        [2, 3],
        [0, 1],
        [2, 3]]])

Collaborator Author:

Added to that issue.

Collaborator Author:

ISTM the main usage of these is to actually index something with the result, so if we return a tensor, the next step for a user is to swear and throw in an explicit tuple() call. So we're back to the same graph break, and all we've earned is a swear from the user.

As discussed in the example above, we should be fine: it's possible to index with a tensor as well while preserving the semantics.
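
For the record, a plain-torch sketch of the two indexing behaviors discussed above (the torch_np wrapper may normalize this itself):

import torch

a = torch.arange(6).reshape(3, 2)
idx = torch.tril_indices(3, 2, offset=0)  # a 2-row tensor of (row, col) indices

print(a[tuple(idx)])  # tensor([0, 2, 3, 4, 5]) -- pairwise, numpy-tuple semantics
print(a[idx].shape)   # torch.Size([2, 5, 2]) -- whole rows selected along dim 0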



def triu_indices_from(arr: ArrayLike, k=0):
if arr.ndim != 2:
raise ValueError("input array must be 2-d")
result = torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)
# unpack: numpy returns a 2-tuple of index arrays; torch returns a 2-row tensor
return tuple(result)
# Return a tensor rather than a tuple to avoid a graphbreak
return torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)


def tri(
@@ -765,34 +743,14 @@ def tri(
if M is None:
M = N
tensor = torch.ones((N, M), dtype=dtype)
tensor = torch.tril(tensor, diagonal=k)
return tensor
return torch.tril(tensor, diagonal=k)


# ### nanfunctions ### # FIXME: this is a stub
# ### nanfunctions ###


def nanmean(
a: ArrayLike,
axis=None,
dtype: Optional[DTypeLike] = None,
out: Optional[OutArray] = None,
keepdims=None,
*,
where: NotImplementedType = None,
):
# XXX: this needs to be rewritten
if dtype is None:
dtype = a.dtype
if axis is None:
result = a.nanmean(dtype=dtype)
if keepdims:
result = torch.full(a.shape, result, dtype=result.dtype)
else:
result = a.nanmean(dtype=dtype, dim=axis, keepdim=bool(keepdims))
if out is not None:
out.copy_(result)
return result
def nanmean():
raise NotImplementedError
Collaborator:

Yeah, these all should go. This weirdness was there to be able to port the nanXXX tests, so it should probably be redone in one go. This PR or a separate one, whichever works.
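
For reference, a masked-reduction sketch of what a redone nanmean might look like (an assumption about the eventual rewrite, not part of this PR):

import torch

def nanmean_sketch(a, dim=None, keepdim=False):
    # Sum the non-NaN entries and divide by their count.
    mask = ~torch.isnan(a)
    filled = torch.where(mask, a, torch.zeros_like(a))
    if dim is None:
        return filled.sum() / mask.sum()
    return filled.sum(dim, keepdim=keepdim) / mask.sum(dim, keepdim=keepdim)

print(nanmean_sketch(torch.tensor([1.0, float("nan"), 3.0])))  # tensor(2.)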



def nanmin():
@@ -999,12 +957,7 @@ def clip(
max: Optional[ArrayLike] = None,
out: Optional[OutArray] = None,
):
# np.clip requires both a_min and a_max not None, while ndarray.clip allows
# one of them to be None. Follow the more lax version.
if min is None and max is None:
raise ValueError("One of max or min must be given")
result = torch.clamp(a, min, max)
return result
return torch.clamp(a, min, max)


def repeat(a: ArrayLike, repeats: ArrayLike, axis=None):
@@ -1368,15 +1321,10 @@ def transpose(a: ArrayLike, axes=None):
# numpy allows both .transpose(sh) and .transpose(*sh)
# also, older code passes axes as a list
if axes in [(), None, (None,)]:
axes = tuple(range(a.ndim))[::-1]
axes = tuple(reversed(range(a.ndim)))
elif len(axes) == 1:
axes = axes[0]

try:
result = a.permute(axes)
except RuntimeError:
raise ValueError("axes don't match array")
return result
return a.permute(axes)
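
The underlying permute call, for illustration (plain torch, outside the diff):

import torch

a = torch.arange(6).reshape(2, 3)
print(a.permute((1, 0)).shape)  # torch.Size([3, 2]) -- explicit axes
# with axes omitted, the code above reverses all axes:
print(a.permute(tuple(reversed(range(a.ndim)))).shape)  # torch.Size([3, 2])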


def ravel(a: ArrayLike, order: NotImplementedType = "C"):
@@ -1391,41 +1339,6 @@ def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
return torch.flatten(a)


# ### Type/shape etc queries ###


def real(a: ArrayLike):
result = torch.real(a)
return result


def imag(a: ArrayLike):
if a.is_complex():
result = a.imag
else:
result = torch.zeros_like(a)
return result


def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
if a.is_floating_point():
result = torch.round(a, decimals=decimals)
elif a.is_complex():
# RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
result = (
torch.round(a.real, decimals=decimals)
+ torch.round(a.imag, decimals=decimals) * 1j
)
else:
# RuntimeError: "round_cpu" not implemented for 'int'
result = a
return result


around = round_
round = round_


# ### reductions ###


@@ -1742,6 +1655,9 @@ def sinc(x: ArrayLike):
return torch.sinc(x)


# ### Type/shape etc queries ###


def real(a: ArrayLike):
return torch.real(a)

@@ -1754,7 +1670,7 @@ def imag(a: ArrayLike):
return result


def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
if a.is_floating_point():
result = torch.round(a, decimals=decimals)
elif a.is_complex():
@@ -1769,8 +1685,8 @@ def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
return result


around = round_
round = round_
around = round
round_ = round
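
A quick illustration of the complex branch above (torch has no complex round kernel, so real and imaginary parts are rounded separately; outside the diff):

import torch

a = torch.tensor([1.26 + 2.74j])
rounded = torch.round(a.real, decimals=1) + torch.round(a.imag, decimals=1) * 1j
print(rounded)  # tensor([1.3000+2.7000j])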


def real_if_close(a: ArrayLike, tol=100):
16 changes: 8 additions & 8 deletions torch_np/tests/numpy_tests/core/test_multiarray.py
@@ -1532,9 +1532,9 @@ def test_squeeze(self):
def test_transpose(self):
a = np.array([[1, 2], [3, 4]])
assert_equal(a.transpose(), [[1, 3], [2, 4]])
assert_raises(ValueError, lambda: a.transpose(0))
assert_raises(ValueError, lambda: a.transpose(0, 0))
assert_raises(ValueError, lambda: a.transpose(0, 1, 2))
assert_raises((RuntimeError, ValueError), lambda: a.transpose(0))
assert_raises((RuntimeError, ValueError), lambda: a.transpose(0, 0))
assert_raises((RuntimeError, ValueError), lambda: a.transpose(0, 1, 2))

def test_sort(self):
# test ordering for floats and complex containing nans. It is only
@@ -7270,8 +7270,8 @@ def test_error(self):
c = [True, True]
a = np.ones((4, 5))
b = np.ones((5, 5))
assert_raises(ValueError, np.where, c, a, a)
assert_raises(ValueError, np.where, c[0], a, b)
assert_raises((RuntimeError, ValueError), np.where, c, a, a)
assert_raises((RuntimeError, ValueError), np.where, c[0], a, b)

def test_empty_result(self):
# pass empty where result through an assignment which reads the data of
@@ -7497,14 +7497,14 @@ def test_view_discard_refcount(self):

class TestArange:
def test_infinite(self):
assert_raises_regex(
ValueError, "size exceeded",
assert_raises(
(RuntimeError, ValueError), # "unsupported range",
np.arange, 0, np.inf
)

def test_nan_step(self):
assert_raises(
ValueError, # "cannot compute length",
(RuntimeError, ValueError), # "cannot compute length",
np.arange, 0, 1, np.nan
)
