diff --git a/pyproject.toml b/pyproject.toml index 1013d480..179bb7e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,3 +39,5 @@ homepage = "https://github.com/SciQLop/speasy" [project.optional-dependencies] zstd = ["pyzstd"] +[tool.ruff.lint] +select = ["NPY201"] diff --git a/speasy/products/variable.py b/speasy/products/variable.py index 2399de8b..f02961a8 100644 --- a/speasy/products/variable.py +++ b/speasy/products/variable.py @@ -247,6 +247,16 @@ def __truediv__(self, other): def __rtruediv__(self, other): return np.divide(other, self) + def __np_build_axes__(self, other, axis=None): + if axis is None or self.ndim == other.ndim: + return deepcopy(self.__axes) + else: + axes = [] + for i, ax in enumerate(self.__axes): + if i != axis: + axes.append(deepcopy(ax)) + return axes + def __array_function__(self, func, types, args, kwargs): if func.__name__ in SpeasyVariable.__LIKE_NP_FUNCTIONS__: return SpeasyVariable.__dict__[func.__name__].__func__(self) @@ -258,29 +268,34 @@ def __array_function__(self, func, types, args, kwargs): if np.isscalar(res): return res if isinstance(res, np.ndarray): - if len(res.shape) != self.shape and res.shape[0] != len(self.time): + if len(res.shape) != self.shape and (res.shape[0] != len(self.time) or kwargs.get('axis', None) == 0): return res n_cols = res.shape[1] if len(res.shape) > 1 else 1 return SpeasyVariable( - axes=deepcopy(self.__axes), + axes=self.__np_build_axes__(res, axis=kwargs.get('axis', None)), values=DataContainer(values=res, name=f"{func.__name__}_{self.__values_container.name}", meta=deepcopy(self.__values_container.meta)), columns=[f"column_{i}" for i in range(n_cols)], ) - def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs): + def __array_ufunc__(self, ufunc, method, *inputs, out: 'SpeasyVariable' or None = None, **kwargs): if out is not None: _out = _values(out[0]) else: _out = None inputs = list(map(_values, inputs)) values = ufunc(*inputs, **{name: _values(value) for name, value in kwargs}, out=_out) + + axes = self.__np_build_axes__(values, axis=kwargs.get('axis', None)) + if out is not None: + if isinstance(out, SpeasyVariable): + out.__axes = axes return out else: return SpeasyVariable( - axes=deepcopy(self.__axes), + axes=axes, values=DataContainer(values=values, name=f"{ufunc.__name__}_{self.__values_container.name}", meta=deepcopy(self.__values_container.meta)), columns=[f"column_{i}" for i in range(values.shape[1])], diff --git a/tests/test_speasy_variable.py b/tests/test_speasy_variable.py index 2d089f69..16534079 100644 --- a/tests/test_speasy_variable.py +++ b/tests/test_speasy_variable.py @@ -54,6 +54,20 @@ def make_2d_var(start: float = 0., stop: float = 0., step: float = 1., coef: flo values=DataContainer(values, is_time_dependent=True, meta={"DISPLAY_TYPE": "spectrogram"}), columns=["Values"]) +def make_3d_var(start: float = 0., stop: float = 0., step: float = 1., coef: float = 1., height: int = 32, + depth: int = 32): + time = np.arange(start, stop, step) + values = np.random.random((len(time), height, depth)) + y = np.repeat(np.arange(height), len(time), axis=0) + z = np.repeat(np.arange(depth), len(time), axis=0) + return SpeasyVariable( + axes=[VariableTimeAxis(values=epoch_to_datetime64(time)), + VariableAxis(name='y', values=y, is_time_dependent=True), + VariableAxis(name='z', values=z, is_time_dependent=True) + ], + values=DataContainer(values, is_time_dependent=True, meta={"DISPLAY_TYPE": "spectrogram"}), columns=["Values"]) + + def make_2d_var_1d_y(start: float = 0., stop: float = 0., step: float = 1., coef: float = 1., height: int = 32): time = np.arange(start, stop, step) values = (time * coef).reshape(-1, 1) * np.arange(height).reshape(1, -1) @@ -299,10 +313,13 @@ def test_time_shift(self): self.assertTrue(np.all(var.time == self.var.time + shift)) +@ddt class TestSpeasyVariableNumpyInterface(unittest.TestCase): def setUp(self): self.var = make_simple_var(1., 10., 1., 10.) self.vector = make_simple_var_3cols(1., 10., 1., 10.) + self.spectro = make_2d_var(1., 10., 10., 32) + self.var3d = make_3d_var(1., 10., 10., 32, 16) def tearDown(self): pass @@ -326,6 +343,29 @@ def test_ufunc_magnitude(self): self.assertTrue(np.allclose(var.values, np.linalg.norm(self.vector.values, axis=1).reshape(-1, 1))) self.assertTrue(np.allclose(var, np.linalg.norm(self.vector, axis=1))) + @data(np.sum, np.mean, np.std, np.var, np.max, np.min) + def test_functions_that_reduce_ndim_on_axis1(self, func): + for var in (self.spectro, self.var3d): + result = func(var, axis=1) + self.assertEqual(len(var.axes) - 1, len(result.axes)) + self.assertTrue(np.all(result.values == func(var.values, axis=1))) + + @data(np.sum, np.mean, np.std, np.var, np.max, np.min) + def test_functions_that_reduce_ndim_on_last_axis(self, func): + for var in (self.spectro, self.var3d): + axis = len(var.axes) - 1 + result = func(var, axis=axis) + self.assertEqual(len(var.axes) - 1, len(result.axes)) + self.assertTrue(np.all(result.values == func(var.values, axis=axis))) + + @data(np.sum, np.mean, np.std, np.var, np.max, np.min) + def test_functions_that_reduce_ndim_on_axis0(self, func): + for var in (self.spectro, self.var3d): + result = func(var, axis=0) + self.assertIsNot(type(result), SpeasyVariable) + self.assertIsInstance(result, np.ndarray) + self.assertTrue(np.all(result == func(var.values, axis=0))) + def test_zeros_like(self): var = np.zeros_like(self.var) self.assertEqual(self.var.shape, var.shape) @@ -349,14 +389,10 @@ def test_empty_like(self): self.assertListEqual(self.var.axes, var.axes) self.assertListEqual(self.var.columns, var.columns) - def test_scalar_result(self): - for v in (self.var, self.vector): - self.assertIsInstance(np.sum(v), float) - self.assertIsInstance(np.mean(v), float) - self.assertIsInstance(np.std(v), float) - self.assertIsInstance(np.var(v), float) - self.assertIsInstance(np.max(v), float) - self.assertIsInstance(np.min(v), float) + @data(np.sum, np.mean, np.std, np.var, np.max, np.min) + def test_scalar_result(self, func): + for v in (self.var, self.vector, self.spectro, self.var3d): + self.assertIsInstance(func(v), float) class SpeasyVariableCompare(unittest.TestCase):