Skip to content

Commit 252827e

Browse files
authored
Merge pull request #2 from EarthyScience/fg/xarray
Ad an xarray wrapper
2 parents 527d6db + a17d019 commit 252827e

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ readme = "README.md"
1515
requires-python = ">=3.8"
1616
dependencies = [
1717
"numpy",
18-
"dask"
18+
"dask",
19+
"xarray"
1920
]
2021
classifiers = [
2122
"Programming Language :: Python :: 3",

rqadeforestation/__init__.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import ctypes as ct
33
import numpy as np
44
import dask.array as da
5+
import xarray as xr
56

67
class MallocVector(ct.Structure):
78
_fields_ = [("pointer", ct.c_void_p),
@@ -38,7 +39,7 @@ def rqatrend(y: np.ndarray, threshold: float, border: int = 10, theiler: int = 1
3839
:param theiler: Theiler window size for the RQA calculation.
3940
:return: The RQA trend value.
4041
"""
41-
py = mvptr(y)
42+
py = mvptr(y.astype(np.float64))
4243
lib.rqatrend.argtypes = (ct.POINTER(MallocVector), ct.c_double, ct.c_int64, ct.c_int64)
4344
lib.rqatrend.restype = ct.c_double
4445
result_single = lib.rqatrend(py, threshold, border, theiler)
@@ -56,16 +57,31 @@ def rqatrend_matrix(matrix: np.ndarray, threshold: float, border: int = 10, thei
5657
:return: Numpy array of all RQA trend values of size n_timeseries.
5758
"""
5859

60+
if not len(matrix.shape) == 2:
61+
raise Exception("Input to rqatrend_matrix must be 2d")
62+
5963
n = matrix.shape[0]
6064
result_several = np.ones(n)
6165
p_result_several = mvptr(result_several)
62-
p_matrix = mmptr(matrix)
63-
66+
p_matrix = mmptr(matrix.astype(np.float64))
6467
# arguments: result_vector, data, threshhold, border, theiler
6568
lib.rqatrend_inplace.argtypes = (ct.POINTER(MallocVector), ct.POINTER(MallocMatrix), ct.c_double, ct.c_int64, ct.c_int64)
6669
return_value = lib.rqatrend_inplace(p_result_several, p_matrix, threshold, border, theiler)
6770
return result_several
6871

72+
def rqatrend_xarray(x, threshold:float, border: int = 10,
73+
theiler: int=1, out_dtype = np.float64,
74+
timeaxis_name = "time"):
75+
return xr.apply_ufunc(
76+
rqatrend,
77+
x.chunk({timeaxis_name: -1}),
78+
kwargs = {'threshold': threshold,'border':border,'theiler':theiler},
79+
input_core_dims = [[timeaxis_name]],
80+
output_core_dims = [[]],
81+
dask = "parallelized",
82+
vectorize=True,
83+
)
84+
6985

7086
def rqatrend_dask(x: da.Array, timeseries_axis: int, threshold: float, border: int = 10, theiler: int = 1, out_dtype: type = np.float64) -> da.Array:
7187
"""

0 commit comments

Comments
 (0)