
Commit 15262bd

Merge pull request #107 from yhcc/dataset
A brand new version update (0.1.1)
2 parents abf840c + 26a4324 · commit 15262bd

File tree

96 files changed: +4455 −3173 lines


examples/readme_example.py

Lines changed: 0 additions & 75 deletions
This file was deleted.

fastNLP/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .core import *
from . import models
from . import modules
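
A minimal sketch of what this change enables (assuming fastNLP.core re-exports its public names, e.g. DataSet; the exact re-exports depend on fastNLP/core/__init__.py):

    import fastNLP
    from fastNLP import models, modules  # subpackages now importable from the top level

    ds = fastNLP.DataSet()  # core names arrive via `from .core import *`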

fastNLP/api/api.py

Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@
import warnings

import torch

warnings.filterwarnings('ignore')
import os

from fastNLP.core.dataset import DataSet

from fastNLP.api.model_zoo import load_url
from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader
from reproduction.pos_tag_model.pos_io.pos_reader import ConlluPOSReader
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
from fastNLP.core.instance import Instance
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.batch import Batch
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SeqLabelEvaluator2
from fastNLP.core.tester import Tester

# TODO add pretrain urls
model_urls = {

}


class API:
    def __init__(self):
        self.pipeline = None

    def predict(self, *args, **kwargs):
        raise NotImplementedError

    def load(self, path, device):
        # Load a saved pipeline from a local path, or fetch it from a URL.
        if os.path.exists(os.path.expanduser(path)):
            _dict = torch.load(path, map_location='cpu')
        else:
            _dict = load_url(path, map_location='cpu')
        self.pipeline = _dict['pipeline']
        self._dict = _dict
        # Move every model-backed processor in the pipeline to the target device.
        for processor in self.pipeline.pipeline:
            if isinstance(processor, ModelProcessor):
                processor.set_model_device(device)


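# load() resolves its argument in two ways (illustrative; paths below are hypothetical):
#   - an existing local path is read with torch.load(..., map_location='cpu')
#   - anything else is treated as a URL and fetched via load_url
# Subclasses call it from __init__, e.g. POS(model_path='~/models/pos.pkl').

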
class POS(API):
    """FastNLP API for Part-Of-Speech tagging.

    """

    def __init__(self, model_path=None, device='cpu'):
        super(POS, self).__init__()
        if model_path is None:
            model_path = model_urls['pos']

        self.load(model_path, device)

    def predict(self, content):
        """Predict the POS tags of the given content.

        :param content: list of list of str. Each string is a token (word).
            A single str is also accepted, in which case a single result is returned.
        :return answer: list of list of str. Each string is a tag.
        """
        if self.pipeline is None:
            raise ValueError("You have to load the model first.")

        sentence_list = []
        # 1. Check the type of the input
        if isinstance(content, str):
            sentence_list.append(content)
        elif isinstance(content, list):
            sentence_list = content

        # 2. Build the dataset
        dataset = DataSet()
        dataset.add_field('words', sentence_list)

        # 3. Run the pipeline
        self.pipeline(dataset)

        output = dataset['word_pos_output'].content
        if isinstance(content, str):
            return output[0]
        elif isinstance(content, list):
            return output

    def test(self, filepath):
        tag_proc = self._dict['tag_indexer']

        model = self.pipeline.pipeline[2].model
        pipeline = self.pipeline.pipeline[0:2]
        pipeline.append(tag_proc)
        pp = Pipeline(pipeline)

        reader = ConlluPOSReader()
        te_dataset = reader.load(filepath)

        # Collect the indices of span-ending tags (E-/S-) for the evaluator.
        evaluator = SeqLabelEvaluator2('word_seq_origin_len')
        end_tagidx_set = set()
        tag_proc.vocab.build_vocab()
        for key, value in tag_proc.vocab.word2idx.items():
            if key.startswith('E-'):
                end_tagidx_set.add(value)
            if key.startswith('S-'):
                end_tagidx_set.add(value)
        evaluator.end_tagidx_set = end_tagidx_set

        default_valid_args = {"batch_size": 64,
                              "use_cuda": True, "evaluator": evaluator}

        pp(te_dataset)
        te_dataset.set_target(truth=True)

        tester = Tester(**default_valid_args)

        test_result = tester.test(model, te_dataset)

        f1 = round(test_result['F'] * 100, 2)
        pre = round(test_result['P'] * 100, 2)
        rec = round(test_result['R'] * 100, 2)
        # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

        return f1, pre, rec


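# A hypothetical usage sketch for POS (model_urls above is still empty, so a
# local checkpoint path is assumed; the path below is illustrative only):
#
#     pos = POS(model_path='~/models/pos_model.pkl', device='cpu')
#     tags = pos.predict([['编者按', ':', '7月12日']])  # one tag list per tokenized sentence

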
class CWS(API):
    def __init__(self, model_path=None, device='cpu'):
        super(CWS, self).__init__()
        if model_path is None:
            model_path = model_urls['cws']

        self.load(model_path, device)

    def predict(self, content):
        if self.pipeline is None:
            raise ValueError("You have to load the model first.")

        sentence_list = []
        # 1. Check the type of the input
        if isinstance(content, str):
            sentence_list.append(content)
        elif isinstance(content, list):
            sentence_list = content

        # 2. Build the dataset
        dataset = DataSet()
        dataset.add_field('raw_sentence', sentence_list)

        # 3. Run the pipeline
        self.pipeline(dataset)

        output = dataset['output'].content
        if isinstance(content, str):
            return output[0]
        elif isinstance(content, list):
            return output

    def test(self, filepath):
        tag_proc = self._dict['tag_indexer']
        cws_model = self.pipeline.pipeline[-2].model
        pipeline = self.pipeline.pipeline[:5]

        pipeline.insert(1, tag_proc)
        pp = Pipeline(pipeline)

        reader = ConlluCWSReader()

        # te_filename = '/home/hyan/ctb3/test.conllx'
        te_dataset = reader.load(filepath)
        pp(te_dataset)

        batch_size = 64
        te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
        pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes')
        f1 = round(f1 * 100, 2)
        pre = round(pre * 100, 2)
        rec = round(rec * 100, 2)
        # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

        return f1, pre, rec


class Parser(API):
    def __init__(self, model_path=None, device='cpu'):
        super(Parser, self).__init__()
        if model_path is None:
            model_path = model_urls['parser']

        self.load(model_path, device)

    def predict(self, content):
        if self.pipeline is None:
            raise ValueError("You have to load the model first.")

        sentence_list = []
        # 1. Check the type of the input
        if isinstance(content, str):
            sentence_list.append(content)
        elif isinstance(content, list):
            sentence_list = content

        # 2. Build the dataset
        dataset = DataSet()
        dataset.add_field('words', sentence_list)
        # dataset.add_field('tag', sentence_list)

        # 3. Run the pipeline
        self.pipeline(dataset)
        for ins in dataset:
            ins['heads'] = ins['heads'].tolist()

        return dataset['heads'], dataset['labels']

    def test(self, filepath):
        data = ConllxDataLoader().load(filepath)
        ds = DataSet()
        for ins1, ins2 in zip(add_seg_tag(data), data):
            ds.append(Instance(words=ins1[0], tag=ins1[1],
                               gold_words=ins2[0], gold_pos=ins2[1],
                               gold_heads=ins2[2], gold_head_tags=ins2[3]))

        pp = self.pipeline
        # Temporarily point the pipeline at the gold fields for evaluation.
        for p in pp:
            if p.field_name == 'word_list':
                p.field_name = 'gold_words'
            elif p.field_name == 'pos_list':
                p.field_name = 'gold_pos'
        pp(ds)
        head_cor, label_cor, total = 0, 0, 0
        for ins in ds:
            head_gold = ins['gold_heads']
            head_pred = ins['heads']
            length = len(head_gold)
            total += length
            for i in range(length):
                head_cor += 1 if head_pred[i] == head_gold[i] else 0
        uas = head_cor / total
        print('uas:{:.2f}'.format(uas))

        # Restore the original field names.
        for p in pp:
            if p.field_name == 'gold_words':
                p.field_name = 'word_list'
            elif p.field_name == 'gold_pos':
                p.field_name = 'pos_list'

        return uas


class Analyzer:
    def __init__(self, device='cpu'):
        self.cws = CWS(device=device)
        self.pos = POS(device=device)
        self.parser = Parser(device=device)

    def predict(self, content, seg=False, pos=False, parser=False):
        # Default to segmentation when no task is requested.
        if seg is False and pos is False and parser is False:
            seg = True
        output_dict = {}
        if seg:
            seg_output = self.cws.predict(content)
            output_dict['seg'] = seg_output
        if pos:
            pos_output = self.pos.predict(content)
            output_dict['pos'] = pos_output
        if parser:
            parser_output = self.parser.predict(content)
            output_dict['parser'] = parser_output

        return output_dict

    def test(self, filepath):
        # Evaluate each loaded component on the given file.
        output_dict = {}
        if self.cws:
            seg_output = self.cws.test(filepath)
            output_dict['seg'] = seg_output
        if self.pos:
            pos_output = self.pos.test(filepath)
            output_dict['pos'] = pos_output
        if self.parser:
            parser_output = self.parser.test(filepath)
            output_dict['parser'] = parser_output

        return output_dict


if __name__ == "__main__":
    # pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl'
    # pos = POS(device='cpu')
    # s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
    #      '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
    #      '那么这款无人机到底有多厉害?']
    # print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll'))
    # print(pos.predict(s))

    # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
    # cws = CWS(device='cpu')
    # s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂',
    #      '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
    #      '那么这款无人机到底有多厉害?']
    # print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll'))
    # print(cws.predict(s))

    parser = Parser(device='cpu')
    # print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll'))
    s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
         '那么这款无人机到底有多厉害?']
    print(parser.predict(s))
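
Taken together, a minimal end-to-end sketch of the new fastNLP.api surface (checkpoint paths below are hypothetical; model_urls is still an empty TODO in this commit, so every component needs an explicit model_path, and Analyzer, which passes none, only works once the URLs are filled in):

    from fastNLP.api.api import CWS, Analyzer

    cws = CWS(model_path='~/models/cws_model.pkl', device='cpu')
    print(cws.predict('那么这款无人机到底有多厉害?'))  # a single str in, a single segmented result out

    analyzer = Analyzer(device='cpu')  # requires model_urls['cws'/'pos'/'parser'] to be set
    result = analyzer.predict('那么这款无人机到底有多厉害?', seg=True, pos=True)
    print(result['seg'], result['pos'])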
