-
Notifications
You must be signed in to change notification settings - Fork 1
/
metrics.py
72 lines (54 loc) · 2.49 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# -*- coding: utf-8 -*-
#
# Developed by Alex Jercan <[email protected]>
#
# References:
# - https://github.com/XinJCheng/CSPN/blob/b3e487bdcdcd8a63333656e69b3268698e543181/cspn_pytorch/utils.py#L19
# - https://web.eecs.umich.edu/~fouhey/2016/evalSN/evalSN.html
#
from math import radians
import torch
import torch.nn.functional as F
class MetricFunction:
    """Accumulate surface-normal error metrics over a sequence of batches.

    Each :meth:`evaluate` call computes per-batch metrics with
    ``evaluate_error_normal`` and folds them into a running, sample-weighted
    average with ``avg_error``. :meth:`show` formats the latest averages.
    """

    def __init__(self, batch_size) -> None:
        # Nominal batch size given at construction. Kept for backward
        # compatibility; evaluate() now weights each batch by its real size.
        self.batch_size = batch_size
        self.total_size = 0   # number of samples folded in so far
        self.error_sum = {}   # metric name -> weighted running sum
        self.error_avg = {}   # metric name -> latest running average

    def evaluate(self, predictions, targets, batch_size=None):
        """Add one batch to the running metrics and return the new averages.

        Args:
            predictions: predicted normals tensor. Its dim 0 is assumed to be
                the batch axis, because ``evaluate_error_normal`` treats dim 1
                as the vector-component axis.
            targets: ground-truth normals with the same layout.
            batch_size: optional explicit weight for this batch. Defaults to
                ``predictions.shape[0]``, so a short final batch is not
                over-weighted.

        Returns:
            dict mapping metric name -> running average over all samples seen.
        """
        # Metrics never need gradients. Detaching keeps the running sums in
        # self.error_sum from holding every batch's autograd graph alive.
        normal_p = predictions.detach()
        normal_gt = targets.detach()
        if batch_size is None:
            batch_size = normal_p.shape[0]
        error_val = evaluate_error_normal(normal_p, normal_gt)
        self.total_size += batch_size
        self.error_avg = avg_error(self.error_sum, error_val, self.total_size, batch_size)
        return self.error_avg

    def show(self):
        """Return the latest running averages as a two-line report string.

        Angular values are whatever ``evaluate_error_normal`` reports (radians).
        The TANGLE entries are fractions. Raises KeyError if called before any
        ``evaluate`` call, because ``error_avg`` is still empty then.
        """
        keys = ('N_MSE', 'N_RMSE', 'N_MAE', 'N_MME',
                'N_TANGLE11.25', 'N_TANGLE22.5', 'N_TANGLE30.0')
        format_str = ('======NORMALS=======\nMSE=%.4f\tRMSE=%.4f\tMAE=%.4f\tMME=%.4f\nTANGLE11.25=%.4f\tTANGLE22.5=%.4f\tTANGLE30.0=%.4f')
        return format_str % tuple(self.error_avg[key] for key in keys)
def evaluate_error_normal(pred_normal, gt_normal, thresholds=(11.25, 22.5, 30.0), eps=1e-7):
    """Compute angular-error statistics between predicted and ground-truth normals.

    Both inputs are L2-normalised along dim 1 and compared by their cosine
    along dim 1. The vector components are therefore expected on dim 1,
    presumably ``(N, 3, ...)``; confirm against callers.

    Args:
        pred_normal: predicted normals tensor.
        gt_normal: ground-truth normals tensor with the same layout.
        thresholds: angles in degrees. For each ``t``, the fraction of
            elements with angular error ``<= t`` is stored under
            ``'N_TANGLE' + str(float(t))``. The defaults give the keys
            ``N_TANGLE11.25``, ``N_TANGLE22.5`` and ``N_TANGLE30.0``.
        eps: margin keeping the cosine strictly inside (-1, 1) before ``acos``.

    Returns:
        dict of 0-dim tensors, in this order:
        ``N_MSE`` (mean squared angle, rad^2), ``N_RMSE`` (rad),
        ``N_MAE`` (mean angle, rad), ``N_MME`` (median angle, rad; for an even
        count ``torch.median`` returns the lower of the two middle values),
        then one ``N_TANGLE*`` fraction per threshold.

    Note:
        Zero-length vectors are not masked. ``F.normalize`` leaves them at
        zero, so their cosine is 0 and they count as 90-degree errors.
    """
    error = {}
    pred_normal = F.normalize(pred_normal, p=2, dim=1)
    gt_normal = F.normalize(gt_normal, p=2, dim=1)
    dot_product = torch.mul(pred_normal, gt_normal).sum(dim=1)
    # Clamp away from +/-1: acos is non-differentiable there, and rounding can
    # push a unit-vector dot product slightly outside [-1, 1].
    angular_error = torch.acos(torch.clamp(dot_product, -1 + eps, 1 - eps))
    error['N_MSE'] = torch.mean(torch.mul(angular_error, angular_error))
    error['N_RMSE'] = torch.sqrt(error['N_MSE'])
    error['N_MAE'] = torch.mean(angular_error)
    error['N_MME'] = torch.median(angular_error)
    for threshold in thresholds:
        threshold = float(threshold)  # so 30 and 30.0 both yield key 'N_TANGLE30.0'
        error[f'N_TANGLE{threshold}'] = torch.mean((angular_error <= radians(threshold)).float())
    return error
def avg_error(error_sum, error_val, total_size, batch_size):
    """Fold one batch's metrics into running sums and return the new averages.

    For every key of ``error_val``, the value is multiplied by ``batch_size``
    and added to ``error_sum[key]`` (which starts from 0 when the key is new).
    ``error_sum`` is updated in place.

    Args:
        error_sum: dict of running weighted sums; mutated in place.
        error_val: dict of this batch's metric values.
        total_size: total sample count the sums now represent (the divisor).
        batch_size: weight applied to this batch's values.

    Returns:
        A new dict mapping each key of ``error_val`` to
        ``error_sum[key] / float(total_size)``. Keys present only in
        ``error_sum`` are not included.
    """
    averages = {}
    for name in error_val:
        updated = error_sum.get(name, 0) + error_val[name] * batch_size
        error_sum[name] = updated
        averages[name] = updated / float(total_size)
    return averages
def print_single_error(epoch, loss, error):
    """Print an evaluation summary to stdout.

    The output is four lines: the tag ``eval_avg_error``, the epoch (rendered
    with ``%d``) and loss (rendered with ``str``), ``str(error)``, and a
    trailing blank line.
    """
    parts = [
        'eval_avg_error',
        'Epoch: %d, loss=%s' % (epoch, loss),
        str(error),
        '',
    ]
    print('\n'.join(parts))