-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
67 lines (49 loc) · 1.61 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
file : metrics
authors : 21112254, 16008937, 20175911, 21180859
"""
import torch
import torch.nn as nn
from sklearn.metrics import f1_score
import numpy as np
class F1(nn.Module):
    """
    F1-score metric wrapped as an ``nn.Module``.

    For the ``'genre'`` task the per-class F1 scores are returned (one entry per
    label); for every other task a single weighted-average F1 score is returned.

    Parameters
    ----------
    task : str
        Name of the task this metric scores.
    labels : sequence of int, optional
        Class labels passed to ``f1_score``. Defaults to ``[0, 1, 2, 3, 4]`` for
        ``'genre'`` and to ``None`` (sklearn's default) for other tasks.
    """

    GENRE_LABELS = (0, 1, 2, 3, 4)

    def __init__(self, task, labels=None):
        super(F1, self).__init__()
        self.task = task
        if labels is not None:
            self.labels = list(labels)
        elif task == 'genre':
            self.labels = list(self.GENRE_LABELS)
        else:
            self.labels = None

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a NumPy array, detaching torch tensors and moving them to CPU first."""
        if isinstance(values, torch.Tensor):
            # NumPy cannot convert tensors that require grad or live on a GPU directly.
            values = values.detach().cpu().numpy()
        return np.asarray(values)

    def forward(self, prediction, truth):
        """
        Score class predictions against ground-truth labels.

        Parameters
        ----------
        prediction : array-like or torch.Tensor
            Per-class scores; the predicted class is the argmax along axis 1
            (assumes shape ``(n_samples, n_classes)`` -- confirm with caller).
        truth : array-like or torch.Tensor
            True class index for each sample.

        Returns
        -------
        numpy.ndarray or float
            Per-class F1 scores for ``'genre'``, otherwise the weighted F1 score.
        """
        pred = np.argmax(self._to_numpy(prediction), axis=1)
        truth = self._to_numpy(truth)
        if self.task == 'genre':
            return f1_score(truth, pred, labels=self.labels, average=None, zero_division=1)
        # labels is None here unless explicitly supplied, matching sklearn's default.
        return f1_score(truth, pred, labels=self.labels, average='weighted', zero_division=1)
def get_metric(task):
    """
    Find the metric function for a given task.

    Parameters
    ----------
    task : str
        Name of the task.

    Returns
    -------
    nn.Module
        An ``F1`` instance for ``'genre'`` and ``'topic'``; a
        ``torch.nn.MSELoss`` instance for ``'violence'``, ``'romantic'``,
        ``'sadness'``, ``'feelings'``, ``'danceability'`` and ``'energy'``.

    Raises
    ------
    ValueError
        If ``task`` is not one of the supported task names.
    """
    if task in ('genre', 'topic'):
        return F1(task)
    if task in ('violence', 'romantic', 'sadness', 'feelings', 'danceability', 'energy'):
        return torch.nn.MSELoss()
    # Fail loudly here instead of returning None, which would otherwise only
    # surface later as an unrelated-looking error from nn.ModuleDict.
    raise ValueError("Unknown task {!r}: no metric is defined for it.".format(task))
class Metrics(nn.Module):
    """
    Combine per-task metrics for a multi-head model.

    One metric module is built for each entry of ``tasks`` (via ``get_metric``)
    and stored in a ``ModuleDict``; calling the instance evaluates each head's
    prediction against the matching target.

    Parameters
    ----------
    tasks : list of str
        Task names; each must be accepted by ``get_metric``.
    """

    def __init__(self, tasks):
        super().__init__()
        self.tasks = tasks
        per_task = {}
        for name in tasks:
            per_task[name] = get_metric(name)
        self.metric_fncts = nn.ModuleDict(per_task)

    def forward(self, prediction, truth):
        """
        Evaluate every task's metric.

        Parameters
        ----------
        prediction : mapping
            Model outputs keyed by task name.
        truth : mapping
            Targets keyed by task name.

        Returns
        -------
        dict
            Metric value for each task in ``self.tasks``, keyed by task name.
        """
        results = {}
        for name in self.tasks:
            scorer = self.metric_fncts[name]
            results[name] = scorer(prediction[name], truth[name])
        return results