@@ -1,14 +1,15 @@
-import torch
-from collections import defaultdict
 import re
+from collections import defaultdict
+
+import torch
 
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.core.batch import Batch
+from fastNLP.core.dataset import DataSet
 from fastNLP.core.sampler import SequentialSampler
+from fastNLP.core.vocabulary import Vocabulary
 
 
-class Processor:
+class Processor(object):
     def __init__(self, field_name, new_added_field_name):
         self.field_name = field_name
         if new_added_field_name is None:
@@ -17,7 +18,7 @@ def __init__(self, field_name, new_added_field_name):
         self.new_added_field_name = new_added_field_name
 
     def process(self, *args, **kwargs):
-        pass
+        raise NotImplementedError
 
     def __call__(self, *args, **kwargs):
         return self.process(*args, **kwargs)
@@ -132,27 +133,29 @@ def process(self, dataset):
 
 
 class IndexerProcessor(Processor):
-    def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False):
+    def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):
 
         assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
 
         super(IndexerProcessor, self).__init__(field_name, new_added_field_name)
         self.vocab = vocab
         self.delete_old_field = delete_old_field
+        self.is_input = is_input
 
     def set_vocab(self, vocab):
         assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
 
         self.vocab = vocab
 
     def process(self, dataset):
-        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
         for ins in dataset:
             tokens = ins[self.field_name]
             index = [self.vocab.to_index(token) for token in tokens]
             ins[self.new_added_field_name] = index
 
-        dataset._set_need_tensor(**{self.new_added_field_name: True})
+        if self.is_input:
+            dataset.set_input(self.new_added_field_name)
 
         if self.delete_old_field:
             dataset.delete_field(self.field_name)
@@ -161,6 +164,9 @@ def process(self, dataset):
 
 
 class VocabProcessor(Processor):
+    """Build vocabulary with a field in the data set.
+
+    """
     def __init__(self, field_name):
         super(VocabProcessor, self).__init__(field_name, None)
         self.vocab = Vocabulary()
@@ -178,17 +184,20 @@ def get_vocab(self):
 
 
 class SeqLenProcessor(Processor):
-    def __init__(self, field_name, new_added_field_name='seq_lens'):
+    def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
         super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
+        self.is_input = is_input
 
     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
         for ins in dataset:
             length = len(ins[self.field_name])
             ins[self.new_added_field_name] = length
-        dataset._set_need_tensor(**{self.new_added_field_name: True})
+        if self.is_input:
+            dataset.set_input(self.new_added_field_name)
         return dataset
 
+
 class ModelProcessor(Processor):
     def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
         """
@@ -238,6 +247,7 @@ def set_model_device(self, device):
         device = torch.device(device)
         self.model.to(device)
 
+
 class Index2WordProcessor(Processor):
     def __init__(self, vocab, field_name, new_added_field_name):
         super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
@@ -251,26 +261,28 @@ def process(self, dataset):
 
 
 class SetTensorProcessor(Processor):
+    # TODO: remove it. It is strange.
     def __init__(self, field_dict, default=False):
         super(SetTensorProcessor, self).__init__(None, None)
         self.field_dict = field_dict
         self.default = default
 
     def process(self, dataset):
-        set_dict = {name: self.default for name in dataset.get_fields().keys()}
+        set_dict = {name: self.default for name in dataset.get_all_fields().keys()}
         set_dict.update(self.field_dict)
         dataset._set_need_tensor(**set_dict)
         return dataset
 
 
class SetIsTargetProcessor(Processor):
+    # TODO; remove it.
     def __init__(self, field_dict, default=False):
         super(SetIsTargetProcessor, self).__init__(None, None)
         self.field_dict = field_dict
         self.default = default
 
     def process(self, dataset):
-        set_dict = {name: self.default for name in dataset.get_fields().keys()}
+        set_dict = {name: self.default for name in dataset.get_all_fields().keys()}
         set_dict.update(self.field_dict)
         dataset.set_target(**set_dict)
         return dataset
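For reference, a minimal sketch of how the reworked processors might be chained after this change: build a vocabulary over a token field, index it, and record sequence lengths, with the new is_input flag routing the added fields through dataset.set_input instead of the removed _set_need_tensor call. The module import path and the DataSet construction from a dict of lists are assumptions for illustration, not taken from the diff.

# Usage sketch under the assumptions stated above; exact construction calls
# may differ from the real fastNLP API of this version.
from fastNLP.core.dataset import DataSet
from fastNLP.api.processor import VocabProcessor, IndexerProcessor, SeqLenProcessor  # path assumed

# Toy dataset with a raw token field (dict-of-lists constructor assumed).
data = DataSet({"words": [["this", "is", "fine"], ["another", "sentence"]]})

# Build the vocabulary from the raw token field.
vocab_proc = VocabProcessor(field_name="words")
vocab_proc.process(data)

# Map tokens to indices; is_input=True marks the new field as model input.
index_proc = IndexerProcessor(vocab_proc.get_vocab(), field_name="words",
                              new_added_field_name="word_index", is_input=True)
index_proc.process(data)

# Record sequence lengths, also flagged as input by default.
len_proc = SeqLenProcessor(field_name="words", new_added_field_name="seq_lens")
len_proc.process(data)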