-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmetrics.py
128 lines (122 loc) · 5 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# -*- coding: utf-8 -*-
"""
@author: truthless
"""
import numpy as np
from dbquery import DBQuery
class Evaluator(object):
    """Score a dialogue against the user goal.

    Two metrics are provided:

    * match rate -- for every domain with a booking, the fraction of the
      goal's inform constraints satisfied by the booked database entity;
    * inform F1 -- TP/FP/FN counts (and precision/recall/F1) of the slots
      the system informed versus the slots the goal requested.

    The evaluator reads ``cfg.belief_domains``, ``cfg.mapping`` and
    ``cfg.requestable`` and looks entities up in ``self.db.dbs``.
    """

    def __init__(self, data_dir, cfg):
        """Create the database accessor from ``data_dir`` and keep ``cfg``."""
        self.db = DBQuery(data_dir)
        self.cfg = cfg

    def _init_dict(self):
        """Return a fresh ``{domain: {}}`` dict covering every belief domain."""
        return {domain: {} for domain in self.cfg.belief_domains}

    @staticmethod
    def _hhmm_to_int(value):
        """Convert an ``'HH:MM'`` string to the integer ``HH * 100 + MM``.

        Raises:
            ValueError: if either field is not an integer literal.
            IndexError: if the string contains no ``':'`` separator.
        """
        parts = value.split(':')
        return int(parts[0]) * 100 + int(parts[1])

    def match_rate_goal(self, goal, booked_entity):
        """Judge whether the booked entities meet the goal's constraints.

        Args:
            goal: dict with ``'book'`` and ``'inform'`` entries, each keyed
                by domain.
            booked_entity: dict mapping a domain to the id of the booked
                entity (a falsy value means nothing was booked).

        Returns:
            A list with one score per domain that has a booking goal and at
            least one inform constraint, in ``cfg.belief_domains`` order.
            A score is 0 when nothing was booked, 1 for the ``'taxi'``
            domain, and otherwise the fraction of satisfied constraints.
        """
        score = []
        for domain in self.cfg.belief_domains:
            if not goal['book'][domain]:
                continue
            constraints = goal['inform'][domain]
            tot = len(constraints)
            if tot == 0:
                continue

            entity_id = booked_entity[domain]
            if not entity_id:
                score.append(0)
                continue
            if domain == 'taxi':
                # Taxi bookings are not looked up in the database; any
                # booking counts as a full match.
                score.append(1)
                continue

            entity = self.db.dbs[domain][int(entity_id)]
            match = 0
            for goal_slot, value in constraints.items():
                slot = self.cfg.mapping[domain][goal_slot]
                if slot in ('leaveAt', 'arriveBy'):
                    try:
                        wanted = self._hhmm_to_int(value)
                        offered = self._hhmm_to_int(entity[slot])
                    except (ValueError, IndexError):
                        # Unparsable times are deliberately counted as
                        # satisfied (best-effort leniency).
                        match += 1
                        continue
                    # leaveAt: depart no earlier than requested;
                    # arriveBy: arrive no later than requested.
                    if slot == 'leaveAt':
                        satisfied = wanted <= offered
                    else:
                        satisfied = wanted >= offered
                    if satisfied:
                        match += 1
                elif value.strip() == entity[slot].strip():
                    match += 1
            score.append(match / tot)
        return score

    def inform_F1_goal(self, goal, sys_history):
        """Judge whether all the requested information was answered.

        Args:
            goal: dict with ``'request'`` and ``'inform'`` entries, each
                keyed by domain.
            sys_history: iterable of system dialogue-act keys, each made of
                four ``'-'``-separated fields: domain, intent, slot and a
                trailing field that is not used here.

        Returns:
            The tuple ``(TP, FP, FN)``.
        """
        informing_intents = ('inform', 'recommend', 'offerbook', 'offerbooked')
        inform_slot = {domain: set() for domain in self.cfg.belief_domains}
        for da in sys_history:
            domain, intent, slot, _ = da.split('-')
            if (intent in informing_intents and slot != 'none'
                    and domain in self.cfg.belief_domains):
                inform_slot[domain].add(slot)

        TP, FP, FN = 0, 0, 0
        for domain in self.cfg.belief_domains:
            requested = goal['request'][domain]
            for k in requested:
                if k in inform_slot[domain]:
                    TP += 1
                else:
                    FN += 1
            for k in inform_slot[domain]:
                # exclude slots that are informed by users
                if (k not in requested
                        and k not in goal['inform'][domain]
                        and k in self.cfg.requestable[domain]):
                    FP += 1
        return TP, FP, FN

    def match_rate(self, metadata, aggregate=False):
        """Compute the match rate from dialogue ``metadata``.

        Args:
            metadata: dict providing ``['belief_state']['booked']`` and
                ``['user_goal']``.
            aggregate: when true, return the mean score instead of the list.

        Returns:
            The per-domain score list, or (with ``aggregate``) its mean, or
            ``None`` when there is nothing to score.
        """
        booked_entity = metadata['belief_state']['booked']
        # NOTE: the goal is taken directly from metadata['user_goal']; an
        # earlier variant rebuilt it from metadata['history']['user'].
        goal = metadata['user_goal']
        score = self.match_rate_goal(goal, booked_entity)
        if not aggregate:
            return score
        return np.mean(score) if score else None

    def inform_F1(self, metadata, aggregate=False):
        """Compute inform statistics from dialogue ``metadata``.

        Args:
            metadata: dict providing ``['history']['sys']``,
                ``['last_sys_action']`` and ``['user_goal']``.
            aggregate: when true, return precision/recall/F1 instead of the
                raw counts.

        Returns:
            ``[TP, FP, FN]`` when ``aggregate`` is false. Otherwise the
            tuple ``(precision, recall, F1)``; ``(None, None, None)`` when
            recall is undefined (``TP + FN == 0``) and ``(0, recall, 0)``
            when precision or F1 is undefined.
        """
        # Entries of the last system action override same-keyed history.
        sys_history = dict(metadata['history']['sys'], **metadata['last_sys_action'])
        # NOTE: the goal is taken directly from metadata['user_goal']; an
        # earlier variant rebuilt it from metadata['history']['user'].
        goal = metadata['user_goal']
        TP, FP, FN = self.inform_F1_goal(goal, sys_history)
        if not aggregate:
            return [TP, FP, FN]

        if TP + FN == 0:
            return None, None, None
        rec = TP / (TP + FN)
        if TP + FP == 0:
            return 0, rec, 0
        prec = TP / (TP + FP)
        if prec + rec == 0:
            return 0, rec, 0
        F1 = 2 * prec * rec / (prec + rec)
        return prec, rec, F1