-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtrainer.py
35 lines (30 loc) · 1.33 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import Dict, List, Tuple, Optional, Any, Union
from transformers.trainer import Trainer
from torch import nn
import torch
class IndexingTrainer(Trainer):
    """``Trainer`` subclass that uses the model's own loss for training and
    constrained generation for evaluation/prediction.

    Training: ``compute_loss`` runs a forward pass with ``input_ids``,
    ``attention_mask`` and ``labels`` and returns ``outputs.loss``.

    Evaluation/prediction: ``prediction_step`` does not compute a loss. It calls
    ``model.generate`` with ``restrict_decode_vocab`` as the
    ``prefix_allowed_tokens_fn``. It returns the generated token ids as
    predictions and the batch's ``labels`` unchanged as label ids.
    """

    # Passed to ``model.generate`` as ``max_length`` in ``prediction_step``.
    # Defined as a class attribute so a subclass or instance can override it
    # without changing the constructor signature. 20 matches the previously
    # hard-coded value.
    decode_max_length = 20

    def __init__(self, restrict_decode_vocab, **kwds):
        """Create the trainer.

        Args:
            restrict_decode_vocab: callable forwarded to ``model.generate`` as
                ``prefix_allowed_tokens_fn``. In the Hugging Face API it is
                called as ``fn(batch_id, input_ids)`` and returns the list of
                token ids allowed at the next step.
            **kwds: forwarded unchanged to ``transformers.Trainer.__init__``.
        """
        super().__init__(**kwds)
        self.restrict_decode_vocab = restrict_decode_vocab

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the loss computed by the model's forward pass.

        Args:
            model: model called with ``input_ids``, ``attention_mask`` and
                ``labels`` keyword arguments; its output must expose ``.loss``.
            inputs: batch mapping containing ``'input_ids'``,
                ``'attention_mask'`` and ``'labels'``.
            return_outputs: if True, also return placeholder outputs.
            num_items_in_batch: accepted for compatibility with newer
                ``transformers.Trainer`` versions, which pass it to
                ``compute_loss``. It is ignored; the model's own loss
                reduction is used as-is.

        Returns:
            ``loss``, or ``(loss, [None, None])`` when ``return_outputs`` is True.
        """
        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            labels=inputs['labels'],
        )
        loss = outputs.loss
        if return_outputs:
            # Placeholder outputs: callers only get the loss, never logits.
            return loss, [None, None]
        return loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Generate constrained predictions for one evaluation batch.

        Args:
            model: a model exposing ``generate`` (Hugging Face generation API).
            inputs: batch mapping with ``'input_ids'`` and ``'labels'``, plus an
                optional ``'attention_mask'``.
            prediction_loss_only: accepted for interface compatibility; unused.
                No loss is computed here.
            ignore_keys: accepted for interface compatibility; unused.

        Returns:
            ``(None, generated_ids, inputs['labels'])``. The loss slot is always
            None. ``generated_ids`` is whatever ``model.generate`` returns.
        """
        model.eval()
        device = self.args.device

        # Forward the padding mask so padded positions in ``input_ids`` are not
        # treated as real tokens during generation. ``compute_loss`` does the
        # same for training.
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        with torch.no_grad():
            # Decoding uses the model's default search settings (the original
            # author notes this is greedy search). ``early_stopping`` only has
            # an effect for beam-based search.
            doc_ids = model.generate(
                inputs['input_ids'].to(device),
                attention_mask=attention_mask,
                max_length=self.decode_max_length,
                prefix_allowed_tokens_fn=self.restrict_decode_vocab,
                early_stopping=True,
            )
        return (None, doc_ids, inputs['labels'])