-
Notifications
You must be signed in to change notification settings - Fork 0
/
metric.py
executable file
·98 lines (88 loc) · 2.96 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
import mxnet as mx
class AccMetric(mx.metric.EvalMetric):
    """Accuracy metric that scores ``preds[1]`` against ``labels[0]``.

    The metric is registered under the name ``'acc'``. Matching predictions
    are accumulated in ``sum_metric`` and the number of scored samples in
    ``num_inst``.
    """

    def __init__(self):
        # The axis must exist before the base constructor runs because it is
        # forwarded to it as a keyword argument.
        self.axis = 1
        super(AccMetric, self).__init__(
            'acc',
            axis=self.axis,
            output_names=None,
            label_names=None)
        self.losses = []
        # Number of update() calls seen by this instance.
        self.count = 0

    def update(self, labels, preds):
        """Accumulate the correct-prediction count for one batch.

        Args:
            labels: sequence whose first entry holds the ground-truth labels.
            preds: sequence of network outputs; entry 1 is scored
                (presumably the class-score output; confirm against the
                symbol's output list).
        """
        self.count += 1
        truth_nd = labels[0]
        scores = preds[1]
        # When the output is not already shaped like the labels, reduce it
        # to class indices along the configured axis.
        if scores.shape != truth_nd.shape:
            scores = mx.ndarray.argmax(scores, axis=self.axis)
        predicted = scores.asnumpy().astype('int32').flatten()

        truth = truth_nd.asnumpy()
        # For 2-D labels only the first column is compared.
        truth = truth[:, 0] if truth.ndim == 2 else truth
        truth = truth.astype('int32').flatten()

        assert truth.shape == predicted.shape
        matches = predicted.flat == truth.flat
        self.sum_metric += matches.sum()
        self.num_inst += predicted.size
## preds[0] classifier loss
class LossValueMetric(mx.metric.EvalMetric):
    """Metric that tracks the first element of the last network output.

    The metric is registered under the name ``'lossvalue'``. Each call to
    ``update`` adds one value to ``sum_metric`` and increments ``num_inst``
    by one.
    """

    def __init__(self):
        # The axis must exist before the base constructor runs because it is
        # forwarded to it as a keyword argument.
        self.axis = 1
        super(LossValueMetric, self).__init__(
            'lossvalue',
            axis=self.axis,
            output_names=None,
            label_names=None)
        self.losses = []

    def update(self, labels, preds):
        """Add the first element of ``preds[-1]`` to the running sum.

        Args:
            labels: unused.
            preds: sequence of network outputs; only the last entry is read.

        NOTE(review): the original author's comments record an output list
        with shapes ((16, 512), (16, 97768), (1,), (16, 1)), described as
        embedding, softmax, ce_loss, me_add_loss, and label the value read
        here as "celoss". Which output is actually last depends on the
        symbol, so confirm against the caller.
        """
        last_output = preds[-1].asnumpy()
        self.sum_metric += last_output[0]
        self.num_inst += 1.0
class AccMetric_d(mx.metric.EvalMetric):
    """Accuracy metric that scores ``preds[0]`` against ``labels[0]``.

    The metric is registered under the name ``'acc'``. Matching predictions
    are accumulated in ``sum_metric`` and the number of scored samples in
    ``num_inst``.
    """

    def __init__(self):
        # The axis must exist before the base constructor runs because it is
        # forwarded to it as a keyword argument.
        self.axis = 1
        super(AccMetric_d, self).__init__(
            'acc',
            axis=self.axis,
            output_names=None,
            label_names=None)
        self.losses = []
        # Number of update() calls seen by this instance.
        self.count = 0

    def update(self, labels, preds):
        """Accumulate the correct-prediction count for one batch.

        Args:
            labels: sequence whose first entry holds the ground-truth labels.
            preds: sequence of network outputs; entry 0 is scored.
        """
        self.count += 1
        truth_nd = labels[0]
        scores = preds[0]
        # When the output is not already shaped like the labels, reduce it
        # to class indices along the configured axis.
        if scores.shape != truth_nd.shape:
            scores = mx.ndarray.argmax(scores, axis=self.axis)
        predicted = scores.asnumpy().astype('int32').flatten()

        truth = truth_nd.asnumpy()
        # For 2-D labels only the first column is compared.
        truth = truth[:, 0] if truth.ndim == 2 else truth
        truth = truth.astype('int32').flatten()

        assert truth.shape == predicted.shape
        matches = predicted.flat == truth.flat
        self.sum_metric += matches.sum()
        self.num_inst += predicted.size
class LossValueMetric_d(mx.metric.EvalMetric):
    """Metric that tracks the first element of the last network output.

    The metric is registered under the name ``'lossvalue'``. Each call to
    ``update`` adds one value to ``sum_metric`` and increments ``num_inst``
    by one.
    """

    def __init__(self):
        # The axis must exist before the base constructor runs because it is
        # forwarded to it as a keyword argument.
        self.axis = 1
        super(LossValueMetric_d, self).__init__(
            'lossvalue',
            axis=self.axis,
            output_names=None,
            label_names=None)
        self.losses = []

    def update(self, labels, preds):
        """Add the first element of ``preds[-1]`` to the running sum.

        Args:
            labels: unused.
            preds: sequence of network outputs; only the last entry is read
                (presumably a loss output; confirm against the symbol's
                output list).
        """
        last_output = preds[-1].asnumpy()
        self.sum_metric += last_output[0]
        self.num_inst += 1.0