From 5ba4f6cc4b773a5546f4c047dec078e0f7a64c03 Mon Sep 17 00:00:00 2001
From: Stefan Exner
Date: Sat, 9 Dec 2023 23:48:08 +0100
Subject: [PATCH] Add on_progress callback for transcribe and align

This adds an optional `on_progress` argument to both the `align` and
`transcribe` methods, so the current progress can be processed via a
callback; until now, it was only possible to print the progress to
`STDOUT`.

Example:

```python
transcribe(
    my_audio,
    batch_size=8,
    on_progress=lambda state, current=0, total=0: print(f"state: {state}, current: {current}, total: {total}")
)
```

States are defined as:

```python
class TranscriptionState(Enum):
    LOADING_AUDIO = "loading_audio"
    GENERATING_VAD_SEGMENTS = "generating_vad_segments"
    TRANSCRIBING = "transcribing"
    FINISHED = "finished"
```

Signed-off-by: Stefan Exner
---
 whisperx/alignment.py |  6 +++++-
 whisperx/asr.py       | 40 ++++++++++++++++++++++++++++++++++++++--
 2 files changed, 43 insertions(+), 3 deletions(-)

diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 74e3f765..c4795e0b 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -3,7 +3,7 @@ C. Max Bain
 """
 from dataclasses import dataclass
-from typing import Iterable, Union, List
+from typing import Iterable, Union, List, Callable, Optional
 
 import numpy as np
 import pandas as pd
@@ -101,6 +101,7 @@ def align(
     return_char_alignments: bool = False,
     print_progress: bool = False,
     combined_progress: bool = False,
+    on_progress: Optional[Callable[[int, int], None]] = None,
 ) -> AlignedTranscriptionResult:
     """
     Align phoneme recognition predictions to known transcription.
@@ -127,6 +128,9 @@ def align(
             base_progress = ((sdx + 1) / total_segments) * 100
             percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
             print(f"Progress: {percent_complete:.2f}%...")
+
+        if on_progress:
+            on_progress(sdx + 1, total_segments)
 
         num_leading = len(segment["text"]) - len(segment["text"].lstrip())
         num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
diff --git a/whisperx/asr.py b/whisperx/asr.py
index 94e0311a..e900e5fd 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -1,6 +1,7 @@
 import os
 import warnings
-from typing import List, Union, Optional, NamedTuple
+from typing import List, Union, Optional, NamedTuple, Callable
+from enum import Enum
 
 import ctranslate2
 import faster_whisper
@@ -93,6 +94,12 @@ class FasterWhisperPipeline(Pipeline):
     # - add support for timestamp mode
     # - add support for custom inference kwargs
 
+    class TranscriptionState(Enum):
+        LOADING_AUDIO = "loading_audio"
+        GENERATING_VAD_SEGMENTS = "generating_vad_segments"
+        TRANSCRIBING = "transcribing"
+        FINISHED = "finished"
+
     def __init__(
         self,
         model,
@@ -166,9 +173,21 @@ def stack(items):
         return final_iterator
 
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
+        self,
+        audio: Union[str, np.ndarray],
+        batch_size=None,
+        num_workers=0,
+        language=None,
+        task=None,
+        chunk_size=30,
+        on_progress: Optional[Callable[[TranscriptionState, Optional[int], Optional[int]], None]] = None,
+        print_progress: bool = False,
+        combined_progress: bool = False
    ) -> TranscriptionResult:
         if isinstance(audio, str):
+            if on_progress:
+                on_progress(self.__class__.TranscriptionState.LOADING_AUDIO)
+
             audio = load_audio(audio)
 
         def data(audio, segments):
@@ -178,13 +197,17 @@ def data(audio, segments):
                 # print(f2-f1)
                 yield {'inputs': audio[f1:f2]}
 
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.GENERATING_VAD_SEGMENTS)
+
         vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
             vad_segments,
             chunk_size,
             onset=self._vad_params["vad_onset"],
             offset=self._vad_params["vad_offset"],
         )
+
         if self.tokenizer is None:
             language = language or self.detect_language(audio)
             task = task or "transcribe"
@@ -210,11 +233,21 @@ def data(audio, segments):
         segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
         total_segments = len(vad_segments)
+
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.TRANSCRIBING, 0, total_segments)
+
         for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
+            # Original print-only behaviour, kept for backwards compatibility.
+            # Should probably be replaced with a default on_progress callback at some point.
             if print_progress:
                 base_progress = ((idx + 1) / total_segments) * 100
                 percent_complete = base_progress / 2 if combined_progress else base_progress
                 print(f"Progress: {percent_complete:.2f}%...")
+
+            if on_progress:
+                on_progress(self.__class__.TranscriptionState.TRANSCRIBING, idx + 1, total_segments)
+
             text = out['text']
             if batch_size in [0, 1, None]:
                 text = text[0]
@@ -226,6 +259,9 @@ def data(audio, segments):
             }
         )
 
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.FINISHED)
+
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
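
As a usage sketch (not part of the patch): the snippet below shows how a caller might wire up both callbacks end to end, using the existing `whisperx.load_model` / `whisperx.load_align_model` loaders; the audio path, model size, batch size, and device are placeholders.

```python
import whisperx

# Placeholder inputs; adjust to your environment.
audio_file = "audio.wav"
device = "cpu"

model = whisperx.load_model("large-v2", device, compute_type="int8")

# transcribe() reports (state, current, total); current/total are only
# meaningful while state is TranscriptionState.TRANSCRIBING, so the lambda
# defaults them to 0 for the single-argument state notifications.
result = model.transcribe(
    audio_file,
    batch_size=8,
    on_progress=lambda state, current=0, total=0: print(f"{state.value}: {current}/{total}"),
)

# align() reports plain (current, total) segment counts.
align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
aligned = whisperx.align(
    result["segments"],
    align_model,
    metadata,
    audio_file,
    device,
    on_progress=lambda current, total: print(f"aligned {current}/{total} segments"),
)
```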