Skip to content

Commit

Permalink
BENCH: fix issues with operator benchmarks in bench_ufunc.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rgommers committed Aug 27, 2023
1 parent 2cdbc42 commit 8a730be
Showing 1 changed file with 38 additions and 12 deletions.
50 changes: 38 additions & 12 deletions benchmarks/benchmarks/bench_ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,23 +159,49 @@ def time_ndarray__0d__(self, methname, npdtypes):
class MethodsV1(Benchmark):
""" Benchmark for the methods which take an argument
"""
params = [['__and__', '__add__', '__eq__', '__floordiv__', '__ge__',
'__gt__', '__le__', '__lt__', '__matmul__',
'__mod__', '__mul__', '__ne__', '__or__',
'__pow__', '__sub__', '__truediv__', '__xor__'],
params = [['__add__', '__eq__', '__ge__', '__gt__', '__le__',
'__lt__', '__matmul__', '__mul__', '__ne__',
'__pow__', '__sub__', '__truediv__'],
TYPES1]
param_names = ['methods', 'npdtypes']
timeout = 10

def setup(self, methname, npdtypes):
if (
npdtypes.startswith("complex")
and methname in ["__floordiv__", "__mod__"]
) or (
not npdtypes.startswith("int")
and methname in ["__and__", "__or__", "__xor__"]
):
raise NotImplementedError # skip
values = get_squares_().get(npdtypes)
self.xargs = [values[0], values[1]]
if np.issubdtype(npdtypes, np.inexact):
# avoid overflow in __pow__/__matmul__ for low-precision dtypes
self.xargs[1] *= 0.01

def time_ndarray_meth(self, methname, npdtypes):
getattr(operator, methname)(*self.xargs)


class MethodsV1IntOnly(Benchmark):
""" Benchmark for the methods which take an argument
"""
params = [['__and__', '__or__', '__xor__'],
['int16', 'int32', 'int64']]
param_names = ['methods', 'npdtypes']
timeout = 10

def setup(self, methname, npdtypes):
values = get_squares_().get(npdtypes)
self.xargs = [values[0], values[1]]

def time_ndarray_meth(self, methname, npdtypes):
getattr(operator, methname)(*self.xargs)


class MethodsV1NoComplex(Benchmark):
""" Benchmark for the methods which take an argument
"""
params = [['__floordiv__', '__mod__'],
[dt for dt in TYPES1 if not dt.startswith('complex')]]
param_names = ['methods', 'npdtypes']
timeout = 10

def setup(self, methname, npdtypes):
values = get_squares_().get(npdtypes)
self.xargs = [values[0], values[1]]

Expand Down

0 comments on commit 8a730be

Please sign in to comment.