# data.py (forked from crystina-z/CEDR_tpu)
import random

from tqdm import tqdm
import torch
import torch_xla.core.xla_model as xm

# Tensors built in this module are shipped straight to the TPU; swap in the
# commented line below to run on GPU/CPU instead.
device = xm.xla_device()
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device in data.py:', device)


def read_datafiles(files):
    """Read datafiles whose lines are `type<TAB>id<TAB>text` into (queries, docs) dicts."""
    queries = {}
    docs = {}
    for file in files:
        for line in tqdm(file, desc='loading datafile (by line)', leave=False):
            cols = line.rstrip().split('\t')
            if len(cols) != 3:
                tqdm.write(f'skipping line: `{line.rstrip()}`')
                continue
            c_type, c_id, c_text = cols
            assert c_type in ('query', 'doc')
            if c_type == 'query':
                queries[c_id] = c_text
            elif c_type == 'doc':
                docs[c_id] = c_text
    return queries, docs
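
# Illustrative sketch only (not in the original file): read_datafiles expects
# tab-separated lines like the following, with queries and docs either mixed
# in one file or split across several:
#
#   query   132     what county is ashland ohio in
#   doc     D100    Ashland is a city in Ashland County, Ohio ...
#
# so a typical call (the file names here are hypothetical) would be:
#
#   with open('queries.tsv') as qf, open('documents.tsv') as df:
#       queries, docs = read_datafiles([qf, df])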


def read_qrels_dict(file):
    """Read a TREC qrels file into a qid -> {docid: relevance} dict."""
    result = {}
    for line in tqdm(file, desc='loading qrels (by line)', leave=False):
        qid, _, docid, score = line.split()
        result.setdefault(qid, {})[docid] = int(score)
    return result
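
# For reference (added note): a TREC qrels line is `qid iteration docid relevance`,
# e.g. `132 0 D100 1`, and the iteration field is ignored; that example line
# would produce {'132': {'D100': 1}}.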


def read_run_dict(file):
    """Read a TREC run file into a qid -> {docid: score} dict."""
    result = {}
    for line in tqdm(file, desc='loading run (by line)', leave=False):
        qid, _, docid, rank, score, _ = line.split()
        result.setdefault(qid, {})[docid] = float(score)
    return result
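
# For reference (added note): a TREC run line is `qid Q0 docid rank score tag`,
# e.g. `132 Q0 D100 1 21.5 bm25`; only qid, docid, and score are kept, so that
# example line would produce {'132': {'D100': 21.5}}.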


def read_pairs_dict(file):
    """Read a `qid docid` pairs file into a qid -> {docid: 1} dict."""
    result = {}
    for line in tqdm(file, desc='loading pairs (by line)', leave=False):
        qid, docid = line.split()
        result.setdefault(qid, {})[docid] = 1
    return result
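
# For reference (added note): a pairs file is simply one `qid docid` pair per
# line; the constant value 1 is a placeholder so the result mirrors the shape
# of the qrels/run dicts above.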


def iter_train_pairs(model, dataset, train_pairs, qrels, batch_size):
    batch = {'query_id': [], 'doc_id': [], 'query_tok': [], 'doc_tok': []}
    for qid, did, query_tok, doc_tok in _iter_train_pairs(model, dataset, train_pairs, qrels):
        batch['query_id'].append(qid)
        batch['doc_id'].append(did)
        batch['query_tok'].append(query_tok)
        batch['doc_tok'].append(doc_tok)
        # each training pair contributes two records (positive + negative), so
        # a full batch holds 2 * batch_size records
        if len(batch['query_id']) // 2 == batch_size:
            yield _pack_n_ship(batch)
            batch = {'query_id': [], 'doc_id': [], 'query_tok': [], 'doc_tok': []}
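
# Hypothetical usage sketch (the names `model`, `dataset`, `train_pairs`, and
# `qrels` follow CEDR's companion train.py; that wiring is an assumption, not
# part of this file):
#
#   for batch in iter_train_pairs(model, dataset, train_pairs, qrels, 16):
#       # batch['query_tok'] has shape [32, 20]: 16 positives + 16 negatives
#       scores = model(batch['query_tok'], batch['query_mask'],
#                      batch['doc_tok'], batch['doc_mask'])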


def _iter_train_pairs(model, dataset, train_pairs, qrels):
    # infinite generator: loops over shuffled qids forever, yielding one
    # positive record followed by one negative record per sampled pair
    ds_queries, ds_docs = dataset
    while True:
        qids = list(train_pairs.keys())
        random.shuffle(qids)
        for qid in qids:
            pos_ids = [did for did in train_pairs[qid] if qrels.get(qid, {}).get(did, 0) > 0]
            if len(pos_ids) == 0:
                continue
            pos_id = random.choice(pos_ids)
            pos_ids_lookup = set(pos_ids)
            neg_ids = [did for did in train_pairs[qid] if did not in pos_ids_lookup]
            if len(neg_ids) == 0:
                continue
            neg_id = random.choice(neg_ids)
            query_tok = model.tokenize(ds_queries[qid])
            pos_doc = ds_docs.get(pos_id)
            neg_doc = ds_docs.get(neg_id)
            if pos_doc is None:
                tqdm.write(f'missing doc {pos_id}! Skipping')
                continue
            if neg_doc is None:
                tqdm.write(f'missing doc {neg_id}! Skipping')
                continue
            yield qid, pos_id, query_tok, model.tokenize(pos_doc)
            yield qid, neg_id, query_tok, model.tokenize(neg_doc)


def iter_valid_records(model, dataset, run, batch_size):
    batch = {'query_id': [], 'doc_id': [], 'query_tok': [], 'doc_tok': []}
    for qid, did, query_tok, doc_tok in _iter_valid_records(model, dataset, run):
        batch['query_id'].append(qid)
        batch['doc_id'].append(did)
        batch['query_tok'].append(query_tok)
        batch['doc_tok'].append(doc_tok)
        if len(batch['query_id']) == batch_size:
            yield _pack_n_ship(batch)
            batch = {'query_id': [], 'doc_id': [], 'query_tok': [], 'doc_tok': []}
    # final (possibly partial) batch
    if len(batch['query_id']) > 0:
        yield _pack_n_ship(batch)
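
# Hypothetical validation sketch (mirrors the training loop above; the run
# file name is an assumption):
#
#   run = read_run_dict(open('valid.run'))
#   for records in iter_valid_records(model, dataset, run, 16):
#       scores = model(records['query_tok'], records['query_mask'],
#                      records['doc_tok'], records['doc_mask'])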


def _iter_valid_records(model, dataset, run):
    ds_queries, ds_docs = dataset
    for qid in run:
        query_tok = model.tokenize(ds_queries[qid])
        for did in run[qid]:
            doc = ds_docs.get(did)
            if doc is None:
                tqdm.write(f'missing doc {did}! Skipping')
                continue
            doc_tok = model.tokenize(doc)
            yield qid, did, query_tok, doc_tok


def _pack_n_ship(batch):
    QLEN = 20
    MAX_DLEN = 800
    # a fixed DLEN keeps tensor shapes static, which avoids recompilation under
    # XLA (likely why the dynamic version below is commented out in this fork)
    # DLEN = min(MAX_DLEN, max(len(b) for b in batch['doc_tok']))
    DLEN = MAX_DLEN
    return {
        'query_id': batch['query_id'],
        'doc_id': batch['doc_id'],
        'query_tok': _pad_crop(batch['query_tok'], QLEN),
        'doc_tok': _pad_crop(batch['doc_tok'], DLEN),
        'query_mask': _mask(batch['query_tok'], QLEN),
        'doc_mask': _mask(batch['doc_tok'], DLEN),
    }
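
# Shape summary (added note): for a batch of B records, _pack_n_ship returns
# LongTensors query_tok [B, 20] and doc_tok [B, 800], plus FloatTensors
# query_mask [B, 20] and doc_mask [B, 800], all already moved to `device`.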


def _pad_crop(items, l):
    result = []
    for item in items:
        if len(item) < l:
            # pad with 0 rather than -1, to avoid the toks[toks == -1] = 0
            # fix-up in modeling.py
            item = item + [0] * (l - len(item))
        else:
            item = item[:l]
        result.append(item)
    # return torch.tensor(result).long().cuda()
    return torch.tensor(result).long().to(device)


def _mask(items, l):
    result = []
    for item in items:
        if len(item) < l:
            mask = [1.] * len(item) + [0.] * (l - len(item))
        else:
            mask = [1.] * l
        result.append(mask)
    # return torch.tensor(result).float().cuda()
    return torch.tensor(result).float().to(device)
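

if __name__ == '__main__':
    # Minimal smoke test (added sketch, not in the original file): pad/crop two
    # token-id lists to length 5 and build the matching attention masks.
    toks = [[1, 2, 3], [4, 5, 6, 7, 8, 9]]
    print(_pad_crop(toks, 5))  # tensor([[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]])
    print(_mask(toks, 5))      # tensor([[1., 1., 1., 0., 0.], [1., 1., 1., 1., 1.]])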