Merge pull request #62 from jrzaurin/pmulinka/losses
Pmulinka/losses - all additional changes by @jrzaurin are ok
5uperpalo committed Dec 6, 2021
2 parents e1e6a2a + 7f323f1 commit 88790c3
Showing 11 changed files with 3,377 additions and 44 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@ dist/
__pycache__*
*.pyc
*.pyo
*.key
*.egg-info

# jupyter / ipython
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
1.0.12
1.0.13
2,978 changes: 2,978 additions & 0 deletions examples/14_ZILNLoss_OrigKeras_vs_pytorch_multimodal.ipynb

Large diffs are not rendered by default.

137 changes: 137 additions & 0 deletions pytorch_widedeep/losses.py
@@ -7,6 +7,131 @@
use_cuda = torch.cuda.is_available()


class TweedieLoss(nn.Module):
"""
Tweedie loss for extremely unbalanced zero-inflated data.
All credits go to `Wenbo Shi
<https://towardsdatascience.com/tweedie-loss-function-for-right-skewed-data-2c5ca470678f>`_
and `<https://arxiv.org/abs/1811.10192>`_
"""

def __init__(self):
super().__init__()

def forward(self, input: Tensor, target: Tensor, p=1.5) -> Tensor:
assert (
input.min() > 0
), """All input values must be > 0. If your model is predicting
values <= 0, try to enforce positive values with an activation function
on the last layer, e.g. `trainer.enforce_positive_output=True`"""
assert target.min() >= 0, "All target values must be >= 0"
loss = -target * torch.pow(input, 1 - p) / (1 - p) + torch.pow(input, 2 - p) / (
2 - p
)
return torch.mean(loss)
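For reference (not part of the diff), a minimal usage sketch of the new loss; the tensor values are made up and ``p=1.5`` is simply the default power shown above:

import torch

from pytorch_widedeep.losses import TweedieLoss

# predictions must be strictly positive and targets non-negative,
# as required by the asserts in TweedieLoss.forward
input = torch.tensor([2.1, 0.5, 3.0])
target = torch.tensor([2.0, 0.0, 3.5])
loss = TweedieLoss()(input, target, p=1.5)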


class QuantileLoss(nn.Module):
"""
Quantile loss, i.e. a quantile of ``q=0.5`` will give half of the mean absolute error, as it is
calculated as ``max(q * (y-y_pred), (1-q) * (y_pred-y))``
All credits go to `pytorch-forecasting
<https://pytorch-forecasting.readthedocs.io/en/latest/_modules/pytorch_forecasting/metrics.html#QuantileLoss>`
"""

def __init__(
self,
quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98],
):
"""
Quantile loss
Args:
quantiles: list of quantiles for which the loss is computed
"""
super().__init__()
self.quantiles = quantiles

def forward(self, input: Tensor, target: Tensor) -> Tensor:
assert input.shape == torch.Size([target.shape[0], len(self.quantiles)]), (
f"Wrong shape of input, pred_dim of the model that is using QuantileLoss must be equal "
f"to number of quantiles, i.e. {len(self.quantiles)}."
)
target = target.view(-1, 1).float()
losses = []
for i, q in enumerate(self.quantiles):
errors = target - input[..., i]
losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
loss = torch.cat(losses, dim=2)

return torch.mean(loss)
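A hedged usage sketch (not part of the diff): the ``pred_dim`` of the model must match the number of quantiles, seven with the defaults above; the tensors here are random placeholders:

import torch

from pytorch_widedeep.losses import QuantileLoss

loss_fn = QuantileLoss()  # defaults to 7 quantiles
input = torch.rand(32, len(loss_fn.quantiles))  # (batch_size, n_quantiles)
target = torch.rand(32)
loss = loss_fn(input, target)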


class ZILNLoss(nn.Module):
r"""Adjusted implementation of the `Zero Inflated LogNormal loss
<https://arxiv.org/pdf/1912.07753.pdf>` and its `code
<https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal.py>`
"""

def __init__(self):
super().__init__()

def forward(self, input: Tensor, target: Tensor) -> Tensor:
r"""
Parameters
----------
input: Tensor
input tensor with predictions (not probabilities) with shape (N,3), where N is the batch size
target: Tensor
target tensor with the actual target values
Examples
--------
>>> import torch
>>>
>>> from pytorch_widedeep.losses import ZILNLoss
>>>
>>> # REGRESSION
>>> target = torch.tensor([[0., 1.5]]).view(-1, 1)
>>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])
>>> ZILNLoss()(input, target)
tensor(1.3114)
"""
positive = target > 0
positive = positive.float()

assert input.shape == torch.Size(
[target.shape[0], 3]
), "Wrong shape of input, pred_dim of the model that is using ZILNLoss must be equal to 3."
positive_input = input[..., :1]

classification_loss = F.binary_cross_entropy_with_logits(
positive_input, positive, reduction="none"
).flatten()

loc = input[..., 1:2]

# when using max the two input tensors (input and other) have to be of
# the same type
max_input = F.softplus(input[..., 2:])
max_other = torch.sqrt(torch.Tensor([torch.finfo(torch.double).eps])).type(
max_input.type()
)
scale = torch.max(max_input, max_other)
safe_labels = positive * target + (1 - positive) * torch.ones_like(target)

regression_loss = -torch.mean(
positive
* torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(
safe_labels
),
dim=-1,
)

return torch.mean(classification_loss + regression_loss)
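Reading the forward pass above, the three columns of ``input`` play distinct roles; a small illustrative sketch (not part of the diff) making that mapping explicit:

import torch

from pytorch_widedeep.losses import ZILNLoss

# column 0 -> logit of P(target > 0)  (classification part)
# column 1 -> location (mu) of the LogNormal for positive targets
# column 2 -> raw scale, passed through softplus and clamped away from 0
input = torch.randn(8, 3)
target = torch.rand(8, 1)
loss = ZILNLoss()(input, target)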


class FocalLoss(nn.Module):
r"""Implementation of the `focal loss
<https://arxiv.org/pdf/1708.02002.pdf>`_ for both binary and
@@ -105,6 +230,12 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
>>> MSLELoss()(input, target)
tensor(0.1115)
"""
assert (
input.min() >= 0
), """All input values must be >= 0. If your model is predicting
values < 0, try to enforce positive values with an activation function
on the last layer, e.g. `trainer.enforce_positive_output=True`"""
assert target.min() >= 0, "All target values must be >= 0"
return self.mse(torch.log(input + 1), torch.log(target + 1))


@@ -163,4 +294,10 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
>>> RMSLELoss()(input, target)
tensor(0.3339)
"""
assert (
input.min() >= 0
), """All input values must be >= 0. If your model is predicting
values < 0, try to enforce positive values with an activation function
on the last layer, e.g. `trainer.enforce_positive_output=True`"""
assert target.min() >= 0, "All target values must be >= 0"
return torch.sqrt(self.mse(torch.log(input + 1), torch.log(target + 1)))
12 changes: 11 additions & 1 deletion pytorch_widedeep/models/tab_mlp.py
@@ -5,7 +5,15 @@

from pytorch_widedeep.wdtypes import * # noqa: F403

allowed_activations = ["relu", "leaky_relu", "tanh", "gelu", "geglu", "reglu"]
allowed_activations = [
"relu",
"leaky_relu",
"tanh",
"gelu",
"geglu",
"reglu",
"softplus",
]


class GEGLU(nn.Module):
@@ -33,6 +41,8 @@ def get_activation_fn(activation):
return GEGLU()
if activation == "reglu":
return REGLU()
if activation == "softplus":
return nn.Softplus()
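The new alias makes it possible to request a softplus activation by name; a quick illustrative check (not part of the diff):

import torch

from pytorch_widedeep.models.tab_mlp import get_activation_fn

act = get_activation_fn("softplus")  # returns nn.Softplus()
print(act(torch.tensor([-3.0, 0.0, 3.0])))  # all outputs are strictly positive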


def dense_layer(
26 changes: 21 additions & 5 deletions pytorch_widedeep/models/wide_deep.py
Expand Up @@ -16,7 +16,7 @@
import torch.nn as nn

from pytorch_widedeep.wdtypes import * # noqa: F403
from pytorch_widedeep.models.tab_mlp import MLP
from pytorch_widedeep.models.tab_mlp import MLP, get_activation_fn
from pytorch_widedeep.models.tabnet.tab_net import TabNetPredLayer

warnings.filterwarnings("default", category=UserWarning)
@@ -87,6 +87,13 @@ class WideDeep(nn.Module):
the order of the operations in the dense layer. If ``True``:
``[LIN -> ACT -> BN -> DP]``. If ``False``: ``[BN -> DP -> LIN ->
ACT]``
enforce_positive: bool, default = False
If ``True``, the output of the final layer is passed through an activation
function that enforces positive predictions. Important if you are using loss
functions with non-negative input restrictions, e.g. RMSLE, or if you know
your predictions are limited to [0, inf)
enforce_positive_activation: str, default = "softplus"
Activation function used to enforce a positive output from the final layer.
Use "softplus" or "relu".
pred_dim: int, default = 1
Size of the final wide and deep output layer containing the
predictions. `1` for regression and binary classification or number
@@ -131,6 +138,8 @@ def __init__(
head_batchnorm: bool = False,
head_batchnorm_last: bool = False,
head_linear_first: bool = False,
enforce_positive: bool = False,
enforce_positive_activation: str = "softplus",
pred_dim: int = 1,
):
super(WideDeep, self).__init__()
@@ -154,6 +163,7 @@ def __init__(
self.deeptext = deeptext
self.deepimage = deepimage
self.deephead = deephead
self.enforce_positive = enforce_positive

if self.deeptabular is not None:
self.is_tabnet = deeptabular.__class__.__name__ == "TabNet"
@@ -173,12 +183,19 @@ def __init__(
else:
self._add_pred_layer()

if self.enforce_positive:
self.enf_pos = get_activation_fn(enforce_positive_activation)

def forward(self, X: Dict[str, Tensor]):
wide_out = self._forward_wide(X)
if self.deephead:
return self._forward_deephead(X, wide_out)
deep = self._forward_deephead(X, wide_out)
else:
deep = self._forward_deep(X, wide_out)
if self.enforce_positive:
return self.enf_pos(deep)
else:
return self._forward_deep(X, wide_out)
return deep
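A hedged sketch (not part of the diff) of how the new ``enforce_positive`` flag might be used; the column set-up is hypothetical and the ``Wide``/``TabMlp`` constructor arguments are assumed to match this version of the library:

from pytorch_widedeep.models import Wide, TabMlp, WideDeep

wide = Wide(wide_dim=100, pred_dim=1)  # hypothetical wide dimension
deeptabular = TabMlp(
    column_idx={"a": 0, "b": 1},  # hypothetical continuous columns
    continuous_cols=["a", "b"],
    mlp_hidden_dims=[16, 8],
)
# softplus on the final layer keeps predictions positive, which losses such
# as TweedieLoss, MSLELoss and RMSLELoss require
model = WideDeep(
    wide=wide,
    deeptabular=deeptabular,
    enforce_positive=True,
    enforce_positive_activation="softplus",
)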

def _build_deephead(
self,
@@ -206,7 +223,6 @@ def _build_deephead(
head_batchnorm_last,
head_linear_first,
)

self.deephead.add_module(
"head_out", nn.Linear(head_hidden_dims[-1], self.pred_dim)
)
@@ -285,7 +301,7 @@ def _forward_deep(self, X, wide_out):
return res

@staticmethod # noqa: C901
def _check_model_components(
def _check_model_components( # noqa: C901
wide,
deeptabular,
deeptext,
7 changes: 7 additions & 0 deletions pytorch_widedeep/training/_loss_and_obj_aliases.py
@@ -31,6 +31,9 @@ class _LossAliases:
"mean_squared_log_error": ["mean_squared_log_error", "msle"],
"root_mean_squared_error": ["root_mean_squared_error", "rmse"],
"root_mean_squared_log_error": ["root_mean_squared_log_error", "rmsle"],
"zero_inflated_lognormal": ["zero_inflated_lognormal", "ziln"],
"quantile": ["quantile"],
"tweedie": ["tweedie"],
}

@classproperty
@@ -69,6 +72,10 @@ class _ObjectiveToMethod:
"rmse": "regression",
"root_mean_squared_log_error": "regression",
"rmsle": "regression",
"zero_inflated_lognormal": "regression",
"ziln": "regression",
"tweedie": "regression",
"quantile": "qregression",
}

@classproperty
18 changes: 16 additions & 2 deletions pytorch_widedeep/training/_trainer_utils.py
@@ -3,7 +3,15 @@
from torch import nn
from sklearn.model_selection import train_test_split

from pytorch_widedeep.losses import MSLELoss, RMSELoss, FocalLoss, RMSLELoss
from pytorch_widedeep.losses import (
MSLELoss,
RMSELoss,
ZILNLoss,
FocalLoss,
RMSLELoss,
TweedieLoss,
QuantileLoss,
)
from pytorch_widedeep.wdtypes import Dict, List, Optional, Transforms
from pytorch_widedeep.training._wd_dataset import WideDeepDataset
from pytorch_widedeep.training._loss_and_obj_aliases import (
@@ -176,7 +184,7 @@ def save_epoch_logs(epoch_logs: Dict, loss: float, score: Dict, stage: str):
return epoch_logs


def alias_to_loss(loss_fn: str, **kwargs):
def alias_to_loss(loss_fn: str, **kwargs): # noqa: C901
r"""
Function that returns the corresponding loss function given an alias
@@ -215,5 +223,11 @@ def alias_to_loss(loss_fn: str, **kwargs):
return RMSELoss()
if loss_fn in _LossAliases.get("root_mean_squared_log_error"):
return RMSLELoss()
if loss_fn in _LossAliases.get("zero_inflated_lognormal"):
return ZILNLoss()
if loss_fn in _LossAliases.get("quantile"):
return QuantileLoss()
if loss_fn in _LossAliases.get("tweedie"):
return TweedieLoss()
if "focal_loss" in loss_fn:
return FocalLoss(**kwargs)
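With the new aliases registered, the added losses can be obtained by name; a small sketch (not part of the diff) of the resolution, assuming the aliases above:

from pytorch_widedeep.training._trainer_utils import alias_to_loss

ziln = alias_to_loss("ziln")          # -> ZILNLoss()
quantile = alias_to_loss("quantile")  # -> QuantileLoss()
tweedie = alias_to_loss("tweedie")    # -> TweedieLoss()

The same strings should also work as the trainer objective, since ``_ObjectiveToMethod`` now maps them to the regression / qregression methods.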
