Skip to content

Commit

Permalink
Changed imports for UDF in feature_extractor.py file
Browse files Browse the repository at this point in the history
  • Loading branch information
GriffinBabe committed Mar 11, 2024
1 parent dbe6116 commit e7c31be
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 80 deletions.
144 changes: 64 additions & 80 deletions src/openeo_gfmap/features/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
implementation of feature extractors of a UDF.
"""

import inspect
import re
from abc import ABC, abstractmethod

import numpy as np
Expand All @@ -13,59 +15,10 @@
from pyproj import Transformer
from pyproj.crs import CRS

# Import preamble, kept as plain source text. generate_udf_code prepends it
# verbatim to the UDF code it builds, so that the class sources appended after
# it can resolve these names.
REQUIRED_IMPORTS = """
from abc import ABC, abstractmethod
import openeo
from openeo.udf import XarrayDataCube, inspect
from openeo.udf.udf_data import UdfData
import xarray as xr
import numpy as np
from pyproj import Transformer
from pyproj.crs import CRS
from typing import Union
"""


# Harmonized string identifiers shared between this module and the generated
# UDF code. EPSG_HARMONIZED_NAME is the key under which the EPSG code read from
# the UDF data is stored in the parameters handed to the feature extractor.
# NOTE(review): the latitude/longitude names are presumably used in the same
# way by the extractor classes - confirm against their implementation.
LAT_HARMONIZED_NAME = "GEO-LAT"
LON_HARMONIZED_NAME = "GEO-LON"
EPSG_HARMONIZED_NAME = "GEO-EPSG"

# Source-code template of the `apply_udf_data` entry point of the generated
# UDF. generate_udf_code fills it in with str.format, supplying the fields
# lat_harmonized_name, lon_harmonized_name, epsg_harmonized_name,
# is_pixel_based and feature_extractor_class (the name of the user-defined
# feature extractor class, which the UDF instantiates).
APPLY_DATACUBE_SOURCE_CODE = """
LAT_HARMONIZED_NAME = "{lat_harmonized_name}"
LON_HARMONIZED_NAME = "{lon_harmonized_name}"
EPSG_HARMONIZED_NAME = "{epsg_harmonized_name}"
from openeo.udf import XarrayDataCube
from openeo.udf.udf_data import UdfData
IS_PIXEL_BASED = {is_pixel_based}
def apply_udf_data(udf_data: UdfData) -> XarrayDataCube:
feature_extractor = {feature_extractor_class}() # User-defined, feature extractor class initialized here
if not IS_PIXEL_BASED:
assert len(udf_data.datacube_list) == 1, "OpenEO GFMAP Feature extractor pipeline only supports single input cubes for the tile."
cube = udf_data.datacube_list[0]
parameters = udf_data.user_context
proj = udf_data.proj
if proj is not None:
proj = proj["EPSG"]
parameters[EPSG_HARMONIZED_NAME] = proj
cube = feature_extractor._execute(cube, parameters=parameters)
udf_data.datacube_list = [cube]
return udf_data
"""


class FeatureExtractor(ABC):
"""Base class for all feature extractor UDFs. It provides some common
Expand Down Expand Up @@ -184,12 +137,65 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
pass


def apply_udf_data(udf_data: UdfData) -> UdfData:
    """Generic UDF entry point running the feature extractor on the input cube.

    This function is a template rather than something to call from this
    module: ``_get_apply_udf_data`` copies its source into the generated UDF
    and textually replaces the quoted placeholder assigned below with the name
    of the user-defined feature extractor class. The placeholder literal must
    therefore stay exactly as written.

    Args:
        udf_data: UDF data object providing the input cubes
            (``datacube_list``), the user parameters (``user_context``) and
            the projection information (``proj``).

    Returns:
        The same ``udf_data`` object, whose ``datacube_list`` has been
        replaced by a single-element list holding the extractor's output.

    Raises:
        AssertionError: if the extractor is not a ``PointFeatureExtractor``
            and ``udf_data`` does not hold exactly one input cube.
    """
    # Placeholder token, swapped for the real class name at UDF generation.
    feature_extractor_class = "<feature_extractor_class>"

    # User-defined, feature extractor class initialized here
    feature_extractor = feature_extractor_class()

    is_pixel_based = issubclass(feature_extractor_class, PointFeatureExtractor)

    if not is_pixel_based:
        assert (
            len(udf_data.datacube_list) == 1
        ), "OpenEO GFMAP Feature extractor pipeline only supports single input cubes for the tile."

    cube = udf_data.datacube_list[0]
    parameters = udf_data.user_context

    # Reduce the projection info to its EPSG entry; stays None when no
    # projection information is attached to the UDF data.
    proj = udf_data.proj
    if proj is not None:
        proj = proj["EPSG"]

    # Expose the EPSG value to the extractor through its parameters.
    parameters[EPSG_HARMONIZED_NAME] = proj

    cube = feature_extractor._execute(cube, parameters=parameters)

    udf_data.datacube_list = [cube]

    return udf_data


