From 7486f4b6679a35074aa9e02cf17bac67ab37b109 Mon Sep 17 00:00:00 2001
From: Ita Zaporozhets
Date: Fri, 20 Sep 2024 17:08:18 +0200
Subject: [PATCH 1/5] handle last element out of range error

---
 src/transformers/models/whisper/tokenization_whisper.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index 0a6eb75c55f66c..5f276a9ed24856 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -1057,9 +1057,9 @@ def new_chunk():
                     start_time = round(token_timestamps[i] + time_offset, 2)
                     if i + 1 < len(token_timestamps):
                         end_time = round(token_timestamps[i + 1] + time_offset, 2)
+                        current_token_timestamps.append((start_time, end_time))
                     else:
                         end_time = None  # should never happen
-                    current_token_timestamps.append((start_time, end_time))
 
         if "stride" in output:
             time_offset += chunk_len - stride_right
@@ -1192,7 +1192,7 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
                 # and have timestamps that are in order
                 matches = sum(
                     1
-                    for idx, elem in enumerate(left)
+                    for idx, elem in enumerate(left[:-1])
                     if (
                         elem == right[idx]
                         and left_token_timestamp_sequence[left_start + idx]

From a547e62e5a6276266edc7d35f313dab2750277f0 Mon Sep 17 00:00:00 2001
From: Ita Zaporozhets
Date: Wed, 9 Oct 2024 12:59:20 +0200
Subject: [PATCH 2/5] simplified

---
 .../models/whisper/tokenization_whisper.py   |  4 ++--
 .../whisper/test_tokenization_whisper.py     | 23 +++++++++----------
 2 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index 5f276a9ed24856..4ef76efdaf626d 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -1052,10 +1052,10 @@ def new_chunk():
                 # 4/ Regular token
                 # We just append to the list of all tokens so we can handle
                 # merges later and decode into text.
-                current_tokens.append(token)
                 if return_timestamps == "word":
                     start_time = round(token_timestamps[i] + time_offset, 2)
                     if i + 1 < len(token_timestamps):
+                        current_tokens.append(token)
                         end_time = round(token_timestamps[i + 1] + time_offset, 2)
                         current_token_timestamps.append((start_time, end_time))
                     else:
@@ -1192,7 +1192,7 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
                 # and have timestamps that are in order
                 matches = sum(
                     1
-                    for idx, elem in enumerate(left[:-1])
+                    for idx, elem in enumerate(left)
                     if (
                         elem == right[idx]
                         and left_token_timestamp_sequence[left_start + idx]
diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py
index 27b24448d5a2be..197fe12b08d9c8 100644
--- a/tests/models/whisper/test_tokenization_whisper.py
+++ b/tests/models/whisper/test_tokenization_whisper.py
@@ -340,18 +340,17 @@ def test_basic_normalizer(self):
 
     def test_decode_asr_with_word_level_timestamps(self):
         # fmt: off
-        model_outputs = [
-            {
-                'stride': [10, 0, 5],
-                'tokens': np.array([[ 50257, 50362, 3363, 11, 345, 460, 0, 2329, 466, 340, 0, 50256 ]]),
-                'token_timestamps': np.array([[ 0, 0, 5.18, 5.56, 5.56, 5.84, 6.36, 7.12, 7.54, 7.82, 8.16, 9.48 ]])
-            },
-            {
-                'stride': [10, 5, 0],
-                'tokens': np.array([[ 50257, 50362, 2329, 466, 340, 0, 3363, 345, 460, 0, 2329, 466, 340, 50256 ]]),
-                'token_timestamps': np.array([[ 0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72 ]])
-            }
-        ]
+        model_outputs = [{'tokens': np.array([[
+            50258, 13, 286, 841, 264,
+            596, 346, 13, 583, 406, 281]]), 'token_timestamps': np.array([[
+            0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
+            0.3200, 0.3400, 0.5000, 0.6600, 0.8600, 1.0000,
+        ]]), 'stride': (10.0, 0.0, 5.0)}, {'tokens': np.array([[
+            50258, 50259, 50359, 50363, 1449,
+            466, 498, 436, 536, 385, 6588]]), 'token_timestamps': np.array([[
+            0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
+            0.2600, 0.6800, 0.9800, 1.2400, 1.3000, 1.5600,
+        ]]), 'stride': (10, 5.0, 0.0)}]
         # fmt: on
 
         tokenizer = WhisperTokenizer.from_pretrained("onnx-community/whisper-tiny.en_timestamped")

From 3db27d3b5394f2e96ba46cb013536cf089af69ec Mon Sep 17 00:00:00 2001
From: Ita Zaporozhets
Date: Wed, 9 Oct 2024 13:01:53 +0200
Subject: [PATCH 3/5] revert test

---
 .../whisper/test_tokenization_whisper.py     | 23 ++++++++++---------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py
index 197fe12b08d9c8..fab21bc958952d 100644
--- a/tests/models/whisper/test_tokenization_whisper.py
+++ b/tests/models/whisper/test_tokenization_whisper.py
@@ -340,17 +340,18 @@ def test_basic_normalizer(self):
 
     def test_decode_asr_with_word_level_timestamps(self):
         # fmt: off
-        model_outputs = [{'tokens': np.array([[
-            50258, 13, 286, 841, 264,
-            596, 346, 13, 583, 406, 281]]), 'token_timestamps': np.array([[
-            0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
-            0.3200, 0.3400, 0.5000, 0.6600, 0.8600, 1.0000,
-        ]]), 'stride': (10.0, 0.0, 5.0)}, {'tokens': np.array([[
-            50258, 50259, 50359, 50363, 1449,
-            466, 498, 436, 536, 385, 6588]]), 'token_timestamps': np.array([[
-            0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
-            0.2600, 0.6800, 0.9800, 1.2400, 1.3000, 1.5600,
-        ]]), 'stride': (10, 5.0, 0.0)}]
+        model_outputs = [
+            {
+                'stride': [10, 0, 5],
+                'tokens': np.array([[50257, 50362, 3363, 11, 345, 460, 0, 2329, 466, 340, 0, 50256]]),
+                'token_timestamps': np.array([[0, 0, 5.18, 5.56, 5.56, 5.84, 6.36, 7.12, 7.54, 7.82, 8.16, 9.48]])
+            },
+            {
+                'stride': [10, 5, 0],
+                'tokens': np.array([[50257, 50362, 2329, 466, 340, 0, 3363, 345, 460, 0, 2329, 466, 340, 50256]]),
+                'token_timestamps': np.array([[0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72]])
+            }
+        ]
         # fmt: on
 
         tokenizer = WhisperTokenizer.from_pretrained("onnx-community/whisper-tiny.en_timestamped")

From bd65496b639a5aaed4a4f87e5739b141387e8352 Mon Sep 17 00:00:00 2001
From: Ita Zaporozhets
Date: Wed, 9 Oct 2024 13:03:10 +0200
Subject: [PATCH 4/5] revert test

---
 tests/models/whisper/test_tokenization_whisper.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py
index fab21bc958952d..27b24448d5a2be 100644
--- a/tests/models/whisper/test_tokenization_whisper.py
+++ b/tests/models/whisper/test_tokenization_whisper.py
@@ -343,13 +343,13 @@ def test_decode_asr_with_word_level_timestamps(self):
         model_outputs = [
             {
                 'stride': [10, 0, 5],
-                'tokens': np.array([[50257, 50362, 3363, 11, 345, 460, 0, 2329, 466, 340, 0, 50256]]),
-                'token_timestamps': np.array([[0, 0, 5.18, 5.56, 5.56, 5.84, 6.36, 7.12, 7.54, 7.82, 8.16, 9.48]])
+                'tokens': np.array([[ 50257, 50362, 3363, 11, 345, 460, 0, 2329, 466, 340, 0, 50256 ]]),
+                'token_timestamps': np.array([[ 0, 0, 5.18, 5.56, 5.56, 5.84, 6.36, 7.12, 7.54, 7.82, 8.16, 9.48 ]])
             },
             {
                 'stride': [10, 5, 0],
-                'tokens': np.array([[50257, 50362, 2329, 466, 340, 0, 3363, 345, 460, 0, 2329, 466, 340, 50256]]),
-                'token_timestamps': np.array([[0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72]])
+                'tokens': np.array([[ 50257, 50362, 2329, 466, 340, 0, 3363, 345, 460, 0, 2329, 466, 340, 50256 ]]),
+                'token_timestamps': np.array([[ 0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72 ]])
             }
         ]
         # fmt: on

From fe5b0ae07e89c6a5a2bcbf25f3370a7e3cd94680 Mon Sep 17 00:00:00 2001
From: Ita Zaporozhets
Date: Wed, 16 Oct 2024 12:54:33 +0200
Subject: [PATCH 5/5] added test

---
 .../whisper/test_tokenization_whisper.py     | 42 +++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py
index 27b24448d5a2be..dac69b9d309247 100644
--- a/tests/models/whisper/test_tokenization_whisper.py
+++ b/tests/models/whisper/test_tokenization_whisper.py
@@ -374,6 +374,48 @@ def test_decode_asr_with_word_level_timestamps(self):
         )
         self.assertEqual(result, EXPECTED_OUTPUT)
 
+        # fmt: off
+        model_outputs = [
+            {
+                'stride': (30.0, 0.0, 5.0),
+                'tokens': np.array([[286, 478, 2633, 760, 420, 2633, 264, 558, 2372, 13, 286, 841, 264, 596, 346, 13, 583, 406, 1547, 281]]),
+                'token_timestamps': np.array(
+                    [[23.88, 24.06, 24.06, 24.3, 24.54, 24.72, 24.98, 25.2, 25.36, 25.62,
+                      25.66, 25.8, 26.06, 26.26, 26.34, 26.48, 26.52, 26.72, 26.86, 27.08]])
+            },
+            {
+                'stride': (10.0075, 5.0, 0.0),
+                'tokens': np.array([[2633, 6385, 286, 478, 2633, 264, 558, 2372, 286, 841, 264, 4588, 457, 406, 1547, 281, 652, 385, 605, 493]]),
+                'token_timestamps': np.array(
+                    [[4.12, 4.32, 4.58, 4.76, 4.84, 4.9, 5.2, 5.36, 5.62, 5.82,
+                      6.02, 6.26, 6.48, 6.74, 6.86, 7.08, 7.32, 7.42, 7.66, 7.8]])
+            }
+        ]
+        # fmt: on
+
+        result = tokenizer._decode_asr(
+            model_outputs, return_timestamps="word", return_language=False, time_precision=0.02
+        )
+
+        EXPECTED_OUTPUT = (
+            " ofectjoy knowocjoy sace threat.ublic s influencept Lians anoryusical",
+            {
+                "chunks": [
+                    {"text": " ofectjoy", "timestamp": (23.88, 24.3)},
+                    {"text": " knowocjoy", "timestamp": (24.3, 24.98)},
+                    {"text": " sace", "timestamp": (24.98, 25.36)},
+                    {"text": " threat", "timestamp": (25.36, 25.62)},
+                    {"text": ".ublic", "timestamp": (25.62, 26.02)},
+                    {"text": " s", "timestamp": (26.02, 26.26)},
+                    {"text": " influencept", "timestamp": (26.26, 26.74)},
+                    {"text": " Lians", "timestamp": (26.74, 27.08)},
+                    {"text": " anoryusical", "timestamp": (27.08, 27.8)},
+                ]
+            },
+        )
+
+        self.assertEqual(result, EXPECTED_OUTPUT)
+
 
 class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
     checkpoint_name = "openai/whisper-small.en"