diff --git a/.gitignore b/.gitignore index c0e944b3..e33f7077 100755 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,18 @@ notebooks/S1A_IW_GRDH_1SDV_20191026T153410_20191026T153444_029631_035FDA_2640.SA scripts/classification/tenpercent_sparse/.nfs00000000c35c9cfd00000035 download.zip catboost_info/catboost_training.json + +*.cbm +*.pt +*.onnx +*.nc +*.7z +*.dmg +*.gz +*.iso +*.jar +*.rar +*.tar +*.zip + +.notebook-tests/ \ No newline at end of file diff --git a/scripts/inference/cropland_mapping.py b/scripts/inference/cropland_mapping.py index 6066b092..03eb7d6c 100644 --- a/scripts/inference/cropland_mapping.py +++ b/scripts/inference/cropland_mapping.py @@ -1,21 +1,19 @@ """Cropland mapping inference script, demonstrating the use of the GFMAP, Presto and WorldCereal classifiers in a first inference pipeline.""" import argparse +from pathlib import Path +import openeo from openeo_gfmap import BoundingBoxExtent, TemporalContext -from openeo_gfmap.backend import Backend, BackendContext, cdse_connection -from openeo_gfmap.features.feature_extractor import PatchFeatureExtractor +from openeo_gfmap.backend import Backend, BackendContext +from openeo_gfmap.features.feature_extractor import apply_feature_extractor +from openeo_gfmap.inference.model_inference import apply_model_inference +from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor +from worldcereal.openeo.inference import CroplandClassifier from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap - -class PrestoFeatureExtractor(PatchFeatureExtractor): - def __init__(self): - pass - - def extract(self, image): - pass - +ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip" if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -27,12 +25,20 @@ def extract(self, image): parser.add_argument("miny", type=float, help="Minimum Y coordinate (south)") parser.add_argument("maxx", type=float, help="Maximum X coordinate (east)") parser.add_argument("maxy", type=float, help="Maximum Y coordinate (north)") + parser.add_argument( + "--epsg", + type=int, + default=4326, + help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.", + ) parser.add_argument( "start_date", type=str, help="Starting date for data extraction." ) parser.add_argument("end_date", type=str, help="Ending date for data extraction.") parser.add_argument( - "output_folder", type=str, help="Path to folder where to save results." 
+ "output_path", + type=Path, + help="Path to folder where to save the resulting NetCDF.", ) args = parser.parse_args() @@ -41,29 +47,70 @@ def extract(self, image): miny = args.miny maxx = args.maxx maxy = args.maxy + epsg = args.epsg start_date = args.start_date end_date = args.end_date - spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy) + spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy, epsg) temporal_extent = TemporalContext(start_date, end_date) - backend = BackendContext(Backend.CDSE) + backend_context = BackendContext(Backend.FED) + + connection = openeo.connect( + "https://openeo.creo.vito.be/openeo/" + ).authenticate_oidc() # Preparing the input cube for the inference - input_cube = worldcereal_preprocessed_inputs_gfmap( - connection=cdse_connection(), - backend_context=backend, + inputs = worldcereal_preprocessed_inputs_gfmap( + connection=connection, + backend_context=backend_context, spatial_extent=spatial_extent, temporal_extent=temporal_extent, ) - # Start the job and download - job = input_cube.create_job( - title=f"Cropland inference BBOX: {minx} {miny} {maxx} {maxy}", - description="Cropland inference using WorldCereal, Presto and GFMAP classifiers", - out_format="NetCDF", + # Test feature computer + presto_parameters = { + "rescale_s1": False, # Will be done in the Presto UDF itself! + } + + features = apply_feature_extractor( + feature_extractor_class=PrestoFeatureExtractor, + cube=inputs, + parameters=presto_parameters, + size=[ + {"dimension": "x", "unit": "px", "value": 100}, + {"dimension": "y", "unit": "px", "value": 100}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 0}, + {"dimension": "y", "unit": "px", "value": 0}, + ], + ) + + catboost_parameters = {} + + classes = apply_model_inference( + model_inference_class=CroplandClassifier, + cube=features, + parameters=catboost_parameters, + size=[ + {"dimension": "x", "unit": "px", "value": 100}, + {"dimension": "y", "unit": "px", "value": 100}, + {"dimension": "t", "value": "P1D"}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 0}, + {"dimension": "y", "unit": "px", "value": 0}, + ], ) - job.start_and_wait() - job.get_results().download_files(args.output_folder) + classes.execute_batch( + outputfile=args.output_path, + out_format="NetCDF", + job_options={ + "driver-memory": "4g", + "executor-memoryOverhead": "12g", + "udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"], + }, + ) diff --git a/src/worldcereal/openeo/feature_extractor.py b/src/worldcereal/openeo/feature_extractor.py new file mode 100644 index 00000000..19037055 --- /dev/null +++ b/src/worldcereal/openeo/feature_extractor.py @@ -0,0 +1,113 @@ +"""Feature computer GFMAP compatible to compute Presto embeddings.""" + +import xarray as xr +from openeo.udf import XarrayDataCube +from openeo_gfmap.features.feature_extractor import PatchFeatureExtractor + + +class PrestoFeatureExtractor(PatchFeatureExtractor): + """Feature extractor to use Presto model to compute per-pixel embeddings. + This will generate a datacube with 128 bands, each band representing a + feature from the Presto model. + + Interesting UDF parameters: + - presto_url: A public URL to the Presto model file. A default Presto + version is provided if the parameter is left undefined. + - rescale_s1: Is specifically disabled by default, as the presto + dependencies already take care of the backscatter decompression. If + specified, should be set as `False`. 
+ """ + + PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" # NOQA + PRESO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.0-py3-none-any.whl" + BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies" # NOQA + DEPENDENCY_NAME = "worldcereal_deps.zip" + + GFMAP_BAND_MAPPING = { + "S2-L2A-B02": "B02", + "S2-L2A-B03": "B03", + "S2-L2A-B04": "B04", + "S2-L2A-B05": "B05", + "S2-L2A-B06": "B06", + "S2-L2A-B07": "B07", + "S2-L2A-B08": "B08", + "S2-L2A-B8A": "B8A", + "S2-L2A-B11": "B11", + "S2-L2A-B12": "B12", + "S1-SIGMA0-VH": "VH", + "S1-SIGMA0-VV": "VV", + "COP-DEM": "DEM", + "AGERA5-TMEAN": "temperature-mean", + "AGERA5-PRECIP": "precipitation-flux", + } + + def unpack_presto_wheel(self, wheel_url: str, destination_dir: str) -> list: + import urllib.request + import zipfile + from pathlib import Path + + # Downloads the wheel file + modelfile, _ = urllib.request.urlretrieve( + wheel_url, filename=Path.cwd() / Path(wheel_url).name + ) + with zipfile.ZipFile(modelfile, "r") as zip_ref: + zip_ref.extractall(destination_dir) + return destination_dir + + def output_labels(self) -> list: + """Returns the output labels from this UDF, which is the output labels + of the presto embeddings""" + return [f"presto_ft_{i}" for i in range(128)] + + def execute(self, inarr: xr.DataArray) -> xr.DataArray: + import sys + from pathlib import Path + + if self.epsg is None: + raise ValueError( + "EPSG code is required for Presto feature extraction, but was " + "not correctly initialized." + ) + presto_model_url = self._parameters.get( + "presto_model_url", self.PRESTO_MODEL_URL + ) + presto_wheel_url = self._parameters.get("presot_wheel_url", self.PRESO_WHL_URL) + + # The below is required to avoid flipping of the result + # when running on OpenEO backend! 
+ inarr = inarr.transpose("bands", "t", "x", "y") + + # Change the band names + new_band_names = [ + self.GFMAP_BAND_MAPPING.get(b.item(), b.item()) for b in inarr.bands + ] + inarr = inarr.assign_coords(bands=new_band_names) + + # Handle NaN values in a Presto-compatible way + inarr = inarr.fillna(65535) + + # Unzip the dependencies on the backend + self.logger.info("Unzipping dependencies") + deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME) + self.logger.info("Unpacking presto wheel") + deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir) + + self.logger.info("Appending dependencies") + sys.path.append(str(deps_dir)) + + # Debug, print the dependency directory + self.logger.info("Dependency directory: %s", list(Path(deps_dir).iterdir())) + + from presto.inference import ( # pylint: disable=import-outside-toplevel + get_presto_features, + ) + + self.logger.info("Extracting presto features") + features = get_presto_features(inarr, presto_model_url, self.epsg) + return features + + def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube: + # Disable S1 rescaling (decompression) by default + if parameters.get("rescale_s1", None) is None: + parameters.update({"rescale_s1": False}) + return super()._execute(cube, parameters) diff --git a/src/worldcereal/openeo/feature_udf.py b/src/worldcereal/openeo/feature_udf.py deleted file mode 100644 index 4ec815d9..00000000 --- a/src/worldcereal/openeo/feature_udf.py +++ /dev/null @@ -1,184 +0,0 @@ -# -*- coding: utf-8 -*- -import sys -from typing import Dict - -import numpy as np -import pandas as pd -import xarray as xr -from openeo.udf import XarrayDataCube -from satio.collections import XArrayTrainingCollection - -from worldcereal.features.settings import ( - get_cropland_features_meta, - get_default_rsi_meta, -) -from worldcereal.fp import L2AFeaturesProcessor - -sys.path.append("/data/users/Public/driesj/openeo/deps/satio") -sys.path.append("/data/users/Public/driesj/openeo/deps/wc-classification/src") -# sys.path.insert(0,'/data/users/Public/driesj/openeo/deps/tf230') - -wheels = [ - "loguru-0.5.3-py3-none-any.whl", - "aiocontextvars-0.2.2-py2.py3-none-any.whl", - "contextvars-2.4", - "immutables-0.14-cp36-cp36m-manylinux1_x86_64.whl", - "importlib_resources-3.3.0-py2.py3-none-any.whl", -] -for wheel in wheels: - sys.path.append("/data/users/Public/driesj/openeo/deps/" + wheel) - - -classifier_file = "/tmp/worldcereal_croplandextent_lpis_unet.h5" - - -features_meta = get_cropland_features_meta() - - -class L2AFeaturesProcessor10m(L2AFeaturesProcessor): - L2A_BANDS_10M = [ - "B02", - "B03", - "B04", - "B08", - "B05", - "B06", - "B07", - "B8A", - "B11", - "B12", - "SCL", - "sunAzimuthAngles", - "sunZenithAngles", - "viewAzimuthMean", - "viewZenithMean", - ] - L2A_BANDS_DICT_ALL_10M = {10: L2A_BANDS_10M, 20: {"DUMMY"}} - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - @property - def supported_bands(self): - return L2AFeaturesProcessor10m.L2A_BANDS_DICT_ALL_10M - - -def apply_datacube(cube: XarrayDataCube, context: Dict) -> XarrayDataCube: - """ - This UDF computes WorldCereal features using SatIO. - It works on a spatiotemporal stack for one specific sensor, - currently Sentinel-2 - - @param cube: - @param context: A context dictionary, has to contain 'satio_settings' - @return: - """ - # access the underlying xarray - inarr = cube.get_array() - - # translate openEO dim name into satio convention - inarr = inarr.rename({"t": "timestamp"}) - # satio expects uint16!
- inarr = inarr.astype(np.uint16) - - settings = context["satio_settings"] - settings["OPTICAL"]["composite"]["start"] = np.datetime_as_string( - inarr.coords["timestamp"].values.min(), unit="D" - ) - settings["OPTICAL"]["composite"]["end"] = np.datetime_as_string( - inarr.coords["timestamp"].values.max(), unit="D" - ) - - classify = context["classify"] - - collection = XArrayTrainingCollection( - sensor="S2", processing_level="L2A", df=pd.DataFrame(), array=inarr - ) - - from satio.rsindices import RSI_META_S2 - - default_rsi_meta = RSI_META_S2.copy() - rsi_meta = get_default_rsi_meta()["OPTICAL"] - - # in openEO, all bands are provided in 10m for now - # so we need to modify satio defaults - rsi_meta["brightness"] = default_rsi_meta["brightness"] - rsi_meta["brightness"]["native_res"] = 10 - - if "sen2agri_temp_feat" in features_meta.get("OPTICAL", {}): - features_meta["OPTICAL"]["sen2agri_temp_feat"]["parameters"][ - "time_start" - ] = settings["OPTICAL"]["composite"]["start"] - - processor = L2AFeaturesProcessor10m( - collection, - settings["OPTICAL"], - rsi_meta=rsi_meta, - features_meta=features_meta["OPTICAL"], - ) - features = processor.compute_features() - - # Extracted core from worldcereal ClassificationProcessor, - # to be seen what we need to keep - - if classify: - windowsize = 64 - import tensorflow as tf - - # from worldcereal.classification.models import WorldCerealUNET - # unetmodel = WorldCerealUNET(windowsize=64, features= 60) - # unetmodel.model.load_weights(classifier_file) - # classifier = unetmodel.model - classifier = tf.keras.models.load_model(classifier_file) - - xdim = features.data.shape[1] - ydim = features.data.shape[2] - - prediction = np.empty((xdim, ydim)) - - # can be avoided by using openEO apply_neighbourhood - for xStart in range(0, xdim, windowsize): - for yStart in range(0, ydim, windowsize): - # We need to check if we're at the end of the master image - # We have to make sure we have a full subtile - # so we need to expand such tile and the resulting overlap - # with previous subtile is not an issue - if xStart + windowsize > xdim: - xStart = xdim - windowsize - xEnd = xdim - else: - xEnd = xStart + windowsize - if yStart + windowsize > ydim: - yStart = ydim - windowsize - yEnd = ydim - else: - yEnd = yStart + windowsize - - features_patch = features.data[:, xStart:xEnd, yStart:yEnd] - patchprediction = ( - classifier.predict( - features_patch.transpose((1, 2, 0)).reshape( - (1, windowsize * windowsize, -1) - ) - ) - .squeeze() - .reshape((windowsize, windowsize)) - ) - - prediction[xStart:xEnd, yStart:yEnd] = patchprediction - - prediction_xarray = xr.DataArray(prediction.astype(np.float32), dims=["x", "y"]) - - # wrap back to datacube and return - return XarrayDataCube(prediction_xarray) - - else: - features_xarray = xr.DataArray( - features.data.astype(np.float32), - dims=["bands", "x", "y"], - coords={"bands": features.names}, - ) - - # wrap back to datacube and return - return XarrayDataCube(features_xarray) - return XarrayDataCube(features_xarray) diff --git a/src/worldcereal/openeo/inference.py b/src/worldcereal/openeo/inference.py new file mode 100644 index 00000000..e1adf289 --- /dev/null +++ b/src/worldcereal/openeo/inference.py @@ -0,0 +1,75 @@ +"""Model inference on Presto features for binary classification.""" + +import xarray as xr +from openeo_gfmap.inference.model_inference import ModelInference + + +class CroplandClassifier(ModelInference): + """Binary cropland classifier using ONNX to load a CatBoost model.
+ + The classifier uses the embeddings computed by the Presto feature + extractor. + + Supported UDF parameters: + - classifier_url: A public URL to the ONNX classification model. Defaults + to the public WorldCereal CatBoost model. + """ + + import numpy as np + + CATBOOST_PATH = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA + + def __init__(self): + super().__init__() + + self.onnx_session = None + + def dependencies(self) -> list: + return [] # Disable the dependencies from PIP install + + def output_labels(self) -> list: + return ["classification"] + + def predict(self, features: np.ndarray) -> np.ndarray: + """ + Predicts labels using the provided features array. + """ + import numpy as np + + if self.onnx_session is None: + raise ValueError("Model has not been loaded. Please load a model first.") + + # Run the ONNX model on the input features + outputs = self.onnx_session.run(None, {"features": features}) + + # Threshold for binary conversion + threshold = 0.5 + + # Extract all prediction values and convert them to binary labels + prediction_values = [sublist["True"] for sublist in outputs[1]] + binary_labels = np.array(prediction_values) >= threshold + binary_labels = binary_labels.astype(int) + + return binary_labels + + def execute(self, inarr: xr.DataArray) -> xr.DataArray: + classifier_url = self._parameters.get("classifier_url", self.CATBOOST_PATH) + + # Save the coordinates and stack the array to ("xy", "bands") for prediction + x_coords, y_coords = inarr.x.values, inarr.y.values + inarr = inarr.transpose("bands", "x", "y").stack(xy=["x", "y"]).transpose() + + self.onnx_session = self.load_ort_session(classifier_url) + + # Run catboost classification + self.logger.info(f"Catboost classification with input shape: {inarr.shape}") + classification = self.predict(inarr.values) + self.logger.info(f"Classification done with shape: {classification.shape}") + + classification = xr.DataArray( + classification.reshape((1, len(x_coords), len(y_coords))), + dims=["bands", "x", "y"], + coords={"bands": ["classification"], "x": x_coords, "y": y_coords}, + ) + + return classification diff --git a/src/worldcereal/openeo/preprocessing.py b/src/worldcereal/openeo/preprocessing.py index 6991ea41..b5ca0987 100644 --- a/src/worldcereal/openeo/preprocessing.py +++ b/src/worldcereal/openeo/preprocessing.py @@ -3,6 +3,7 @@ from openeo import UDF, Connection, DataCube from openeo_gfmap import ( + Backend, BackendContext, BoundingBoxExtent, FetchType, @@ -12,12 +13,9 @@ from openeo_gfmap.fetching.generic import build_generic_extractor from openeo_gfmap.fetching.s1 import build_sentinel1_grd_extractor from openeo_gfmap.fetching.s2 import build_sentinel2_l2a_extractor -from openeo_gfmap.preprocessing.compositing import ( - max_ndvi_compositing, - mean_compositing, -) -from openeo_gfmap.preprocessing.interpolation import linear_interpolation +from openeo_gfmap.preprocessing.compositing import mean_compositing, median_compositing from openeo_gfmap.preprocessing.sar import compress_backscatter_uint16 +from openeo_gfmap.utils.catalogue import UncoveredS1Exception, select_S1_orbitstate COMPOSITE_WINDOW = "month" @@ -66,6 +64,14 @@ def raw_datacube_S2( if filter_tile: scl_cube_properties["tileId"] = lambda val: val == filter_tile + # Create the job to extract S2 + extraction_parameters = { + "target_resolution": None, # Disable target resolution + "load_collection": { + "eo:cloud_cover": lambda val: val <= 95.0, + }, + } + scl_cube = connection.load_collection(
collection_id="SENTINEL2_L2A", bands=["SCL"], @@ -89,35 +95,31 @@ def raw_datacube_S2( erosion_kernel_size=3, ).rename_labels("bands", ["S2-L2A-SCL_DILATED_MASK"]) - # Compute the distance to cloud and add it to the cube - distance_to_cloud = scl_cube.apply_neighborhood( - process=UDF.from_file(Path(__file__).parent / "udf_distance_to_cloud.py"), - size=[ - {"dimension": "x", "unit": "px", "value": 256}, - {"dimension": "y", "unit": "px", "value": 256}, - {"dimension": "t", "unit": "null", "value": "P1D"}, - ], - overlap=[ - {"dimension": "x", "unit": "px", "value": 16}, - {"dimension": "y", "unit": "px", "value": 16}, - ], - ).rename_labels("bands", ["S2-L2A-DISTANCE-TO-CLOUD"]) + if additional_masks: + # Compute the distance to cloud and add it to the cube + distance_to_cloud = scl_cube.apply_neighborhood( + process=UDF.from_file(Path(__file__).parent / "udf_distance_to_cloud.py"), + size=[ + {"dimension": "x", "unit": "px", "value": 256}, + {"dimension": "y", "unit": "px", "value": 256}, + {"dimension": "t", "unit": "null", "value": "P1D"}, + ], + overlap=[ + {"dimension": "x", "unit": "px", "value": 16}, + {"dimension": "y", "unit": "px", "value": 16}, + ], + ).rename_labels("bands", ["S2-L2A-DISTANCE-TO-CLOUD"]) - additional_masks = scl_dilated_mask.merge_cubes(distance_to_cloud) + additional_masks = scl_dilated_mask.merge_cubes(distance_to_cloud) - # Try filtering using the geometry - if fetch_type == FetchType.TILE: - additional_masks = additional_masks.filter_spatial(spatial_extent.to_geojson()) + # Try filtering using the geometry + if fetch_type == FetchType.TILE: + additional_masks = additional_masks.filter_spatial( + spatial_extent.to_geojson() + ) - # Create the job to extract S2 - extraction_parameters = { - "target_resolution": None, # Disable target resolution - "load_collection": { - "eo:cloud_cover": lambda val: val <= 95.0, - }, - } - if additional_masks: extraction_parameters["pre_merge"] = additional_masks + if filter_tile: extraction_parameters["load_collection"]["tileId"] = ( lambda val: val == filter_tile @@ -162,10 +164,37 @@ def raw_datacube_S1( List of Sentinel-1 bands to extract. fetch_type : FetchType GFMAP Fetch type to use for extraction. + target_resolution : float, optional + Target resolution to resample the data to, by default 20.0. + orbit_direction : Optional[str], optional + Orbit direction to filter the data, by default None. If None and the + backend is in CDSE, then querries the catalogue for the best orbit + direction to use. In the case querrying is unavailable or fails, then + uses "ASCENDING" as a last resort. """ extractor_parameters = { "target_resolution": target_resolution, } + if orbit_direction is None and backend_context.backend in [ + Backend.CDSE, + Backend.CDSE_STAGING, + Backend.FED, + ]: + try: + orbit_direction = select_S1_orbitstate( + backend_context, spatial_extent, temporal_extent + ) + print( + f"Selected orbit direction: {orbit_direction} from max " + "accumulated area overlap between bounds and products." + ) + except UncoveredS1Exception as exc: + orbit_direction = "ASCENDING" + print( + f"Could not find any Sentinel-1 data for the given spatio-temporal context. " + f"Using ASCENDING orbit direction as a last resort. 
Error: {exc}" + ) + if orbit_direction is not None: extractor_parameters["load_collection"] = { "sat:orbit_state": lambda orbit: orbit == orbit_direction @@ -193,6 +222,22 @@ def raw_datacube_DEM( return extractor.get_cube(connection, spatial_extent, None) +def raw_datacube_METEO( + connection: Connection, + backend_context: BackendContext, + spatial_extent: SpatialContext, + temporal_extent: TemporalContext, + fetch_type: FetchType, +) -> DataCube: + extractor = build_generic_extractor( + backend_context=backend_context, + bands=["AGERA5-TMEAN", "AGERA5-PRECIP"], + fetch_type=fetch_type, + collection_name="AGERA5", + ) + return extractor.get_cube(connection, spatial_extent, temporal_extent) + + def worldcereal_preprocessed_inputs_gfmap( connection: Connection, backend_context: BackendContext, @@ -213,7 +258,6 @@ def worldcereal_preprocessed_inputs_gfmap( "S2-L2A-B06", "S2-L2A-B07", "S2-L2A-B08", - "S2-L2A-B8A", "S2-L2A-B11", "S2-L2A-B12", ], @@ -223,12 +267,14 @@ def worldcereal_preprocessed_inputs_gfmap( apply_mask=True, ) - s2_data = max_ndvi_compositing(s2_data, period="month") - s2_data = linear_interpolation(s2_data) + s2_data = median_compositing(s2_data, period="month") # Cast to uint16 s2_data = s2_data.linear_scale_range(0, 65534, 0, 65534) + # Decide on the orbit direction from the maximum overlapping area of + # available products. + # Extraction of the S1 data s1_data = raw_datacube_S1( connection=connection, @@ -236,16 +282,15 @@ def worldcereal_preprocessed_inputs_gfmap( spatial_extent=spatial_extent, temporal_extent=temporal_extent, bands=[ - "S1-SIGMA0-VV", "S1-SIGMA0-VH", + "S1-SIGMA0-VV", ], fetch_type=FetchType.TILE, target_resolution=10.0, # Compute the backscatter at 20m resolution, then upsample nearest neighbor when merging cubes - orbit_direction="ASCENDING", + orbit_direction=None, # Make the querry on the catalogue for the best orbit ) s1_data = mean_compositing(s1_data, period="month") - s1_data = linear_interpolation(s1_data) s1_data = compress_backscatter_uint16(backend_context, s1_data) dem_data = raw_datacube_DEM( @@ -255,10 +300,26 @@ def worldcereal_preprocessed_inputs_gfmap( fetch_type=FetchType.TILE, ) - dem_data = dem_data.resample_cube_spatial(s2_data, method="cubic") dem_data = dem_data.linear_scale_range(0, 65534, 0, 65534) + # meteo_data = raw_datacube_METEO( + # connection=connection, + # backend_context=backend_context, + # spatial_extent=spatial_extent, + # temporal_extent=temporal_extent, + # fetch_type=FetchType.TILE, + # ) + + # # Perform compositing differently depending on the bands + # mean_temperature = meteo_data.band("AGERA5-TMEAN") + # mean_temperature = mean_compositing(mean_temperature, period="month") + + # total_precipitation = meteo_data.band("AGERA5-PRECIP") + # total_precipitation = sum_compositing(total_precipitation, period="month") + data = s2_data.merge_cubes(s1_data) data = data.merge_cubes(dem_data) + # data = data.merge_cubes(mean_temperature) + # data = data.merge_cubes(total_precipitation) return data