Skip to content

Commit

Permalink
Pass output_dir to superclass
Browse files Browse the repository at this point in the history
  • Loading branch information
bluestealth committed Feb 5, 2025
1 parent b961717 commit 34370ef
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/setfit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sentence_transformers import SentenceTransformerTrainer, losses
from sentence_transformers.losses.BatchHardTripletLoss import BatchHardTripletLossDistanceFunction
from sentence_transformers.model_card import ModelCardCallback as STModelCardCallback
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments
from sklearn.preprocessing import LabelEncoder
from torch import nn
from transformers import __version__ as transformers_version
Expand Down Expand Up @@ -47,7 +47,11 @@ def __init__(
self._setfit_model = setfit_model
self._setfit_args = setfit_args
self.logs_prefix = "embedding"
super().__init__(model=setfit_model.model_body, **kwargs)
super().__init__(
model=setfit_model.model_body,
args=SentenceTransformerTrainingArguments(output_dir=setfit_args.output_dir),
**kwargs,
)
self._apply_training_arguments(setfit_args)

for callback in list(self.callback_handler.callbacks):
Expand Down

0 comments on commit 34370ef

Please sign in to comment.