From 7d6f9b4711e6ee6924abf2e5c2bd070af0767e6e Mon Sep 17 00:00:00 2001
From: Eustache Le Bihan
Date: Mon, 4 Nov 2024 18:00:38 +0100
Subject: [PATCH] handle edge case double tokens ending with different tokens

---
 .../models/whisper/generation_whisper.py        | 14 ++++++++++----
 .../models/whisper/tokenization_whisper.py      | 16 ++++++++++------
 .../models/whisper/tokenization_whisper_fast.py | 16 ++++++++++------
 3 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index e382af0bf16fc6..845d895efdee48 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -308,6 +308,7 @@ def generate(
         num_segment_frames: Optional[int] = None,
         attention_mask: Optional[torch.Tensor] = None,
         time_precision: float = 0.02,
+        time_precision_features: float = 0.01,
         return_token_timestamps: Optional[bool] = None,
         return_segments: bool = False,
         return_dict_in_generate: Optional[bool] = None,
@@ -417,6 +418,8 @@
             time_precision (`int`, *optional*, defaults to 0.02):
                 The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
                 for 20 ms.
+            time_precision_features (`int`, *optional*, defaults to 0.01):
+                The duration represented by a feature frame in seconds.
             return_token_timestamps (`bool`, *optional*):
                 Whether to return token-level timestamps with the text. This can be used with or without the
                 `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
@@ -718,6 +721,7 @@ def generate(
                     timestamp_begin=timestamp_begin,
                     seek_num_frames=seek_num_frames,
                     time_precision=time_precision,
+                    time_precision_features=time_precision_features,
                     input_stride=input_stride,
                     prev_idx=prev_i,
                     idx=i,
@@ -1778,6 +1782,7 @@
         timestamp_begin,
         seek_num_frames,
         time_precision,
+        time_precision_features,
         input_stride,
         prev_idx,
         idx,
@@ -1805,10 +1810,11 @@ def _retrieve_segment(
 
             last_slice = 0
             # Add each segment to list of all segments
-            for current_slice in slices:
+            for i, current_slice in enumerate(slices):
+                is_last_slice = i == len(slices) - 1
                 sliced_tokens = seek_sequence[last_slice:current_slice]
                 start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
-                end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
+                end_timestamp_pos = sliced_tokens[-1 if not is_last_slice else -2].item() - timestamp_begin
                 segments.append(
                     {
                         "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
@@ -1830,13 +1836,13 @@
                 # otherwise, ignore the unfinished segment and seek to the last timestamp
                 # here we throw away all predictions after the last predicted "end of segment"
                 # since we are cutting right in the middle of an audio
-                last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
+                last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
                 segment_offset = last_timestamp_pos * input_stride
         else:
             # If whisper does not predict any "end of segment" token, then
             # the whole decoding is considered a segment and we add it to the list of segments
             timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
-            last_timestamp_pos = seek_num_frames[prev_idx]
+            last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
             if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
                 # no consecutive timestamps but it has a timestamp; use the last one.
                 last_timestamp_pos = timestamps[-1].item() - timestamp_begin
diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index 824a9839b84b37..e537ef95da6751 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -540,20 +540,23 @@ def _decode_with_timestamps(
 
         cur_max_timestamp = 0.0
         prev_segments_len = 0.0
-        # track if last timestamp was single ending
         penultimate_timestamp = 0.0
 
-        for token in token_ids:
+        for i, token in enumerate(token_ids):
             if token >= timestamp_begin:
                 timestamp = float((token - timestamp_begin) * time_precision)
 
                 if timestamp < cur_max_timestamp:
                     # next segment has started
-                    last_was_single_ending = not cur_max_timestamp == penultimate_timestamp
+                    last_was_single_ending = i >= 2 and not (
+                        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
+                    )
                     if last_was_single_ending:
                         prev_segments_len += time_precision * segment_size
                     else:
-                        prev_segments_len += cur_max_timestamp
+                        cur_max_timestamp = penultimate_timestamp
+                        prev_segments_len += penultimate_timestamp
+                        outputs = outputs[:-2]
 
                 penultimate_timestamp = cur_max_timestamp
                 cur_max_timestamp = timestamp
@@ -608,8 +611,9 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):
 
             if start_timestamp_position < cur_max_timestamp:
                 # next segment has started
-                # here in the worst case we have [<|start_timestamp_position before|>, <|cur_max_timestamp|>, <|start_timestamp_position|>], so last_slice (idx of start_timestamp_position) - 2 is safe
-                is_single_ending = not token_ids[last_slice - 2] == token_ids[last_slice - 1]
+                is_single_ending = last_slice >= 2 and not (
+                    token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
+                )
                 if is_single_ending:
                     prev_segments_len += segment_size
                 else:
diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py
index 8a155e5bd64cde..f0383cb0def76f 100644
--- a/src/transformers/models/whisper/tokenization_whisper_fast.py
+++ b/src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -181,20 +181,23 @@ def _decode_with_timestamps(
 
         cur_max_timestamp = 0.0
         prev_segments_len = 0.0
-        # track if last timestamp was single ending
         penultimate_timestamp = 0.0
 
-        for token in token_ids:
+        for i, token in enumerate(token_ids):
             if token >= timestamp_begin:
                 timestamp = float((token - timestamp_begin) * time_precision)
 
                 if timestamp < cur_max_timestamp:
                     # next segment has started
-                    last_was_single_ending = not cur_max_timestamp == penultimate_timestamp
+                    last_was_single_ending = i >= 2 and not (
+                        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
+                    )
                     if last_was_single_ending:
                         prev_segments_len += time_precision * segment_size
                     else:
-                        prev_segments_len += cur_max_timestamp
+                        cur_max_timestamp = penultimate_timestamp
+                        prev_segments_len += penultimate_timestamp
+                        outputs = outputs[:-2]
 
                 penultimate_timestamp = cur_max_timestamp
                 cur_max_timestamp = timestamp
@@ -250,8 +253,9 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):
 
             if start_timestamp_position < cur_max_timestamp:
                 # next segment has started
-                # here in the worst case we have [<|start_timestamp_position before|>, <|cur_max_timestamp|>, <|start_timestamp_position|>], so last_slice (idx of start_timestamp_position) - 2 is safe
-                is_single_ending = not token_ids[last_slice - 2] == token_ids[last_slice - 1]
+                is_single_ending = last_slice >= 2 and not (
+                    token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
+                )
                 if is_single_ending:
                     prev_segments_len += segment_size
                 else:
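
Note for reviewers (not part of the patch): Whisper normally closes a segment with a pair of consecutive timestamp tokens, while a chunk whose speech runs to the end of the 30 s window ends with a single timestamp token. The old code inferred a "single ending" by comparing the last two timestamp *values*, which misfires in the edge case this patch targets: a closing pair whose two tokens differ (e.g. `<|28.00|><|29.00|>`). The new check instead asks whether the two tokens immediately before a segment restart are both timestamp tokens. Below is a minimal, self-contained sketch of that logic; the toy token IDs, the `timestamp_begin` value, and the helper `prev_segments_increment` are illustrative assumptions, not code from the patch.

```python
# Sketch of the "single ending" detection introduced by this patch.
# The helper and all concrete values are assumed for illustration only.

timestamp_begin = 50364  # ID of <|0.00|>; timestamp tokens are >= this ID (assumed)
time_precision = 0.02    # seconds per timestamp step (from the patch)
segment_size = 1500      # feature frames per 30 s chunk (from the patch)


def prev_segments_increment(token_ids, i, penultimate_timestamp):
    """Seconds to add to prev_segments_len when token_ids[i] starts a new window."""
    # New check: a "single ending" means the two tokens right before the
    # restart were NOT both timestamp tokens.
    last_was_single_ending = i >= 2 and not (
        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
    )
    if last_was_single_ending:
        # Speech ran to the end of the window: advance by the full 30 s chunk.
        return time_precision * segment_size
    # Pair ending: the segment closed early, at the first timestamp of the pair.
    return penultimate_timestamp


# Pair ending with *different* tokens: ... <|28.00|> <|29.00|> <|0.00|> ...
tokens = [50364, 123, 456, 50364 + 1400, 50364 + 1450, 50364]
print(prev_segments_increment(tokens, 5, penultimate_timestamp=28.0))  # 28.0

# Single ending: ... text <|29.00|> <|0.00|> ...
tokens = [50364, 123, 456, 50364 + 1450, 50364]
print(prev_segments_increment(tokens, 4, penultimate_timestamp=0.0))   # 30.0

# The generation-side change converts seek_num_frames, measured in feature
# frames of time_precision_features = 0.01 s each, into 0.02 s timestamp steps:
time_precision_features = 0.01
seek_num_frames = 3000  # a full 30 s chunk of feature frames (assumed value)
print(int(seek_num_frames * time_precision_features / time_precision))  # 1500
```

The same pair-vs-single test appears in all three files: `_retrieve_segment` uses it to pick the closing timestamp of the last slice, and both tokenizers use it when accumulating `prev_segments_len` across 30 s windows.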