def _get_imports() -> str:
    """Collect this module's top-level imports and constants as source text.

    The module file is parsed with ``ast`` and only statements located
    directly at module level are kept:

    * ``import ...`` and ``from ... import ...`` statements;
    * assignments to a single name made only of upper-case letters, digits
      and underscores.

    Working on the syntax tree rather than on raw lines keeps multi-line
    statements whole, and ignores imports nested inside functions (which a
    line scan would emit with their indentation, producing invalid code) as
    well as docstring lines that merely start with "from ".

    Returns:
        The import statements joined by newlines, followed by a blank line
        and the constant assignments joined by newlines.
    """
    import ast  # local import: only needed while generating UDF code

    with open(__file__, "r", encoding="UTF-8") as f:
        script_source = f.read()

    imports = []
    static_globals = []

    for node in ast.parse(script_source).body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            imports.append(ast.get_source_segment(script_source, node))
        elif (
            isinstance(node, ast.Assign)
            and len(node.targets) == 1
            and isinstance(node.targets[0], ast.Name)
            and re.fullmatch(r"[A-Z_0-9]+", node.targets[0].id)
        ):
            static_globals.append(ast.get_source_segment(script_source, node))

    return "\n".join(imports) + "\n\n" + "\n".join(static_globals)


def _get_apply_udf_data(feature_extractor: FeatureExtractor) -> str:
    """Return the source of ``apply_udf_data`` bound to a feature extractor.

    The source text of ``apply_udf_data`` is fetched with ``inspect`` and
    every occurrence of its quoted placeholder token is replaced by the name
    of the given feature extractor.

    Args:
        feature_extractor: object whose ``__name__`` is substituted for the
            placeholder.

    Returns:
        The resulting source code of the entry point, as a string.
    """
    placeholder = '"<feature_extractor_class>"'
    template = inspect.getsource(apply_udf_data)
    return template.replace(placeholder, feature_extractor.__name__)


def generate_udf_code(feature_extractor_class: FeatureExtractor) -> openeo.UDF:
"""Generates the udf code by packing imports of this file, the necessary
superclass and subclasses as well as the user defined feature extractor
class and the apply_datacube function.
"""
import inspect

# UDF code that will be built here
udf_code = ""
Expand All @@ -198,36 +204,12 @@ class and the apply_datacube function.
feature_extractor_class, FeatureExtractor
), "The feature extractor class must be a subclass of FeatureExtractor."

if issubclass(feature_extractor_class, PatchFeatureExtractor):
udf_code += f"{REQUIRED_IMPORTS}\n\n"
udf_code += f"{inspect.getsource(FeatureExtractor)}\n\n"
udf_code += f"{inspect.getsource(PatchFeatureExtractor)}\n\n"
udf_code += f"{inspect.getsource(feature_extractor_class)}\n\n"
udf_code += APPLY_DATACUBE_SOURCE_CODE.format(
lat_harmonized_name=LAT_HARMONIZED_NAME,
lon_harmonized_name=LON_HARMONIZED_NAME,
epsg_harmonized_name=EPSG_HARMONIZED_NAME,
is_pixel_based=False,
feature_extractor_class=feature_extractor_class.__name__,
)
elif issubclass(feature_extractor_class, PointFeatureExtractor):
udf_code += f"{REQUIRED_IMPORTS}\n\n"
udf_code += f"{inspect.getsource(FeatureExtractor)}\n\n"
udf_code += f"{inspect.getsource(PointFeatureExtractor)}\n\n"
udf_code += f"{inspect.getsource(feature_extractor_class)}\n\n"
udf_code += APPLY_DATACUBE_SOURCE_CODE.format(
lat_harmonized_name=LAT_HARMONIZED_NAME,
lon_harmonized_name=LON_HARMONIZED_NAME,
epsg_harmonized_name=EPSG_HARMONIZED_NAME,
is_pixel_based=True,
feature_extractor_class=feature_extractor_class.__name__,
)
else:
raise NotImplementedError(
"The feature extractor must be a subclass of either "
"PatchFeatureExtractor or PointFeatureExtractor."
)

udf_code += _get_imports() + "\n\n"
udf_code += f"{inspect.getsource(FeatureExtractor)}\n\n"
udf_code += f"{inspect.getsource(PatchFeatureExtractor)}\n\n"
udf_code += f"{inspect.getsource(PointFeatureExtractor)}\n\n"
udf_code += f"{inspect.getsource(feature_extractor_class)}\n\n"
udf_code += _get_apply_udf_data(feature_extractor_class)
return udf_code


Expand Down Expand Up @@ -266,6 +248,8 @@ def apply_feature_extractor_local(
"""
udf_code = generate_udf_code(feature_extractor_class)

print(udf_code)

udf = openeo.UDF(code=udf_code, context=parameters)

cube = XarrayDataCube(cube)
Expand Down
Binary file modified tests/test_openeo_gfmap/resources/test_optical_cube.nc
Binary file not shown.

0 comments on commit e7c31be

Please sign in to comment.