Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12
14 changes: 5 additions & 9 deletions example_for_mac.py → examples/example_for_mac.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
"""
uv run examples/example_for_mac.py
"""
import torch
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from chatterbox.models.utils import get_device

# Detect device (Mac with M1/M2/M3/M4)
device = "mps" if torch.backends.mps.is_available() else "cpu"
device = get_device()
map_location = torch.device(device)

torch_load_original = torch.load
def patched_torch_load(*args, **kwargs):
    """Forward to the saved ``torch.load`` with a default ``map_location``.

    All positional and keyword arguments are passed through untouched.
    The module-level ``map_location`` is supplied only when the caller
    did not pass a ``map_location`` keyword of their own.

    Returns:
        Whatever the saved ``torch_load_original`` returns.
    """
    # setdefault leaves an explicit caller-provided map_location alone.
    kwargs.setdefault('map_location', map_location)
    return torch_load_original(*args, **kwargs)

torch.load = patched_torch_load

model = ChatterboxTTS.from_pretrained(device=device)
text = "Today is the day. I want to move like a titan at dawn, sweat like a god forging lightning. No more excuses. From now on, my mornings will be temples of discipline. I am going to work out like the gods… every damn day."

Expand Down
12 changes: 5 additions & 7 deletions example_tts.py → examples/example_tts.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
"""
uv run examples/example_tts.py
"""
import torchaudio as ta
import torch
from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
from chatterbox.models.utils import get_device

# Automatically detect the best available device
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
device = get_device()

print(f"Using device: {device}")

Expand Down
12 changes: 5 additions & 7 deletions example_vc.py → examples/example_vc.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import torch
"""
uv run examples/example_vc.py
"""
import torchaudio as ta

from chatterbox.vc import ChatterboxVC
from chatterbox.models.utils import get_device

# Automatically detect the best available device
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
device = get_device()

print(f"Using device: {device}")

Expand Down
6 changes: 5 additions & 1 deletion gradio_tts_app.py → examples/gradio_tts_app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""
uv run examples/gradio_tts_app.py
"""
import random
import numpy as np
import torch
import gradio as gr
from chatterbox.tts import ChatterboxTTS
from chatterbox.models.utils import get_device


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = get_device()


def set_seed(seed: int):
Expand Down
7 changes: 5 additions & 2 deletions gradio_vc_app.py → examples/gradio_vc_app.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import torch
"""
uv run examples/gradio_vc_app.py
"""
import gradio as gr
from chatterbox.vc import ChatterboxVC
from chatterbox.models.utils import get_device


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = get_device()


model = ChatterboxVC.from_pretrained(DEVICE)
Expand Down
6 changes: 5 additions & 1 deletion multilingual_app.py → examples/multilingual_app.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""
uv run examples/multilingual_app.py
"""
import random
import numpy as np
import torch
from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
from chatterbox.models.utils import get_device
import gradio as gr

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = get_device()
print(f"🚀 Running on device: {DEVICE}")

# --- Global Model Initialization ---
Expand Down
39 changes: 21 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[project]
name = "chatterbox-tts"
name = "chatterbox"
version = "0.1.4"
description = "Chatterbox: Open Source TTS and Voice Conversion by Resemble AI"
readme = "README.md"
Expand All @@ -9,29 +9,32 @@ authors = [
{name = "resemble-ai", email = "[email protected]"}
]
dependencies = [
"numpy>=1.24.0,<1.26.0",
"librosa==0.11.0",
"s3tokenizer",
"torch==2.6.0",
"torchaudio==2.6.0",
"conformer>=0.3.2",
"diffusers>=0.35.1",
"gradio>=5.44.1",
"librosa>=0.11.0",
"numpy>=2.2.6",
"pykakasi>=2.3.0",
"resemble-perth>=1.0.1",
"s3tokenizer>=0.2.0",
"safetensors>=0.6.2",
"setuptools>=80.9.0",
"torch>=2.8.0",
"torchaudio>=2.8.0",
"transformers==4.46.3",
"diffusers==0.29.0",
"resemble-perth==1.0.1",
"conformer==0.3.2",
"safetensors==0.5.3",
"pkuseg ==0.0.25",
"pykakasi==2.3.0",
"gradio==5.44.1",
]

# Optional Chinese language support.
# To install it, request the `chinese` extra via pip, or run `uv sync --extra chinese`.
[project.optional-dependencies]
chinese = [
"pkuseg>=0.0.25",
]

[project.urls]
Homepage = "https://github.com/resemble-ai/chatterbox"
Repository = "https://github.com/resemble-ai/chatterbox"

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]
requires = ["hatchling"]
build-backend = "hatchling.build"
2 changes: 1 addition & 1 deletion src/chatterbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
except ImportError:
from importlib_metadata import version # For Python <3.8

__version__ = version("chatterbox-tts")
__version__ = version("chatterbox")


from .tts import ChatterboxTTS
Expand Down
22 changes: 22 additions & 0 deletions src/chatterbox/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,25 @@
import torch


def get_device() -> str:
    """
    Pick the PyTorch device string to run on.

    Backends are probed in order of preference, and the first one that
    reports itself as available is chosen:
    1. 'cuda' - when ``torch.cuda.is_available()`` is true
    2. 'mps'  - when ``torch.backends.mps.is_available()`` is true
    3. 'cpu'  - when neither of the above is available

    Returns:
        str: Device string ('cuda', 'mps', or 'cpu')
    """
    if torch.cuda.is_available():
        chosen = "cuda"
    else:
        # MPS is only probed once CUDA has been ruled out.
        chosen = "mps" if torch.backends.mps.is_available() else "cpu"
    return chosen

class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion src/chatterbox/mtl_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':

s3gen = S3Gen()
s3gen.load_state_dict(
torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
torch.load(ckpt_dir / "s3gen.pt", weights_only=True, map_location=device)
)
s3gen.to(device).eval()

Expand Down
8 changes: 1 addition & 7 deletions src/chatterbox/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,7 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':

@classmethod
def from_pretrained(cls, device) -> 'ChatterboxTTS':
# Check if MPS is available on macOS
if device == "mps" and not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
print("MPS not available because the current PyTorch install was not built with MPS enabled.")
else:
print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
device = "cpu"
# Use the requested device directly

for fpath in ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]:
local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
Expand Down
8 changes: 1 addition & 7 deletions src/chatterbox/vc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,7 @@ def from_local(cls, ckpt_dir, device) -> 'ChatterboxVC':

@classmethod
def from_pretrained(cls, device) -> 'ChatterboxVC':
# Check if MPS is available on macOS
if device == "mps" and not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
print("MPS not available because the current PyTorch install was not built with MPS enabled.")
else:
print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
device = "cpu"
# Use the requested device directly

for fpath in ["s3gen.safetensors", "conds.pt"]:
local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
Expand Down
Loading