Add on_progress callback for transcribe and align #620

Open · wants to merge 1 commit into base: main
6 changes: 5 additions & 1 deletion whisperx/alignment.py
@@ -3,7 +3,7 @@
 C. Max Bain
 """
 from dataclasses import dataclass
-from typing import Iterable, Union, List
+from typing import Iterable, Union, List, Callable
 
 import numpy as np
 import pandas as pd
@@ -101,6 +101,7 @@ def align(
     return_char_alignments: bool = False,
     print_progress: bool = False,
     combined_progress: bool = False,
+    on_progress: Callable[[int, int], None] = None
 ) -> AlignedTranscriptionResult:
     """
     Align phoneme recognition predictions to known transcription.
@@ -127,6 +128,9 @@
             base_progress = ((sdx + 1) / total_segments) * 100
             percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
             print(f"Progress: {percent_complete:.2f}%...")
+
+        if on_progress:
+            on_progress(sdx + 1, total_segments)
 
         num_leading = len(segment["text"]) - len(segment["text"].lstrip())
         num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
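For reviewers, a minimal usage sketch of the new align() hook. The model setup, the audio.wav path, and the report helper below are illustrative, not part of this diff; the callback simply receives the 1-based index of the segment that just finished aligning plus the total segment count.

import whisperx

device = "cpu"
audio = whisperx.load_audio("audio.wav")  # illustrative input file

model = whisperx.load_model("small", device, compute_type="int8")
result = model.transcribe(audio)

align_model, metadata = whisperx.load_align_model(
    language_code=result["language"], device=device
)

def report(done: int, total: int) -> None:
    # Called once per segment with (segments aligned so far, total segments).
    print(f"aligned {done}/{total} segments")

aligned = whisperx.align(
    result["segments"], align_model, metadata, audio, device,
    on_progress=report,
)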
40 changes: 38 additions & 2 deletions whisperx/asr.py
@@ -1,6 +1,7 @@
 import os
 import warnings
-from typing import List, Union, Optional, NamedTuple
+from typing import List, Union, Optional, NamedTuple, Callable
+from enum import Enum
 
 import ctranslate2
 import faster_whisper
@@ -93,6 +94,12 @@ class FasterWhisperPipeline(Pipeline):
     # - add support for timestamp mode
     # - add support for custom inference kwargs
 
+    class TranscriptionState(Enum):
+        LOADING_AUDIO = "loading_audio"
+        GENERATING_VAD_SEGMENTS = "generating_vad_segments"
+        TRANSCRIBING = "transcribing"
+        FINISHED = "finished"
+
     def __init__(
         self,
         model,
@@ -166,9 +173,21 @@ def stack(items):
         return final_iterator
 
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
+        self,
+        audio: Union[str, np.ndarray],
+        batch_size=None,
+        num_workers=0,
+        language=None,
+        task=None,
+        chunk_size=30,
+        on_progress: Callable[[TranscriptionState, Optional[int], Optional[int]], None] = None,
+        print_progress: bool = False,
+        combined_progress: bool = False
     ) -> TranscriptionResult:
         if isinstance(audio, str):
+            if on_progress:
+                on_progress(self.__class__.TranscriptionState.LOADING_AUDIO)
+
             audio = load_audio(audio)
 
         def data(audio, segments):
@@ -178,13 +197,17 @@ def data(audio, segments):
                 # print(f2-f1)
                 yield {'inputs': audio[f1:f2]}
 
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.GENERATING_VAD_SEGMENTS)
+
         vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
             vad_segments,
             chunk_size,
             onset=self._vad_params["vad_onset"],
             offset=self._vad_params["vad_offset"],
         )
+
         if self.tokenizer is None:
             language = language or self.detect_language(audio)
             task = task or "transcribe"
@@ -210,11 +233,21 @@ def data(audio, segments):
         segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
         total_segments = len(vad_segments)
+
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.TRANSCRIBING, 0, total_segments)
+
         for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
+            # Original print-only behaviour to keep the method backwards compatible
+            # Should probably be replaced with a default on_progress callback some-when.
             if print_progress:
                 base_progress = ((idx + 1) / total_segments) * 100
                 percent_complete = base_progress / 2 if combined_progress else base_progress
                 print(f"Progress: {percent_complete:.2f}%...")
+
+            if on_progress:
+                on_progress(self.__class__.TranscriptionState.TRANSCRIBING, idx + 1, total_segments)
+
             text = out['text']
             if batch_size in [0, 1, None]:
                 text = text[0]
@@ -226,6 +259,9 @@
                 }
             )
 
+        if on_progress:
+            on_progress(self.__class__.TranscriptionState.FINISHED)
+
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
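And a sketch of a transcribe() callback consuming the new TranscriptionState enum. The file name and print formatting are placeholders, and the import path assumes the nested enum lands on FasterWhisperPipeline exactly as written above. Note that current/total are only supplied in the TRANSCRIBING state, so the callback needs None defaults for them.

from typing import Optional

import whisperx
from whisperx.asr import FasterWhisperPipeline

State = FasterWhisperPipeline.TranscriptionState

def on_progress(state: State, current: Optional[int] = None, total: Optional[int] = None) -> None:
    # current/total arrive only with State.TRANSCRIBING (see the diff above).
    if state is State.TRANSCRIBING and total:
        print(f"{state.value}: {current}/{total} ({current / total:.0%})")
    else:
        print(state.value)

model = whisperx.load_model("small", "cpu", compute_type="int8")
result = model.transcribe("audio.wav", on_progress=on_progress)  # str input triggers LOADING_AUDIO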