diff --git a/scripts/whisper/export-onnx.py b/scripts/whisper/export-onnx.py
index 1bfe03d0f..feca2542d 100755
--- a/scripts/whisper/export-onnx.py
+++ b/scripts/whisper/export-onnx.py
@@ -63,6 +63,10 @@ def add_meta_data(filename: str, meta_data: Dict[str, Any]):
         Key-value pairs.
     """
     model = onnx.load(filename)
+
+    while len(model.metadata_props):
+        model.metadata_props.pop()
+
     for key, value in meta_data.items():
         meta = model.metadata_props.add()
         meta.key = key
@@ -422,7 +426,7 @@ def main():
 
     encoder_meta_data = {
         "model_type": f"whisper-{name}",
-        "version": "1",
+        "version": "2",
         "maintainer": "k2-fsa",
         "n_mels": model.dims.n_mels,
         "n_audio_ctx": model.dims.n_audio_ctx,
@@ -453,6 +457,7 @@ def main():
         "sot_prev": tokenizer.sot_prev,
         "sot_lm": tokenizer.sot_lm,
         "no_timestamps": tokenizer.no_timestamps,
+        "timestamp_begin": tokenizer.timestamp_begin,
     }
     print(f"encoder_meta_data: {encoder_meta_data}")
     add_meta_data(filename=encoder_filename, meta_data=encoder_meta_data)
diff --git a/scripts/whisper/test.py b/scripts/whisper/test.py
index 014a19e6a..c34c669b9 100755
--- a/scripts/whisper/test.py
+++ b/scripts/whisper/test.py
@@ -6,9 +6,10 @@
 """
 import argparse
 import base64
-from typing import Tuple
+from typing import List, Tuple
 
 import kaldi_native_fbank as knf
+import numpy as np
 import onnxruntime as ort
 import torch
 import torchaudio
@@ -86,21 +87,30 @@ def init_encoder(self, encoder: str):
         )
 
         meta = self.encoder.get_modelmeta().custom_metadata_map
+        version = int(meta["version"])
+        if version < 2:
+            raise RuntimeError(
+                "If you have exported your model before 2024-06-21, please re-export it using the latest "
+                "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/export-onnx.py\n"
+                "See https://github.com/k2-fsa/sherpa-onnx/pull/1037 for details"
+            )
 
         self.n_text_layer = int(meta["n_text_layer"])
         self.n_text_ctx = int(meta["n_text_ctx"])
         self.n_text_state = int(meta["n_text_state"])
         self.sot = int(meta["sot"])
+        self.sot_lm = int(meta["sot_lm"])
+        self.sot_prev = int(meta["sot_prev"])
         self.eot = int(meta["eot"])
         self.translate = int(meta["translate"])
         self.transcribe = int(meta["transcribe"])
         self.no_timestamps = int(meta["no_timestamps"])
+        self.timestamp_begin = int(meta["timestamp_begin"])
         self.no_speech = int(meta["no_speech"])
+        self.non_speech_tokens = list(map(int, meta["non_speech_tokens"].split(",")))
         self.blank = int(meta["blank_id"])
 
         self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
-        self.sot_sequence.append(self.no_timestamps)
-
         self.all_language_tokens = list(
             map(int, meta["all_language_tokens"].split(","))
         )
@@ -179,20 +189,62 @@ def get_self_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
         return n_layer_self_k_cache, n_layer_self_v_cache
 
     def suppress_tokens(self, logits, is_initial: bool) -> None:
+        """
+        Args:
+          logits: 1-D np.ndarray. Changed in-place.
+ """ # suppress blank if is_initial: logits[self.eot] = float("-inf") logits[self.blank] = float("-inf") - # suppress <|notimestamps|> - logits[self.no_timestamps] = float("-inf") - logits[self.sot] = float("-inf") + logits[self.sot_prev] = float("-inf") + logits[self.sot_lm] = float("-inf") logits[self.no_speech] = float("-inf") - # logits is changed in-place + logits[self.transcribe] = float("-inf") logits[self.translate] = float("-inf") + logits[self.no_timestamps] = float("-inf") + + # logits[self.all_language_tokens] = float("-inf") + logits[self.non_speech_tokens] = float("-inf") + + def apply_timestamp_rules(self, tokens: List[int], logits: np.ndarray): + """ + Args: + logits: 1-D nd.array. Changed in-place. + """ + sample_begin = 0 + max_initial_timestamp_index = 50 + logits[self.no_timestamps] = float("-inf") + + last_was_timestamp = len(tokens) >= 1 and tokens[-1] >= self.timestamp_begin + penultimate_was_timestamp = ( + len(tokens) < 2 or tokens[-2] >= self.timestamp_begin + ) + + if last_was_timestamp: + if penultimate_was_timestamp: + logits[self.timestamp_begin :] = float("-inf") + else: + logits[: self.eot] = float("-inf") + sampled_tokens = torch.tensor(tokens) + timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)] + if timestamps.numel() > 0: + logits[self.timestamp_begin : timestamps[-1].item()] = float("-inf") + + if len(tokens) == 0: + logits[: self.timestamp_begin] = float("-inf") + last_allowed = self.timestamp_begin + max_initial_timestamp_index + logits[last_allowed + 1 :] = float("-inf") + logprobs = torch.as_tensor(logits).log_softmax(dim=-1) + timestamp_logprob = logprobs[self.timestamp_begin :].logsumexp(dim=-1) + max_text_token_lobprob = logprobs[: self.timestamp_begin].max() + if timestamp_logprob > max_text_token_lobprob: + logits[: self.timestamp_begin] = float("-inf") + def detect_language( self, n_layer_cross_k: torch.Tensor, n_layer_cross_v: torch.Tensor ) -> int: @@ -348,6 +400,7 @@ def main(): offset += 1 logits = logits[0, -1] model.suppress_tokens(logits, is_initial=False) + model.apply_timestamp_rules(tokens=results, logits=logits) max_token_id = logits.argmax(dim=-1) token_table = load_tokens(args.tokens) s = b""