From c0df89355d2249d82499c88139391a8c3b6e68ed Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 20 Jun 2024 21:49:33 +0800 Subject: [PATCH 1/2] Fix whisper --- scripts/whisper/export-onnx.py | 7 +++- scripts/whisper/test.py | 66 ++++++++++++++++++++++++++++++---- 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/scripts/whisper/export-onnx.py b/scripts/whisper/export-onnx.py index 1bfe03d0f..feca2542d 100755 --- a/scripts/whisper/export-onnx.py +++ b/scripts/whisper/export-onnx.py @@ -63,6 +63,10 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]): Key-value pairs. """ model = onnx.load(filename) + + while len(model.metadata_props): + model.metadata_props.pop() + for key, value in meta_data.items(): meta = model.metadata_props.add() meta.key = key @@ -422,7 +426,7 @@ def main(): encoder_meta_data = { "model_type": f"whisper-{name}", - "version": "1", + "version": "2", "maintainer": "k2-fsa", "n_mels": model.dims.n_mels, "n_audio_ctx": model.dims.n_audio_ctx, @@ -453,6 +457,7 @@ def main(): "sot_prev": tokenizer.sot_prev, "sot_lm": tokenizer.sot_lm, "no_timestamps": tokenizer.no_timestamps, + "timestamp_begin": tokenizer.timestamp_begin, } print(f"encoder_meta_data: {encoder_meta_data}") add_meta_data(filename=encoder_filename, meta_data=encoder_meta_data) diff --git a/scripts/whisper/test.py b/scripts/whisper/test.py index 014a19e6a..ee9aec505 100755 --- a/scripts/whisper/test.py +++ b/scripts/whisper/test.py @@ -6,9 +6,10 @@ """ import argparse import base64 -from typing import Tuple +from typing import List, Tuple import kaldi_native_fbank as knf +import numpy as np import onnxruntime as ort import torch import torchaudio @@ -86,21 +87,29 @@ def init_encoder(self, encoder: str): ) meta = self.encoder.get_modelmeta().custom_metadata_map + version = int(meta("version")) + if version < 2: + raise RuntimeError( + "If you have exported your model before 2024-06-21, please re-export it using the latest " + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/export-onnx.py" + ) self.n_text_layer = int(meta["n_text_layer"]) self.n_text_ctx = int(meta["n_text_ctx"]) self.n_text_state = int(meta["n_text_state"]) self.sot = int(meta["sot"]) + self.sot_lm = int(meta["sot_lm"]) + self.sot_prev = int(meta["sot_prev"]) self.eot = int(meta["eot"]) self.translate = int(meta["translate"]) self.transcribe = int(meta["transcribe"]) self.no_timestamps = int(meta["no_timestamps"]) + self.timestamp_begin = int(meta["timestamp_begin"]) self.no_speech = int(meta["no_speech"]) + self.non_speech_tokens = list(map(int, meta["non_speech_tokens"].split(","))) self.blank = int(meta["blank_id"]) self.sot_sequence = list(map(int, meta["sot_sequence"].split(","))) - self.sot_sequence.append(self.no_timestamps) - self.all_language_tokens = list( map(int, meta["all_language_tokens"].split(",")) ) @@ -179,20 +188,62 @@ def get_self_cache(self) -> Tuple[torch.Tensor, torch.Tensor]: return n_layer_self_k_cache, n_layer_self_v_cache def suppress_tokens(self, logits, is_initial: bool) -> None: + """ + Args: + logits: 1-D nd.array. Changed in-place. + """ # suppress blank if is_initial: logits[self.eot] = float("-inf") logits[self.blank] = float("-inf") - # suppress <|notimestamps|> - logits[self.no_timestamps] = float("-inf") - logits[self.sot] = float("-inf") + logits[self.sot_prev] = float("-inf") + logits[self.sot_lm] = float("-inf") logits[self.no_speech] = float("-inf") - # logits is changed in-place + logits[self.transcribe] = float("-inf") logits[self.translate] = float("-inf") + logits[self.no_timestamps] = float("-inf") + + # logits[self.all_language_tokens] = float("-inf") + logits[self.non_speech_tokens] = float("-inf") + + def apply_timestamp_rules(self, tokens: List[int], logits: np.ndarray): + """ + Args: + logits: 1-D nd.array. Changed in-place. + """ + sample_begin = 0 + max_initial_timestamp_index = 50 + logits[self.no_timestamps] = float("-inf") + + last_was_timestamp = len(tokens) >= 1 and tokens[-1] >= self.timestamp_begin + penultimate_was_timestamp = ( + len(tokens) < 2 or tokens[-2] >= self.timestamp_begin + ) + + if last_was_timestamp: + if penultimate_was_timestamp: + logits[self.timestamp_begin :] = float("-inf") + else: + logits[: self.eot] = float("-inf") + sampled_tokens = torch.tensor(tokens) + timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)] + if timestamps.numel() > 0: + logits[self.timestamp_begin : timestamps[-1].item()] = float("-inf") + + if len(tokens) == 0: + logits[: self.timestamp_begin] = float("-inf") + last_allowed = self.timestamp_begin + max_initial_timestamp_index + logits[last_allowed + 1 :] = float("-inf") + logprobs = torch.as_tensor(logits).log_softmax(dim=-1) + timestamp_logprob = logprobs[self.timestamp_begin :].logsumexp(dim=-1) + max_text_token_lobprob = logprobs[: self.timestamp_begin].max() + if timestamp_logprob > max_text_token_lobprob: + logits[: self.timestamp_begin] = float("-inf") + def detect_language( self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor ) -> int: @@ -348,6 +399,7 @@ def main(): offset += 1 logits = logits[0, -1] model.suppress_tokens(logits, is_initial=False) + model.apply_timestamp_rules(tokens=results, logits=logits) max_token_id = logits.argmax(dim=-1) token_table = load_tokens(args.tokens) s = b"" From 0070136b67094aeaacbd8df26ecbd046bd0222c1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 20 Jun 2024 22:04:08 +0800 Subject: [PATCH 2/2] fix typos --- scripts/whisper/test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/whisper/test.py b/scripts/whisper/test.py index ee9aec505..c34c669b9 100755 --- a/scripts/whisper/test.py +++ b/scripts/whisper/test.py @@ -87,11 +87,12 @@ def init_encoder(self, encoder: str): ) meta = self.encoder.get_modelmeta().custom_metadata_map - version = int(meta("version")) + version = int(meta["version"]) if version < 2: raise RuntimeError( "If you have exported your model before 2024-06-21, please re-export it using the latest " - "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/export-onnx.py" + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/export-onnx.py\n" + "See https://github.com/k2-fsa/sherpa-onnx/pull/1037 for details" ) self.n_text_layer = int(meta["n_text_layer"]) self.n_text_ctx = int(meta["n_text_ctx"])