Skip to content

Commit 1110e4a

Browse files
author
Darius Couchard
authored
Merge pull request #153 from Open-EO/job-splitter-centroids
Job splitters should retain original geometries
2 parents 27a239e + ff556cd commit 1110e4a

File tree

2 files changed

+117
-33
lines changed

2 files changed

+117
-33
lines changed

src/openeo_gfmap/manager/job_splitters.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,27 @@
1212
from openeo_gfmap.manager import _log
1313

1414

def load_s2_grid(web_mercator: bool = False) -> gpd.GeoDataFrame:
    """Returns a geo data frame from the S2 grid.

    The grid file is cached under ``~/.openeo-gfmap`` and downloaded from
    artifactory on first use.

    Parameters
    ----------
    web_mercator:
        If True, load the grid variant with bounds in EPSG:3857
        (Web Mercator); otherwise the EPSG:4326 variant is loaded.

    Returns
    -------
    gpd.GeoDataFrame
        The Sentinel-2 tile grid.

    Raises
    ------
    requests.HTTPError
        If the grid file must be downloaded and the server answers with an
        error status.
    """
    # Builds the path where the geodataframe should be
    if not web_mercator:
        gdf_path = Path.home() / ".openeo-gfmap" / "s2grid_bounds_4326.geoparquet"
        url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap/s2grid_bounds_4326.geoparquet"
    else:
        gdf_path = Path.home() / ".openeo-gfmap" / "s2grid_bounds_3857.geoparquet"
        url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap/s2grid_bounds_3857.geoparquet"

    if not gdf_path.exists():
        _log.info("S2 grid not found, downloading it from artifactory.")
        # Downloads the file from the artifactory URL
        gdf_path.parent.mkdir(exist_ok=True)
        response = requests.get(
            url,
            timeout=180,  # 3mins
        )
        # Fail loudly on a bad HTTP status; otherwise an error page would be
        # written to disk and cached forever as if it were the grid file.
        response.raise_for_status()
        with open(gdf_path, "wb") as f:
            f.write(response.content)
    return gpd.read_parquet(gdf_path)
3036

3137

