From 97c0f45b9c4274ee2b1cfe821ba52b8a70029e76 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 2 Sep 2024 14:37:49 +0100
Subject: [PATCH] Generate: fix assistant in different device (#33257)

---
 src/transformers/generation/utils.py | 1 +
 tests/generation/test_utils.py       | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index c0fe3acb9eb32a..79105667dbe0c7 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -3964,6 +3964,7 @@ def _assisted_decoding(
 
         # 1. Fetch candidate sequences from a `CandidateGenerator`
         candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
+        candidate_input_ids = candidate_input_ids.to(self.device)
 
         if candidate_logits is not None:
             candidate_logits = candidate_logits.to(self.device)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index ba28ffa51857b5..3a33f7cd704e24 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -3323,7 +3323,7 @@ def test_assisted_decoding_in_different_gpu(self):
 
     @slow
     @require_torch_gpu
-    def test_assisted_decoding_in_gpu_cpu(self):
+    def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
         # PT-only test: TF doesn't support assisted decoding yet.
         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
         assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(