fcakyon/balanced-loss

Easy-to-use class-balanced cross-entropy and focal loss implementation for PyTorch.

Theory

When the labels in a training dataset are imbalanced, one remedy is to reweight the loss so that it is balanced across classes.

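  • First, the effective number of samples is calculated for each class as:

$$E_n = \frac{1 - \beta^{n}}{1 - \beta}$$

where $n$ is the number of samples in the class and $\beta \in [0, 1)$ is a hyperparameter.

  • Then the class-balanced loss function is defined as:

$$CB(p, y) = \frac{1}{E_{n_y}} \mathcal{L}(p, y) = \frac{1 - \beta}{1 - \beta^{n_y}} \mathcal{L}(p, y)$$

where $n_y$ is the number of training samples in the ground-truth class $y$ and $\mathcal{L}(p, y)$ is the underlying loss (for example, cross-entropy or focal loss).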
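To make the weighting concrete, here is a minimal pure-Python sketch that computes the effective number of samples, E_n = (1 - beta^n) / (1 - beta), and turns it into per-class weights proportional to 1 / E_n. The helper names `effective_number` and `class_balanced_weights` are illustrative and not part of this package. Rescaling the weights to sum to the number of classes is an assumption borrowed from the reference implementations; the theory only fixes the weights up to a constant factor.

```python
def effective_number(n: int, beta: float) -> float:
    """Effective number of samples E_n = (1 - beta^n) / (1 - beta), for 0 <= beta < 1."""
    return (1.0 - beta ** n) / (1.0 - beta)


def class_balanced_weights(samples_per_class, beta=0.999):
    """Per-class weights proportional to 1 / E_n (illustrative helper).

    Assumption: the weights are rescaled so that they sum to the number
    of classes; this normalization is a convention, not part of the theory.
    Every class is expected to have at least one sample.
    """
    raw = [1.0 / effective_number(n, beta) for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]


# 30, 100 and 25 samples for labels 0, 1 and 2, respectively
samples_per_class = [30, 100, 25]
weights = class_balanced_weights(samples_per_class, beta=0.999)
print(weights)  # the rarest class (label 2) receives the largest weight
```

With beta = 0 every class has E_n = 1, so all weights are equal and no rebalancing happens. As beta approaches 1, E_n approaches n and the weights approach inverse class frequency.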

Installation

pip install balanced-loss

Usage

  • Standard losses:
import torch
from balanced_loss import Loss

# model outputs (logits) and ground-truth labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch of 1 sample, 3 classes
labels = torch.tensor([0]) # class index of the single sample

# focal loss
focal_loss = Loss(loss_type="focal_loss")
loss = focal_loss(logits, labels)
# cross-entropy loss
ce_loss = Loss(loss_type="cross_entropy")
loss = ce_loss(logits, labels)
# binary cross-entropy loss
bce_loss = Loss(loss_type="binary_cross_entropy")
loss = bce_loss(logits, labels)
  • Class-balanced losses:
import torch
from balanced_loss import Loss

# model outputs (logits) and ground-truth labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch of 1 sample, 3 classes
labels = torch.tensor([0]) # class index of the single sample

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
# class-balanced cross-entropy loss
ce_loss = Loss(
    loss_type="cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = ce_loss(logits, labels)
# class-balanced binary cross-entropy loss
bce_loss = Loss(
    loss_type="binary_cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = bce_loss(logits, labels)
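The `samples_per_class` list is hard-coded in the example above. In practice it would usually be derived from the training labels. A minimal sketch, assuming the labels are integer class indices in the range 0 to num_classes - 1; the helper `count_samples_per_class` and the `train_labels` data are hypothetical and not part of this package:

```python
from collections import Counter


def count_samples_per_class(labels, num_classes):
    """Return a list whose entry c is the number of samples with label c."""
    counts = Counter(labels)
    return [counts.get(c, 0) for c in range(num_classes)]


# hypothetical training labels, for illustration only
train_labels = [0, 1, 1, 2, 1, 0]
samples_per_class = count_samples_per_class(train_labels, num_classes=3)
print(samples_per_class)
```

A class with zero samples has an effective number of zero, so its weight is undefined; such classes likely need to be handled separately before the counts are used for class balancing.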
  • Customize parameters:
import torch
from balanced_loss import Loss

# model outputs (logits) and ground-truth labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch of 1 sample, 3 classes
labels = torch.tensor([0]) # class index of the single sample

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    beta=0.999, # class-balanced loss beta
    fl_gamma=2, # focal loss gamma
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
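To give some intuition for the focal loss gamma parameter, the sketch below evaluates the textbook focal loss for a single sample, FL(p_t) = -(1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. The function `focal_term` is illustrative only; the package's internal computation may differ (for example, it may operate on sigmoid outputs), so this is not a description of its implementation.

```python
import math


def focal_term(p_t: float, gamma: float = 2.0) -> float:
    """Textbook focal loss for one sample: -(1 - p_t)^gamma * log(p_t)."""
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


for p_t in (0.9, 0.5, 0.1):
    ce = focal_term(p_t, gamma=0.0)  # gamma = 0 reduces to plain cross-entropy
    fl = focal_term(p_t, gamma=2.0)  # gamma = 2 down-weights confident predictions
    print(f"p_t={p_t}: cross-entropy={ce:.4f}, focal={fl:.4f}")
```

With gamma = 2, a well-classified sample (p_t = 0.9) has its loss scaled by (1 - 0.9)^2 = 0.01, while a badly classified one (p_t = 0.1) keeps 81% of its cross-entropy loss, so training focuses on the hard samples.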

Improvements

How does this repo differ from vandit15's?

  • This repo is a pip-installable package published on PyPI.
  • Loss functions are implemented as torch.nn.Module subclasses.
  • In addition to the class-balanced losses, the standard cross-entropy, binary cross-entropy, and focal losses are available through the same API.
  • All typos and errors in vandit15's source are fixed.
  • Continuously tested on PyTorch 1.13.1 and 2.5.1.
  • The loss module's device is selected automatically based on the logits.

References

https://arxiv.org/abs/1901.05555

https://github.com/richardaecn/class-balanced-loss

https://github.com/vandit15/Class-balanced-loss-pytorch