3238
def _resplit_group(
@@ -38,7 +44,7 @@ def _resplit_group(
3844

3945

4046
def split_job_s2grid(
41-
polygons: gpd.GeoDataFrame, max_points: int = 500
47+
polygons: gpd.GeoDataFrame, max_points: int = 500, web_mercator: bool = False
4248
) -> List[gpd.GeoDataFrame]:
4349
"""Split a job into multiple jobs from the position of the polygons/points. The centroid of
4450
the geometries to extract are used to select tile in the Sentinel-2 tile grid.
@@ -60,17 +66,23 @@ def split_job_s2grid(
6066
if polygons.crs is None:
6167
raise ValueError("The GeoDataFrame must contain a CRS")
6268

63-
polygons = polygons.to_crs(epsg=4326)
64-
if polygons.geometry.geom_type[0] != "Point":
65-
polygons["geometry"] = polygons.geometry.centroid
69+
epsg = 3857 if web_mercator else 4326
70+
71+
original_crs = polygons.crs
72+
73+
polygons = polygons.to_crs(epsg=epsg)
74+
75+
polygons["centroid"] = polygons.geometry.centroid
6676

6777
# Dataset containing all the S2 tiles, find the nearest S2 tile for each point
68-
s2_grid = load_s2_grid()
78+
s2_grid = load_s2_grid(web_mercator)
6979
s2_grid["geometry"] = s2_grid.geometry.centroid
7080

71-
polygons = gpd.sjoin_nearest(polygons, s2_grid[["tile", "geometry"]]).drop(
72-
columns=["index_right"]
73-
)
81+
polygons = gpd.sjoin_nearest(
82+
polygons.set_geometry("centroid"), s2_grid[["tile", "geometry"]]
83+
).drop(columns=["index_right", "centroid"])
84+
85+
polygons = polygons.set_geometry("geometry").to_crs(original_crs)
7486

7587
split_datasets = []
7688
for _, sub_gdf in polygons.groupby("tile"):
@@ -86,10 +98,13 @@ def append_h3_index(
8698
polygons: gpd.GeoDataFrame, grid_resolution: int = 3
8799
) -> gpd.GeoDataFrame:
88100
"""Append the H3 index to the polygons."""
89-
if polygons.geometry.geom_type[0] != "Point":
90-
geom_col = polygons.geometry.centroid
91-
else:
92-
geom_col = polygons.geometry
101+
102+
# Project to Web mercator to calculate centroids
103+
polygons = polygons.to_crs(epsg=3857)
104+
geom_col = polygons.geometry.centroid
105+
# Project to lat lon to calculate the h3 index
106+
geom_col = geom_col.to_crs(epsg=4326)
107+
93108
polygons["h3index"] = geom_col.apply(
94109
lambda pt: h3.geo_to_h3(pt.y, pt.x, grid_resolution)
95110
)
@@ -127,12 +142,13 @@ def split_job_hex(
127142
if polygons.crs is None:
128143
raise ValueError("The GeoDataFrame must contain a CRS")
129144

130-
# Project to lat/lon positions
131-
polygons = polygons.to_crs(epsg=4326)
145+
original_crs = polygons.crs
132146

133147
# Split the polygons into multiple jobs
134148
polygons = append_h3_index(polygons, grid_resolution)
135149

150+
polygons = polygons.to_crs(original_crs)
151+
136152
split_datasets = []
137153
for _, sub_gdf in polygons.groupby("h3index"):
138154
if len(sub_gdf) > max_points:
Lines changed: 83 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,94 @@
11
"""Test the job splitters and managers of OpenEO GFMAP."""
22

3-
from pathlib import Path
43

54
import geopandas as gpd
5+
from shapely.geometry import Point, Polygon
66

7-
from openeo_gfmap.manager.job_splitters import split_job_hex
7+
from openeo_gfmap.manager.job_splitters import split_job_hex, split_job_s2grid
88

99

def test_split_job_s2grid():
    # Five mock points spread over two different Sentinel-2 tiles; with at
    # most two points per job this must produce three sub-dataframes.
    points = [
        Point(60.02, 4.57),
        Point(59.6, 5.04),
        Point(59.92, 3.37),
        Point(59.07, 4.11),
        Point(58.77, 4.87),
    ]
    polygons = gpd.GeoDataFrame(
        {"id": [1, 2, 3, 4, 5], "geometry": points}, crs="EPSG:4326"
    )

    # Expected split size per group
    max_points = 2

    result = split_job_s2grid(polygons, max_points)

    assert (
        len(result) == 3
    ), "The number of GeoDataFrames returned should match the number of splits needed."

    # The splitter must hand back the untouched input geometries and CRS.
    for gdf in result:
        assert (
            "geometry" in gdf.columns
        ), "Each GeoDataFrame should have a geometry column."
        assert gdf.crs == 4326, "The original CRS should be preserved."
        assert all(
            gdf.geometry.geom_type == "Point"
        ), "Original geometries should be preserved."
46+
def test_split_job_hex():
    # Mock geometries located in three different H3 hexes of size 3; one hex
    # holds more entries than max_points and is re-split, giving four groups.
    square = Polygon(
        [
            (58.78, 4.88),
            (58.78, 4.86),
            (58.76, 4.86),
            (58.76, 4.88),
            (58.78, 4.88),
        ]
    )
    geometries = [
        Point(60.02, 4.57),
        Point(58.34, 5.06),
        Point(59.92, 3.37),
        Point(58.85, 4.90),
        Point(58.77, 4.87),
        square,
    ]
    polygons = gpd.GeoDataFrame(
        {"id": [1, 2, 3, 4, 5, 6], "geometry": geometries}, crs="EPSG:4326"
    )

    max_points = 3

    result = split_job_hex(polygons, max_points)

    assert (
        len(result) == 4
    ), "The number of GeoDataFrames returned should match the number of splits needed."

    for idx, gdf in enumerate(result):
        assert (
            "geometry" in gdf.columns
        ), "Each GeoDataFrame should have a geometry column."
        assert gdf.crs == 4326, "The CRS should be preserved."
        # NOTE(review): this relies on the grouping order being stable — the
        # second split is expected to hold the lone polygon, the rest points.
        if idx == 1:
            assert all(
                gdf.geometry.geom_type == "Polygon"
            ), "Original geometries should be preserved."
        else:
            assert all(
                gdf.geometry.geom_type == "Point"
            ), "Original geometries should be preserved."

    assert (
        len(result[0]) == 3
    ), "The number of geometries in the first split should be 3."

0 commit comments

Comments
 (0)