
Commit cf8f836

Merge pull request #5 from Stability-AI/docs_update
Docs update
2 parents 77536b9 + de48dbc commit cf8f836

File tree: 5 files changed (+239, −15 lines)


README.md

Lines changed: 2 additions & 3 deletions

```diff
@@ -118,11 +118,10 @@ The following properties are defined in the top level of the model configuration
 - The training configuration for the model, varies based on `model_type`. Provides parameters for training as well as demos.
 
 ## Dataset config
-`stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3.
+`stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in [the dataset config documentation](docs/datasets.md)
 
 # Todo
-- [ ] Add documentation for dataset configs
 - [ ] Add documentation for different model types
-- [ ] Add documentation on pretransforms
 - [ ] Add documentation for Gradio interface
 - [ ] Add troubleshooting section
+- [ ] Add contribution guidelines
```

docs/conditioning.md

Lines changed: 151 additions & 1 deletion

# Conditioning

Conditioning, in the context of `stable-audio-tools`, is the use of additional signals that give an extra level of control over a model's behavior. For example, we can condition the outputs of a diffusion model on a text prompt, creating a text-to-audio model.

# Conditioning types

There are a few different kinds of conditioning, depending on the conditioning signal being used.

## Cross attention

Cross attention is a type of conditioning that allows the model to find correlations between two sequences of potentially different lengths, for example between a sequence of features from a text encoder and a sequence of high-level audio features.

Signals used for cross-attention conditioning should be of the shape `[batch, sequence, channels]`.

## Global conditioning

Global conditioning is the use of a single n-dimensional tensor to provide conditioning information that pertains to the whole sequence being conditioned. For example, this could be the single embedding output of a CLAP model, or a learned class embedding.

Signals used for global conditioning should be of the shape `[batch, channels]`.

## Input concatenation

Input concatenation applies a conditioning signal that aligns with the model's input along the sequence dimension and has the same length. The conditioning signal is concatenated with the model's input data along the channel dimension. This can be used for things like inpainting information, melody conditioning, or for creating a diffusion autoencoder.

Signals used for input concatenation conditioning should be of the shape `[batch, channels, sequence]` and must be the same length as the model's input.
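To make these shapes concrete, here is a minimal PyTorch sketch (tensor names and sizes are illustrative only, not taken from the library) of what each kind of conditioning signal looks like for a batch of two items:

```python
import torch

batch, channels, sequence = 2, 768, 1024

# Cross-attention conditioning: one feature vector per token of the
# conditioning sequence, e.g. text-encoder outputs.
cross_attn_cond = torch.randn(batch, 77, channels)      # [batch, sequence, channels]

# Global conditioning: a single vector per batch item, e.g. a CLAP
# embedding or a learned class embedding.
global_cond = torch.randn(batch, channels)               # [batch, channels]

# Input-concatenation conditioning: concatenated with the model input
# along the channel dimension, so it must match the input's length.
model_input = torch.randn(batch, 64, sequence)           # [batch, channels, sequence]
concat_cond = torch.randn(batch, 8, sequence)            # [batch, channels, sequence]
conditioned_input = torch.cat([model_input, concat_cond], dim=1)
print(conditioned_input.shape)  # torch.Size([2, 72, 1024])
```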
# Conditioners and conditioning configs

`stable-audio-tools` uses Conditioner modules to translate human-readable metadata, such as text prompts or a number of seconds, into tensors that the model can take as input.

Each conditioner has a corresponding `id` that it expects to find in the conditioning dictionary provided during training or inference. Each conditioner takes in the relevant conditioning data and returns a tuple containing the corresponding tensor and a mask.

The ConditionedDiffusionModelWrapper manages the translation between the user-provided metadata dictionary (e.g. `{"prompt": "a beautiful song", "seconds_start": 22, "seconds_total": 193}`) and the dictionary of different conditioning types that the model uses (e.g. `{"cross_attn_cond": ...}`).

To apply conditioning to a model, you must provide a `conditioning` configuration in the model's config. At the moment, we only support conditioning diffusion models through the `diffusion_cond` model type.

The `conditioning` configuration should contain a `configs` array, which allows you to define multiple conditioning signals.

Each item in the `configs` array should define the `id` of the corresponding metadata, the type of conditioner to be used, and the config for that conditioner.

The `cond_dim` property enforces the same dimension on all conditioning inputs; however, it can be overridden with an explicit `output_dim` property on any of the individual configs.
## Example config
```json
"conditioning": {
    "configs": [
        {
            "id": "prompt",
            "type": "t5",
            "config": {
                "t5_model_name": "t5-base",
                "max_length": 77,
                "project_out": true
            }
        }
    ],
    "cond_dim": 768
}
```
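For intuition, the flow described above can be sketched roughly as follows (hypothetical class and variable names; this is not the library's actual Conditioner or ConditionedDiffusionModelWrapper code): each conditioner pulls the value matching its `id` out of the metadata and returns a `(tensor, mask)` pair, and those pairs are collected into the conditioning dictionary the model consumes.

```python
import typing as tp
import torch

class ToyConditioner:
    """Hypothetical conditioner: raw metadata values in, (tensor, mask) out."""

    def __init__(self, cond_id: str, output_dim: int):
        self.id = cond_id
        self.output_dim = output_dim

    def __call__(self, values: tp.List[tp.Any], device: str = "cpu") -> tp.Tuple[torch.Tensor, torch.Tensor]:
        # A real conditioner would encode `values` (text, ints, floats, ...);
        # here we just return zeros of the right shape plus an all-ones mask.
        cond = torch.zeros(len(values), 1, self.output_dim, device=device)
        mask = torch.ones(len(values), 1, dtype=torch.bool, device=device)
        return cond, mask

# Wrapper-style translation: user metadata -> per-conditioner (tensor, mask) pairs.
conditioners = {
    "prompt": ToyConditioner("prompt", 768),
    "seconds_total": ToyConditioner("seconds_total", 768),
}
metadata = [{"prompt": "a beautiful song", "seconds_total": 193}]

conditioning = {
    cond_id: conditioner([item[cond_id] for item in metadata])
    for cond_id, conditioner in conditioners.items()
}
print({k: v[0].shape for k, v in conditioning.items()})
```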
# Conditioners
56+
57+
## Text encoders
58+
59+
### `t5`
60+
This uses a frozen [T5](https://huggingface.co/docs/transformers/model_doc/t5) text encoder from the `transformers` library to encode text prompts into a sequence of text features.
61+
62+
The `t5_model_name` property determines which T5 model is loaded from the `transformers` library.
63+
64+
The `max_length` property determines the maximum number of tokens that the text encoder will take in, as well as the sequence length of the output text features.
65+
66+
If you set `enable_grad` to `true`, the T5 model will be un-frozen and saved with the model checkpoint, allowing you to fine-tune the T5 model.
67+
68+
T5 encodings are only compatible with cross attention conditioning.
69+
70+
#### Example config
71+
```json
72+
{
73+
"id": "prompt",
74+
"type": "t5",
75+
"config": {
76+
"t5_model_name": "t5-base",
77+
"max_length": 77,
78+
"project_out": true
79+
}
80+
}
81+
```
82+
83+
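As a point of reference for the shapes involved, this is roughly what a frozen T5 encoding pass looks like with the `transformers` library (a standalone sketch, not the library's actual conditioner code):

```python
import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = T5EncoderModel.from_pretrained("t5-base").eval().requires_grad_(False)

prompts = ["a beautiful song", "rainfall on a tin roof"]
tokens = tokenizer(prompts, padding="max_length", max_length=77,
                   truncation=True, return_tensors="pt")

with torch.no_grad():
    features = model(**tokens).last_hidden_state  # [batch, 77, 768] for t5-base

mask = tokens["attention_mask"].bool()            # padding mask for cross attention
print(features.shape, mask.shape)
```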
### `clap_text`
This loads the text encoder from a [CLAP](https://github.com/LAION-AI/CLAP) model, which can provide either a sequence of text features or a single multimodal text/audio embedding.

The CLAP model must be provided as a local file path, set in the `clap_ckpt_path` property, along with the correct `audio_model_type` and `enable_fusion` properties for the provided model.

If the `use_text_features` property is set to `true`, the conditioner output will be a sequence of text features instead of a single multimodal embedding. This allows more fine-grained text information to be used by the model, at the cost of losing the ability to prompt with CLAP audio embeddings.

By default, if `use_text_features` is true, the features from the last layer of the CLAP text encoder are returned. You can return the text features of earlier layers by specifying the index of the layer to return in the `feature_layer_ix` property. For example, you can return the text features of the next-to-last layer of the CLAP model by setting `feature_layer_ix` to `-2`.

If you set `enable_grad` to `true`, the CLAP model will be unfrozen and saved with the model checkpoint, allowing you to fine-tune it.

CLAP text embeddings are compatible with global conditioning and cross attention conditioning. If `use_text_features` is set to `true`, the features are not compatible with global conditioning.

#### Example config
```json
{
    "id": "prompt",
    "type": "clap_text",
    "config": {
        "clap_ckpt_path": "/path/to/clap/model.ckpt",
        "audio_model_type": "HTSAT-base",
        "enable_fusion": true,
        "use_text_features": true,
        "feature_layer_ix": -2
    }
}
```
## Number encoders

### `int`
The IntConditioner takes in a list of integers in a given range and returns a discrete learned embedding for each of those integers.

The `min_val` and `max_val` properties set the range of the embedding values. Input integers are clamped to this range.

This can be used for things like discrete timing embeddings or learned class embeddings.

Int embeddings are compatible with global conditioning and cross attention conditioning.

#### Example config
```json
{
    "id": "seconds_start",
    "type": "int",
    "config": {
        "min_val": 0,
        "max_val": 512
    }
}
```
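The mechanism is essentially an embedding-table lookup over clamped integers; a minimal sketch (hypothetical, not the library's IntConditioner implementation) looks like this:

```python
import torch
import torch.nn as nn

min_val, max_val, embed_dim = 0, 512, 768
embedder = nn.Embedding(max_val - min_val + 1, embed_dim)  # one learned vector per integer

seconds_start = torch.tensor([22, 640])           # 640 is out of range
clamped = seconds_start.clamp(min_val, max_val)   # -> [22, 512]
embeddings = embedder(clamped - min_val)          # [batch, embed_dim]
print(embeddings.shape)  # torch.Size([2, 768])
```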
### `number`
The NumberConditioner takes in a list of floats in a given range and returns a continuous Fourier embedding of the provided floats.

The `min_val` and `max_val` properties set the range of the float values. This is the range used to normalize the input float values.

Number embeddings are compatible with global conditioning and cross attention conditioning.

#### Example config
```json
{
    "id": "seconds_total",
    "type": "number",
    "config": {
        "min_val": 0,
        "max_val": 512
    }
}
```
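For intuition, a continuous Fourier embedding of a normalized scalar can be sketched as follows (an illustrative sketch only; the exact frequencies and layout used by the library's NumberConditioner may differ):

```python
import math
import torch

def fourier_embed(x, min_val=0.0, max_val=512.0, num_freqs=64):
    # Normalize the raw value into [0, 1] using the configured range.
    x = (torch.as_tensor(x, dtype=torch.float32) - min_val) / (max_val - min_val)
    # Project onto sines and cosines at a bank of frequencies.
    freqs = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)  # 1, 2, 4, ...
    angles = 2 * math.pi * x[..., None] * freqs                  # [batch, num_freqs]
    return torch.cat([angles.sin(), angles.cos()], dim=-1)       # [batch, 2 * num_freqs]

seconds_total = torch.tensor([193.0, 47.5])
emb = fourier_embed(seconds_total)
print(emb.shape)  # torch.Size([2, 128])
```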

docs/datasets.md

Lines changed: 75 additions & 1 deletion

# Datasets

`stable-audio-tools` supports loading data from local file storage, as well as loading audio files and JSON files in the [WebDataset](https://github.com/webdataset/webdataset/tree/main/webdataset) format from Amazon S3 buckets.

# Dataset configs

To specify the dataset used for training, you must provide a dataset config JSON file to `train.py`.

The dataset config consists of a `dataset_type` property specifying the type of data loader to use, a `datasets` array for providing multiple data sources, and a `random_crop` property, which determines whether the audio cropped from each training sample comes from a random position in the file or always from the beginning.

## Local audio files

To use a local directory of audio samples, set the `dataset_type` property in your dataset config to `"audio_dir"`, and provide the `datasets` property with a list of objects, each containing a `path` property set to the path of your directory of audio samples.

This will load all of the compatible audio files from the provided directory and all of its subdirectories.

### Example config
```json
{
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "my_audio",
            "path": "/path/to/audio/dataset/"
        }
    ],
    "random_crop": true
}
```

## S3 WebDataset

To load audio files and related metadata from WebDataset .tar files hosted in Amazon S3 buckets, set the `dataset_type` property to `s3` and provide the `datasets` property with a list of objects, each containing the S3 path to the shared bucket prefix under which the WebDataset .tar files are stored. The S3 bucket is searched recursively from that path, and any .tar files found are assumed to contain audio files and corresponding JSON metadata files whose names differ only in file extension (e.g. "000001.flac", "000001.json", "000002.flac", "000002.json", and so on).

### Example config
```json
{
    "dataset_type": "s3",
    "datasets": [
        {
            "id": "s3-test",
            "s3_path": "s3://my-bucket/datasets/webdataset/audio/"
        }
    ],
    "random_crop": true
}
```
# Custom metadata

To customize the metadata provided to the conditioners during model training, you can provide a separate custom metadata module in the dataset config. This module should be a Python file containing a function called `get_custom_metadata` that takes two parameters, `info` and `audio`, and returns a dictionary.

For local training, the `info` parameter contains a few pieces of information about the loaded audio file, such as its path and how the audio was cropped from the original training sample. For S3 WebDataset datasets, it also contains the metadata from the related JSON files.

The `audio` parameter contains the audio sample that will be passed to the model at training time. This lets you analyze the audio for extra properties that you can then pass in as extra conditioning signals.

The dictionary returned from the `get_custom_metadata` function will have its properties added to the `metadata` object used at training time. For more information on how conditioning works, please see the [Conditioning documentation](./conditioning.md).

## Example config and custom metadata module
```json
{
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "my_audio",
            "path": "/path/to/audio/dataset/"
        }
    ],
    "custom_metadata_module": "/path/to/custom_metadata.py",
    "random_crop": true
}
```

`custom_metadata.py`:
```py
def get_custom_metadata(info, audio):

    # Pass in the relative path of the audio file as the prompt
    return {"prompt": info["relpath"]}
```
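As a further illustration of using the `audio` parameter, here is a hedged sketch of a metadata module that also derives a property from the audio itself (it assumes `audio` behaves like a torch tensor; the `peak_amplitude` key is hypothetical and only matters if a conditioner with that `id` exists in your model config):

```py
def get_custom_metadata(info, audio):

    # `info` carries file/crop details (e.g. "relpath"); `audio` is the cropped
    # training excerpt. Assuming `audio` is a torch tensor, derive a simple
    # property from it and pass it along as extra conditioning metadata.
    # Returned keys are only consumed if they match a conditioner `id`.
    return {
        "prompt": info["relpath"],
        "peak_amplitude": float(audio.abs().max()),
    }
```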

stable_audio_tools/configs/dataset_configs/s3_wds_example.json

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@
     "datasets": [
         {
             "id": "s3-test",
-            "s3_path": "s3://my-bucket/datasets/webdatset/audio/"
+            "s3_path": "s3://my-bucket/datasets/webdataset/audio/"
         }
     ],
     "random_crop": true
```

stable_audio_tools/models/conditioners.py

Lines changed: 10 additions & 9 deletions

```diff
@@ -45,8 +45,6 @@ def __init__(self,
 
     def forward(self, ints: tp.List[int], device=None) -> tp.Any:
 
-        #self.int_embedder.to(device)
-
         ints = torch.tensor(ints).to(device)
         ints = ints.clamp(self.min_val, self.max_val)
 
@@ -94,12 +92,12 @@ def __init__(self,
                  audio_model_type="HTSAT-base",
                  enable_fusion=True,
                  project_out: bool = False,
-                 finetune: bool = False):
+                 enable_grad: bool = False):
         super().__init__(768 if use_text_features else 512, output_dim, 1, project_out=project_out)
 
         self.use_text_features = use_text_features
         self.feature_layer_ix = feature_layer_ix
-        self.finetune = finetune
+        self.enable_grad = enable_grad
 
         # Suppress logging from transformers
         previous_level = logging.root.manager.disable
@@ -111,15 +109,15 @@ def __init__(self,
 
             model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
 
-            if self.finetune:
+            if self.enable_grad:
                 self.model = model
             else:
                 self.__dict__["model"] = model
 
             state_dict = clap_load_state_dict(clap_ckpt_path)
             self.model.model.load_state_dict(state_dict, strict=False)
 
-            if self.finetune:
+            if self.enable_grad:
                 self.model.model.text_branch.requires_grad_(True)
                 self.model.model.text_branch.train()
             else:
@@ -173,9 +171,12 @@ def __init__(self,
                  clap_ckpt_path,
                  audio_model_type="HTSAT-base",
                  enable_fusion=True,
-                 project_out: bool = False):
+                 project_out: bool = False,
+                 enable_grad: bool = False):
         super().__init__(512, output_dim, 1, project_out=project_out)
 
+        self.enable_grad = enable_grad
+
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
         # Suppress logging from transformers
@@ -188,15 +189,15 @@ def __init__(self,
 
             model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
 
-            if self.finetune:
+            if self.enable_grad:
                 self.model = model
             else:
                 self.__dict__["model"] = model
 
             state_dict = clap_load_state_dict(clap_ckpt_path)
             self.model.model.load_state_dict(state_dict, strict=False)
 
-            if self.finetune:
+            if self.enable_grad:
                 self.model.model.audio_branch.requires_grad_(True)
                 self.model.model.audio_branch.train()
             else:
```
