Skip to content

Commit 5cabc75

Browse files
authored
Add compute_loss_func to Seq2SeqTrainer (huggingface#35136)
1 parent 90f256c commit 5cabc75

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/transformers/trainer_seq2seq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(
6464
Union["PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"]
6565
] = None,
6666
model_init: Optional[Callable[[], "PreTrainedModel"]] = None,
67+
compute_loss_func: Optional[Callable] = None,
6768
compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None,
6869
callbacks: Optional[List["TrainerCallback"]] = None,
6970
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
@@ -77,6 +78,7 @@ def __init__(
7778
eval_dataset=eval_dataset,
7879
processing_class=processing_class,
7980
model_init=model_init,
81+
compute_loss_func=compute_loss_func,
8082
compute_metrics=compute_metrics,
8183
callbacks=callbacks,
8284
optimizers=optimizers,

0 commit comments

Comments
 (0)