Skip to content

Commit ac7046d

Browse files
authored
Merge pull request #25 from Stability-AI/update
Update 0.0.8
2 parents f410c25 + 3a3f4cd commit ac7046d

25 files changed

+2340
-484
lines changed

LICENSES/LICENSE_NVIDIA.txt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2022 NVIDIA CORPORATION.
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

scripts/ds_zero_to_pl_ckpt.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import argparse
2+
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
3+
4+
if __name__ == "__main__":
    # Command-line entry point: hand a DeepSpeed ZeRO checkpoint location and
    # an output location to Lightning's fp32 state-dict converter.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        # Without this flag the converter would be called with None and fail
        # with an unrelated-looking error, so make argparse reject it up front.
        required=True,
        help="Path to the zero checkpoint",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="lightning_model.pt",
        help="Path to the output checkpoint",
    )
    args = parser.parse_args()

    # lightning deepspeed has saved a directory instead of a file
    convert_zero_checkpoint_to_fp32_state_dict(args.save_path, args.output_path)

setup.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='stable-audio-tools',
5-
version='0.0.7',
5+
version='0.0.8',
66
url='https://github.com/Stability-AI/stable-audio-tools.git',
77
author='Stability AI',
88
description='Training and inference tools for generative audio models from Stability AI',
@@ -13,31 +13,33 @@
1313
'alias-free-torch==0.0.6',
1414
'auraloss==0.4.0',
1515
'descript-audio-codec==1.0.0',
16-
'einops==0.6.1',
16+
'einops==0.7.0',
1717
'einops-exts==0.0.4',
1818
'ema-pytorch==0.2.3',
1919
'encodec==0.1.1',
2020
'gradio==3.42.0',
2121
'importlib-resources==5.12.0',
22-
'k-diffusion==0.0.15',
22+
'k-diffusion==0.1.1',
2323
'laion-clap==1.1.4',
2424
'local-attention==1.8.6',
25+
'nwt-pytorch==0.0.4',
2526
'pandas==2.0.2',
2627
'pedalboard==0.7.4',
2728
'prefigure==0.0.9',
28-
'pytorch_lightning==2.0.9',
29+
'pytorch_lightning==2.1.0',
2930
'PyWavelets==1.4.1',
31+
'safetensors',
3032
'sentencepiece==0.1.99',
31-
's3fs==2023.6.0',
33+
's3fs',
3234
'torch>=2.0.1',
3335
'torchaudio>=2.0.2',
3436
'torchmetrics==0.11.4',
3537
'tqdm',
3638
'transformers==4.33.3',
3739
'v-diffusion-pytorch==0.0.2',
38-
'vector-quantize-pytorch==1.6.21',
40+
'vector-quantize-pytorch==1.9.14',
3941
'wandb==0.15.4',
4042
'webdataset==0.2.48',
41-
'x-transformers==1.16.16'
43+
'x-transformers>=1.25.15'
4244
],
4345
)

stable_audio_tools/data/dataset.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def __getitem__(self, idx):
168168
start_time = time.time()
169169
audio = self.load_file(audio_filename)
170170

171-
audio, t_start, t_end, seconds_start, seconds_total = self.pad_crop(audio)
171+
audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)
172172

173173
# Run augmentations on this sample (including random crop)
174174
if self.augs is not None:
@@ -190,6 +190,7 @@ def __getitem__(self, idx):
190190
info["timestamps"] = (t_start, t_end)
191191
info["seconds_start"] = seconds_start
192192
info["seconds_total"] = seconds_total
193+
info["padding_mask"] = padding_mask
193194

194195
end_time = time.time()
195196

@@ -199,6 +200,9 @@ def __getitem__(self, idx):
199200
custom_metadata = self.custom_metadata_fn(info, audio)
200201
info.update(custom_metadata)
201202

203+
if "__reject__" in info and info["__reject__"]:
204+
return self[random.randrange(len(self))]
205+
202206
return (audio, info)
203207
except Exception as e:
204208
print(f'Couldn\'t load file {audio_filename}: {e}')
@@ -339,8 +343,12 @@ def log_and_continue(exn):
339343

340344

341345
def is_valid_sample(sample):
    """Return True when a sample should be kept, False when it should be dropped.

    A sample is kept only if it has both a "json" and an "audio" entry, its
    "json" metadata does not carry a truthy "__reject__" flag, and
    ``is_silence`` does not report its audio as silent.
    """
    # Check for the keys before indexing: looking up sample["audio"] or
    # sample["json"] on an incomplete sample would raise KeyError instead of
    # simply filtering the sample out.
    if "json" not in sample or "audio" not in sample:
        return False

    # Test the metadata flag before the audio so that rejected samples never
    # reach the silence check.
    metadata = sample["json"]
    if "__reject__" in metadata and metadata["__reject__"]:
        return False

    return not is_silence(sample["audio"])
344352

345353
class S3DatasetConfig:
346354
def __init__(
@@ -446,10 +454,11 @@ def wds_preprocess(self, sample):
446454
# Pad/crop and get the relative timestamp
447455
pad_crop = PadCrop_Normalized_T(
448456
self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
449-
audio, t_start, t_end, seconds_start, seconds_total = pad_crop(
457+
audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(
450458
audio)
451459
sample["json"]["seconds_start"] = seconds_start
452460
sample["json"]["seconds_total"] = seconds_total
461+
sample["json"]["padding_mask"] = padding_mask
453462
else:
454463
t_start, t_end = 0, 1
455464

stable_audio_tools/data/utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,40 @@ def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, in
3333

3434
n_channels, n_samples = source.shape
3535

36+
# Largest valid crop offset; 0 when the audio is not longer than the desired length (it is zero-padded below)
3637
upper_bound = max(0, n_samples - self.n_samples)
3738

39+
# If randomize is False, always start at the beginning of the audio
3840
offset = 0
3941
if(self.randomize and n_samples > self.n_samples):
4042
offset = random.randint(0, upper_bound)
4143

44+
# Calculate the chunk's start and end as fractions of the audio length (or of the padded length, if the audio is shorter than the chunk)
4245
t_start = offset / (upper_bound + self.n_samples)
4346
t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
4447

48+
# Create the chunk
4549
chunk = source.new_zeros([n_channels, self.n_samples])
50+
51+
# Copy the audio into the chunk
4652
chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]
4753

54+
# Calculate the chunk's start time and the total duration of the source audio, in whole seconds
4855
seconds_start = math.floor(offset / self.sample_rate)
4956
seconds_total = math.ceil(n_samples / self.sample_rate)
57+
58+
# Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
59+
padding_mask = torch.zeros([self.n_samples])
60+
padding_mask[:min(n_samples, self.n_samples)] = 1
61+
5062

5163
return (
5264
chunk,
5365
t_start,
5466
t_end,
5567
seconds_start,
56-
seconds_total
68+
seconds_total,
69+
padding_mask
5770
)
5871

5972
class PhaseFlipper(nn.Module):

0 commit comments

Comments
 (0)