-
Notifications
You must be signed in to change notification settings - Fork 2
/
evaluate.py
120 lines (98 loc) · 4.4 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# ====================================
# @Project :insights-span-aste
# @IDE :PyCharm
# @Author :Hao,Wireless Zhiheng
# @Email :[email protected]
# @Date :03/08/2022 14:46
# ====================================
import argparse
import os
import random
from functools import partial
import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from models.collate import gold_labels, collate_fn
from models.metrics import SpanEvaluator
from models.model import SpanAsteModel
from utils.dataset import CustomDataset
from utils.processor import Res15DataProcessor
from utils.tager import SpanLabel, RelationLabel
def set_seed(seed):
    """Seed every RNG in play (Python, NumPy, PyTorch) for reproducibility.

    Args:
        seed (int): Seed value applied to all random number generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA keeps its own generator; seed it only when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
def evaluate(model, metric, data_loader, device):
    """
    Evaluate the model on a dataset and accumulate precision/recall/F1.

    Args:
        model (torch.nn.Module): The SpanAsteModel to evaluate.
        metric (SpanEvaluator): Accumulates (correct, inferred, label) counts.
        data_loader (torch.utils.data.DataLoader): Yields evaluation batches
            in the 8-tuple layout produced by ``collate_fn``.
        device (str): "cuda" or "cpu"; inputs are moved here before forward.

    Returns:
        tuple: (precision, recall, f1) accumulated over the whole dataset.
    """
    model.eval()
    metric.reset()
    with torch.no_grad():
        for batch_ix, batch in enumerate(data_loader):
            input_ids, attention_mask, token_type_ids, spans, relations, span_labels, relation_labels, seq_len = batch
            # as_tensor avoids a redundant copy (and a deprecation warning)
            # when the collate function already returns tensors.
            input_ids = torch.as_tensor(input_ids, device=device)
            attention_mask = torch.as_tensor(attention_mask, device=device)
            token_type_ids = torch.as_tensor(token_type_ids, device=device)
            # forward
            spans_probability, span_indices, relations_probability, candidate_indices = model(
                input_ids, attention_mask, token_type_ids, seq_len)
            # Only relation-level (triplet) gold labels feed the metric; the
            # span-level gold labels the original computed here were never
            # used, so that dead per-batch work has been removed.
            _, gold_relation_labels = gold_labels(candidate_indices, relations, relation_labels)
            num_correct, num_infer, num_label = metric.compute(relations_probability.cpu(),
                                                               torch.tensor(gold_relation_labels))
            metric.update(num_correct, num_infer, num_label)
    precision, recall, f1 = metric.accumulate()
    # Restore training mode so a caller mid-training isn't left in eval mode.
    model.train()
    return precision, recall, f1
def do_eval():
    """Load the saved SPAN-ASTE checkpoint and report P/R/F1 on the test set.

    Reads configuration from the module-level ``args`` namespace
    (bert_model, model_path, test_path, batch_size, max_seq_len).
    Prints the accumulated precision, recall and F1 to stdout.
    """
    set_seed(1024)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print(f"using device:{device}")
    # tokenizer
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    # create processor
    processor = Res15DataProcessor(tokenizer, args.max_seq_len)
    print("Loading Eval Dataset...")
    # NOTE(review): the dataset mode is "dev" while the path is args.test_path —
    # confirm CustomDataset treats "dev" as a generic evaluation mode.
    test_dataset = CustomDataset("dev", args.test_path, processor, tokenizer, args.max_seq_len)
    print("Construct Dataloader...")
    # Evaluation order does not affect the accumulated metric, so keep the
    # loader deterministic (the original shuffled the test set for no benefit).
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
    print("Building SPAN-ASTE model...")
    # sizes of the span-type and relation-type label spaces
    target_dim, relation_dim = len(SpanLabel), len(RelationLabel)
    # build span-aste model
    model = SpanAsteModel(
        args.bert_model,
        target_dim,
        relation_dim,
        device=device
    )
    # map_location lets a GPU-trained checkpoint load on a CPU-only host
    model.load_state_dict(torch.load(os.path.join(args.model_path, "model.pt"), map_location=torch.device(device)))
    model.to(device)
    metric = SpanEvaluator()
    precision, recall, f1 = evaluate(model, metric, test_dataloader, device)
    print("-----------------------------")
    print("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
          (precision, recall, f1))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--bert_model", type=str, default=None, help="The name of bert.")
parser.add_argument("--model_path", type=str, default=None, help="The path of saved model that you want to load.")
parser.add_argument("--test_path", type=str, default=None, help="The path of test set.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU/CPU for training.")
parser.add_argument("--max_seq_len", type=int, default=512,
help="The maximum total input sequence length after tokenization.")
args = parser.parse_args()
do_eval()