Skip to content

Commit

Permalink
WF Array: use ufuncs in Array operators [EC-1072] (#6517)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 70d7a3354d0b6f707e624dfdbf557b825707eab6
  • Loading branch information
gjoseph92 authored and Descartes Labs Build committed Mar 11, 2020
1 parent 6c9dcb7 commit cb9384e
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 75 deletions.
145 changes: 71 additions & 74 deletions descarteslabs/workflows/types/array/array_.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ... import env
from descarteslabs.common.graft import client
from ...cereal import serializable
from ..core import GenericProxytype, typecheck_promote, ProxyTypeError
from ..core import GenericProxytype, ProxyTypeError
from ..containers import Slice, Tuple, List, Dict
from ..primitives import Int, Float, Bool, NoneType

Expand All @@ -13,6 +13,24 @@
WF_TO_DTYPE_KIND = dict(zip(DTYPE_KIND_TO_WF.values(), DTYPE_KIND_TO_WF.keys()))


def _delayed_numpy_overrides():
# avoid circular imports
from descarteslabs.workflows.types.numpy import numpy_overrides

return numpy_overrides


def allow_reflect(func):
@functools.wraps(func)
def wrapped(*args):
try:
return func(*args)
except ProxyTypeError:
return NotImplemented

return wrapped


@serializable()
class Array(GenericProxytype):
"""
Expand Down Expand Up @@ -286,7 +304,7 @@ def __array_function__(self, func, types, args, kwargs):
args: arguments directly passed from the original call
kwargs: kwargs directly passed from the original call
"""
from descarteslabs.workflows.types.numpy import numpy_overrides
numpy_overrides = _delayed_numpy_overrides()

if func not in numpy_overrides.HANDLED_FUNCTIONS:
raise NotImplementedError(
Expand Down Expand Up @@ -320,7 +338,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
inputs: Tuple of the input arguments to ufunc
kwargs: Dict of optional input arguments to ufunc
"""
from descarteslabs.workflows.types.numpy import numpy_overrides
numpy_overrides = _delayed_numpy_overrides()

if method == "__call__":
if ufunc.__name__ not in numpy_overrides.HANDLED_UFUNCS:
Expand All @@ -332,93 +350,101 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
return NotImplemented

def __neg__(self):
return self._from_apply("neg", self)
return _delayed_numpy_overrides().negative(self)

def __pos__(self):
return self._from_apply("pos", self)

def __abs__(self):
return self._from_apply("abs", self)
return _delayed_numpy_overrides().absolute(self)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __lt__(self, other):
return self._result_type(other, is_bool=True)._from_apply("lt", self, other)
return _delayed_numpy_overrides().less(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __le__(self, other):
return self._result_type(other, is_bool=True)._from_apply("le", self, other)
return _delayed_numpy_overrides().less_equal(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __gt__(self, other):
return self._result_type(other, is_bool=True)._from_apply("gt", self, other)
return _delayed_numpy_overrides().greater(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __ge__(self, other):
return self._result_type(other, is_bool=True)._from_apply("ge", self, other)
return _delayed_numpy_overrides().greater_equal(self, other)

@allow_reflect
def __eq__(self, other):
return _delayed_numpy_overrides().equal(self, other)

@allow_reflect
def __ne__(self, other):
return _delayed_numpy_overrides().not_equal(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __add__(self, other):
return self._result_type(other)._from_apply("add", self, other)
return _delayed_numpy_overrides().add(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __sub__(self, other):
return self._result_type(other)._from_apply("sub", self, other)
return _delayed_numpy_overrides().subtract(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __mul__(self, other):
return self._result_type(other)._from_apply("mul", self, other)
return _delayed_numpy_overrides().multiply(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __div__(self, other):
return self._result_type(other)._from_apply("div", self, other)
return _delayed_numpy_overrides().divide(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __floordiv__(self, other):
return self._result_type(other)._from_apply("floordiv", self, other)
return _delayed_numpy_overrides().floor_divide(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __truediv__(self, other):
return self._result_type(other)._from_apply("truediv", self, other)
return _delayed_numpy_overrides().true_divide(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __mod__(self, other):
return self._result_type(other)._from_apply("mod", self, other)
return _delayed_numpy_overrides().mod(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __pow__(self, other):
return self._result_type(other)._from_apply("pow", self, other)
return _delayed_numpy_overrides().power(self, other)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __radd__(self, other):
return self._result_type(other)._from_apply("add", other, self)
return _delayed_numpy_overrides().add(other, self)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __rsub__(self, other):
return self._result_type(other)._from_apply("sub", other, self)
return _delayed_numpy_overrides().subtract(other, self)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __rmul__(self, other):
return self._result_type(other)._from_apply("mul", other, self)
return _delayed_numpy_overrides().multiply(other, self)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __rdiv__(self, other):
return self._result_type(other)._from_apply("div", other, self)
return _delayed_numpy_overrides().divide(other, self)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __rfloordiv__(self, other):
return self._result_type(other)._from_apply("floordiv", other, self)
return _delayed_numpy_overrides().floor_divide(other, self)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __rtruediv__(self, other):
return self._result_type(other)._from_apply("truediv", other, self)
return _delayed_numpy_overrides().true_divide(other, self)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __rmod__(self, other):
return self._result_type(other)._from_apply("mod", other, self)
return _delayed_numpy_overrides().mod(other, self)

@typecheck_promote((lambda: Array, Int, Float))
@allow_reflect
def __rpow__(self, other):
return self._result_type(other)._from_apply("pow", other, self)
return _delayed_numpy_overrides().power(other, self)

def min(self, axis=None):
""" Minimum along a given axis.
Expand Down Expand Up @@ -573,35 +599,6 @@ def _stats_return_type(self, axis):
return_type = type(self)._generictype[self.dtype, self.ndim - 1]
return return_type

def _result_type(self, other, is_bool=False):
result_generictype = type(self)._generictype
try:
other_generictype = type(other)._generictype
except AttributeError:
pass
else:
if issubclass(other_generictype, result_generictype):
result_generictype = other_generictype
dtype = self._result_dtype(other, is_bool)
ndim = self.ndim
other_ndim = getattr(other, "ndim", -1)
if ndim < other_ndim:
ndim = other_ndim
return result_generictype[dtype, ndim]

def _result_dtype(self, other, is_bool=False):
if is_bool:
return Bool
other_dtype = getattr(other, "dtype", None)
# If either are Float, the result is a Float
if self.dtype is Float or other_dtype is Float:
return Float
# Neither are Float, so if either are Int, the result is an Int
if self.dtype is Int or other.dtype is Int:
return Int
# Neither are Float, neither are Int, they must be Bool, so the result is Bool
return Bool


def typecheck_getitem(idx, ndim):
return_ndim = ndim
Expand Down
7 changes: 6 additions & 1 deletion descarteslabs/workflows/types/array/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,17 @@ def test_to_imagery():
assert isinstance(arr.to_imagery({}, {}), ImageCollection)


@pytest.mark.parametrize("method", [operator.lt, operator.le, operator.gt, operator.ge])
@pytest.mark.parametrize(
"method",
[operator.lt, operator.le, operator.gt, operator.ge, operator.eq, operator.ne],
)
@pytest.mark.parametrize("other", [Array[Int, 2]([[1, 2, 3], [4, 5, 6]]), 1, 0.5])
def test_container_bool_methods(method, other):
arr = Array[Int, 2]([[10, 11, 12], [13, 14, 15]])
result = method(arr, other)
r_result = method(other, arr)
assert isinstance(result, Array[Bool, 2])
assert isinstance(r_result, Array[Bool, 2])


@pytest.mark.parametrize(
Expand Down

0 comments on commit cb9384e

Please sign in to comment.