Skip to content

Commit

Permalink
Merge pull request #189 from WorldCereal/custom-training-update
Browse files Browse the repository at this point in the history
Make the use of class weights optional (default=False)
  • Loading branch information
kvantricht authored Oct 15, 2024
2 parents 9c60085 + e8e9e68 commit a805195
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 48 deletions.
17 changes: 0 additions & 17 deletions notebooks/README.md

This file was deleted.

88 changes: 57 additions & 31 deletions notebooks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@
import random
from calendar import monthrange
from datetime import datetime, timedelta
from typing import List, Optional
from typing import List, Optional, Tuple, Union

import ipywidgets as widgets
import leafmap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
from catboost import CatBoostClassifier, Pool
from IPython.display import display
from loguru import logger
from matplotlib.patches import Rectangle
from openeo_gfmap import BoundingBoxExtent, TemporalContext
from presto.utils import DEFAULT_SEED
from pyproj import Transformer
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

from worldcereal.parameters import CropLandParameters, CropTypeParameters
from worldcereal.seasons import get_season_dates_for_extent
Expand Down Expand Up @@ -372,14 +377,31 @@ def get_custom_cropland_labels(df, checkbox_widgets, new_label="cropland"):


def train_classifier(
training_dataframe: pd.DataFrame, class_names: Optional[List[str]] = None
):
import numpy as np
from catboost import CatBoostClassifier, Pool
from presto.utils import DEFAULT_SEED
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
training_dataframe: pd.DataFrame,
class_names: Optional[List[str]] = None,
balance_classes: bool = False,
) -> Tuple[CatBoostClassifier, Union[str | dict], np.ndarray]:
"""Method to train a custom CatBoostClassifier on a training dataframe.
Parameters
----------
training_dataframe : pd.DataFrame
training dataframe containing inputs and targets
class_names : Optional[List[str]], optional
class names to use, by default None
balance_classes : bool, optional
if True, class weights are used during training to balance the classes, by default False
Returns
-------
Tuple[CatBoostClassifier, Union[str | dict], np.ndarray]
The trained CatBoost model, the classification report, and the confusion matrix
Raises
------
ValueError
When not enough classes are present in the training dataframe to train a model
"""

logger.info("Split train/test ...")
samples_train, samples_test = train_test_split(
Expand All @@ -400,28 +422,32 @@ def train_classifier(
loss_function = "Logloss"

# Compute sample weights
logger.info("Computing class weights ...")
class_weights = np.round(
compute_class_weight(
class_weight="balanced",
classes=np.unique(samples_train["downstream_class"]),
y=samples_train["downstream_class"],
),
3,
)
class_weights = {
k: v
for k, v in zip(np.unique(samples_train["downstream_class"]), class_weights)
}
logger.info(f"Class weights: {class_weights}")

sample_weights = np.ones((len(samples_train["downstream_class"]),))
sample_weights_val = np.ones((len(samples_test["downstream_class"]),))
for k, v in class_weights.items():
sample_weights[samples_train["downstream_class"] == k] = v
sample_weights_val[samples_test["downstream_class"] == k] = v
samples_train["weight"] = sample_weights
samples_test["weight"] = sample_weights_val
if balance_classes:
logger.info("Computing class weights ...")
class_weights = np.round(
compute_class_weight(
class_weight="balanced",
classes=np.unique(samples_train["downstream_class"]),
y=samples_train["downstream_class"],
),
3,
)
class_weights = {
k: v
for k, v in zip(np.unique(samples_train["downstream_class"]), class_weights)
}
logger.info(f"Class weights: {class_weights}")

sample_weights = np.ones((len(samples_train["downstream_class"]),))
sample_weights_val = np.ones((len(samples_test["downstream_class"]),))
for k, v in class_weights.items():
sample_weights[samples_train["downstream_class"] == k] = v
sample_weights_val[samples_test["downstream_class"] == k] = v
samples_train["weight"] = sample_weights
samples_test["weight"] = sample_weights_val
else:
samples_train["weight"] = 1
samples_test["weight"] = 1

# Define classifier
custom_downstream_model = CatBoostClassifier(
Expand Down

0 comments on commit a805195

Please sign in to comment.