forked from jlrussin/transformer_scan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
23 lines (20 loc) · 780 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Function for testing models
import numpy as np
import torch
def test(data, model, pad_idx, device, args):
    """Compute exact-match (whole-sequence) accuracy of ``model`` on ``data``.

    An example counts as correct only if every target position is either
    predicted correctly or is padding (``trg == pad_idx``); padding positions
    are therefore ignored.

    Args:
        data: Iterable of batches. Each batch must expose ``src`` and ``trg``
            tensors; they are passed to the model as-is (this function does
            not move them to ``device``).
        model: ``torch.nn.Module`` called as ``model(src, trg)`` and returning
            a 2-tuple whose first element holds logits; predictions are taken
            with ``argmax`` over dim 2 of those logits.
        pad_idx: Token index used for padding in ``trg``.
        device: Not used by this function; kept for interface compatibility.
        args: Not used by this function; kept for interface compatibility.

    Returns:
        Fraction of examples predicted exactly right, as a float. Returns
        ``nan`` when ``data`` yields no examples (the same value
        ``np.mean([])`` produced, without its RuntimeWarning).

    The model's train/eval mode from before the call is restored on exit,
    including when an exception propagates.
    """
    was_training = model.training
    n_correct = 0
    n_total = 0
    model.eval()
    try:
        with torch.no_grad():
            for batch in data:
                out, _attn_wts = model(batch.src, batch.trg)
                preds = torch.argmax(out, dim=2)
                # A position is acceptable if it matches the target or is padding.
                token_ok = (preds == batch.trg) | (batch.trg == pad_idx)
                # Reduce over dim 0: each remaining entry is one example.
                # NOTE(review): assumes trg is (seq_len, batch), i.e. sequence-first
                # -- confirm against the data iterator / model.
                seq_ok = token_ok.all(dim=0)
                n_correct += int(seq_ok.sum().item())
                n_total += seq_ok.numel()
    finally:
        # Restore whatever mode the caller had, rather than forcing train mode.
        model.train(was_training)
    if n_total == 0:
        return float("nan")
    return n_correct / n_total


# This is an evaluation helper, not a pytest test; without this flag pytest
# collects it when it is imported into a test module and fails on its arguments.
test.__test__ = False