Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write patch extraction metadata directly to STAC API #228

Merged
merged 7 commits into from
Nov 29, 2024
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ train = [
"scikit-learn==1.5.0",
"torch==2.3.1",
"ipywidgets==8.1.3",
"duckdb==1.1.0"
"duckdb==1.1.0",
"pystac==1.10.1",
"pystac-client==0.8.3"
]

[tool.pytest.ini_options]
Expand Down
4 changes: 4 additions & 0 deletions scripts/extractions/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,17 @@ def setup_extraction_functions(
title="Sentinel-1 GRD",
spatial_resolution="20m",
s1_orbit_fix=True,
sensor="Sentinel1",
write_stac_api=True,
),
ExtractionCollection.PATCH_SENTINEL2: partial(
post_job_action_patch,
extract_value=extract_value,
description="Sentinel2 L2A observations, processed.",
title="Sentinel-2 L2A",
spatial_resolution="10m",
sensor="Sentinel2",
write_stac_api=True,
),
ExtractionCollection.PATCH_METEO: partial(
post_job_action_patch,
Expand Down
21 changes: 21 additions & 0 deletions src/worldcereal/openeo/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
import xarray as xr
from shapely import Point

from worldcereal.stac.stac_api_interaction import (
StacApiInteraction,
VitoStacApiAuthentication,
)

# Logger used for the pipeline
pipeline_log = logging.getLogger("extraction_pipeline")

Expand Down Expand Up @@ -49,6 +54,8 @@ def post_job_action_patch(
title: str,
spatial_resolution: str,
s1_orbit_fix: bool = False, # To rename the samples from the S1 orbit
write_stac_api: bool = False,
sensor: str = "Sentinel1",
) -> list:
"""From the job items, extract the metadata and save it in a netcdf file."""
base_gpd = gpd.GeoDataFrame.from_features(json.loads(row.geometry)).set_crs(
Expand Down Expand Up @@ -128,6 +135,20 @@ def post_job_action_patch(
ds.to_netcdf(temp_file.name)
shutil.move(temp_file.name, item_asset_path)

if write_stac_api:
username = os.getenv("STAC_API_USERNAME")
password = os.getenv("STAC_API_PASSWORD")

stac_api_interaction = StacApiInteraction(
sensor=sensor,
base_url="https://stac.openeo.vito.be",
auth=VitoStacApiAuthentication(username=username, password=password),
)

pipeline_log.info("Writing the STAC API metadata")
stac_api_interaction.upload_items_bulk(job_items)
pipeline_log.info("STAC API metadata written")

return job_items


Expand Down
208 changes: 208 additions & 0 deletions src/worldcereal/stac/stac_api_interaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
import concurrent
from concurrent.futures import ThreadPoolExecutor
from typing import Iterable

import pystac
import pystac_client
import requests
from openeo.rest.auth.oidc import (
OidcClientInfo,
OidcProviderInfo,
OidcResourceOwnerPasswordAuthenticator,
)
from requests.auth import AuthBase


class VitoStacApiAuthentication(AuthBase):
    """Requests auth hook for the VITO STAC API (https://stac.openeo.vito.be/).

    Each outgoing request gets an ``Authorization`` header holding a bearer
    token obtained from the Terrascope OIDC issuer with the resource-owner
    password flow.
    """

    def __init__(self, **kwargs):
        # Absent keys become None; the credentials are only checked when a
        # token is actually requested.
        self.username = kwargs.get("username")
        self.password = kwargs.get("password")

    def __call__(self, request):
        # A token is requested anew for every request that passes through.
        header_value = self.get_access_token()
        request.headers["Authorization"] = header_value
        return request

    def get_access_token(self) -> str:
        """Get API bearer access token via password flow.

        Returns
        -------
        str
            The access token prefixed with ``"Bearer "``, ready to be used
            as an ``Authorization`` header value.

        Raises
        ------
        ValueError
            If the username or the password is missing or empty.
        """
        client_info = OidcClientInfo(
            client_id="terracatalogueclient",
            provider=OidcProviderInfo(
                issuer="https://sso.terrascope.be/auth/realms/terrascope"
            ),
        )

        if not (self.username and self.password):
            raise ValueError(
                "Credentials are required to obtain an access token. Please set STAC_API_USERNAME and STAC_API_PASSWORD environment variables."
            )

        authenticator = OidcResourceOwnerPasswordAuthenticator(
            client_info=client_info,
            username=self.username,
            password=self.password,
        )
        access_token = authenticator.get_tokens().access_token

        return f"Bearer {access_token}"


class StacApiInteraction:
    """Create a patch-extraction collection on a STAC API and upload items to it.

    The target collection id is derived from the sensor name
    (``worldcereal_<sensor>_patch_extractions``). Write requests are sent
    with ``requests`` using the supplied ``auth`` object.

    Parameters
    ----------
    sensor : str
        Either ``"Sentinel1"`` or ``"Sentinel2"``.
    base_url : str
        Root URL of the STAC API.
    auth : AuthBase
        Authentication hook attached to every POST request.
    bulk_size : int
        Number of items sent per bulk request by ``upload_items_bulk``.

    Raises
    ------
    ValueError
        If ``sensor`` is not one of the allowed values.
    """

    def __init__(
        self, sensor: str, base_url: str, auth: AuthBase, bulk_size: int = 500
    ):
        if sensor not in ["Sentinel1", "Sentinel2"]:
            raise ValueError(
                f"Invalid sensor '{sensor}'. Allowed values are 'Sentinel1' and 'Sentinel2'."
            )
        self.sensor = sensor
        self.base_url = base_url
        self.collection_id = f"worldcereal_{sensor.lower()}_patch_extractions"

        self.auth = auth

        self.bulk_size = bulk_size

    def exists(self) -> bool:
        """Return True when the API already lists the target collection."""
        client = pystac_client.Client.open(self.base_url)
        # any() stops at the first match instead of building a full list.
        return any(c.id == self.collection_id for c in client.get_collections())

    def _join_url(self, url_path: str) -> str:
        """Join ``url_path`` onto the base URL with exactly one slash in between."""
        # Normalising both sides avoids "//" when base_url ends with a slash.
        return str(self.base_url).rstrip("/") + "/" + url_path.lstrip("/")

    def _post(self, url_path: str, payload: dict) -> requests.Response:
        """POST ``payload`` as JSON to ``url_path`` and verify the HTTP status.

        Raises
        ------
        Exception
            If the response status is not ok, created or accepted.
        """
        response = requests.post(
            self._join_url(url_path), auth=self.auth, json=payload
        )

        expected_status = [
            requests.status_codes.codes.ok,
            requests.status_codes.codes.created,
            requests.status_codes.codes.accepted,
        ]
        self._check_response_status(response, expected_status)

        return response

    def create_collection(self):
        """Create the target collection on the API and return the response.

        The collection is created with a global spatial extent and an open
        temporal extent, plus an ``_auth`` entry listing the read and write
        roles.
        """
        spatial_extent = pystac.SpatialExtent([[-180, -90, 180, 90]])
        temporal_extent = pystac.TemporalExtent([[None, None]])
        extent = pystac.Extent(spatial=spatial_extent, temporal=temporal_extent)

        collection = pystac.Collection(
            id=self.collection_id,
            description=f"WorldCereal Patch Extractions for {self.sensor}",
            extent=extent,
        )

        # Validate locally before anything is sent to the API.
        collection.validate()
        coll_dict = collection.to_dict()

        default_auth = {
            "_auth": {
                "read": ["anonymous"],
                "write": ["stac-openeo-admin", "stac-openeo-editor"],
            }
        }

        coll_dict.update(default_auth)

        return self._post("collections", coll_dict)

    def add_item(self, item: pystac.Item):
        """Upload a single item, creating the collection first when needed."""
        if not self.exists():
            self.create_collection()

        self._prepare_item(item)

        url_path = f"collections/{self.collection_id}/items"
        return self._post(url_path, item.to_dict())

    def _prepare_item(self, item: pystac.Item):
        """Point ``item`` at the target collection (modifies the item in place)."""
        item.collection_id = self.collection_id
        if not item.get_links(pystac.RelType.COLLECTION):
            item.add_link(
                pystac.Link(rel=pystac.RelType.COLLECTION, target=item.collection_id)
            )

    def _ingest_bulk(self, items: Iterable[pystac.Item]) -> dict:
        """Upsert ``items`` with one bulk request and return the decoded JSON reply.

        Raises
        ------
        Exception
            If an item belongs to another collection, or if the HTTP status
            is not ok, created or accepted.
        """
        # The items are traversed twice below; materialise them so that a
        # one-shot iterator is not exhausted by the collection-id check.
        items = list(items)

        if not all(i.collection_id == self.collection_id for i in items):
            raise Exception("All collection IDs should be identical for bulk ingests")

        url_path = f"collections/{self.collection_id}/bulk_items"
        data = {
            "method": "upsert",
            "items": {item.id: item.to_dict() for item in items},
        }

        return self._post(url_path, data).json()

    def upload_items_bulk(self, items: Iterable[pystac.Item]) -> None:
        """Upload ``items`` in chunks of ``bulk_size`` using a small thread pool.

        The collection is created first when it does not exist yet.

        Raises
        ------
        Exception
            Re-raised from the first completed chunk whose bulk request
            failed.
        """
        if not self.exists():
            self.create_collection()

        chunk = []
        futures = []

        with ThreadPoolExecutor(max_workers=4) as executor:
            for item in items:
                self._prepare_item(item)
                chunk.append(item)

                if len(chunk) == self.bulk_size:
                    futures.append(executor.submit(self._ingest_bulk, chunk))
                    # Rebind rather than clear: the submitted list stays intact.
                    chunk = []

            if chunk:
                futures.append(executor.submit(self._ingest_bulk, chunk))

        # result() re-raises whatever the worker raised; without this call a
        # failed chunk would go unnoticed.
        for future in concurrent.futures.as_completed(futures):
            future.result()

    def _check_response_status(
        self, response: requests.Response, expected_status_codes: list[int]
    ):
        """Raise when ``response`` does not carry one of the expected status codes."""
        if response.status_code not in expected_status_codes:
            message = (
                f"Expecting HTTP status to be any of {expected_status_codes} "
                + f"but received {response.status_code} - {response.reason}, request method={response.request.method}\n"
                + f"response body:\n{response.text}"
            )

            raise Exception(message)

    def get_collection_id(self) -> str:
        """Return the id of the collection this instance writes to."""
        return self.collection_id
Loading