diff --git a/README.md b/README.md
index a8f5ce3f..bb4915a8 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ $ pip install stable-audio-tools
 
 To run the training scripts or inference code, you'll want to clone this repository, navigate to the root, and run:
 ```bash
-$ pip install .
+$ pip install .[train,viz,ui]
 ```
 
 # Requirements
diff --git a/setup.py b/setup.py
index 7e7470d3..a3d07c40 100644
--- a/setup.py
+++ b/setup.py
@@ -6,39 +6,48 @@
     url='https://github.com/Stability-AI/stable-audio-tools.git',
     author='Stability AI',
     description='Training and inference tools for generative audio models from Stability AI',
-    packages=find_packages(),
+    packages=find_packages(),
     install_requires=[
-        'aeiou==0.0.20',
         'alias-free-torch==0.0.6',
-        'auraloss==0.4.0',
         'descript-audio-codec==1.0.0',
-        'einops==0.7.0',
         'einops-exts==0.0.4',
-        'ema-pytorch==0.2.3',
-        'encodec==0.1.1',
-        'gradio>=3.42.0',
-        'huggingface_hub',
-        'importlib-resources==5.12.0',
-        'k-diffusion==0.1.1',
+        'einops<0.9.0',
+        'k-diffusion>=0.1.1',
         'laion-clap==1.1.4',
-        'local-attention==1.8.6',
-        'pandas==2.0.2',
-        'pedalboard==0.7.4',
-        'prefigure==0.0.9',
-        'pytorch_lightning==2.1.0',
-        'PyWavelets==1.4.1',
         'safetensors',
-        'sentencepiece==0.1.99',
-        's3fs',
         'torch>=2.0.1',
         'torchaudio>=2.0.2',
-        'torchmetrics==0.11.4',
-        'tqdm',
         'transformers',
-        'v-diffusion-pytorch==0.0.2',
         'vector-quantize-pytorch==1.9.14',
-        'wandb==0.15.4',
-        'webdataset==0.2.48',
-        'x-transformers<1.27.0'
+        'x-transformers<1.27.0',
     ],
-)
\ No newline at end of file
+    extras_require={
+        'ui': [
+            'gradio>=3.42.0',
+        ],
+        'viz': [
+            'aeiou==0.0.20',
+        ],
+        'train': [
+            'auraloss==0.4.0',
+            'ema-pytorch==0.2.3',
+            'encodec==0.1.1',
+            'huggingface_hub',
+            'importlib-resources==5.12.0',
+            'local-attention==1.8.6',
+            'pandas==2.0.2',
+            'pedalboard==0.7.4',
+            'prefigure==0.0.9',
+            'pytorch_lightning==2.1.0',
+            'PyWavelets==1.4.1',
+            'safetensors',
+            'sentencepiece==0.1.99',
+            's3fs',
+            'tqdm',
+            'transformers',
+            'v-diffusion-pytorch==0.0.2',
+            'wandb==0.15.4',
+            'webdataset==0.2.48',
+        ],
+    },
+)
diff --git a/stable_audio_tools/data/dataset.py b/stable_audio_tools/data/dataset.py
index 4bc535a1..7db6b0e4 100644
--- a/stable_audio_tools/data/dataset.py
+++ b/stable_audio_tools/data/dataset.py
@@ -11,7 +11,6 @@
 import torchaudio
 import webdataset as wds
 
-from aeiou.core import is_silence
 from os import path
 from pedalboard.io import AudioFile
 from torchaudio import transforms as T
@@ -367,6 +366,7 @@ def log_and_continue(exn):
 
 
 def is_valid_sample(sample):
+    from aeiou.core import is_silence
     has_json = "json" in sample
     has_audio = "audio" in sample
     is_silent = is_silence(sample["audio"])
diff --git a/stable_audio_tools/inference/sampling.py b/stable_audio_tools/inference/sampling.py
index 2229e508..1dfbae64 100644
--- a/stable_audio_tools/inference/sampling.py
+++ b/stable_audio_tools/inference/sampling.py
@@ -1,8 +1,11 @@
 import torch
 import math
 
-from tqdm import trange, tqdm
-import k_diffusion as K
+from k_diffusion.external import VDenoiser
+from k_diffusion.sampling import get_sigmas_polyexponential, sample_heun, sample_lms, sample_dpmpp_2s_ancestral, \
+    sample_dpm_2, sample_dpm_fast, sample_dpm_adaptive, sample_dpmpp_2m_sde, sample_dpmpp_3m_sde
+from k_diffusion.utils import append_dims
+from tqdm import trange, tqdm
 
 # Define the noise schedule and sampling loop
 def get_alphas_sigmas(t):
@@ -99,7 +102,7 @@ def cond_model_fn(x, sigma, **kwargs):
             x = x.detach().requires_grad_()
             denoised = model(x, sigma, **kwargs)
             cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
-            cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
+            cond_denoised = denoised.detach() + cond_grad * append_dims(sigma**2, x.ndim)
         return cond_denoised
     return cond_model_fn
 
@@ -123,13 +126,13 @@ def sample_k(
         **extra_args
     ):
 
-    denoiser = K.external.VDenoiser(model_fn)
+    denoiser = VDenoiser(model_fn)
 
     if cond_fn is not None:
         denoiser = make_cond_model_fn(denoiser, cond_fn)
 
     # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
-    sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
+    sigmas = get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
 
     # Scale the initial noise by sigma
     noise = noise * sigmas[0]
@@ -176,21 +179,21 @@ def inpainting_callback(args):
 
     with torch.cuda.amp.autocast():
         if sampler_type == "k-heun":
-            return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+            return sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
         elif sampler_type == "k-lms":
-            return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+            return sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
         elif sampler_type == "k-dpmpp-2s-ancestral":
-            return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+            return sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
         elif sampler_type == "k-dpm-2":
-            return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+            return sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
         elif sampler_type == "k-dpm-fast":
-            return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
+            return sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
         elif sampler_type == "k-dpm-adaptive":
-            return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
+            return sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
         elif sampler_type == "dpmpp-2m-sde":
-            return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+            return sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
         elif sampler_type == "dpmpp-3m-sde":
-            return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+            return sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
 
 # Uses discrete Euler sampling for rectified flow models
 # init_data is init_audio as latents (if this is latent diffusion)
diff --git a/stable_audio_tools/interface/gradio.py b/stable_audio_tools/interface/gradio.py
index f38468bc..3720226b 100644
--- a/stable_audio_tools/interface/gradio.py
+++ b/stable_audio_tools/interface/gradio.py
@@ -6,8 +6,8 @@
 import json
 import torch
 import torchaudio
+import logging
 
-from aeiou.viz import audio_spectrogram_image
 from einops import rearrange
 from safetensors.torch import load_file
 from torch.nn import functional as F
@@ -20,6 +20,12 @@
 from ..inference.utils import prepare_audio
 from ..training.utils import copy_state_dict
 
+try:
+    from aeiou.viz import audio_spectrogram_image
+except ImportError:
+    audio_spectrogram_image = None
+    logging.warning("Could not import aeiou, audio spectrogram images will not be displayed")
+
 model = None
 sample_rate = 32000
 sample_size = 1920000
@@ -137,6 +143,8 @@ def generate_cond(
 
     def progress_callback(callback_info):
         global preview_images
+        if not audio_spectrogram_image:
+            return
         denoised = callback_info["denoised"]
         current_step = callback_info["i"]
         sigma = callback_info["sigma"]
@@ -192,10 +200,11 @@ def progress_callback(callback_info):
     audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
     torchaudio.save("output.wav", audio, sample_rate)
 
-    # Let's look at a nice spectrogram too
-    audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
+    if audio_spectrogram_image:
+        # Let's look at a nice spectrogram too
+        preview_images.insert(0, audio_spectrogram_image(audio, sample_rate=sample_rate))
 
-    return ("output.wav", [audio_spectrogram, *preview_images])
+    return ("output.wav", preview_images)
 
 def generate_uncond(
         steps=250,
@@ -290,9 +299,11 @@ def progress_callback(callback_info):
 
     torchaudio.save("output.wav", audio, sample_rate)
 
-    audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
+    if audio_spectrogram_image:
+        audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
+        preview_images.insert(0, audio_spectrogram)
 
-    return ("output.wav", [audio_spectrogram, *preview_images])
+    return ("output.wav", preview_images)
 
 def generate_lm(
         temperature=1.0,
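For reference, a minimal sketch of how the extras groups introduced in the `setup.py` change above could be installed from the repository root. The extras names (`ui`, `viz`, `train`) are taken from the diff; the quoting is only a precaution for shells such as zsh that treat square brackets as glob characters, and is not required by pip itself.

```bash
# Core inference dependencies only (install_requires)
pip install .

# Add the Gradio interface and the aeiou spectrogram visualization
pip install ".[ui,viz]"

# Everything, including the training stack, as the updated README suggests
pip install ".[train,viz,ui]"
```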