diff --git a/pyrfu/pyrf/cart2sph_ts.py b/pyrfu/pyrf/cart2sph_ts.py index e429816..d0cdcc6 100644 --- a/pyrfu/pyrf/cart2sph_ts.py +++ b/pyrfu/pyrf/cart2sph_ts.py @@ -3,6 +3,7 @@ # 3rd party imports import numpy as np +import xarray as xr # Local imports from .ts_vec_xyz import ts_vec_xyz @@ -37,8 +38,14 @@ def cart2sph_ts(inp, direction_flag: int = 1): """ - if inp.attrs["TENSOR_ORDER"] != 1 or inp.data.ndim != 2: - raise TypeError("Input must be vector field") + # Check input type + assert isinstance(inp, xr.DataArray), "inp must be a xarray.DataArray" + + # Check that inp is a vector time series + assert inp.data.ndim == 2 and inp.shape[1] == 3, "inp must be a vector time series" + + # Check direction +/-1 + assert direction_flag in [-1, 1], "direction_flag must be +/-1" if direction_flag == -1: r_data = inp.data[:, 0] @@ -52,7 +59,7 @@ def cart2sph_ts(inp, direction_flag: int = 1): x_data = r_data * cos_the * cos_phi y_data = r_data * cos_the * sin_phi - out_data = np.hstack([x_data, y_data, z_data]) + out_data = np.transpose(np.vstack([x_data, y_data, z_data])) else: xy2 = inp.data[:, 0] ** 2 + inp.data[:, 1] ** 2 diff --git a/pyrfu/tests/test_pyrf.py b/pyrfu/tests/test_pyrf.py index 2d35fd3..f5f85b3 100644 --- a/pyrfu/tests/test_pyrf.py +++ b/pyrfu/tests/test_pyrf.py @@ -465,6 +465,24 @@ def test_cart2sph_output(self): ) +class Cart2SphTsTestCase(unittest.TestCase): + def test_cart2sph_ts_input(self): + with self.assertRaises(AssertionError): + pyrf.cart2sph_ts(0.0) + pyrf.cart2sph_ts(generate_data(100, "vector")) + pyrf.cart2sph_ts(generate_ts(64.0, 100, "scalar")) + pyrf.cart2sph_ts(generate_ts(64.0, 100, "vector"), 2) + + def test_cart2sph_ts_output(self): + result = pyrf.cart2sph_ts(generate_ts(64.0, 100, "vector"), 1) + self.assertIsInstance(result, xr.DataArray) + self.assertListEqual(list(result.shape), [100, 3]) + + result = pyrf.cart2sph_ts(generate_ts(64.0, 100, "vector"), -1) + self.assertIsInstance(result, xr.DataArray) + self.assertListEqual(list(result.shape), [100, 3]) + + class CdfEpoch2Datetime64TestCase(unittest.TestCase): def test_cdfepoch2datetime64_input_type(self): ref_time = 599572869184000000