Assorted fixes and simplifications #121 (Merged)
Changes from all commits (4 commits)
@@ -2,4 +2,5 @@
 __pycache__/
 *.py[cod]
 .coverage
+.hypothesis
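(The newly ignored `.hypothesis/` directory is the local example database that the Hypothesis testing library writes during test runs.)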
@@ -14,7 +14,7 @@
 import torch

-from . import _dtypes_impl, _helpers
+from . import _dtypes_impl
 from . import _reductions as _impl
 from . import _util
 from ._normalizations import (
@@ -27,7 +27,7 @@
     normalize_array_like,
 )

-###### array creation routines
+# ###### array creation routines


 def copy(
@@ -71,18 +71,16 @@ def atleast_3d(*arys: ArrayLike):


 def _concat_check(tup, dtype, out):
     """Check inputs in concatenate et al."""
     if tup == ():
-        # XXX:RuntimeError in torch, ValueError in numpy
         raise ValueError("need at least one array to concatenate")

-    if out is not None:
-        if dtype is not None:
-            # mimic numpy
-            raise TypeError(
-                "concatenate() only takes `out` or `dtype` as an "
-                "argument, but both were provided."
-            )
+    if out is not None and dtype is not None:
+        # mimic numpy
+        raise TypeError(
+            "concatenate() only takes `out` or `dtype` as an "
+            "argument, but both were provided."
+        )


 def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
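For context, a quick standalone check of the NumPy behaviour the simplified `_concat_check` mimics (not part of the diff; the `dtype` argument requires NumPy >= 1.20):

import numpy as np

a = np.ones(3)
out = np.empty(6)
try:
    np.concatenate((a, a), out=out, dtype=np.float64)
except TypeError as e:
    # "concatenate() only takes `out` or `dtype` as an argument,
    #  but both were provided."
    print(e)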
@@ -104,12 +102,7 @@ def _concatenate(tensors, axis=0, out=None, dtype=None, casting="same_kind"):
     # pure torch implementation, used below and in cov/corrcoef below
     tensors, axis = _util.axis_none_ravel(*tensors, axis=axis)
     tensors = _concat_cast_helper(tensors, out, dtype, casting)
-
-    try:
-        result = torch.cat(tensors, axis)
-    except (IndexError, RuntimeError) as e:
-        raise _util.AxisError(*e.args)
-    return result
+    return torch.cat(tensors, axis)


 def concatenate(
@@ -177,11 +170,7 @@ def stack(
     tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting)
     result_ndim = tensors[0].ndim + 1
     axis = _util.normalize_axis_index(axis, result_ndim)
-    try:
-        result = torch.stack(tensors, axis=axis)
-    except RuntimeError as e:
-        raise ValueError(*e.args)
-    return result
+    return torch.stack(tensors, axis=axis)


 # ### split ###
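One consequence of dropping these try/except blocks (a standalone sketch, not from the PR itself): invalid axes now surface torch's native exceptions instead of being translated into numpy-style AxisError/ValueError:

import torch

try:
    torch.cat([torch.ones(2), torch.ones(2)], dim=5)
except IndexError as e:
    # torch raises IndexError ("Dimension out of range ...") here,
    # where the old wrapper re-raised a numpy-style AxisError
    print(e)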
@@ -352,24 +341,17 @@ def arange(
         dtype = _dtypes_impl.default_dtypes.int_dtype
     dt_list = [_util._coerce_to_tensor(x).dtype for x in (start, stop, step)]
     dt_list.append(dtype)
-    dtype = _dtypes_impl.result_type_impl(dt_list)
+    target_dtype = _dtypes_impl.result_type_impl(dt_list)

     # work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'
-    if dtype.is_complex:
-        work_dtype, target_dtype = torch.float64, dtype
-    else:
-        work_dtype, target_dtype = dtype, dtype
+    work_dtype = torch.float64 if target_dtype.is_complex else target_dtype

     if (step > 0 and start > stop) or (step < 0 and start < stop):
         # empty range
         return torch.empty(0, dtype=target_dtype)

-    try:
-        result = torch.arange(start, stop, step, dtype=work_dtype)
-        result = _util.cast_if_needed(result, target_dtype)
-    except RuntimeError:
-        raise ValueError("Maximum allowed size exceeded")
-
+    result = torch.arange(start, stop, step, dtype=work_dtype)
+    result = _util.cast_if_needed(result, target_dtype)
     return result
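A minimal standalone sketch of the complex-dtype workaround this hunk keeps: torch.arange has no complex kernel, so the range is computed in float64 and cast afterwards (example dtype assumed):

import torch

target_dtype = torch.complex128  # e.g. what np.arange(3, dtype=complex) needs
work_dtype = torch.float64 if target_dtype.is_complex else target_dtype
result = torch.arange(0, 3, 1, dtype=work_dtype).to(target_dtype)
print(result)  # tensor([0.+0.j, 1.+0.j, 2.+0.j])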
@@ -593,8 +575,7 @@ def where(
     y: Optional[ArrayLike] = None,
     /,
 ):
-    selector = (x is None) == (y is None)
-    if not selector:
+    if (x is None) != (y is None):
         raise ValueError("either both or neither of x and y should be given")

     if condition.dtype != torch.bool:
@@ -603,14 +584,11 @@ def where(
     if x is None and y is None:
         result = torch.where(condition)
     else:
-        try:
-            result = torch.where(condition, x, y)
-        except RuntimeError as e:
-            raise ValueError(*e.args)
+        result = torch.where(condition, x, y)
     return result


-###### module-level queries of object properties
+# ###### module-level queries of object properties


 def ndim(a: ArrayLike):
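For reference, the two call forms the simplified guard distinguishes (standalone sketch, not part of the diff):

import torch

cond = torch.tensor([True, False, True])
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.zeros(3)
print(torch.where(cond))        # index form: (tensor([0, 2]),)
print(torch.where(cond, x, y))  # selection form: tensor([1., 0., 3.])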
@@ -628,7 +606,7 @@ def size(a: ArrayLike, axis=None):
         return a.shape[axis]


-###### shape manipulations and indexing
+# ###### shape manipulations and indexing


 def expand_dims(a: ArrayLike, axis):
@@ -665,6 +643,7 @@ def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False):
     return torch.broadcast_to(array, size=shape)


+
 # This is a function from tuples to tuples, so we just reuse it
 from torch import broadcast_shapes
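A one-line usage note on the reused function (standalone): torch.broadcast_shapes operates purely on shape tuples, matching np.broadcast_shapes:

import torch

print(torch.broadcast_shapes((2, 1), (3,)))  # torch.Size([2, 3])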
@@ -742,16 +721,15 @@ def triu_indices(n, k=0, m=None):
 def tril_indices_from(arr: ArrayLike, k=0):
     if arr.ndim != 2:
         raise ValueError("input array must be 2-d")
-    result = torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)
-    return tuple(result)
+    # Return a tensor rather than a tuple to avoid a graphbreak
+    return torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)


 def triu_indices_from(arr: ArrayLike, k=0):
     if arr.ndim != 2:
         raise ValueError("input array must be 2-d")
-    result = torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)
-    # unpack: numpy returns a 2-tuple of index arrays; torch returns a 2-row tensor
-    return tuple(result)
+    # Return a tensor rather than a tuple to avoid a graphbreak
+    return torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)


 def tri(
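A standalone sketch of why the tensor return still supports the common numpy idiom (see also the review discussion at the end): the 2-row index tensor can be unpacked when indexing, matching numpy's tuple-of-arrays semantics:

import torch

a = torch.arange(9.0).reshape(3, 3)
idx = torch.tril_indices(3, 3)  # a (2, N) tensor, not a 2-tuple as in numpy
print(a[idx[0], idx[1]])        # lower-triangular elements, as with
                                # a[np.tril_indices_from(a)]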
@@ -765,34 +743,14 @@ def tri(
     if M is None:
         M = N
     tensor = torch.ones((N, M), dtype=dtype)
-    tensor = torch.tril(tensor, diagonal=k)
-    return tensor
+    return torch.tril(tensor, diagonal=k)


-# ### nanfunctions ### # FIXME: this is a stub
+# ### nanfunctions ###


-def nanmean(
-    a: ArrayLike,
-    axis=None,
-    dtype: Optional[DTypeLike] = None,
-    out: Optional[OutArray] = None,
-    keepdims=None,
-    *,
-    where: NotImplementedType = None,
-):
-    # XXX: this needs to be rewritten
-    if dtype is None:
-        dtype = a.dtype
-    if axis is None:
-        result = a.nanmean(dtype=dtype)
-        if keepdims:
-            result = torch.full(a.shape, result, dtype=result.dtype)
-    else:
-        result = a.nanmean(dtype=dtype, dim=axis, keepdim=bool(keepdims))
-    if out is not None:
-        out.copy_(result)
-    return result
+def nanmean():
+    raise NotImplementedError
Review comment: Yeah, these all should go. This weirdness was to be able to port the nanXXX tests, so it should probably be redone in one go. This PR or a separate one, whichever works.
 def nanmin():
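As the review comment above notes, these stubs are meant to be redone; one possible direction (my assumption, not from the PR) is to delegate straight to torch.nanmean, which already handles dim/keepdim/dtype:

import torch

a = torch.tensor([1.0, float("nan"), 3.0])
print(torch.nanmean(a))         # tensor(2.) - NaNs excluded from the mean
print(torch.nanmean(a, dim=0))  # same, along an explicit dimension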
@@ -999,12 +957,7 @@ def clip(
     max: Optional[ArrayLike] = None,
     out: Optional[OutArray] = None,
 ):
-    # np.clip requires both a_min and a_max not None, while ndarray.clip allows
-    # one of them to be None. Follow the more lax version.
-    if min is None and max is None:
-        raise ValueError("One of max or min must be given")
-    result = torch.clamp(a, min, max)
-    return result
+    return torch.clamp(a, min, max)


 def repeat(a: ArrayLike, repeats: ArrayLike, axis=None):
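A standalone check of why the explicit guard was redundant (hedged; the exact message may vary by torch version): torch.clamp accepts one open bound and itself errors when both are None, though the error type becomes torch's RuntimeError instead of the old ValueError:

import torch

a = torch.tensor([1.0, 5.0, 9.0])
print(torch.clamp(a, None, 4.0))  # tensor([1., 4., 4.])
try:
    torch.clamp(a, None, None)
except RuntimeError as e:
    print(e)  # e.g. "At least one of 'min' or 'max' must not be None"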
@@ -1368,15 +1321,10 @@ def transpose(a: ArrayLike, axes=None):
     # numpy allows both .tranpose(sh) and .transpose(*sh)
     # also older code uses axes being a list
     if axes in [(), None, (None,)]:
-        axes = tuple(range(a.ndim))[::-1]
+        axes = tuple(reversed(range(a.ndim)))
     elif len(axes) == 1:
         axes = axes[0]

-    try:
-        result = a.permute(axes)
-    except RuntimeError:
-        raise ValueError("axes don't match array")
-    return result
+    return a.permute(axes)


 def ravel(a: ArrayLike, order: NotImplementedType = "C"):
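As with cat/stack above, dropping the try/except means permute's native error propagates (standalone sketch; the message wording is torch's and not guaranteed stable):

import torch

try:
    torch.ones(2, 3).permute((0,))  # wrong number of axes for a 2-d tensor
except RuntimeError as e:
    print(e)  # previously re-raised as ValueError("axes don't match array")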
@@ -1391,41 +1339,6 @@ def _flatten(a: ArrayLike, order: NotImplementedType = "C"):
     return torch.flatten(a)


-# ### Type/shape etc queries ###
-
-
-def real(a: ArrayLike):
-    result = torch.real(a)
-    return result
-
-
-def imag(a: ArrayLike):
-    if a.is_complex():
-        result = a.imag
-    else:
-        result = torch.zeros_like(a)
-    return result
-
-
-def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
-    if a.is_floating_point():
-        result = torch.round(a, decimals=decimals)
-    elif a.is_complex():
-        # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
-        result = (
-            torch.round(a.real, decimals=decimals)
-            + torch.round(a.imag, decimals=decimals) * 1j
-        )
-    else:
-        # RuntimeError: "round_cpu" not implemented for 'int'
-        result = a
-    return result
-
-
-around = round_
-round = round_
-
-
 # ### reductions ###
@@ -1742,6 +1655,9 @@ def sinc(x: ArrayLike):
     return torch.sinc(x)


+# ### Type/shape etc queries ###
+
+
 def real(a: ArrayLike):
     return torch.real(a)
@@ -1754,7 +1670,7 @@ def imag(a: ArrayLike):
     return result


-def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
+def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
     if a.is_floating_point():
         result = torch.round(a, decimals=decimals)
     elif a.is_complex():
@@ -1769,8 +1685,8 @@ def round_(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
     return result


-around = round_
-round = round_
+around = round
+round_ = round


 def real_if_close(a: ArrayLike, tol=100):
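For reference, a standalone check of the complex-rounding workaround that the renamed `round` keeps (torch.round has no complex kernel, so the real and imaginary parts are rounded separately):

import torch

a = torch.tensor([1.26 + 3.91j])
rounded = torch.round(a.real, decimals=1) + torch.round(a.imag, decimals=1) * 1j
print(rounded)  # tensor([1.3000+3.9000j])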
Review comment: This is a break w.r.t. numpy, so it should be added to the list in #73, I guess. Also, this should fail on CI, so those tests need to be adjusted. Semantics-wise, should we not be worried about …

Reply: Added to that issue.

Reply: As discussed in this example, we should be fine, since it is also possible to index with a tensor while preserving the semantics.