Skip to content

Commit

Permalink
Normalize device to CPU when evaluating (#363)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen authored Jul 27, 2023
1 parent 1ecd91c commit 4ebee43
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/setfit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import evaluate
import numpy as np
import torch
from datasets import Dataset, DatasetDict
from sentence_transformers import InputExample, losses
from sentence_transformers.datasets import SentenceLabelDataset
Expand Down Expand Up @@ -438,6 +439,8 @@ def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:

logger.info("***** Running evaluation *****")
y_pred = self.model.predict(x_test)
if isinstance(y_pred, torch.Tensor):
y_pred = y_pred.cpu()

if isinstance(self.metric, str):
metric_config = "multilabel" if self.model.multi_target_strategy is not None else None
Expand Down
25 changes: 25 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import evaluate
import pytest
import torch
from datasets import Dataset, load_dataset
from sentence_transformers import losses
from transformers.testing_utils import require_optuna
Expand Down Expand Up @@ -497,3 +498,27 @@ def test_trainer_evaluate_multilabel_f1():
trainer.train()
metrics = trainer.evaluate()
assert metrics == {"f1": 1.0}


def test_trainer_evaluate_on_cpu() -> None:
# This test used to fail if CUDA was available
dataset = Dataset.from_dict(
{"text": ["positive sentence", "negative sentence"], "label": ["positive", "negative"]}
)
model = SetFitModel.from_pretrained(
"sentence-transformers/paraphrase-albert-small-v2", use_differentiable_head=True
)

def compute_metric(y_pred, y_test) -> None:
assert y_pred.device == torch.device("cpu")
return 1.0

trainer = SetFitTrainer(
model=model,
train_dataset=dataset,
eval_dataset=dataset,
metric=compute_metric,
num_iterations=5,
)
trainer.train()
trainer.evaluate()

0 comments on commit 4ebee43

Please sign in to comment.