diff --git a/openeo/udf/run_code.py b/openeo/udf/run_code.py index 6c0657dd1..76e159cd6 100644 --- a/openeo/udf/run_code.py +++ b/openeo/udf/run_code.py @@ -235,7 +235,9 @@ def run_udf_code(code: str, data: UdfData) -> UdfData: raise OpenEoUdfException("No UDF found.") -def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.DataArray, XarrayDataCube], fmt='netcdf'): +def execute_local_udf( + udf: Union[str, openeo.UDF], datacube: Union[str, pathlib.Path, xarray.DataArray, XarrayDataCube], fmt="netcdf" +): """ Locally executes an user defined function on a previously downloaded datacube. @@ -244,8 +246,8 @@ def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.D :param fmt: format of the file if datacube is string :return: the resulting DataCube """ - if isinstance(udf, openeo.UDF): - udf = udf.code + if isinstance(udf, str): + udf = openeo.UDF(code=udf) if isinstance(datacube, (str, pathlib.Path)): d = XarrayDataCube.from_file(path=datacube, fmt=fmt) @@ -266,13 +268,13 @@ def execute_local_udf(udf: Union[str, openeo.UDF], datacube: Union[str, xarray.D .astype(numpy.float64) ) # wrap to udf_data - udf_data = UdfData(datacube_list=[d]) + udf_data = UdfData(datacube_list=[d], user_context=udf.context) # TODO: enrich to other types like time series, vector data,... probalby by adding named arguments # signature: UdfData(proj, datacube_list, feature_collection_list, structured_data_list, ml_model_list, metadata) # run the udf through the same routine as it would have been parsed in the backend - result = run_udf_code(udf, udf_data) + result = run_udf_code(udf.code, udf_data) return result diff --git a/tests/udf/test_run_code.py b/tests/udf/test_run_code.py index ea9c6d24b..2579b574e 100644 --- a/tests/udf/test_run_code.py +++ b/tests/udf/test_run_code.py @@ -307,6 +307,31 @@ def test_run_local_udf_from_file_netcdf(tmp_path): assert result[2, 0, 4, 3] == _ndvi(2034, 2134) +def test_run_local_udf_from_file_netcdf_with_context(tmp_path): + udf_code = _get_udf_code("multiply_factor.py") + xdc = _build_xdc( + ts=[numpy.datetime64("2020-08-01"), numpy.datetime64("2020-08-11"), numpy.datetime64("2020-08-21")], + bands=["bandzero", "bandone"], + xs=[10.0, 11.0, 12.0, 13.0, 14.0], + ys=[20.0, 21.0, 22.0, 23.0, 24.0, 25.0], + ) + assert xdc.array.shape == (3, 2, 5, 6) + data_path = tmp_path / "data.nc" + xdc.save_to_file(path=data_path, fmt="netcdf") + + factor = 100 + udf = UDF(udf_code, runtime="Python", context={"factor": factor}) + res = execute_local_udf(udf, data_path, fmt="netcdf") + + assert isinstance(res, UdfData) + result = res.get_datacube_list()[0].get_array() + + assert result.shape == (3, 2, 6, 5) + swapped_result = result.transpose("t", "bands", "x", "y") + expected = xdc.array * factor + xarray.testing.assert_equal(swapped_result, expected) + + def _is_package_available(name: str) -> bool: # TODO: move this to a more general test utility module. return importlib.util.find_spec(name) is not None diff --git a/tests/udf/udf_code/multiply_factor.py b/tests/udf/udf_code/multiply_factor.py new file mode 100644 index 000000000..e00b333fa --- /dev/null +++ b/tests/udf/udf_code/multiply_factor.py @@ -0,0 +1,7 @@ +from openeo.udf import XarrayDataCube + + +def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube: + factor = context["factor"] + array = cube.get_array() * factor + return XarrayDataCube(array)