Skip to content

Commit

Permalink
Generate: fix assistant in different device (huggingface#33257)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored Sep 2, 2024
1 parent 52a0213 commit 97c0f45
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3964,6 +3964,7 @@ def _assisted_decoding(

# 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
candidate_input_ids = candidate_input_ids.to(self.device)
if candidate_logits is not None:
candidate_logits = candidate_logits.to(self.device)

Expand Down
2 changes: 1 addition & 1 deletion tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3323,7 +3323,7 @@ def test_assisted_decoding_in_different_gpu(self):

@slow
@require_torch_gpu
def test_assisted_decoding_in_gpu_cpu(self):
def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
Expand Down

0 comments on commit 97c0f45

Please sign in to comment.