Changes from all commits
21 commits
1159d11
Upgrade pytorch_lightning 2.1.0 to 2.4.0
sskalnik Sep 14, 2025
0f77d40
Add AIF and AIFF support. Use dill because pickling fails on Windows 11.
sskalnik Sep 14, 2025
4c038a7
Use dill for PreEncodedDataset because pickling fails on Windows 11.
sskalnik Sep 14, 2025
3fb8f3d
Initial commit of working SAO-1.0 and SAO-small training configs and …
sskalnik Sep 23, 2025
a39109d
Create TRAINING_FOR_NOOBS.md
sskalnik Sep 25, 2025
6be3a2f
Update TRAINING_FOR_NOOBS.md for clarity regarding /rawfiles.
sskalnik Sep 25, 2025
01c716a
Add support for AdamW8bit from bitsandbytes. Improve custom metadata …
sskalnik Sep 29, 2025
ca23373
Update TRAINING_FOR_NOOBS.md with Terminology section.
sskalnik Oct 4, 2025
fb033b6
Add experimental support for pytorch_optimizer.lr_scheduler.chebyshev.
sskalnik Oct 4, 2025
112b5d5
Add sao_small/acid_v2_base_model_config.json and sao_small/acid_v1_ba…
sskalnik Oct 4, 2025
5dd8cb7
Add sao_small/acid_v3_base_model_config.json. Note that proper "sampl…
sskalnik Oct 4, 2025
c6a824b
Add experimental support for google/t5gemma-b-b-ul2. Remove spurious …
sskalnik Oct 5, 2025
d4d3274
Change learning rate to 5e-5 (still experimenting with SAO-small). Us…
sskalnik Oct 6, 2025
531531a
Freeze requirements for 2025 Oct 07.
sskalnik Oct 8, 2025
632191e
Merge branch 'Stability-AI:main' into main
sskalnik Oct 10, 2025
75b25a6
Update TRAINING_FOR_NOOBS.md
sskalnik Oct 10, 2025
9d837d6
Remove onesided=True from mel_spectrogram_op. Add config files for ki…
sskalnik Oct 18, 2025
429278a
Add support for CosineAnnealingWarmRestarts.
sskalnik Oct 21, 2025
ff24e5a
Add zip files to .gitignore.
sskalnik Oct 30, 2025
47bb5cf
Add rawfiles* and pre_encoded* to .gitignore.
sskalnik Nov 1, 2025
4ef0193
Remove experimental artifacts for SAT upstream PR.
sskalnik Nov 1, 2025
8 changes: 7 additions & 1 deletion .gitignore
@@ -161,4 +161,10 @@ cython_debug/

*.ckpt
*.wav
wandb/*
wandb/*

# Dataset folders and outputs
/outputs
/pre_encoded*
/rawfiles*
*.zip
8 changes: 8 additions & 0 deletions dataset_config.json
@@ -0,0 +1,8 @@
{
"dataset_type": "pre_encoded",
"datasets": [{
"id": "audio_pre_encoded",
"path": "pre_encoded",
"custom_metadata_module": "paths_md.py"
}]
}
17 changes: 17 additions & 0 deletions paths_md.py
@@ -0,0 +1,17 @@
import os
import re


def get_custom_metadata(info, audio):
# Get filename without extension
file_name = os.path.basename(info["relpath"])
file_name_without_extension = os.path.splitext(file_name)[0]

# Replace non-alphanumeric characters with spaces, and remove leading/trailing spaces
cleaned_file_name = re.sub('[^0-9a-zA-Z]+', ' ', file_name_without_extension).strip()

# Sanity check
print(f'{info["relpath"]} => {cleaned_file_name}')

return {"prompt": cleaned_file_name}
17 changes: 17 additions & 0 deletions paths_md_pre_encode.py
@@ -0,0 +1,17 @@
import os
import re


def get_custom_metadata(info, audio):
# Get filename without extension
file_name = os.path.basename(info["relpath"])
file_name_without_extension = os.path.splitext(file_name)[0]

# Replace non-alphanumeric characters with spaces, and remove leading/trailing spaces
cleaned_file_name = re.sub('[^0-9a-zA-Z]+', ' ', file_name_without_extension).strip()

# Sanity check
print(f'{info["relpath"]} => {cleaned_file_name}')

return {"prompt": cleaned_file_name}
11 changes: 11 additions & 0 deletions pe_dataset_config.json
@@ -0,0 +1,11 @@
{
"dataset_type": "audio_dir",
"datasets": [{
"id": "audio",
"path": "./rawfiles",
"custom_metadata_module": "./paths_md_pre_encode.py",
"drop_last": false
}],
"drop_last": false,
"random_crop": false
}
9 changes: 9 additions & 0 deletions pre-encode.bat
@@ -0,0 +1,9 @@
python ./pre_encode.py ^
--ckpt-path ./vae_model.ckpt ^
--model-config ./vae_model_config.json ^
--batch-size 8 ^
--dataset-config pe_dataset_config.json ^
--output-path ./pre_encoded ^
--model-half ^
--sample-size 131072

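For orientation, the chunk size used here and the window in the model config further down are related by simple arithmetic (a sanity-check sketch; the 44100 Hz sample rate and 2048x downsampling ratio come from the configs in this PR):

sample_rate, downsampling_ratio = 44100, 2048

pre_encode_chunk = 131072                      # --sample-size above
print(pre_encode_chunk / sample_rate)          # ~2.97 s of audio per pre-encoded chunk
print(pre_encode_chunk // downsampling_ratio)  # 64 latent frames per chunk

train_window = 524288                          # "sample_size" in the diffusion config below
print(train_window / sample_rate)              # ~11.89 s
print(train_window // downsampling_ratio)      # 256 latents, matching "min_length": 256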
5 changes: 4 additions & 1 deletion setup.py
@@ -10,7 +10,9 @@
install_requires=[
'alias-free-torch==0.0.6',
'auraloss==0.4.0',
'bitsandbytes==0.47.0',
'descript-audio-codec==1.0.0',
'dill==0.4.0',
'einops',
'einops-exts',
'ema-pytorch==0.2.3',
@@ -23,7 +25,8 @@
'local-attention==1.8.6',
'pandas==2.0.2',
'prefigure==0.0.9',
'pytorch_lightning==2.1.0',
'pytorch_lightning==2.4.0',
'pytorch_optimizer==3.1.2',
'PyWavelets==1.4.1',
'safetensors',
'sentencepiece==0.1.99',
@@ -0,0 +1,132 @@
{
"model_type": "diffusion_cond",
"sample_size": 524288,
"sample_rate": 44100,
"audio_channels": 2,
"model": {
"pretransform": {
"type": "autoencoder",
"iterate_batch": true,
"model_half": true,
"chunked": true,
"config": {
"encoder": {
"type": "oobleck",
"requires_grad": false,
"config": {
"in_channels": 2,
"channels": 128,
"c_mults": [1, 2, 4, 8, 16],
"strides": [2, 4, 4, 8, 8],
"latent_dim": 128,
"use_snake": true
}
},
"decoder": {
"type": "oobleck",
"config": {
"out_channels": 2,
"channels": 128,
"c_mults": [1, 2, 4, 8, 16],
"strides": [2, 4, 4, 8, 8],
"latent_dim": 64,
"use_snake": true,
"final_tanh": false
}
},
"bottleneck": {
"type": "vae"
},
"latent_dim": 64,
"downsampling_ratio": 2048,
"io_channels": 2
}
},
"conditioning": {
"configs": [
{
"id": "prompt",
"type": "t5",
"config": {
"t5_model_name": "google/t5gemma-b-b-ul2",
"max_length": 128
}
},
{
"id": "seconds_total",
"type": "number",
"config": {
"min_val": 0,
"max_val": 256
}
}
],
"cond_dim": 768
},
"diffusion": {
"cross_attention_cond_ids": ["prompt", "seconds_total"],
"global_cond_ids": ["seconds_total"],
"diffusion_objective": "rectified_flow",
"distribution_shift_options": {
"min_length": 256,
"max_length": 4096
},
"type": "dit",
"config": {
"io_channels": 64,
"embed_dim": 1024,
"depth": 16,
"num_heads": 8,
"cond_token_dim": 768,
"global_cond_dim": 768,
"transformer_type": "continuous_transformer",
"attn_kwargs": {
"qk_norm": "ln"
}
}
},
"io_channels": 64
},
"training": {
"use_ema": true,
"log_loss_info": false,
"pre_encoded": true,
"timestep_sampler": "trunc_logit_normal",
"optimizer_configs": {
"diffusion": {
"optimizer": {
"type": "AdamW8bit",
"config": {
"lr": 1e-5,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 1e-2,
"block_wise": true
}
},
"scheduler": {
"type": "CosineAnnealingWarmRestarts",
"config": {
"T_0": 10,
"T_mult": 2
}
}
}
},
"demo": {
"demo_every": 512,
"demo_steps": 100,
"num_demos": 7,
"demo_cond": [
{"prompt": "kick", "seconds_total": 2},
{"prompt": "bass", "seconds_total": 2},
{"prompt": "drum breaks 174 BPM", "seconds_total": 6},
{"prompt": "A short, beautiful piano riff in C minor", "seconds_total": 6},
{"prompt": "Tight Snare Drum", "seconds_total": 1},
{"prompt": "Glitchy bass design, I used Serum for this", "seconds_total": 4},
{"prompt": "Synth pluck arp with reverb and delay, 128 BPM", "seconds_total": 6}
],
"demo_cfg_scales": [0.5, 1, 1.5, 8]
}
}
}
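The scheduler block above corresponds to torch.optim.lr_scheduler.CosineAnnealingWarmRestarts. A minimal sketch of the cadence it implies, using torch.optim.AdamW as a stand-in for the bitsandbytes AdamW8bit named in the config (the dummy parameter and loop are illustrative, not training code):

import torch

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-5)
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=2)

# With T_0=10 and T_mult=2 the LR follows cosine cycles of 10, 20, 40, ... steps,
# snapping back to the base LR (1e-5) at the start of each new cycle.
for step in range(30):
    opt.step()
    sched.step()
    print(step + 1, opt.param_groups[0]["lr"])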
41 changes: 29 additions & 12 deletions stable_audio_tools/data/dataset.py
@@ -1,3 +1,4 @@
import dill
import importlib
import numpy as np
import io
@@ -19,7 +20,10 @@

from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, VolumeNorm

AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
from torchdata.stateful_dataloader import StatefulDataLoader


AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus", "aiff", "aif")

# fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py

@@ -94,7 +98,7 @@ def keyword_scandir(
def get_audio_filenames(
paths: list, # directories in which to search
keywords=None,
exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus', '.aiff']
):
"recursively get a list of audio filenames"
filenames = []
@@ -178,7 +182,7 @@ def __init__(
self.root_paths.append(config.path)
self.filenames.extend(get_audio_filenames(config.path, keywords))
if config.custom_metadata_fn is not None:
self.custom_metadata_fns[config.path] = config.custom_metadata_fn
self.custom_metadata_fns[config.path] = dill.dumps(config.custom_metadata_fn)

print(f'Found {len(self.filenames)} files')

@@ -238,8 +242,8 @@ def __getitem__(self, idx):

for custom_md_path in self.custom_metadata_fns.keys():
if custom_md_path in audio_filename:
custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
custom_metadata = custom_metadata_fn(info, audio)
custom_metadata_fn_deserialized = dill.loads(self.custom_metadata_fns[custom_md_path])
custom_metadata = custom_metadata_fn_deserialized(info, audio)
info.update(custom_metadata)

if "__reject__" in info and info["__reject__"]:
@@ -282,7 +286,7 @@ def __init__(
for config in configs:
self.filenames.extend(get_latent_filenames(config.path, [latent_extension]))
if config.custom_metadata_fn is not None:
self.custom_metadata_fns[config.path] = config.custom_metadata_fn
self.custom_metadata_fns[config.path] = dill.dumps(config.custom_metadata_fn)

self.latent_crop_length = latent_crop_length
self.random_crop = random_crop
@@ -339,8 +343,8 @@ def __getitem__(self, idx):

for custom_md_path in self.custom_metadata_fns.keys():
if custom_md_path in latent_filename:
custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
custom_metadata = custom_metadata_fn(info, None)

custom_metadata_fn_deserialized = dill.loads(self.custom_metadata_fns[custom_md_path])
custom_metadata = custom_metadata_fn_deserialized(info, None)
info.update(custom_metadata)

if "__reject__" in info and info["__reject__"]:
@@ -849,8 +854,14 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl
force_channels=force_channels
)

return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle,
num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn)
# https://docs.pytorch.org/docs/stable/notes/randomness.html#dataloader
g = torch.Generator()
g.manual_seed(0)

return StatefulDataLoader(train_set, batch_size, shuffle=shuffle,
num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn, generator=g)

elif dataset_type == "pre_encoded":

@@ -899,8 +910,14 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl
latent_extension=latent_extension
)

return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle,
num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn)
# https://docs.pytorch.org/docs/stable/notes/randomness.html#dataloader
g = torch.Generator()
g.manual_seed(0)

return StatefulDataLoader(train_set, batch_size, shuffle=shuffle,
num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn, generator=g)

elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility
wds_configs = []
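Two notes on the dataset.py changes above. The dill round-trip exists because spawn-based DataLoader workers (the default on Windows, hence the Windows 11 failures mentioned in the commit log) must pickle the dataset, and a function loaded from a loose file via importlib, as custom_metadata_module is, cannot be pickled by the standard library; dill serializes it by value instead. A minimal standalone sketch of the pattern (not the library code itself):

import dill

def get_custom_metadata(info, audio):
    return {"prompt": info["relpath"]}

payload = dill.dumps(get_custom_metadata)  # once, when the dataset is built
fn = dill.loads(payload)                   # per lookup, inside the worker process
print(fn({"relpath": "rawfiles/kick.wav"}, None))  # {'prompt': 'rawfiles/kick.wav'}

The seeded torch.Generator follows the linked PyTorch randomness note: it pins the shuffle order for reproducibility, and torchdata's StatefulDataLoader adds state_dict/load_state_dict so a resumed run can continue an epoch where it left off.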