Skip to content

Commit

Permalink
Correct Whisper's beam search scores computation (huggingface#32336)
Browse files Browse the repository at this point in the history
fix proposal
  • Loading branch information
ylacombe authored Sep 12, 2024
1 parent e688996 commit 8f8af0f
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,9 @@ def _postprocess_outputs(

seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]

def split_by_batch_index(values, key, batch_idx, is_shortform):
def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None):
if beam_indices is not None and key == "scores":
return [v[beam_idx].cpu() for (v, beam_idx) in zip(values, beam_indices[batch_idx][: len(values)])]
if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
return [v[batch_idx].cpu() for v in values]
if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
Expand Down Expand Up @@ -985,7 +987,10 @@ def split_by_batch_index(values, key, batch_idx, is_shortform):

sequence_tokens = seek_outputs["sequences"]
seek_outputs = [
{k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
{
k: split_by_batch_index(v, k, i, is_shortform, beam_indices=seek_outputs.get("beam_indices"))
for k, v in seek_outputs.items()
}
for i in range(sequence_tokens.shape[0])
]

Expand Down

0 comments on commit 8f8af0f

Please sign in to comment.