From 9f58692a42286af2b4057608efabe0a354f3eab6 Mon Sep 17 00:00:00 2001 From: Christina Butsko Date: Tue, 2 Jul 2024 17:29:26 +0200 Subject: [PATCH 1/2] added notebook for custom crop-nocrop model training + respective functions updated in utils --- .../system_v1_custom_model_cropland.ipynb | 316 ++++++++++++++++++ notebooks/utils.py | 51 ++- 2 files changed, 353 insertions(+), 14 deletions(-) create mode 100644 notebooks/system_v1_custom_model_cropland.ipynb diff --git a/notebooks/system_v1_custom_model_cropland.ipynb b/notebooks/system_v1_custom_model_cropland.ipynb new file mode 100644 index 00000000..beb2527a --- /dev/null +++ b/notebooks/system_v1_custom_model_cropland.ipynb @@ -0,0 +1,316 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![](./resources/System_v1_training_header.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook contains a demonstration on how to train custom crop type models based on your own reference data and how to apply the resulting model to generate a custom crop type map.\n", + "\n", + "# Content\n", + "\n", + "- [Before you start](#before-you-start)\n", + "- [1. Define region of interest](#1.-Define-a-region-of-interest)\n", + "- [2. Check public in-situ reference data](#2.-Check-public-in-situ-reference-data)\n", + "- [3. Prepare own reference data](#3.-Prepare-own-reference-data)\n", + "- [4. Extract required model inputs](#4.-Extract-required-model-inputs)\n", + "- [5. Train custom classification model](#5.-Train-custom-classification-model)\n", + "- [6. Generate a map](#6.-Generate-a-map)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Before you start\n", + "\n", + "In order to run this notebook, you need to create an account on:\n", + "\n", + "- The Copernicus Data Space Ecosystem (CDSE)\n", + "--> by completing the form [HERE](https://identity.dataspace.copernicus.eu/auth/realms/CDSE/login-actions/registration?client_id=cdse-public&tab_id=eRKGqDvoYI0)\n", + "\n", + "- VITO's Terrascope platform\n", + "--> by completing the form [HERE](https://sso.terrascope.be/auth/realms/terrascope/login-actions/registration?client_id=drupal-terrascope&tab_id=irBzckp2aDo)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from worldcereal.utils.map import get_ui_map\n", + "RDM_API = \"https://ewoc-rdm-api.iiasa.ac.at\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 1. Define a region of interest\n", + "\n", + "When running the code snippet below, an interactive map will be visualized.\n", + "Click the Rectangle button on the left hand side of the map to start drawing your region of interest.\n", + "When finished, execute the second cell to store the coordinates of your region of interest. " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6bd4e17b496f4a468030bdd2f334c622", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map(center=[51.1872, 5.1154], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom_in_title', 'zoo…" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "m, dc = get_ui_map()\n", + "m" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Your area of interest: (4.535852, 51.173725, 4.561776, 51.186639) (7 km2)\n" + ] + } + ], + "source": [ + "# retrieve bounding box from drawn rectangle\n", + "from utils import get_bbox_from_draw\n", + "\n", + "bbox, poly = get_bbox_from_draw(dc)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. Check public in situ reference data\n", + "\n", + "Here we do a series of requests to the RDM API to retrieve the collections and samples overlapping our bbox..." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "‼ The following snippet does not query the RDM API, but parquet file on Cloudferro bucket with Phase I extractions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Applying a buffer of 50.0 km to the selected area ...\n", + "Querying WorldCereal global database ...\n", + "Processing selected samples ...\n", + "Extracted and processed 38542 samples from global database.\n" + ] + } + ], + "source": [ + "from utils import query_worldcereal_samples\n", + "\n", + "public_df = query_worldcereal_samples(poly, buffer=50000, filter_cropland=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3.Prepare own reference data\n", + "\n", + "Include some guidelines on how to upload user dataset to RDM (using the UI) and requesting those user samples through the API." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "merged_df = public_df.copy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 4. Extract required model inputs\n", + "\n", + "Here we launch point extractions for all samples intersecting our bbox resulting in a set of parquet files.\n", + "\n", + "We collect all these inputs and prepare presto features for each sample." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading Presto model ...\n", + "Computing Presto embeddings ...\n", + "Done.\n" + ] + } + ], + "source": [ + "from utils import get_inputs_outputs\n", + "\n", + "encodings, targets = get_inputs_outputs(merged_df, task_type=\"cropland\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5. Train custom classification model\n", + "We train a catboost model and upload this model to artifactory." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Split train/test ...\n", + "Computing class weights ...\n", + "Class weights: {0: 0.6461708017223703, 1: 2.2103278975977556}\n", + "Training CatBoost classifier ...\n", + "0:\tlearn: 0.8181683\ttest: 0.8073083\tbest: 0.8073083 (0)\ttotal: 169ms\tremaining: 22m 32s\n", + "25:\tlearn: 0.8607315\ttest: 0.8497237\tbest: 0.8497237 (25)\ttotal: 2.01s\tremaining: 10m 15s\n", + "50:\tlearn: 0.8709879\ttest: 0.8579291\tbest: 0.8581425 (48)\ttotal: 3.88s\tremaining: 10m 4s\n", + "75:\tlearn: 0.8792820\ttest: 0.8605337\tbest: 0.8609643 (73)\ttotal: 5.76s\tremaining: 10m\n", + "100:\tlearn: 0.8863502\ttest: 0.8644228\tbest: 0.8645838 (99)\ttotal: 7.59s\tremaining: 9m 53s\n", + "125:\tlearn: 0.8938005\ttest: 0.8671287\tbest: 0.8679282 (118)\ttotal: 9.4s\tremaining: 9m 47s\n", + "150:\tlearn: 0.9001281\ttest: 0.8672146\tbest: 0.8684957 (145)\ttotal: 11.3s\tremaining: 9m 47s\n", + "175:\tlearn: 0.9073189\ttest: 0.8663867\tbest: 0.8684957 (145)\ttotal: 13.1s\tremaining: 9m 42s\n", + "Stopped by overfitting detector (50 iterations wait)\n", + "\n", + "bestTest = 0.8684957122\n", + "bestIteration = 145\n", + "\n", + "Shrink model to first 146 iterations.\n" + ] + } + ], + "source": [ + "from utils import train_classifier\n", + "\n", + "custom_model, report = train_classifier(encodings, targets)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 0.95 0.91 0.93 8362\n", + " 1 0.73 0.84 0.78 2444\n", + "\n", + " accuracy 0.89 10806\n", + " macro avg 0.84 0.87 0.85 10806\n", + "weighted avg 0.90 0.89 0.90 10806\n", + "\n" + ] + } + ], + "source": [ + "print(report)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 6. Deploy custom model\n", + "\n", + "Once trained, we have to upload our model to the cloud so it can be used for inference.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 7. Generate a map\n", + "\n", + "Using our custom model, we generate a map for our region of interest..." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "exp", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/utils.py b/notebooks/utils.py index cfe3d2ba..0f42f658 100644 --- a/notebooks/utils.py +++ b/notebooks/utils.py @@ -59,7 +59,7 @@ def pick_croptypes(df: pd.DataFrame, samples_threshold: int = 100): return vbox, checkbox_widgets -def query_worldcereal_samples(bbox_poly, buffer=250000): +def query_worldcereal_samples(bbox_poly, buffer=250000, filter_cropland=True): import duckdb import geopandas as gpd @@ -72,6 +72,16 @@ def query_worldcereal_samples(bbox_poly, buffer=250000): .to_crs(epsg=4326)[0] ) + xmin, ymin, xmax, ymax = bbox_poly.bounds + twisted_bbox_poly = Polygon([(ymin, xmin), (ymin, xmax), (ymax, xmax), (ymax, xmin)]) + h3_cells_lst = [] + res = 5 + while len(h3_cells_lst)==0: + h3_cells_lst = list(h3.polyfill(twisted_bbox_poly.__geo_interface__, res)) + res += 1 + if res>5: + h3_cells_lst = tuple(np.unique([h3.h3_to_parent(xx, 5) for xx in h3_cells_lst])) + db = duckdb.connect() db.sql("INSTALL spatial") db.load_extension("spatial") @@ -80,17 +90,26 @@ def query_worldcereal_samples(bbox_poly, buffer=250000): # only querying the croptype data here print("Querying WorldCereal global database ...") - public_df_raw = db.sql( - f""" - set s3_endpoint='s3.waw3-1.cloudferro.com'; - set enable_progress_bar=false; - select * - from read_parquet('{parquet_path}', hive_partitioning = 1) original_data - where st_within(ST_Point(original_data.lon, original_data.lat), ST_GeomFromText('{bbox_poly.wkt}')) - and original_data.LANDCOVER_LABEL = 11 - and original_data.CROPTYPE_LABEL not in (0, 991, 7900, 9900, 9998, 1910, 1900, 1920, 1000, 11, 9910, 6212, 7920, 9520, 3400, 3900, 4390, 4000, 4300) - """ - ).df() + if filter_cropland: + query = f""" + set s3_endpoint='s3.waw3-1.cloudferro.com'; + set enable_progress_bar=false; + select * + from read_parquet('{parquet_path}', hive_partitioning = 1) original_data + where original_data.h3_l5_cell in {h3_cells_lst} + and original_data.LANDCOVER_LABEL = 11 + and original_data.CROPTYPE_LABEL not in (0, 991, 7900, 9900, 9998, 1910, 1900, 1920, 1000, 11, 9910, 6212, 7920, 9520, 3400, 3900, 4390, 4000, 4300) + """ + else: + query = f""" + set s3_endpoint='s3.waw3-1.cloudferro.com'; + set enable_progress_bar=false; + select * + from read_parquet('{parquet_path}', hive_partitioning = 1) original_data + where original_data.h3_l5_cell in {h3_cells_lst} + """ + + public_df_raw = db.sql(query).df() print("Processing selected samples ...") public_df = process_parquet(public_df_raw) public_df = map_croptypes(public_df) @@ -100,12 +119,16 @@ def query_worldcereal_samples(bbox_poly, buffer=250000): def get_inputs_outputs( - df: pd.DataFrame, batch_size: int = 256 + df: pd.DataFrame, batch_size: int = 256, task_type: str = "croptype" ) -> Tuple[np.ndarray, np.ndarray]: from presto.dataset import WorldCerealLabelledDataset from presto.presto import Presto - presto_model_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt" + if task_type == "croptype": + presto_model_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt" + if task_type == "cropland": + presto_model_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ft-cl_30D_cropland_random.pt" + df["custom_class"] = (df["LANDCOVER_LABEL"]==11).astype(int) print("Loading Presto model ...") presto_model = Presto.load_pretrained_url(presto_url=presto_model_url, strict=False) From f39877ab648ce4d29ac4504e21281f611f52611e Mon Sep 17 00:00:00 2001 From: Christina Butsko Date: Tue, 2 Jul 2024 17:41:36 +0200 Subject: [PATCH 2/2] Revert "added notebook for custom crop-nocrop model training + respective functions updated in utils" This reverts commit 9f58692a42286af2b4057608efabe0a354f3eab6. --- .../system_v1_custom_model_cropland.ipynb | 316 ------------------ notebooks/utils.py | 51 +-- 2 files changed, 14 insertions(+), 353 deletions(-) delete mode 100644 notebooks/system_v1_custom_model_cropland.ipynb diff --git a/notebooks/system_v1_custom_model_cropland.ipynb b/notebooks/system_v1_custom_model_cropland.ipynb deleted file mode 100644 index beb2527a..00000000 --- a/notebooks/system_v1_custom_model_cropland.ipynb +++ /dev/null @@ -1,316 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "![](./resources/System_v1_training_header.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This notebook contains a demonstration on how to train custom crop type models based on your own reference data and how to apply the resulting model to generate a custom crop type map.\n", - "\n", - "# Content\n", - "\n", - "- [Before you start](#before-you-start)\n", - "- [1. Define region of interest](#1.-Define-a-region-of-interest)\n", - "- [2. Check public in-situ reference data](#2.-Check-public-in-situ-reference-data)\n", - "- [3. Prepare own reference data](#3.-Prepare-own-reference-data)\n", - "- [4. Extract required model inputs](#4.-Extract-required-model-inputs)\n", - "- [5. Train custom classification model](#5.-Train-custom-classification-model)\n", - "- [6. Generate a map](#6.-Generate-a-map)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Before you start\n", - "\n", - "In order to run this notebook, you need to create an account on:\n", - "\n", - "- The Copernicus Data Space Ecosystem (CDSE)\n", - "--> by completing the form [HERE](https://identity.dataspace.copernicus.eu/auth/realms/CDSE/login-actions/registration?client_id=cdse-public&tab_id=eRKGqDvoYI0)\n", - "\n", - "- VITO's Terrascope platform\n", - "--> by completing the form [HERE](https://sso.terrascope.be/auth/realms/terrascope/login-actions/registration?client_id=drupal-terrascope&tab_id=irBzckp2aDo)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "from worldcereal.utils.map import get_ui_map\n", - "RDM_API = \"https://ewoc-rdm-api.iiasa.ac.at\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 1. Define a region of interest\n", - "\n", - "When running the code snippet below, an interactive map will be visualized.\n", - "Click the Rectangle button on the left hand side of the map to start drawing your region of interest.\n", - "When finished, execute the second cell to store the coordinates of your region of interest. " - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6bd4e17b496f4a468030bdd2f334c622", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Map(center=[51.1872, 5.1154], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom_in_title', 'zoo…" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "m, dc = get_ui_map()\n", - "m" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Your area of interest: (4.535852, 51.173725, 4.561776, 51.186639) (7 km2)\n" - ] - } - ], - "source": [ - "# retrieve bounding box from drawn rectangle\n", - "from utils import get_bbox_from_draw\n", - "\n", - "bbox, poly = get_bbox_from_draw(dc)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 2. Check public in situ reference data\n", - "\n", - "Here we do a series of requests to the RDM API to retrieve the collections and samples overlapping our bbox..." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "‼ The following snippet does not query the RDM API, but parquet file on Cloudferro bucket with Phase I extractions" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Applying a buffer of 50.0 km to the selected area ...\n", - "Querying WorldCereal global database ...\n", - "Processing selected samples ...\n", - "Extracted and processed 38542 samples from global database.\n" - ] - } - ], - "source": [ - "from utils import query_worldcereal_samples\n", - "\n", - "public_df = query_worldcereal_samples(poly, buffer=50000, filter_cropland=False)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 3.Prepare own reference data\n", - "\n", - "Include some guidelines on how to upload user dataset to RDM (using the UI) and requesting those user samples through the API." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "merged_df = public_df.copy()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 4. Extract required model inputs\n", - "\n", - "Here we launch point extractions for all samples intersecting our bbox resulting in a set of parquet files.\n", - "\n", - "We collect all these inputs and prepare presto features for each sample." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loading Presto model ...\n", - "Computing Presto embeddings ...\n", - "Done.\n" - ] - } - ], - "source": [ - "from utils import get_inputs_outputs\n", - "\n", - "encodings, targets = get_inputs_outputs(merged_df, task_type=\"cropland\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 5. Train custom classification model\n", - "We train a catboost model and upload this model to artifactory." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Split train/test ...\n", - "Computing class weights ...\n", - "Class weights: {0: 0.6461708017223703, 1: 2.2103278975977556}\n", - "Training CatBoost classifier ...\n", - "0:\tlearn: 0.8181683\ttest: 0.8073083\tbest: 0.8073083 (0)\ttotal: 169ms\tremaining: 22m 32s\n", - "25:\tlearn: 0.8607315\ttest: 0.8497237\tbest: 0.8497237 (25)\ttotal: 2.01s\tremaining: 10m 15s\n", - "50:\tlearn: 0.8709879\ttest: 0.8579291\tbest: 0.8581425 (48)\ttotal: 3.88s\tremaining: 10m 4s\n", - "75:\tlearn: 0.8792820\ttest: 0.8605337\tbest: 0.8609643 (73)\ttotal: 5.76s\tremaining: 10m\n", - "100:\tlearn: 0.8863502\ttest: 0.8644228\tbest: 0.8645838 (99)\ttotal: 7.59s\tremaining: 9m 53s\n", - "125:\tlearn: 0.8938005\ttest: 0.8671287\tbest: 0.8679282 (118)\ttotal: 9.4s\tremaining: 9m 47s\n", - "150:\tlearn: 0.9001281\ttest: 0.8672146\tbest: 0.8684957 (145)\ttotal: 11.3s\tremaining: 9m 47s\n", - "175:\tlearn: 0.9073189\ttest: 0.8663867\tbest: 0.8684957 (145)\ttotal: 13.1s\tremaining: 9m 42s\n", - "Stopped by overfitting detector (50 iterations wait)\n", - "\n", - "bestTest = 0.8684957122\n", - "bestIteration = 145\n", - "\n", - "Shrink model to first 146 iterations.\n" - ] - } - ], - "source": [ - "from utils import train_classifier\n", - "\n", - "custom_model, report = train_classifier(encodings, targets)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " precision recall f1-score support\n", - "\n", - " 0 0.95 0.91 0.93 8362\n", - " 1 0.73 0.84 0.78 2444\n", - "\n", - " accuracy 0.89 10806\n", - " macro avg 0.84 0.87 0.85 10806\n", - "weighted avg 0.90 0.89 0.90 10806\n", - "\n" - ] - } - ], - "source": [ - "print(report)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 6. Deploy custom model\n", - "\n", - "Once trained, we have to upload our model to the cloud so it can be used for inference.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 7. Generate a map\n", - "\n", - "Using our custom model, we generate a map for our region of interest..." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "exp", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/utils.py b/notebooks/utils.py index 0f42f658..cfe3d2ba 100644 --- a/notebooks/utils.py +++ b/notebooks/utils.py @@ -59,7 +59,7 @@ def pick_croptypes(df: pd.DataFrame, samples_threshold: int = 100): return vbox, checkbox_widgets -def query_worldcereal_samples(bbox_poly, buffer=250000, filter_cropland=True): +def query_worldcereal_samples(bbox_poly, buffer=250000): import duckdb import geopandas as gpd @@ -72,16 +72,6 @@ def query_worldcereal_samples(bbox_poly, buffer=250000, filter_cropland=True): .to_crs(epsg=4326)[0] ) - xmin, ymin, xmax, ymax = bbox_poly.bounds - twisted_bbox_poly = Polygon([(ymin, xmin), (ymin, xmax), (ymax, xmax), (ymax, xmin)]) - h3_cells_lst = [] - res = 5 - while len(h3_cells_lst)==0: - h3_cells_lst = list(h3.polyfill(twisted_bbox_poly.__geo_interface__, res)) - res += 1 - if res>5: - h3_cells_lst = tuple(np.unique([h3.h3_to_parent(xx, 5) for xx in h3_cells_lst])) - db = duckdb.connect() db.sql("INSTALL spatial") db.load_extension("spatial") @@ -90,26 +80,17 @@ def query_worldcereal_samples(bbox_poly, buffer=250000, filter_cropland=True): # only querying the croptype data here print("Querying WorldCereal global database ...") - if filter_cropland: - query = f""" - set s3_endpoint='s3.waw3-1.cloudferro.com'; - set enable_progress_bar=false; - select * - from read_parquet('{parquet_path}', hive_partitioning = 1) original_data - where original_data.h3_l5_cell in {h3_cells_lst} - and original_data.LANDCOVER_LABEL = 11 - and original_data.CROPTYPE_LABEL not in (0, 991, 7900, 9900, 9998, 1910, 1900, 1920, 1000, 11, 9910, 6212, 7920, 9520, 3400, 3900, 4390, 4000, 4300) - """ - else: - query = f""" - set s3_endpoint='s3.waw3-1.cloudferro.com'; - set enable_progress_bar=false; - select * - from read_parquet('{parquet_path}', hive_partitioning = 1) original_data - where original_data.h3_l5_cell in {h3_cells_lst} - """ - - public_df_raw = db.sql(query).df() + public_df_raw = db.sql( + f""" + set s3_endpoint='s3.waw3-1.cloudferro.com'; + set enable_progress_bar=false; + select * + from read_parquet('{parquet_path}', hive_partitioning = 1) original_data + where st_within(ST_Point(original_data.lon, original_data.lat), ST_GeomFromText('{bbox_poly.wkt}')) + and original_data.LANDCOVER_LABEL = 11 + and original_data.CROPTYPE_LABEL not in (0, 991, 7900, 9900, 9998, 1910, 1900, 1920, 1000, 11, 9910, 6212, 7920, 9520, 3400, 3900, 4390, 4000, 4300) + """ + ).df() print("Processing selected samples ...") public_df = process_parquet(public_df_raw) public_df = map_croptypes(public_df) @@ -119,16 +100,12 @@ def query_worldcereal_samples(bbox_poly, buffer=250000, filter_cropland=True): def get_inputs_outputs( - df: pd.DataFrame, batch_size: int = 256, task_type: str = "croptype" + df: pd.DataFrame, batch_size: int = 256 ) -> Tuple[np.ndarray, np.ndarray]: from presto.dataset import WorldCerealLabelledDataset from presto.presto import Presto - if task_type == "croptype": - presto_model_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt" - if task_type == "cropland": - presto_model_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ft-cl_30D_cropland_random.pt" - df["custom_class"] = (df["LANDCOVER_LABEL"]==11).astype(int) + presto_model_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt" print("Loading Presto model ...") presto_model = Presto.load_pretrained_url(presto_url=presto_model_url, strict=False)