
Commit c9b0107 (merge main)

2 parents: 84bb171 + ffb105a

18 files changed: +465 additions, -140 deletions

setup.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -2,7 +2,7 @@
 
 setup(
     name='stable-audio-tools-dev',
-    version='0.0.9',
+    version='0.0.10',
     url='https://github.com/Stability-AI/stable-audio-tools-dev.git',
     author='Stability AI',
     description='Training and inference tools for generative audio models from Stability AI',
@@ -17,7 +17,8 @@
         'einops-exts==0.0.4',
         'ema-pytorch==0.2.3',
         'encodec==0.1.1',
-        'gradio==3.42.0',
+        'flash-attn>=2.5.0',
+        'gradio>=3.42.0',
         'huggingface_hub',
         'importlib-resources==5.12.0',
         'k-diffusion==0.1.1',
```
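This bump loosens the Gradio pin to a floor and adds flash-attn, which compiles CUDA kernels at install time and so fails more often than pure-Python dependencies. A hedged sanity check one could run after install (not part of this commit; the message strings are mine):

```python
# Hypothetical post-install check, not part of the repo:
# confirm the new 'flash-attn>=2.5.0' requirement actually imports.
try:
    import flash_attn
    print(f"flash-attn {flash_attn.__version__} is available")
except ImportError:
    print("flash-attn missing: the CUDA build step may have failed")
```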

stable_audio_tools/data/dataset.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -589,7 +589,7 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl
         sample_rate=sample_rate,
         sample_size=sample_size,
         batch_size=batch_size,
-        random_crop=True,
+        random_crop=dataset_config.get("random_crop", True),
         num_workers=num_workers,
         persistent_workers=True,
         force_channels=force_channels,
```
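This makes the crop behavior configurable instead of hard-coded: `random_crop` is now read from the dataset config, defaulting to the old behavior when absent. A minimal sketch of disabling it, assuming the config is the usual JSON-derived dict (the other key shown is illustrative):

```python
# Illustrative config; "random_crop" is the only key this commit reads.
dataset_config = {
    "dataset_type": "audio_dir",  # assumed for context, not part of the change
    "random_crop": False,         # deterministic crops; omit to keep the default of True
}
# create_dataloader_from_config(dataset_config, ...) now forwards
# random_crop=False to the dataset instead of always passing True.
```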

stable_audio_tools/inference/generation.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -211,6 +211,8 @@ def generate_diffusion_cond(
     # Denoising process done.
     # If this is latent diffusion, decode latents back into audio
     if model.pretransform is not None and not return_latents:
+        #cast sampled latents to pretransform dtype
+        sampled = sampled.to(next(model.pretransform.parameters()).dtype)
         sampled = model.pretransform.decode(sampled)
 
     # Return audio
```
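The added cast prevents a dtype mismatch when sampling runs in reduced precision (e.g. float16 under autocast) while the pretransform's weights are float32. A standalone illustration of the same pattern, using a stand-in module rather than a real pretransform:

```python
import torch

decoder = torch.nn.Linear(64, 64)                  # float32 weights; stands in for model.pretransform
latents = torch.randn(1, 64, dtype=torch.float16)  # as if sampled under mixed precision

# Same pattern as the commit: match the latents to the decoder's weight dtype
latents = latents.to(next(decoder.parameters()).dtype)
decoded = decoder(latents)  # runs without a dtype-mismatch RuntimeError
```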

stable_audio_tools/interface/gradio.py

Lines changed: 22 additions & 21 deletions

```diff
@@ -64,10 +64,10 @@ def generate_cond(
         steps=250,
         preview_every=None,
         seed=-1,
-        sampler_type="dpmpp-2m-sde",
+        sampler_type="dpmpp-3m-sde",
         sigma_min=0.03,
-        sigma_max=50,
-        cfg_rescale=0.4,
+        sigma_max=1000,
+        cfg_rescale=0.0,
         use_init=False,
         init_audio=None,
         init_noise_level=1.0,
@@ -86,6 +86,8 @@ def generate_cond(
         torch.cuda.empty_cache()
     gc.collect()
 
+    print(f"Prompt: {prompt}")
+
     global preview_images
     preview_images = []
     if preview_every == 0:
@@ -185,7 +187,7 @@ def progress_callback(callback_info):
 
     # Convert to WAV file
     audio = rearrange(audio, "b d n -> d (b n)")
-    audio = audio.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
     torchaudio.save("output.wav", audio, sample_rate)
 
     # Let's look at a nice spectrogram too
@@ -196,9 +198,9 @@
 def generate_uncond(
         steps=250,
         seed=-1,
-        sampler_type="dpmpp-2m-sde",
+        sampler_type="dpmpp-3m-sde",
         sigma_min=0.03,
-        sigma_max=50,
+        sigma_max=1000,
         use_init=False,
         init_audio=None,
         init_noise_level=1.0,
@@ -282,7 +284,7 @@ def progress_callback(callback_info):
 
     audio = rearrange(audio, "b d n -> d (b n)")
 
-    audio = audio.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
 
     torchaudio.save("output.wav", audio, sample_rate)
 
@@ -296,7 +298,6 @@ def generate_lm(
         top_k=0,
         batch_size=1,
     ):
-
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -317,7 +318,7 @@ def generate_lm(
 
     audio = rearrange(audio, "b d n -> d (b n)")
 
-    audio = audio.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
 
     torchaudio.save("output.wav", audio, sample_rate)
 
@@ -342,9 +343,9 @@ def create_uncond_sampling_ui(model_config):
 
     # Sampler params
     with gr.Row():
-        sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde")
+        sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde")
         sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
-        sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max")
+        sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=1000, label="Sigma max")
 
     with gr.Accordion("Init audio", open=False):
         init_audio_checkbox = gr.Checkbox(label="Use init audio")
@@ -417,10 +418,10 @@ def create_sampling_ui(model_config, inpainting=False):
 
     # Sampler params
     with gr.Row():
-        sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde")
+        sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde")
         sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
-        sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max")
-        cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.2, label="CFG rescale amount")
+        sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=1000, label="Sigma max")
+        cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG rescale amount")
 
     if inpainting:
         # Inpainting Tab
@@ -547,18 +548,18 @@ def autoencoder_process(audio, latent_noise, n_quantizers):
     audio = audio.to(dtype)
 
     if n_quantizers > 0:
-        latents = model.encode_audio(audio, chunked=True, n_quantizers=n_quantizers)
+        latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers)
     else:
-        latents = model.encode_audio(audio, chunked=True)
+        latents = model.encode_audio(audio, chunked=False)
 
     if latent_noise > 0:
         latents = latents + torch.randn_like(latents) * latent_noise
 
-    audio = model.decode_audio(latents, chunked=True)
+    audio = model.decode_audio(latents, chunked=False)
 
     audio = rearrange(audio, "b d n -> d (b n)")
 
-    audio = audio.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+    audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
 
     torchaudio.save("output.wav", audio, sample_rate)
 
@@ -607,7 +608,7 @@ def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max):
 
     audio = rearrange(audio, "b d n -> d (b n)")
 
-    audio = audio.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
 
     torchaudio.save("output.wav", audio, sample_rate)
 
@@ -620,9 +621,9 @@ def create_diffusion_prior_ui(model_config):
     # Sampler params
     with gr.Row():
        steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
-        sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-2m-sde")
+        sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde")
         sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
-        sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max")
+        sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=1000, label="Sigma max")
     process_button = gr.Button("Process", variant='primary', scale=1)
     process_button.click(fn=diffusion_prior_process, inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider], outputs=output_audio, api_name="process")
```
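Beyond the new dpmpp-3m-sde defaults and the wider sigma range, the recurring WAV edit swaps hard clamping for peak normalization before int16 quantization, so quiet generations use the full sample range. A standalone sketch of that conversion (the function name is mine; note that, like the diff, it has no guard against division by zero on all-silent audio):

```python
import torch

def to_int16_wav(audio: torch.Tensor) -> torch.Tensor:
    """Peak-normalize a float waveform and quantize to int16, mirroring the diff."""
    audio = audio.to(torch.float32)
    audio = audio.div(torch.max(torch.abs(audio)))  # peak normalize; silent input yields NaNs
    return audio.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
```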

stable_audio_tools/models/adp.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,6 +1,7 @@
 # Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
 # License can be found in LICENSES/LICENSE_ADP.txt
 
+import math
 from inspect import isfunction
 from math import ceil, floor, log, pi, log2
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
```
