Skip to content

Commit

Permalink
Enable new models in audio-to-text (#163)
Browse files Browse the repository at this point in the history
* (a2t) add new models and optimization defaults
---------

Co-authored-by: Brad P <[email protected]>
Co-authored-by: Rick Staa <[email protected]>
  • Loading branch information
3 people authored Nov 5, 2024
1 parent d046b05 commit acf9b15
Show file tree
Hide file tree
Showing 14 changed files with 2,012 additions and 125 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ require (
github.com/getkin/kin-openapi v0.128.0
github.com/go-chi/chi/v5 v5.1.0
github.com/oapi-codegen/runtime v1.1.1
github.com/pebbe/zmq4 v1.2.11
github.com/vincent-petithory/dataurl v1.0.0
)

Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY=
golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
Expand Down
111 changes: 83 additions & 28 deletions runner/app/pipelines/audio_to_text.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,58 @@
from enum import Enum
import logging
import os
from typing import List
from dataclasses import dataclass

import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from app.pipelines.utils.audio import AudioConverter
from app.utils.errors import InferenceError
from fastapi import File, UploadFile
from huggingface_hub import file_download
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

logger = logging.getLogger(__name__)


MODEL_INCOMPATIBLE_EXTENSIONS = {
"openai/whisper-large-v3": ["mp4", "m4a", "ac3"],
class ModelName(Enum):
"""Enumeration mapping model names to their corresponding IDs. Returns None if the
model ID is not found."""

WHISPER_LARGE_V3 = "openai/whisper-large-v3"
WHISPER_MEDIUM = "openai/whisper-medium"
WHISPER_DISTIL_LARGE_V3 = "distil-whisper/distil-large-v3"

@classmethod
def list(cls):
"""Return a list of all model IDs."""
return [model.value for model in cls]

@classmethod
def get(cls, model_id: str) -> Enum | None:
"""Return the enum or None if the model ID is not found."""
try:
return cls(model_id)
except ValueError:
return None


@dataclass
class ModelConfig:
"""Model configuration parameters."""

torch_dtype: torch.dtype = (
torch.float16 if torch.cuda.is_available() else torch.float32
)
chunk_length_s: int = 30


MODEL_CONFIGS = {
ModelName.WHISPER_LARGE_V3: ModelConfig(),
ModelName.WHISPER_MEDIUM: ModelConfig(torch_dtype=torch.float32),
ModelName.WHISPER_DISTIL_LARGE_V3: ModelConfig(chunk_length_s=25),
}
INCOMPATIBLE_EXTENSIONS = ["mp4", "m4a", "ac3"]


class AudioToTextPipeline(Pipeline):
Expand All @@ -25,31 +61,26 @@ def __init__(self, model_id: str):
kwargs = {}

torch_device = get_torch_device()
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"

# Get model specific configuration parameters.
model_enum = ModelName.get(model_id)
self._model_cfg = (
ModelConfig() if model_enum is None else MODEL_CONFIGS[model_enum]
)
folder_path = os.path.join(get_model_dir(), folder_name)
# Load fp16 variant if fp16 safetensors files are found in cache
has_fp16_variant = any(
".fp16.safetensors" in fname
for _, _, files in os.walk(folder_path)
for fname in files
kwargs["torch_dtype"] = self._model_cfg.torch_dtype
logger.info(
"AudioToText loading '%s' on device '%s' with '%s' variant",
model_id,
torch_device,
kwargs["torch_dtype"],
)
if torch_device.type != "cpu" and has_fp16_variant:
logger.info("AudioToTextPipeline loading fp16 variant for %s", model_id)

kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

if os.environ.get("BFLOAT16"):
logger.info("AudioToTextPipeline using bfloat16 precision for %s", model_id)
kwargs["torch_dtype"] = torch.bfloat16

model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
low_cpu_mem_usage=True,
use_safetensors=True,
cache_dir=get_model_dir(),
attn_implementation="eager", # TODO: enable flash attention.
**kwargs,
).to(torch_device)

Expand All @@ -61,24 +92,48 @@ def __init__(self, model_id: str):
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
device=torch_device,
**kwargs,
)

def __call__(self, audio: UploadFile, **kwargs) -> List[File]:
self._audio_converter = AudioConverter()

def __call__(self, audio: UploadFile, duration: float, **kwargs) -> List[File]:
audioBytes = audio.file.read()

# Convert M4A/MP4 files for pipeline compatibility.
if (
os.path.splitext(audio.filename)[1].lower().lstrip(".")
in MODEL_INCOMPATIBLE_EXTENSIONS[self.model_id]
in INCOMPATIBLE_EXTENSIONS
):
audio_converter = AudioConverter()
converted_bytes = audio_converter.convert(audio, "mp3")
audio_converter.write_bytes_to_file(converted_bytes, audio)
audioBytes = self._audio_converter.convert(audioBytes, "mp3")

# Adjust batch size and chunk length based on timestamps and duration.
# NOTE: Done to prevent CUDA OOM errors for large audio files.
kwargs["batch_size"] = 16
kwargs["chunk_length_s"] = self._model_cfg.chunk_length_s
if kwargs["return_timestamps"] == "word":
if duration > 3600:
raise InferenceError(
f"Word timestamps are only supported for audio files up to 60 minutes for model {self.model_id}"
)
if duration > 200:
kwargs["batch_size"] = 4
if duration <= kwargs["chunk_length_s"]:
kwargs.pop("batch_size", None)
kwargs.pop("chunk_length_s", None)
inference_mode = "sequential"
else:
inference_mode = f"chunked (batch_size={kwargs['batch_size']}, chunk_length_s={kwargs['chunk_length_s']})"
logger.info(
f"AudioToTextPipeline: Starting inference mode={inference_mode} with duration={duration}"
)

try:
outputs = self.tm(audio.file.read(), **kwargs)
outputs = self.tm(audioBytes, **kwargs)
outputs.setdefault("chunks", [])
except torch.cuda.OutOfMemoryError as e:
raise e
except Exception as e:
raise InferenceError(original_exception=e)

Expand Down
22 changes: 4 additions & 18 deletions runner/app/pipelines/utils/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from io import BytesIO

import av
from fastapi import UploadFile


class AudioConversionError(Exception):
Expand All @@ -20,13 +19,11 @@ class AudioConverter:
"""Converts audio files to different formats."""

@staticmethod
def convert(
upload_file: UploadFile, output_extension: str, output_codec=None
) -> bytes:
def convert(input_bytes: bytes, output_extension: str, output_codec=None) -> bytes:
"""Converts an audio file to a different format.
Args:
upload_file: The audio file to convert.
input_bytes: The audio file as bytes to convert.
output_extension: The desired output format.
output_codec: The desired output codec.
Expand All @@ -38,7 +35,8 @@ def convert(

output_buffer = BytesIO()

input_container = av.open(upload_file.file)
input_buffer = BytesIO(input_bytes)
input_container = av.open(input_buffer)
output_container = av.open(output_buffer, mode="w", format=output_extension)

try:
Expand All @@ -65,15 +63,3 @@ def convert(
output_buffer.seek(0)
converted_bytes = output_buffer.read()
return converted_bytes

@staticmethod
def write_bytes_to_file(bytes: bytes, upload_file: UploadFile):
"""Writes bytes to a file.
Args:
bytes: The bytes to write.
upload_file: The file to write to.
"""
upload_file.file.seek(0)
upload_file.file.write(bytes)
upload_file.file.seek(0)
25 changes: 24 additions & 1 deletion runner/app/routes/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
HTTPError,
TextResponse,
file_exceeds_max_size,
parse_key_from_metadata,
get_media_duration_ffmpeg,
http_error,
handle_pipeline_exception,
)
Expand Down Expand Up @@ -106,6 +108,10 @@ async def audio_to_text(
)
),
] = "true",
metadata: Annotated[
str,
Form(description="Additional job information to be passed to the pipeline."),
] = "{}",
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
return_timestamps = parse_return_timestamps(return_timestamps)
Expand Down Expand Up @@ -134,9 +140,26 @@ async def audio_to_text(
)

try:
return pipeline(audio=audio, return_timestamps=return_timestamps)
duration = parse_key_from_metadata(metadata, "duration", float)
if duration is None:
logger.warning(
f"duration not provided in request, calculating with ffprobe"
)
duration = get_media_duration_ffmpeg(audio.file.read())
audio.file.seek(0) # Reset file pointer
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error("Unable to calculate duration of file"),
)

try:
return pipeline(
audio=audio, return_timestamps=return_timestamps, duration=duration
)
except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
# TODO: Investigate why not all VRAM memory is cleared.
torch.cuda.empty_cache()
logger.error(f"AudioToText pipeline error: {e}")
return handle_pipeline_exception(
Expand Down
79 changes: 78 additions & 1 deletion runner/app/routes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import io
import json
import os
import subprocess
import tempfile
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -59,7 +61,9 @@ class MasksResponse(BaseModel):
class Chunk(BaseModel):
"""A chunk of text with a timestamp."""

timestamp: Tuple[float, float] = Field(..., description="The timestamp of the chunk.")
timestamp: Tuple[Optional[float], Optional[float]] = Field(
..., description="The timestamp of the chunk."
)
text: str = Field(..., description="The text of the chunk.")


Expand Down Expand Up @@ -270,3 +274,76 @@ def handle_pipeline_exception(
status_code=status_code,
content=content,
)


def parse_key_from_metadata(
metadata: str, key: str, expected_type: type
) -> Union[Optional[Union[str, int, float, bool]]]:
"""Parse a specific key from the metadata JSON string.
Args:
metadata: The metadata JSON string.
key: The key to parse from the metadata.
expected_type: The expected type of the key's value.
Returns:
The value of the key if it exists and is of the expected type, otherwise None.
Raises:
ValueError: If the metadata is not valid JSON.
TypeError: If the value is not of the expected type.
"""
try:
metadata_dict = json.loads(metadata)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON: {e}")

value = metadata_dict.get(key)
if value is not None:
if isinstance(value, expected_type):
return value
try:
return expected_type(value)
except (ValueError, TypeError):
raise TypeError(
f"Invalid {key} value. Must be of type {expected_type.__name__}."
)
return None


def get_media_duration_ffmpeg(bytes: bytes) -> float:
"""Gets the duration of the media using ffprobe.
Args:
bytes: The media file as bytes.
Returns:
The duration of the media in seconds.
"""
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
temp_file.write(bytes)
temp_file_path = temp_file.name

try:
result = subprocess.run(
[
"ffprobe",
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
temp_file_path,
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
duration = float(result.stdout.strip())
except Exception as e:
raise Exception(f"Failed to get duration with ffmpeg: {e}")
finally:
os.remove(temp_file_path)

return duration
2 changes: 2 additions & 0 deletions runner/dl_checkpoints.sh
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ function download_beta_models() {

# Download audio-to-text models.
huggingface-cli download openai/whisper-large-v3 --include "*.safetensors" "*.json" --cache-dir models
huggingface-cli download distil-whisper/distil-large-v3 --include "*.safetensors" "*.json" --cache-dir models
huggingface-cli download openai/whisper-medium --include "*.safetensors" "*.json" --cache-dir models

# Download custom pipeline models.
huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models
Expand Down
Loading

0 comments on commit acf9b15

Please sign in to comment.