Skip to content

Commit

Permalink
Merge pull request #709 from Open-EO/execute_local_udf_context
Browse files Browse the repository at this point in the history
Execute local udf context
  • Loading branch information
EmileSonneveld authored Jan 17, 2025
2 parents ee7941a + ab9c070 commit 5ffc554
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 5 deletions.
12 changes: 7 additions & 5 deletions openeo/udf/run_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def run_udf_code(code: str, data: UdfData) -> UdfData:
raise OpenEoUdfException("No UDF found.")


def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.DataArray, XarrayDataCube], fmt='netcdf'):
def execute_local_udf(
udf: Union[str, openeo.UDF], datacube: Union[str, pathlib.Path, xarray.DataArray, XarrayDataCube], fmt="netcdf"
):
"""
Locally executes an user defined function on a previously downloaded datacube.
Expand All @@ -244,8 +246,8 @@ def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.D
:param fmt: format of the file if datacube is string
:return: the resulting DataCube
"""
if isinstance(udf, openeo.UDF):
udf = udf.code
if isinstance(udf, str):
udf = openeo.UDF(code=udf)

if isinstance(datacube, (str, pathlib.Path)):
d = XarrayDataCube.from_file(path=datacube, fmt=fmt)
Expand All @@ -266,13 +268,13 @@ def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.D
.astype(numpy.float64)
)
# wrap to udf_data
udf_data = UdfData(datacube_list=[d])
udf_data = UdfData(datacube_list=[d], user_context=udf.context)

# TODO: enrich to other types like time series, vector data,... probalby by adding named arguments
# signature: UdfData(proj, datacube_list, feature_collection_list, structured_data_list, ml_model_list, metadata)

# run the udf through the same routine as it would have been parsed in the backend
result = run_udf_code(udf, udf_data)
result = run_udf_code(udf.code, udf_data)
return result


Expand Down
25 changes: 25 additions & 0 deletions tests/udf/test_run_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,31 @@ def test_run_local_udf_from_file_netcdf(tmp_path):
assert result[2, 0, 4, 3] == _ndvi(2034, 2134)


def test_run_local_udf_from_file_netcdf_with_context(tmp_path):
udf_code = _get_udf_code("multiply_factor.py")
xdc = _build_xdc(
ts=[numpy.datetime64("2020-08-01"), numpy.datetime64("2020-08-11"), numpy.datetime64("2020-08-21")],
bands=["bandzero", "bandone"],
xs=[10.0, 11.0, 12.0, 13.0, 14.0],
ys=[20.0, 21.0, 22.0, 23.0, 24.0, 25.0],
)
assert xdc.array.shape == (3, 2, 5, 6)
data_path = tmp_path / "data.nc"
xdc.save_to_file(path=data_path, fmt="netcdf")

factor = 100
udf = UDF(udf_code, runtime="Python", context={"factor": factor})
res = execute_local_udf(udf, data_path, fmt="netcdf")

assert isinstance(res, UdfData)
result = res.get_datacube_list()[0].get_array()

assert result.shape == (3, 2, 6, 5)
swapped_result = result.transpose("t", "bands", "x", "y")
expected = xdc.array * factor
xarray.testing.assert_equal(swapped_result, expected)


def _is_package_available(name: str) -> bool:
# TODO: move this to a more general test utility module.
return importlib.util.find_spec(name) is not None
Expand Down
7 changes: 7 additions & 0 deletions tests/udf/udf_code/multiply_factor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from openeo.udf import XarrayDataCube


def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube:
factor = context["factor"]
array = cube.get_array() * factor
return XarrayDataCube(array)

0 comments on commit 5ffc554

Please sign in to comment.