encoder.py
import torch
from transformers import (
    RobertaPreTrainedModel,
    RobertaModel,
    BertModel,
    BertPreTrainedModel,
)


class BertEncoder_For_CrossEncoder(BertPreTrainedModel):
    """
    Encoder for a cross-encoder that uses BertModel as the backbone.
    In a cross-encoder, the question and the phrase are concatenated into a
    single input, and the scalar obtained by passing the final [CLS]
    representation through a linear layer is used as the similarity score
    for the q-p pair.
    """

    def __init__(self, config):
        super(BertEncoder_For_CrossEncoder, self).__init__(config)
        self.bert = BertModel(config)  # backbone BERT model
        classifier_dropout = (  # dropout probability for the scoring head
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = torch.nn.Dropout(classifier_dropout)
        self.linear = torch.nn.Linear(config.hidden_size, 1)
        self.init_weights()  # initialize weights after all sub-modules are defined

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,  # comment this line out if you use a RoBERTa backbone
        )
        pooled_output = outputs[1]  # pooled [CLS] output
        pooled_output = self.dropout(pooled_output)  # apply dropout
        output = self.linear(pooled_output)  # project to a single similarity score
        return output
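

# Illustrative sketch, not part of the original file: how the cross-encoder above
# could score a single question-phrase pair. The checkpoint name, tokenizer, and
# function name are assumptions for the example, not requirements of this module.
def _example_cross_encoder_score():
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")  # assumed checkpoint
    # Note: the linear scoring head is randomly initialized unless a fine-tuned
    # checkpoint of this class is loaded.
    model = BertEncoder_For_CrossEncoder.from_pretrained("bert-base-uncased")
    model.eval()

    question = "Who wrote the novel?"
    phrase = "The novel was written by the author in 1952."
    # Question and phrase are encoded together as one sequence: [CLS] q [SEP] p [SEP]
    inputs = tokenizer(question, phrase, return_tensors="pt", truncation=True)

    with torch.no_grad():
        score = model(**inputs)  # shape (1, 1): a scalar similarity score
    return score.item()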


class RoBertaEncoder_For_CrossEncoder(RobertaPreTrainedModel):
    """
    Encoder for a cross-encoder that uses RobertaModel as the backbone.
    In a cross-encoder, the question and the phrase are concatenated into a
    single input, and the scalar obtained by passing the final [CLS]
    representation through a linear layer is used as the similarity score
    for the q-p pair.
    """

    def __init__(self, config):
        super(RoBertaEncoder_For_CrossEncoder, self).__init__(config)
        self.roberta = RobertaModel(config)  # backbone RoBERTa model
        classifier_dropout = (  # dropout probability for the scoring head
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = torch.nn.Dropout(classifier_dropout)
        self.linear = torch.nn.Linear(config.hidden_size, 1)
        self.init_weights()  # initialize weights after all sub-modules are defined

    def forward(
        self,
        input_ids,
        attention_mask=None,
        # token_type_ids=None  # RoBERTa does not use token type ids
    ):
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            # token_type_ids=token_type_ids
        )
        pooled_output = outputs[1]  # pooled [CLS] output
        pooled_output = self.dropout(pooled_output)  # apply dropout
        output = self.linear(pooled_output)  # project to a single similarity score
        return output
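

# Illustrative sketch, not part of the original file: with the RoBERTa variant, the
# tokenizer output may still contain token_type_ids depending on the tokenizer
# version; since forward() above does not accept them, drop the key before calling.
# Checkpoint and function names are assumptions for the example.
def _example_roberta_cross_encoder_score():
    from transformers import RobertaTokenizerFast

    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")  # assumed checkpoint
    model = RoBertaEncoder_For_CrossEncoder.from_pretrained("roberta-base")
    model.eval()

    inputs = tokenizer(
        "Who wrote the novel?",
        "The novel was written by the author in 1952.",
        return_tensors="pt",
        truncation=True,
    )
    inputs.pop("token_type_ids", None)  # forward() takes only input_ids and attention_mask

    with torch.no_grad():
        score = model(**inputs)  # shape (1, 1): a scalar similarity score
    return score.item()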


class BertEncoder_For_BiEncoder(BertPreTrainedModel):
    """
    Encoder for a bi-encoder that uses BertModel as the backbone.
    In a bi-encoder, the question and the phrase are encoded separately, and
    each side returns the hidden embedding of its [CLS] token as the final
    output.
    """

    def __init__(self, config):
        super(BertEncoder_For_BiEncoder, self).__init__(config)
        self.bert = BertModel(config)  # backbone BERT model
        self.init_weights()  # initialize weights

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )
        pooled_output = outputs[1]  # pooled [CLS] embedding
        return pooled_output
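

# Illustrative sketch, not part of the original file: with the bi-encoder, the
# question and the phrase are encoded independently and compared afterwards,
# here with a simple dot product between the two [CLS] embeddings (one common
# choice; the original code does not fix a similarity function). Checkpoint and
# function names are assumptions for the example.
def _example_bi_encoder_similarity():
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")  # assumed checkpoint
    q_encoder = BertEncoder_For_BiEncoder.from_pretrained("bert-base-uncased")
    p_encoder = BertEncoder_For_BiEncoder.from_pretrained("bert-base-uncased")
    q_encoder.eval()
    p_encoder.eval()

    q_inputs = tokenizer("Who wrote the novel?", return_tensors="pt", truncation=True)
    p_inputs = tokenizer(
        "The novel was written by the author in 1952.", return_tensors="pt", truncation=True
    )

    with torch.no_grad():
        q_emb = q_encoder(**q_inputs)  # (1, hidden_size) question embedding
        p_emb = p_encoder(**p_inputs)  # (1, hidden_size) phrase embedding

    similarity = torch.matmul(q_emb, p_emb.T)  # (1, 1) similarity score
    return similarity.item()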