handle edge case double tokens ending with different tokens
eustlb committed Nov 4, 2024
1 parent 09af9de commit 7d6f9b4
Showing 3 changed files with 30 additions and 16 deletions.
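
For context on the edge case named in the commit title: in Whisper's long-form decoding, a 30-second window can end either with a single trailing timestamp token or with a pair of consecutive timestamp tokens, and in rare cases the two tokens of that pair carry different values. The changed code stops comparing timestamp values and instead checks whether the two preceding tokens are both timestamp tokens. A minimal sketch of that check, with a made-up timestamp_begin and illustrative token ids (not taken from this commit):

# Sketch of the "single vs. double ending" check introduced by this commit.
# timestamp_begin and the token ids are illustrative values, not Whisper's real vocabulary.
timestamp_begin = 50365

def last_was_single_ending(token_ids, i):
    # The previous window ended with a "double" ending if the two tokens right
    # before position i are both timestamp tokens, even when their values differ.
    return i >= 2 and not (
        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
    )

# Double ending whose two timestamp tokens differ (the edge case handled here).
print(last_was_single_ending([100, 200, 50400, 50390, 50365], 4))  # False

# Single ending: only one timestamp token before the next window starts.
print(last_was_single_ending([100, 50400, 200, 50365], 3))  # True
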
14 changes: 10 additions & 4 deletions src/transformers/models/whisper/generation_whisper.py
@@ -308,6 +308,7 @@ def generate(
num_segment_frames: Optional[int] = None,
attention_mask: Optional[torch.Tensor] = None,
time_precision: float = 0.02,
time_precision_features: float = 0.01,
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
return_dict_in_generate: Optional[bool] = None,
@@ -417,6 +418,8 @@ def generate(
time_precision (`float`, *optional*, defaults to 0.02):
The duration of each output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
for 20 ms.
time_precision_features (`float`, *optional*, defaults to 0.01):
The duration represented by a single feature frame in seconds.
return_token_timestamps (`bool`, *optional*):
Whether to return token-level timestamps with the text. This can be used with or without the
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
@@ -718,6 +721,7 @@ def generate(
timestamp_begin=timestamp_begin,
seek_num_frames=seek_num_frames,
time_precision=time_precision,
time_precision_features=time_precision_features,
input_stride=input_stride,
prev_idx=prev_i,
idx=i,
@@ -1778,6 +1782,7 @@ def _retrieve_segment(
timestamp_begin,
seek_num_frames,
time_precision,
time_precision_features,
input_stride,
prev_idx,
idx,
@@ -1805,10 +1810,11 @@

last_slice = 0
# Add each segment to list of all segments
for current_slice in slices:
for i, current_slice in enumerate(slices):
is_last_slice = i == len(slices) - 1
sliced_tokens = seek_sequence[last_slice:current_slice]
start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
end_timestamp_pos = sliced_tokens[-1 if not is_last_slice else -2].item() - timestamp_begin
segments.append(
{
"start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
@@ -1830,13 +1836,13 @@
# otherwise, ignore the unfinished segment and seek to the last timestamp
# here we throw away all predictions after the last predicted "end of segment"
# since we are cutting right in the middle of an audio
last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
last_timestamp_pos = seek_sequence[last_slice - 2].item() - timestamp_begin
segment_offset = last_timestamp_pos * input_stride
else:
# If whisper does not predict any "end of segment" token, then
# the whole decoding is considered a segment and we add it to the list of segments
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
last_timestamp_pos = seek_num_frames[prev_idx]
last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
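
A rough worked example of why the new time_precision_features argument matters in _retrieve_segment (a sketch assuming the defaults shown in the diff; the 3000-frame window size is the usual 30 s of Whisper features, not a value stated in this commit): feature frames advance 0.01 s at a time while timestamp positions advance 0.02 s at a time, so a frame count has to be rescaled before it can stand in for a timestamp position.

# Sketch, assuming the defaults from the signature above.
time_precision = 0.02           # seconds per timestamp position
time_precision_features = 0.01  # seconds per feature frame

seek_num_frames = 3000  # assumed: a full 30 s window of features

# Old behaviour: the raw frame count was used directly as a timestamp position.
old_last_timestamp_pos = seek_num_frames

# New behaviour: rescale frames into timestamp positions first.
new_last_timestamp_pos = int(seek_num_frames * time_precision_features / time_precision)

print(old_last_timestamp_pos * time_precision)  # 60.0 s -- twice the audio actually seen
print(new_last_timestamp_pos * time_precision)  # 30.0 s -- matches the 30 s window
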
16 changes: 10 additions & 6 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -540,20 +540,23 @@ def _decode_with_timestamps(

cur_max_timestamp = 0.0
prev_segments_len = 0.0
# track if last timestamp was single ending
penultimate_timestamp = 0.0

for token in token_ids:
for i, token in enumerate(token_ids):
if token >= timestamp_begin:
timestamp = float((token - timestamp_begin) * time_precision)

if timestamp < cur_max_timestamp:
# next segment has started
last_was_single_ending = not cur_max_timestamp == penultimate_timestamp
last_was_single_ending = i >= 2 and not (
token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
)
if last_was_single_ending:
prev_segments_len += time_precision * segment_size
else:
prev_segments_len += cur_max_timestamp
cur_max_timestamp = penultimate_timestamp
prev_segments_len += penultimate_timestamp
outputs = outputs[:-2]

penultimate_timestamp = cur_max_timestamp
cur_max_timestamp = timestamp
@@ -608,8 +611,9 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):

if start_timestamp_position < cur_max_timestamp:
# next segment has started
# here in the worst case we have [<|start_timestamp_position before|>, <|cur_max_timestamp|>, <|start_timestamp_position|>], so last_slice (idx of start_timestamp_position) - 2 is safe
is_single_ending = not token_ids[last_slice - 2] == token_ids[last_slice - 1]
is_single_ending = last_slice >= 2 and not (
token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
)
if is_single_ending:
prev_segments_len += segment_size
else:
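
For reference, _decode_with_timestamps and _compute_offsets are reached through the tokenizer's public decode method. A usage sketch (the checkpoint name and the decode_with_timestamps / output_offsets keyword arguments are assumptions based on the WhisperTokenizer docs, worth double-checking):

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# First timestamp token id, derived the same way as in the methods changed above.
timestamp_begin = tokenizer.all_special_ids[-1] + 1

# Illustrative sequence: "<|0.00|> hello <|1.00|><|1.00|> world <|2.00|>" at 0.02 s precision.
token_ids = (
    [timestamp_begin]
    + tokenizer.encode(" hello", add_special_tokens=False)
    + [timestamp_begin + 50, timestamp_begin + 50]
    + tokenizer.encode(" world", add_special_tokens=False)
    + [timestamp_begin + 100]
)

# Text annotated with <|x.xx|> markers (routes through _decode_with_timestamps).
print(tokenizer.decode(token_ids, decode_with_timestamps=True))

# Per-segment offsets (routes through _compute_offsets).
print(tokenizer.decode(token_ids, output_offsets=True))
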
16 changes: 10 additions & 6 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -181,20 +181,23 @@ def _decode_with_timestamps(

cur_max_timestamp = 0.0
prev_segments_len = 0.0
# track if last timestamp was single ending
penultimate_timestamp = 0.0

for token in token_ids:
for i, token in enumerate(token_ids):
if token >= timestamp_begin:
timestamp = float((token - timestamp_begin) * time_precision)

if timestamp < cur_max_timestamp:
# next segment has started
last_was_single_ending = not cur_max_timestamp == penultimate_timestamp
last_was_single_ending = i >= 2 and not (
token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
)
if last_was_single_ending:
prev_segments_len += time_precision * segment_size
else:
prev_segments_len += cur_max_timestamp
cur_max_timestamp = penultimate_timestamp
prev_segments_len += penultimate_timestamp
outputs = outputs[:-2]

penultimate_timestamp = cur_max_timestamp
cur_max_timestamp = timestamp
@@ -250,8 +253,9 @@ def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):

if start_timestamp_position < cur_max_timestamp:
# next segment has started
# here in the worst case we have [<|start_timestamp_position before|>, <|cur_max_timestamp|>, <|start_timestamp_position|>], so last_slice (idx of start_timestamp_position) - 2 is safe
is_single_ending = not token_ids[last_slice - 2] == token_ids[last_slice - 1]
is_single_ending = last_slice >= 2 and not (
token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
)
if is_single_ending:
prev_segments_len += segment_size
else:
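
To make the two branches above concrete, here is a small worked example of how prev_segments_len accumulates across 30-second windows, per my reading of the logic in these tokenizer methods, with the defaults time_precision=0.02 and segment_size=1500 from the signatures (the 28.5 s figure is invented for illustration):

time_precision = 0.02
segment_size = 1500

prev_segments_len = 0.0

# Window 1 ended with a single trailing timestamp token: the next window starts
# a full 30 s later, so the whole window length is added to the running offset.
prev_segments_len += time_precision * segment_size  # += 30.0 s

# Window 2 ended with a consecutive (double) timestamp ending whose last segment
# ends at 28.5 s: the next window starts there, so only that end time is added.
last_segment_end = 28.5
prev_segments_len += last_segment_end  # += 28.5 s

print(prev_segments_len)  # 58.5 -- offset applied to timestamps in the next window
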
