Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: include openai-whisper into thirdparty #2232

Merged
merged 2 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ global-exclude conftest.py
include xinference/locale/*.json
include xinference/model/llm/*.json
include xinference/model/embedding/*.json
graft xinference/thirdparty
global-include xinference/web/ui/build/**/*
8 changes: 3 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ packages = find:
install_requires =
xoscar>=0.3.0
torch
gradio==4.26.0
typer[all]<0.12.0 # fix typer required by gradio
gradio
pillow
click
tqdm>=4.27
Expand All @@ -44,7 +43,7 @@ install_requires =
python-jose[cryptography]
passlib[bcrypt]
aioprometheus[starlette]>=23.12.0
pynvml
nvidia-ml-py
async-timeout
peft
timm
Expand Down Expand Up @@ -121,7 +120,6 @@ all =
pyarrow # For CosyVoice, matcha
HyperPyYAML # For CosyVoice
onnxruntime==1.16.0 # For CosyVoice, use onnxruntime-gpu==1.16.0 if possible
openai-whisper # For CosyVoice
boto3>=1.28.55,<1.28.65 # For tensorizer
tensorizer~=2.9.0
eva-decord # For video in VL
Expand Down Expand Up @@ -180,6 +178,7 @@ audio =
xxhash
torchaudio
ChatTTS>0.1
tiktoken # For CosyVoice, openai-whisper
torch>=2.0.0 # For CosyVoice, matcha
lightning>=2.0.0 # For CosyVoice, matcha
hydra-core>=1.3.2 # For CosyVoice, matcha
Expand All @@ -190,7 +189,6 @@ audio =
pyarrow # For CosyVoice, matcha
HyperPyYAML # For CosyVoice
onnxruntime==1.16.0 # For CosyVoice, use onnxruntime-gpu==1.16.0 if possible
openai-whisper # For CosyVoice
loguru # For Fish Speech
natsort # For Fish Speech
loralib # For Fish Speech
Expand Down
1 change: 0 additions & 1 deletion xinference/deploy/docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ pyarrow # For CosyVoice, matcha
HyperPyYAML # For CosyVoice
onnxruntime-gpu==1.16.0; sys_platform == 'linux' # For CosyVoice
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # For CosyVoice
openai-whisper # For CosyVoice
boto3>=1.28.55,<1.28.65 # For tensorizer
tensorizer~=2.9.0
imageio-ffmpeg # For video
Expand Down
1 change: 0 additions & 1 deletion xinference/deploy/docker/requirements_cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ pyarrow # For CosyVoice, matcha
HyperPyYAML # For CosyVoice
onnxruntime-gpu==1.16.0; sys_platform == 'linux' # For CosyVoice
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # For CosyVoice
openai-whisper # For CosyVoice
imageio-ffmpeg # For video
eva-decord # For video in VL
jj-pytorchvideo # For CogVLM2-video
Expand Down
156 changes: 156 additions & 0 deletions xinference/thirdparty/whisper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__

_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)

expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))

if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")

if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)

with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break

output.write(buffer)
loop.update(len(buffer))

model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)

return model_bytes if in_memory else download_target


def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())


def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
) -> Whisper:
"""
Load a Whisper ASR model

Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory

Returns
-------
model : Whisper
The Whisper ASR model instance
"""

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)

with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file

dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])

if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)

return model.to(device)
3 changes: 3 additions & 0 deletions xinference/thirdparty/whisper/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .transcribe import cli

cli()
Loading
Loading