Skip to content

Commit

Permalink
add TranscribeProgressReceiver for update monitoring
Browse files Browse the repository at this point in the history
The current `transcribe` API only outputs the progress and transcribed texts
on stdout. Callers can only access the result after the whole
transcription is done, and they need to hijack the `tqdm` interface to get
the realtime transcription progress. This commit adds a simple interface
that can be passed as a parameter to `transcribe`, so API users don't
need to fall back to the above hacks or low-level APIs for this need.

Signed-off-by: Austin Chang <[email protected]>
  • Loading branch information
austin880625 committed Oct 19, 2024
1 parent 25639fc commit bdbe6bf
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 8 deletions.
21 changes: 20 additions & 1 deletion tests/test_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,37 @@
from whisper.tokenizer import get_tokenizer


class TestingProgressReceiver(whisper.TranscribeProgressReceiver):
    """Progress receiver that records everything it is told, so the test can
    verify the receiver callbacks against `transcribe`'s returned result."""

    def __init__(self):
        # Initialize state up front so get_result()/verify_total() are safe
        # even if start() was never called (avoids AttributeError).
        self.result = ""
        self.total = 0
        self.progress = 0

    def start(self, total: int):
        """Reset accumulated state for a new transcription; returns self so it
        can serve as the context manager inside `transcribe`."""
        self.result = ""
        self.total = total
        self.progress = 0
        return self

    def update_line(self, start: float, end: float, text: str):
        """Accumulate each transcribed segment's text."""
        self.result += text

    def update(self, n):
        """Advance the processed-frame counter by `n`."""
        self.progress += n

    def get_result(self):
        """Return the concatenation of all segment texts seen so far."""
        return self.result

    def verify_total(self):
        """True iff the accumulated `update` increments sum exactly to the
        total announced in `start` (i.e. progress reporting was complete)."""
        return self.total == self.progress

@pytest.mark.parametrize("model_name", whisper.available_models())
def test_transcribe(model_name: str):
device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model(model_name).to(device)
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
receiver = TestingProgressReceiver()

language = "en" if model_name.endswith(".en") else None
result = model.transcribe(
audio_path, language=language, temperature=0.0, word_timestamps=True
audio_path, language=language, temperature=0.0, word_timestamps=True,
progress_receiver=receiver
)
assert receiver.verify_total()
assert result["language"] == "en"
assert result["text"] == "".join([s["text"] for s in result["segments"]])
assert result["text"] == receiver.get_result()

transcription = result["text"].lower()
assert "my fellow americans" in transcription
Expand Down
2 changes: 1 addition & 1 deletion whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .transcribe import TranscribeProgressReceiver, transcribe
from .version import __version__

_MODELS = {
Expand Down
60 changes: 54 additions & 6 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import traceback
import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Self

import numpy as np
import torch
Expand Down Expand Up @@ -34,12 +34,57 @@
if TYPE_CHECKING:
from .model import Whisper

class TranscribeProgressReceiver:
    """
    Base class for observing transcription progress.

    Subclass this and override the hooks below, then pass an instance as the
    `progress_receiver` argument of `transcribe`. Every default implementation
    is a no-op, so subclasses only override the hooks they care about.

    NOTE: annotations use string forward references instead of `typing.Self`
    because `Self` requires Python 3.11+ and would break older interpreters.
    """

    def start(self, total: int) -> "TranscribeProgressReceiver":
        """
        Called once when transcription starts.

        Parameters
        ----------
        total: int
            The total length of the audio content, in frames.

        In most cases this method should return `self`; the return value is
        used as a context manager for the duration of the transcription.
        """
        return self

    def update(self, n: int):
        """
        Called with the increment `n` (in frames) whenever a segment is
        transcribed.
        """
        pass

    def update_line(self, start: float, end: float, text: str):
        """
        Called whenever a segment is transcribed.

        Parameters
        ----------
        start: float
            Start time of the segment in seconds
        end: float
            End time of the segment in seconds
        text: str
            The transcribed text
        """
        pass

    def __enter__(self) -> "TranscribeProgressReceiver":
        """
        Override if resources must be allocated when the transcription starts.
        In most cases this method should return `self`.
        """
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        """
        Override if resources must be released when the transcription finishes
        or is terminated.
        """
        pass

def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
*,
verbose: Optional[bool] = None,
progress_receiver: TranscribeProgressReceiver = TranscribeProgressReceiver(),
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
Expand Down Expand Up @@ -253,7 +298,8 @@ def new_segment(
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
) as pbar, \
progress_receiver.start(total=content_frames) as ext_progress:
last_speech_timestamp = 0.0
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
Expand Down Expand Up @@ -459,10 +505,11 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
if last_word_end is not None:
last_speech_timestamp = last_word_end

if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
ext_progress.update_line(start, end, make_safe(text))
if verbose:
print(make_safe(line))

# if a segment is instantaneous or does not contain text, clear it
Expand Down Expand Up @@ -490,6 +537,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:

# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
ext_progress.update(min(content_frames, seek) - previous_seek)

return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
Expand Down

0 comments on commit bdbe6bf

Please sign in to comment.