forked from ArneBinder/pytorch-ie
-
Notifications
You must be signed in to change notification settings - Fork 0
/
simple_transformer_text_classification.py
225 lines (186 loc) · 7.9 KB
/
simple_transformer_text_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
"""
workflow:
document
-> (input_encoding, target_encoding) -> task_encoding
-> model_encoding -> model_output
-> task_output
-> document
"""
import logging
from typing import Any, Dict, Iterator, MutableMapping, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import TruncationStrategy
from typing_extensions import TypeAlias
from pytorch_ie.annotations import Label
from pytorch_ie.core import AnnotationList, TaskEncoding, TaskModule, annotation_field
from pytorch_ie.documents import TextDocument
from pytorch_ie.models.transformer_text_classification import (
TransformerTextClassificationModelBatchOutput,
TransformerTextClassificationModelStepBatchEncoding,
)
logger = logging.getLogger(__name__)
class TextDocumentWithLabel(TextDocument):
label: AnnotationList[Label] = annotation_field()
class TaskOutput(TypedDict, total=False):
label: str
probability: float
# Define task specific input and output types
DocumentType: TypeAlias = TextDocumentWithLabel
InputEncodingType: TypeAlias = MutableMapping[str, Any]
TargetEncodingType: TypeAlias = int
ModelEncodingType: TypeAlias = TransformerTextClassificationModelStepBatchEncoding
ModelOutputType: TypeAlias = TransformerTextClassificationModelBatchOutput
TaskOutputType: TypeAlias = TaskOutput
# This should be the same for all taskmodules
TaskEncodingType: TypeAlias = TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType]
TaskModuleType: TypeAlias = TaskModule[
DocumentType,
InputEncodingType,
TargetEncodingType,
ModelEncodingType,
ModelOutputType,
TaskOutputType,
]
@TaskModule.register()
class SimpleTransformerTextClassificationTaskModule(TaskModuleType):
# If these attributes are set, the taskmodule is considered as prepared. They should be calculated
# within _prepare() and are dumped automatically when saving the taskmodule with save_pretrained().
PREPARED_ATTRIBUTES = ["label_to_id"]
def __init__(
self,
tokenizer_name_or_path: str,
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = True,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
label_to_id: Optional[Dict[str, int]] = None,
**kwargs,
) -> None:
# Important: Remaining keyword arguments need to be passed to super.
super().__init__(**kwargs)
# Save all passed arguments. They will be available via self._config().
self.save_hyperparameters()
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
# some tokenization and padding parameters
self.truncation = truncation
self.max_length = max_length
self.padding = padding
self.pad_to_multiple_of = pad_to_multiple_of
# this will be prepared from the data or loaded from the config
self.label_to_id = label_to_id
def _prepare(self, documents: Sequence[DocumentType]) -> None:
"""
Prepare the task module with training documents, e.g. collect all possible labels.
This method needs to set all attributes listed in PREPARED_ATTRIBUTES.
"""
# create the label-to-id mapping
labels = set()
for document in documents:
# all annotations of a document are hold in list like containers,
# so we have to take its first element
label_annotation = document.label[0]
labels.add(label_annotation.label)
# create the mapping, but spare the first index for the "O" (outside) class
self.label_to_id = {label: i + 1 for i, label in enumerate(sorted(labels))}
self.label_to_id["O"] = 0
def _post_prepare(self):
"""
Any further preparation logic that requires the result of _prepare(). But its result is not serialized
with the taskmodule.
"""
self.id_to_label = {v: k for k, v in self.label_to_id.items()}
def encode_input(
self,
document: DocumentType,
is_training: bool = False,
) -> TaskEncodingType:
"""
Create one or multiple task encodings for the given document.
"""
# tokenize the input text, this will be the input
inputs = self.tokenizer(
document.text,
# we do not pad here, this will be done in collate()
# when the actual batches are created
padding=False,
truncation=self.truncation,
max_length=self.max_length,
)
return TaskEncoding(
document=document,
inputs=inputs,
)
def encode_target(
self,
task_encoding: TaskEncodingType,
) -> TargetEncodingType:
"""
Create a target for a task encoding. This may use any annotations of the underlying document.
"""
# as above, all annotations are hold in lists, so we have to take its first element
label_annotation = task_encoding.document.label[0]
# translate the textual label to the target id
assert self.label_to_id is not None
return self.label_to_id[label_annotation.label]
def collate(self, task_encodings: Sequence[TaskEncodingType]) -> ModelEncodingType:
"""
Convert a list of task encodings to a batch that will be passed to the model.
"""
# get the inputs from the task encodings
input_features = [task_encoding.inputs for task_encoding in task_encodings]
# pad the inputs and return torch tensors
inputs = self.tokenizer.pad(
input_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
if task_encodings[0].has_targets:
# convert the targets (label ids) to a tensor
targets = torch.tensor(
[task_encoding.targets for task_encoding in task_encodings], dtype=torch.int64
)
else:
# during inference, we do not have any targets
targets = None
return inputs, targets
def unbatch_output(
self, model_output: TransformerTextClassificationModelBatchOutput
) -> Sequence[TaskOutputType]:
"""
Convert one model output batch to a sequence of taskmodule outputs.
"""
# get the logits from the model output
logits = model_output["logits"]
# convert the logits to "probabilities"
probabilities = logits.softmax(dim=-1).detach().cpu().numpy()
# get the max class index per example
max_label_ids = np.argmax(probabilities, axis=-1)
outputs = []
for idx, label_id in enumerate(max_label_ids):
# translate the label id back to the label text
label = self.id_to_label[label_id]
# get the probability and convert from tensor value to python float
prob = float(probabilities[idx, label_id])
# we create TransformerTextClassificationTaskOutput primarily for typing purposes,
# a simple dict would also work
result: TaskOutput = {
"label": label,
"probability": prob,
}
outputs.append(result)
return outputs
def create_annotations_from_output(
self,
task_encodings: TaskEncodingType,
task_outputs: TaskOutputType,
) -> Iterator[Tuple[str, Label]]:
"""
Convert a task output to annotations. The method has to yield tuples (annotation_name, annotation).
"""
# just yield a single annotation (other tasks may need multiple annotations per task output)
yield "label", Label(label=task_outputs["label"], score=task_outputs["probability"])