Skip to content

Commit

Permalink
minor code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
5uperpalo committed Dec 11, 2021
1 parent b324222 commit f17fe84
Showing 1 changed file with 0 additions and 60 deletions.
60 changes: 0 additions & 60 deletions pytorch_widedeep/metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import torch
from torchmetrics import Metric as TorchMetric
from torchmetrics import AUC

from .wdtypes import * # noqa: F403

Expand Down Expand Up @@ -394,62 +393,3 @@ def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
y_true_avg = self.y_true_sum / self.num_examples
self.denominator += ((y_true - y_true_avg) ** 2).sum().item()
return np.array((1 - (self.numerator / self.denominator)))


class Accuracy(Metric):
r"""Class to calculate the accuracy for both binary and categorical problems
Parameters
----------
top_k: int, default = 1
Accuracy will be computed using the top k most likely classes in
multiclass problems
Examples
--------
>>> import torch
>>>
>>> from pytorch_widedeep.metrics import Accuracy
>>>
>>> acc = Accuracy()
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> acc(y_pred, y_true)
array(0.5)
>>>
>>> acc = Accuracy(top_k=2)
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> acc(y_pred, y_true)
array(0.66666667)
"""

def __init__(self, top_k: int = 1):
super(Accuracy, self).__init__()

self.top_k = top_k
self.correct_count = 0
self.total_count = 0
self._name = "acc"

def reset(self):
"""
resets counters to 0
"""
self.correct_count = 0
self.total_count = 0

def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
num_classes = y_pred.size(1)

if num_classes == 1:
y_pred = y_pred.round()
y_true = y_true
elif num_classes > 1:
y_pred = y_pred.topk(self.top_k, 1)[1]
y_true = y_true.view(-1, 1).expand_as(y_pred)

self.correct_count += y_pred.eq(y_true).sum().item() # type: ignore[assignment]
self.total_count += len(y_pred)
accuracy = float(self.correct_count) / float(self.total_count)
return np.array(accuracy)

0 comments on commit f17fe84

Please sign in to comment.