Skip to content

Commit

Permalink
keep logits.device unchanged
Browse files Browse the repository at this point in the history
  • Loading branch information
lvhan028 committed Sep 24, 2024
1 parent e49b964 commit 28b0258
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lmdeploy/serve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,12 @@ def get_ppl(self, inputs: List[str]) -> List[float]:
]
steps = [i] * bs
logits = generator.decode(
token_ids,
input_ids=token_ids,
steps=steps,
sequence_start=(i == 0),
sequence_end=(i + max_input_len >= max_seq_len))
bsz, seq_len, vocab_size = logits.shape
logits = logits.float().cpu()
logits = logits.float()
padding_token_id = -100
# meaning logits[..., :, :] corresponds to labels
# token_ids[1:] + predict_token_id, which is
Expand Down

0 comments on commit 28b0258

Please sign in to comment.