From 4ebee43ca887ed0ed2858329e30885f7a09bd3e4 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Thu, 27 Jul 2023 16:31:24 +0200
Subject: [PATCH] Normalize device to CPU when evaluating (#363)

---
 src/setfit/trainer.py |  3 +++
 tests/test_trainer.py | 25 +++++++++++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/src/setfit/trainer.py b/src/setfit/trainer.py
index 1c4b02ed..6304ce5b 100644
--- a/src/setfit/trainer.py
+++ b/src/setfit/trainer.py
@@ -3,6 +3,7 @@
 
 import evaluate
 import numpy as np
+import torch
 from datasets import Dataset, DatasetDict
 from sentence_transformers import InputExample, losses
 from sentence_transformers.datasets import SentenceLabelDataset
@@ -438,6 +439,8 @@ def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:
 
         logger.info("***** Running evaluation *****")
         y_pred = self.model.predict(x_test)
+        if isinstance(y_pred, torch.Tensor):
+            y_pred = y_pred.cpu()
 
         if isinstance(self.metric, str):
             metric_config = "multilabel" if self.model.multi_target_strategy is not None else None
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 26d340d9..af2c9a82 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -5,6 +5,7 @@
 
 import evaluate
 import pytest
+import torch
 from datasets import Dataset, load_dataset
 from sentence_transformers import losses
 from transformers.testing_utils import require_optuna
@@ -497,3 +498,27 @@ def test_trainer_evaluate_multilabel_f1():
     trainer.train()
     metrics = trainer.evaluate()
     assert metrics == {"f1": 1.0}
+
+
+def test_trainer_evaluate_on_cpu() -> None:
+    # This test used to fail if CUDA was available
+    dataset = Dataset.from_dict(
+        {"text": ["positive sentence", "negative sentence"], "label": ["positive", "negative"]}
+    )
+    model = SetFitModel.from_pretrained(
+        "sentence-transformers/paraphrase-albert-small-v2", use_differentiable_head=True
+    )
+
+    def compute_metric(y_pred, y_test) -> float:
+        assert y_pred.device == torch.device("cpu")
+        return 1.0
+
+    trainer = SetFitTrainer(
+        model=model,
+        train_dataset=dataset,
+        eval_dataset=dataset,
+        metric=compute_metric,
+        num_iterations=5,
+    )
+    trainer.train()
+    trainer.evaluate()