Pytorch implementation #1

Open
wants to merge 22 commits into base: dev/pytorch
Commits (22)
bccb1fd
implemented pytorch port of NIG continuous loss
Dariusrussellkish Dec 13, 2020
bf84275
implemented pytorch equivalents of layers and most losses
Dariusrussellkish Dec 13, 2020
d973b95
refactor imports for neurips2020 folder
Dariusrussellkish Dec 13, 2020
a88d63e
refactor inits
Dariusrussellkish Dec 13, 2020
f69c59f
refactor imports
Dariusrussellkish Dec 13, 2020
d376190
device matching on tensors
Dariusrussellkish Dec 13, 2020
f558457
reimplement student_t to be consistent with TF impl.
Dariusrussellkish Dec 13, 2020
ab90de4
undo setup.py hotfix
Dariusrussellkish Dec 13, 2020
24cc9aa
refactor imports
Dariusrussellkish Dec 13, 2020
797210e
Revert "refactor imports"
Dariusrussellkish Dec 14, 2020
135967c
Revert "refactor imports"
Dariusrussellkish Dec 14, 2020
6500e91
Revert "refactor imports for neurips2020 folder"
Dariusrussellkish Dec 14, 2020
ce9f606
Fixed namespace manipulation
Dariusrussellkish Dec 14, 2020
afce1fa
Implement Dirichlet_SOS loss
Dariusrussellkish Dec 14, 2020
ce7ed5d
Reimplement Dirichlet_SOS based on @dougbrion
Dariusrussellkish Dec 24, 2020
ad5d66c
pytorch validation with discrete loss based on arxiv.org/abs/1806.01768
Dariusrussellkish Dec 24, 2020
a1e184c
Comment update to include pytorch
Dariusrussellkish Dec 24, 2020
f6726c3
Added pytorch discrete validation plotting
Dariusrussellkish Dec 24, 2020
ff58a90
added license attribution
Dariusrussellkish Dec 24, 2020
a561049
added pytorch environment
Dariusrussellkish Dec 25, 2020
f518eb6
update logging
Dariusrussellkish Dec 25, 2020
8edf33a
bugfix in logging
Dariusrussellkish Dec 25, 2020
60 changes: 58 additions & 2 deletions evidential_deep_learning/__init__.py
@@ -1,2 +1,58 @@
from . import layers
from . import losses
# TODO: This is pretty hacky namespace manipulation but it works
import sys

self = sys.modules[__name__]

default_backend = 'tf'

self.torch_avail = False
try:
    import torch

    self.torch_avail = True
    self.backend = 'torch'
except ImportError:
    pass

self.tf_avail = False
try:
    import tensorflow as tf

    self.tf_avail = True
    self.backend = 'tf'
except ImportError:
    pass

if not (self.torch_avail or self.tf_avail):
    raise ImportError("Must have either PyTorch or TensorFlow available")

if self.torch_avail and self.tf_avail:
    self.backend = default_backend


def set_backend(backend):
    if backend == 'tf':
        if not self.tf_avail:
            raise ImportError("Cannot use backend 'tf' if TensorFlow is not installed")
        from .tf import layers as layers
        from .tf import losses as losses
        self.layers = layers
        self.losses = losses
    elif backend == 'torch':
        if not self.torch_avail:
            raise ImportError("Cannot use backend 'torch' if PyTorch is not installed")
        from .pytorch import layers as layers
        from .pytorch import losses as losses
        self.layers = layers
        self.losses = losses
    else:
        raise ValueError(f"Invalid choice of backend: {backend}, options are 'tf' or 'torch'")


def get_backend():
    return self.backend


self.get_backend = get_backend
self.set_backend = set_backend
self.set_backend(self.backend)
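
For reference, a minimal usage sketch of the backend switch above (assuming both frameworks are installed; the `edl` alias and layer size are illustrative):

import evidential_deep_learning as edl

print(edl.get_backend())    # 'tf' is the default when both backends are available
edl.set_backend('torch')    # rebinds edl.layers and edl.losses to the PyTorch implementations
layer = edl.layers.DenseNormalGamma(64, n_out_tasks=1)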
1 change: 1 addition & 0 deletions evidential_deep_learning/pytorch/__init__.py
@@ -0,0 +1 @@
from . import losses, layers
2 changes: 2 additions & 0 deletions evidential_deep_learning/pytorch/layers/__init__.py
@@ -0,0 +1,2 @@
from .dense import *
from .conv2d import *
48 changes: 48 additions & 0 deletions evidential_deep_learning/pytorch/layers/conv2d.py
@@ -0,0 +1,48 @@
import torch
from torch.nn import Module, Conv2d
import torch.nn.functional as F


# TODO: efficiently handle batch dimension


class Conv2DNormal(Module):
    def __init__(self, in_channels, out_tasks, kernel_size, **kwargs):
        super(Conv2DNormal, self).__init__()
        self.in_channels = in_channels
        self.out_channels = 2 * out_tasks
        self.n_tasks = out_tasks
        self.conv = Conv2d(self.in_channels, self.out_channels, kernel_size, **kwargs)

    def forward(self, x):
        output = self.conv(x)
        # Channel dim is 0 for unbatched (C, H, W) input and 1 for batched (N, C, H, W).
        if len(x.shape) == 3:
            mu, logsigma = torch.split(output, self.n_tasks, dim=0)
        else:
            mu, logsigma = torch.split(output, self.n_tasks, dim=1)

        sigma = F.softplus(logsigma) + 1e-6

        return torch.stack([mu, sigma]).to(x.device)


class Conv2DNormalGamma(Module):
    def __init__(self, in_channels, out_tasks, kernel_size, **kwargs):
        super(Conv2DNormalGamma, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_tasks  # channels per NIG parameter; the conv emits 4 * out_tasks
        self.conv = Conv2d(in_channels, 4 * out_tasks, kernel_size, **kwargs)

    def forward(self, x):
        output = self.conv(x)

        if len(x.shape) == 3:
            gamma, lognu, logalpha, logbeta = torch.split(output, self.out_channels, dim=0)
        else:
            gamma, lognu, logalpha, logbeta = torch.split(output, self.out_channels, dim=1)

        # Softplus keeps nu and beta positive and shifts alpha above 1.
        nu = F.softplus(lognu)
        alpha = F.softplus(logalpha) + 1.
        beta = F.softplus(logbeta)
        return torch.stack([gamma, nu, alpha, beta]).to(x.device)
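
A quick shape check for the convolutional heads, as a sketch (channel counts and spatial size are illustrative):

import torch

head = Conv2DNormalGamma(in_channels=16, out_tasks=1, kernel_size=3, padding=1)
x = torch.randn(8, 16, 32, 32)       # batched (N, C, H, W) input
gamma, nu, alpha, beta = head(x)     # the stacked output unpacks along dim 0
print(gamma.shape)                   # torch.Size([8, 1, 32, 32])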

47 changes: 47 additions & 0 deletions evidential_deep_learning/pytorch/layers/dense.py
@@ -0,0 +1,47 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Module

# TODO: Find a way to efficiently handle batch dimension


class DenseNormal(Module):
    def __init__(self, n_input, n_out_tasks=1):
        super(DenseNormal, self).__init__()
        self.n_in = n_input
        self.n_out = 2 * n_out_tasks
        self.n_tasks = n_out_tasks
        self.l1 = nn.Linear(self.n_in, self.n_out)

    def forward(self, x):
        x = self.l1(x)
        if len(x.shape) == 1:
            mu, logsigma = torch.split(x, self.n_tasks, dim=0)
        else:
            mu, logsigma = torch.split(x, self.n_tasks, dim=1)

        sigma = F.softplus(logsigma) + 1e-6
        return torch.stack([mu, sigma]).to(x.device)


class DenseNormalGamma(Module):
    def __init__(self, n_input, n_out_tasks=1):
        super(DenseNormalGamma, self).__init__()
        self.n_in = n_input
        self.n_out = 4 * n_out_tasks
        self.n_tasks = n_out_tasks
        self.l1 = nn.Linear(self.n_in, self.n_out)

    def forward(self, x):
        x = self.l1(x)
        if len(x.shape) == 1:
            gamma, lognu, logalpha, logbeta = torch.split(x, self.n_tasks, dim=0)
        else:
            gamma, lognu, logalpha, logbeta = torch.split(x, self.n_tasks, dim=1)

        nu = F.softplus(lognu)
        alpha = F.softplus(logalpha) + 1.
        beta = F.softplus(logbeta)

        return torch.stack([gamma, nu, alpha, beta]).to(x.device)
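
A corresponding sketch for the dense heads (sizes are illustrative):

import torch

head = DenseNormalGamma(n_input=64, n_out_tasks=1)
x = torch.randn(32, 64)
gamma, nu, alpha, beta = head(x)     # each of shape (32, 1)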
2 changes: 2 additions & 0 deletions evidential_deep_learning/pytorch/losses/__init__.py
@@ -0,0 +1,2 @@
from .continous import *
from .discrete import *
57 changes: 57 additions & 0 deletions evidential_deep_learning/pytorch/losses/continous.py
@@ -0,0 +1,57 @@
import torch
from torch.distributions import Normal
from torch import nn
import numpy as np

MSE = nn.MSELoss(reduction='mean')


def reduce(val, reduction):
    if reduction == 'mean':
        val = val.mean()
    elif reduction == 'sum':
        val = val.sum()
    elif reduction == 'none':
        pass
    else:
        raise ValueError(f"Invalid reduction argument: {reduction}")
    return val


def RMSE(y, y_):
    return MSE(y, y_).sqrt()


def Gaussian_NLL(y, mu, sigma, reduction='mean'):
    dist = Normal(loc=mu, scale=sigma)
    # TODO: refactor to mirror TF implementation due to numerical instability
    logprob = -1. * dist.log_prob(y)
    return reduce(logprob, reduction=reduction)


def NIG_NLL(y: torch.Tensor,
            gamma: torch.Tensor,
            nu: torch.Tensor,
            alpha: torch.Tensor,
            beta: torch.Tensor, reduction='mean'):
    # Negative log-likelihood of a Normal-Inverse-Gamma evidential distribution
    # (Amini et al., NeurIPS 2020); inter is Omega = 2 * beta * (1 + nu).
    inter = 2 * beta * (1 + nu)

    nll = 0.5 * (np.pi / nu).log() \
        - alpha * inter.log() \
        + (alpha + 0.5) * (nu * (y - gamma) ** 2 + inter).log() \
        + torch.lgamma(alpha) \
        - torch.lgamma(alpha + 0.5)
    return reduce(nll, reduction=reduction)


def NIG_Reg(y, gamma, nu, alpha, reduction='mean'):
    # Regularizer: absolute prediction error scaled by total evidence.
    error = (y - gamma).abs()
    evidence = 2. * nu + alpha
    return reduce(error * evidence, reduction=reduction)


def EvidentialRegression(y: torch.Tensor, evidential_output: torch.Tensor, lmbda=1.):
    gamma, nu, alpha, beta = evidential_output
    loss_nll = NIG_NLL(y, gamma, nu, alpha, beta)
    loss_reg = NIG_Reg(y, gamma, nu, alpha)
    return loss_nll, lmbda * loss_reg
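
Note the two loss terms are returned separately here, so the caller combines them. A minimal training-step sketch, assuming DenseNormalGamma and EvidentialRegression are imported from the modules in this PR, with an illustrative model and lmbda:

import torch

model = torch.nn.Sequential(torch.nn.Linear(1, 64), torch.nn.ReLU(), DenseNormalGamma(64))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x, y = torch.randn(32, 1), torch.randn(32, 1)
output = model(x)                       # stacked (gamma, nu, alpha, beta)
loss_nll, loss_reg = EvidentialRegression(y, output, lmbda=1e-2)
loss = loss_nll + loss_reg              # combine NLL and regularizer
optimizer.zero_grad()
loss.backward()
optimizer.step()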
103 changes: 103 additions & 0 deletions evidential_deep_learning/pytorch/losses/discrete.py
@@ -0,0 +1,103 @@
import torch
import torch.nn.functional as F

BCELoss = torch.nn.BCEWithLogitsLoss()


def Dirichlet_SOS(y, outputs, device=None):
    return edl_log_loss(outputs, y, device=device if device else outputs.device)


def Dirichlet_Evidence(outputs):
    """Calculate ReLU evidence"""
    return relu_evidence(outputs)


def Dirichlet_Matches(predictions, labels):
    """Calculate the number of matches from index predictions"""
    assert predictions.shape == labels.shape, f"Dimension mismatch between predictions " \
                                              f"({predictions.shape}) and labels ({labels.shape})"
    return torch.reshape(torch.eq(predictions, labels).float(), (-1, 1))


def Dirichlet_Predictions(outputs):
    """Calculate predictions from logits"""
    return torch.argmax(outputs, dim=1)


def Dirichlet_Uncertainty(outputs):
    """Calculate uncertainty from logits"""
    alpha = relu_evidence(outputs) + 1
    return alpha.size(1) / torch.sum(alpha, dim=1, keepdim=True)


def Sigmoid_CE(y, y_logits, device=None):
    device = device if device else y_logits.device
    return BCELoss(y_logits.to(device), y.to(device))


# MIT License
#
# Copyright (c) 2019 Douglas Brion
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

def relu_evidence(y):
    return F.relu(y)


def exp_evidence(y):
    return torch.exp(torch.clamp(y, -10, 10))


def softplus_evidence(y):
    return F.softplus(y)


def kl_divergence(alpha, num_classes, device=None):
    # KL(Dir(alpha) || Dir(1)) against the uniform Dirichlet prior.
    beta = torch.ones([1, num_classes], dtype=torch.float32, device=device)
    S_alpha = torch.sum(alpha, dim=1, keepdim=True)
    S_beta = torch.sum(beta, dim=1, keepdim=True)
    lnB = torch.lgamma(S_alpha) - \
        torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
    lnB_uni = torch.sum(torch.lgamma(beta), dim=1,
                        keepdim=True) - torch.lgamma(S_beta)

    dg0 = torch.digamma(S_alpha)
    dg1 = torch.digamma(alpha)

    kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1,
                   keepdim=True) + lnB + lnB_uni
    return kl


def edl_loss(func, y, alpha, device=None):
    S = torch.sum(alpha, dim=1, keepdim=True)
    A = torch.sum(y * (func(S) - func(alpha)), dim=1, keepdim=True)

    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = kl_divergence(kl_alpha, y.shape[1], device=device)
    return A + kl_div


def edl_log_loss(output, target, device=None):
    evidence = relu_evidence(output)
    alpha = evidence + 1
    loss = torch.mean(edl_loss(torch.log, target, alpha, device=device))
    assert loss is not None
    return loss
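
Finally, a sketch of the discrete path based on the functions above (the 10-class batch is illustrative; arxiv.org/abs/1806.01768 is the reference cited in the commits):

import torch
import torch.nn.functional as F

logits = torch.randn(32, 10)                               # raw network outputs
y = F.one_hot(torch.randint(0, 10, (32,)), 10).float()     # one-hot targets

loss = Dirichlet_SOS(y, logits)            # wraps edl_log_loss(logits, y)
preds = Dirichlet_Predictions(logits)      # argmax class indices, shape (32,)
unc = Dirichlet_Uncertainty(logits)        # num_classes / sum(alpha), shape (32, 1)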
1 change: 1 addition & 0 deletions evidential_deep_learning/tf/__init__.py
@@ -0,0 +1 @@
from . import losses, layers