diff --git a/scripts/extractions/extract_meteo.py b/scripts/extractions/patch_extractions/extract_meteo.py
similarity index 100%
rename from scripts/extractions/extract_meteo.py
rename to scripts/extractions/patch_extractions/extract_meteo.py
diff --git a/scripts/extractions/extract_optical.py b/scripts/extractions/patch_extractions/extract_optical.py
similarity index 100%
rename from scripts/extractions/extract_optical.py
rename to scripts/extractions/patch_extractions/extract_optical.py
diff --git a/scripts/extractions/extract_sar.py b/scripts/extractions/patch_extractions/extract_sar.py
similarity index 100%
rename from scripts/extractions/extract_sar.py
rename to scripts/extractions/patch_extractions/extract_sar.py
diff --git a/scripts/extractions/point_extractions/point_extractions.py b/scripts/extractions/point_extractions/point_extractions.py
new file mode 100644
index 00000000..8f3efc07
--- /dev/null
+++ b/scripts/extractions/point_extractions/point_extractions.py
@@ -0,0 +1,280 @@
+"""Extract point data using OpenEO-GFMAP package."""
+import argparse
+import logging
+from functools import partial
+from pathlib import Path
+from typing import List, Optional
+
+import geojson
+import geopandas as gpd
+import openeo
+import pandas as pd
+import pystac
+from openeo_gfmap import Backend, BackendContext, FetchType, TemporalContext
+from openeo_gfmap.backend import cdse_connection
+from openeo_gfmap.manager.job_manager import GFMAPJobManager
+from openeo_gfmap.manager.job_splitters import split_job_s2grid
+
+from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap
+
+# Logger for this current pipeline
+pipeline_log: Optional[logging.Logger] = None
+
+
+def setup_logger(level=logging.INFO) -> None:
+    """Setup the logger from the openeo_gfmap package to the assigned level."""
+    global pipeline_log
+    pipeline_log = logging.getLogger("pipeline_sar")
+
+    pipeline_log.setLevel(level)
+
+    stream_handler = logging.StreamHandler()
+    pipeline_log.addHandler(stream_handler)
+
+    formatter = logging.Formatter("%(asctime)s|%(name)s|%(levelname)s: %(message)s")
+    stream_handler.setFormatter(formatter)
+
+    # Exclude the other loggers from other libraries
+    class ManagerLoggerFilter(logging.Filter):
+        """Filter to only accept the OpenEO-GFMAP manager logs."""
+
+        def filter(self, record):
+            return record.name in [pipeline_log.name]
+
+    stream_handler.addFilter(ManagerLoggerFilter())
+
+
+def filter_extract_true(
+    geometries: geojson.FeatureCollection,
+) -> geojson.FeatureCollection:
+    """Remove all the geometries from the Feature Collection that have the property field `extract` set to `False`"""
+    return geojson.FeatureCollection(
+        [f for f in geometries.features if f.properties.get("extract", 0) != 0]
+    )
+
+
+def get_job_nb_points(row: pd.Series) -> int:
+    """Get the number of points in the geometry."""
+    return len(
+        list(
+            filter(
+                lambda feat: feat.properties.get("extract"),
+                geojson.loads(row.geometry)["features"],
+            )
+        )
+    )
+
+
+# TODO: this is an example output_path. Adjust this function to your needs for production.
+def generate_output_path(root_folder: Path, geometry_index: int, row: pd.Series):
+    features = geojson.loads(row.geometry)
+    sample_id = features[geometry_index].properties.get("sample_id", None)
+    if sample_id is None:
+        sample_id = features[geometry_index].properties["sampleID"]
+
+    s2_tile_id = row.s2_tile
+
+    subfolder = root_folder / s2_tile_id
+    return subfolder / f"{sample_id}{row.out_extension}"
+
+
+def create_job_dataframe(
+    backend: Backend, split_jobs: List[gpd.GeoDataFrame]
+) -> pd.DataFrame:
+    """Create a dataframe from the split jobs, containing all the necessary information to run the job."""
+    columns = [
+        "backend_name",
+        "out_extension",
+        "start_date",
+        "end_date",
+        "s2_tile",
+        "geometry",
+    ]
+    rows = []
+    for job in split_jobs:
+        # Compute the average of the valid dates and make a buffer of 1.5 years around it
+        median_time = pd.to_datetime(job.valid_time).mean()
+        # A bit more than 9 months
+        start_date = median_time - pd.Timedelta(days=275)
+        # A bit more than 9 months
+        end_date = median_time + pd.Timedelta(days=275)
+        s2_tile = job.tile.iloc[0]
+        rows.append(
+            pd.Series(
+                dict(
+                    zip(
+                        columns,
+                        [
+                            backend.value,
+                            ".parquet",
+                            start_date.strftime("%Y-%m-%d"),
+                            end_date.strftime("%Y-%m-%d"),
+                            s2_tile,
+                            job.to_json(),
+                        ],
+                    )
+                )
+            )
+        )
+
+    return pd.DataFrame(rows)
+
+
+def create_datacube(
+    row: pd.Series,
+    connection: openeo.Connection,
+    provider,
+    connection_provider,
+    executor_memory: str = "3G",
+    executor_memory_overhead: str = "5G",
+):
+    """Creates an OpenEO BatchJob from the given row information."""
+
+    # Load the temporal and spatial extent
+    temporal_extent = TemporalContext(row.start_date, row.end_date)
+
+    # Get the feature collection containing the geometry to the job
+    geometry = geojson.loads(row.geometry)
+    assert isinstance(geometry, geojson.FeatureCollection)
+
+    # Filter the geometry to the rows with the extract only flag
+    geometry = filter_extract_true(geometry)
+    assert len(geometry.features) > 0, "No geometries with the extract flag found"
+
+    # Backend name and fetching type
+    backend = Backend(row.backend_name)
+    backend_context = BackendContext(backend)
+
+    inputs = worldcereal_preprocessed_inputs_gfmap(
+        connection=connection,
+        backend_context=backend_context,
+        spatial_extent=geometry,
+        temporal_extent=temporal_extent,
+        fetch_type=FetchType.POINT,
+    )
+
+    # Finally, create a vector cube based on the Point geometries
+    cube = inputs.aggregate_spatial(geometries=geometry, reducer="mean")
+
+    # Increase the memory of the jobs depending on the number of points to extract
+    number_points = get_job_nb_points(row)
+    pipeline_log.debug("Number of points to extract %s", number_points)
+
+    job_options = {
+        "executor-memory": executor_memory,
+        "executor-memoryOverhead": executor_memory_overhead,
+        "soft-error": True,
+    }
+    return cube.create_job(
+        out_format="Parquet",
+        title=f"GFMAP_Feature_Extraction_{row.s2_tile}",
+        job_options=job_options,
+    )
+
+
+def post_job_action(
+    job_items: List[pystac.Item], row: pd.Series, parameters: dict = None
+) -> list:
+    for idx, item in enumerate(job_items):
+        item_asset_path = Path(list(item.assets.values())[0].href)
+
+        gdf = gpd.read_parquet(item_asset_path)
+
+        # Convert the dates to datetime format
+        gdf["date"] = pd.to_datetime(gdf["date"])
+
+        # Convert band dtype to uint16 (temporary fix)
+        # TODO: remove this step when the issue is fixed on the OpenEO backend
+        bands = [
+            "S2-L2A-B02",
+            "S2-L2A-B03",
+            "S2-L2A-B04",
+            "S2-L2A-B05",
+            "S2-L2A-B06",
+            "S2-L2A-B07",
"S2-L2A-B08", + "S2-L2A-B11", + "S2-L2A-B12", + "S1-SIGMA0-VH", + "S1-SIGMA0-VV", + "COP-DEM", + "AGERA5-PRECIP", + "AGERA5-TMEAN", + ] + gdf[bands] = gdf[bands].fillna(65535).astype("uint16") + + gdf.to_parquet(item_asset_path, index=False) + + return job_items + + +if __name__ == "__main__": + setup_logger() + + parser = argparse.ArgumentParser( + description="S2 point extractions with OpenEO-GFMAP package." + ) + parser.add_argument( + "output_path", type=Path, help="Path where to save the extraction results." + ) + + # TODO: get the reference data from the RDM API. + parser.add_argument( + "input_df", type=str, help="Path to the input dataframe for the training data." + ) + parser.add_argument( + "--max_locations", + type=int, + default=500, + help="Maximum number of locations to extract per job.", + ) + parser.add_argument( + "--memory", type=str, default="3G", help="Memory to allocate for the executor." + ) + parser.add_argument( + "--memory-overhead", + type=str, + default="5G", + help="Memory overhead to allocate for the executor.", + ) + + args = parser.parse_args() + + tracking_df_path = Path(args.output_path) / "job_tracking.csv" + + # Load the input dataframe, and perform dataset splitting using the h3 tile + # to respect the area of interest. Also filters out the jobs that have + # no location with the extract=True flag. + pipeline_log.info("Loading input dataframe from %s.", args.input_df) + + input_df = gpd.read_parquet(args.input_df) + + split_dfs = split_job_s2grid(input_df, max_points=args.max_locations) + split_dfs = [df for df in split_dfs if df.extract.any()] + + job_df = create_job_dataframe(Backend.CDSE, split_dfs).iloc[ + [2] + ] # TODO: remove iloc + + # Setup the memory parameters for the job creator. + create_datacube = partial( + create_datacube, + executor_memory=args.memory, + executor_memory_overhead=args.memory_overhead, + ) + + manager = GFMAPJobManager( + output_dir=args.output_path, + output_path_generator=generate_output_path, + post_job_action=post_job_action, + collection_id="POINT-FEATURE-EXTRACTION", + collection_description="Worldcereal point feature extraction.", + poll_sleep=60, + n_threads=2, + post_job_params={}, + ) + + manager.add_backend(Backend.CDSE.value, cdse_connection, parallel_jobs=2) + + pipeline_log.info("Launching the jobs from the manager.") + manager.run_jobs(job_df, create_datacube, tracking_df_path) diff --git a/src/worldcereal/openeo/preprocessing.py b/src/worldcereal/openeo/preprocessing.py index fad5102a..1ea50714 100644 --- a/src/worldcereal/openeo/preprocessing.py +++ b/src/worldcereal/openeo/preprocessing.py @@ -29,8 +29,9 @@ def raw_datacube_S2( bands: List[str], fetch_type: FetchType, filter_tile: Optional[str] = None, - additional_masks: bool = True, - apply_mask: bool = False, + distance_to_cloud_flag: Optional[bool] = True, + additional_masks_flag: Optional[bool] = True, + apply_mask_flag: Optional[bool] = False, ) -> DataCube: """Extract Sentinel-2 datacube from OpenEO using GFMAP routines. 
     Raw data is extracted with no cloud masking applied by default (can be
@@ -96,7 +97,9 @@ def raw_datacube_S2(
         erosion_kernel_size=3,
     ).rename_labels("bands", ["S2-L2A-SCL_DILATED_MASK"])

-    if additional_masks:
+    additional_masks = scl_dilated_mask
+
+    if distance_to_cloud_flag:
         # Compute the distance to cloud and add it to the cube
         distance_to_cloud = scl_cube.apply_neighborhood(
             process=UDF.from_file(Path(__file__).parent / "udf_distance_to_cloud.py"),
@@ -119,13 +122,14 @@
             spatial_extent.to_geojson()
         )

+    if additional_masks_flag:
         extraction_parameters["pre_merge"] = additional_masks

     if filter_tile:
         extraction_parameters["load_collection"]["tileId"] = (
             lambda val: val == filter_tile
         )
-    if apply_mask:
+    if apply_mask_flag:
         extraction_parameters["pre_mask"] = scl_dilated_mask

     extractor = build_sentinel2_l2a_extractor(
@@ -277,7 +281,16 @@ def worldcereal_preprocessed_inputs_gfmap(
     backend_context: BackendContext,
     spatial_extent: BoundingBoxExtent,
     temporal_extent: TemporalContext,
+    fetch_type: Optional[FetchType] = FetchType.TILE,
+    collections: Optional[List[str]] = None,
 ) -> DataCube:
+    if collections is not None and not all(
+        coll in {"sentinel1", "sentinel2", "cop_dem", "meteo"} for coll in collections
+    ):
+        raise ValueError(
+            f"Invalid collection name. Choose from {['sentinel1', 'sentinel2', 'cop_dem', 'meteo']}"
+        )
+    cube_list = []
     # Extraction of S2 from GFMAP
     s2_data = raw_datacube_S2(
         connection=connection,
@@ -295,10 +308,11 @@
             "S2-L2A-B11",
             "S2-L2A-B12",
         ],
-        fetch_type=FetchType.TILE,
+        fetch_type=fetch_type,
         filter_tile=False,
-        additional_masks=False,
-        apply_mask=True,
+        distance_to_cloud_flag=False if fetch_type == FetchType.POINT else True,
+        additional_masks_flag=False,
+        apply_mask_flag=True,
     )

     s2_data = median_compositing(s2_data, period="month")
@@ -306,6 +320,10 @@

     # Cast to uint16
     s2_data = s2_data.linear_scale_range(0, 65534, 0, 65534)

+    # Append the S2 data to the list
+    if collections is None or "sentinel2" in collections:
+        cube_list.append(s2_data)
+
     # Decide on the orbit direction from the maximum overlapping area of
     # available products.
@@ -319,7 +337,7 @@ def worldcereal_preprocessed_inputs_gfmap(
             "S1-SIGMA0-VH",
             "S1-SIGMA0-VV",
         ],
-        fetch_type=FetchType.TILE,
+        fetch_type=fetch_type,
         target_resolution=10.0,  # Compute the backscatter at 20m resolution, then upsample nearest neighbor when merging cubes
         orbit_direction=None,  # Make the query on the catalogue for the best orbit
     )
@@ -327,23 +345,36 @@

     s1_data = mean_compositing(s1_data, period="month")
     s1_data = compress_backscatter_uint16(backend_context, s1_data)

+    # Append the S1 data to the list
+    if collections is None or "sentinel1" in collections:
+        cube_list.append(s1_data)
+
     dem_data = raw_datacube_DEM(
         connection=connection,
         backend_context=backend_context,
         spatial_extent=spatial_extent,
-        fetch_type=FetchType.TILE,
+        fetch_type=fetch_type,
     )
     dem_data = dem_data.linear_scale_range(0, 65534, 0, 65534)

+    # Append the Copernicus DEM data to the list
+    if collections is None or "cop_dem" in collections:
+        cube_list.append(dem_data)
+
     meteo_data = precomposited_datacube_METEO(
         connection=connection,
         spatial_extent=spatial_extent,
         temporal_extent=temporal_extent,
     )

-    data = s2_data.merge_cubes(s1_data)
-    data = data.merge_cubes(dem_data)
-    data = data.merge_cubes(meteo_data)
+    # Append the meteo data to the list
+    if collections is None or "meteo" in collections:
+        cube_list.append(meteo_data)
+
+    # Merge the cubes
+    data = cube_list[0]
+    for cube in cube_list[1:]:
+        data = data.merge_cubes(cube)

     return data
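For context, a minimal sketch of how the reworked worldcereal_preprocessed_inputs_gfmap entrypoint could be called with the new fetch_type and collections arguments introduced by this patch; the connection helper call, bounding box and date range below are illustrative placeholders and not part of the change:

# Hypothetical usage sketch (not part of the patch): exercise the new
# `fetch_type` / `collections` parameters of the preprocessing entrypoint.
from openeo_gfmap import Backend, BackendContext, BoundingBoxExtent, FetchType, TemporalContext
from openeo_gfmap.backend import cdse_connection

from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap

connection = cdse_connection()  # assumed to return an authenticated openeo Connection
backend_context = BackendContext(Backend.CDSE)

# Placeholder extents; any valid bounding box and date range would do.
spatial_extent = BoundingBoxExtent(west=5.0, south=51.0, east=5.1, north=51.1, epsg=4326)
temporal_extent = TemporalContext("2021-01-01", "2021-12-31")

# Request only the Sentinel-2 and Sentinel-1 cubes; collections=None keeps all four inputs.
cube = worldcereal_preprocessed_inputs_gfmap(
    connection=connection,
    backend_context=backend_context,
    spatial_extent=spatial_extent,
    temporal_extent=temporal_extent,
    fetch_type=FetchType.TILE,
    collections=["sentinel2", "sentinel1"],
)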