Showing 123 changed files with 4,916 additions and 2 deletions.
@@ -0,0 +1,16 @@
# macOS DS_STORE files
**/*.DS_STORE
# Python precompiled files
*.pyc
# PyCharm config files
.idea/
# PyCharm Virtual Environment
venv/
#
data/**

# personal test file
test.py

# Dataset directory
# data/*
@@ -1,2 +1,87 @@
# Byte Pair Encoding for Symbolic Music

Code of the paper *Byte Pair Encoding for Symbolic Music*.

## Steps to reproduce

1. `pip install -r requirements` to install requirements
2. Download the [GiantMIDI](https://github.com/bytedance/GiantMIDI-Piano/blob/master/disclaimer.md) dataset and put it in `data/`
3. `sh scripts/download_pop909.sh` to download and preprocess the [POP909](https://github.com/music-x-lab/POP909-Dataset) dataset
4. `python scripts/tokenize_datasets.py` to tokenize data and learn BPE
5. `python exp_gen.py` to train generative models and generate results
6. `python exp_cla.py` to train classification models and test them

The [scripts](./scripts) can be run to reproduce the analysis.

## BPE learning

<img src="figures/tokenizations_bpe_token_types/POP909-merged_TSD.png" alt="POP909 TSD" width="400"/><img src="figures/tokenizations_bpe_token_types/POP909-merged_REMI.png" alt="POP909 REMI" width="400"/>

<img src="figures/tokenizations_bpe_token_types/GiantMIDI_TSD.png" alt="GiantMIDI TSD" width="400"/><img src="figures/tokenizations_bpe_token_types/GiantMIDI_REMI.png" alt="GiantMIDI REMI" width="400"/>

In order, the figures above are for POP909 TSD, POP909 REMI, GiantMIDI TSD and GiantMIDI REMI.

<img src="figures/bpe_nb_tok_combinations.png" alt="Number of token combinations with BPE" width="800"/>

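For intuition, the sketch below shows the core of what BPE learning does to the token sequences: the most frequent pair of adjacent tokens is repeatedly merged into a new token until the vocabulary reaches a target size. This is a minimal, illustrative re-implementation rather than the code used in the experiments (which relies on MidiTok); `sequences` stands for already-tokenized MIDI files given as lists of token ids.

```python
from collections import Counter
from typing import Dict, List, Tuple


def learn_bpe(sequences: List[List[int]], base_vocab_size: int, bpe_factor: int) -> Dict[Tuple[int, int], int]:
    """Learn BPE merges until the vocabulary is bpe_factor times its base size."""
    merges = {}  # (token_a, token_b) -> new token id
    next_id = base_vocab_size
    while next_id < base_vocab_size * bpe_factor:
        # Count every pair of successive tokens over the whole corpus
        pairs = Counter()
        for seq in sequences:
            pairs.update(zip(seq, seq[1:]))
        if not pairs:
            break
        best, _ = pairs.most_common(1)[0]
        merges[best] = next_id
        # Replace every occurrence of the most frequent pair by the new token
        new_sequences = []
        for seq in sequences:
            merged, i = [], 0
            while i < len(seq):
                if i + 1 < len(seq) and (seq[i], seq[i + 1]) == best:
                    merged.append(next_id)
                    i += 2
                else:
                    merged.append(seq[i])
                    i += 1
            new_sequences.append(merged)
        sequences = new_sequences
        next_id += 1
    return merges
```

In the figures, BPEx4 to BPEx100 denote BPE vocabularies 4 to 100 times the size of the base vocabulary, the same notion as the `bpe_factor` appearing in the configuration code further below.
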
## Experiment results

We refer you to the tables of the paper.

## Learned embedding space

### Singular values

#### Generators: POP909 TSD, POP909 REMI, GiantMIDI TSD and GiantMIDI REMI

<img src="figures/singular_value_gen/singular_value_POP909-merged_TSD.png" alt="POP909 TSD" width="200"/><img src="figures/singular_value_gen/singular_value_POP909-merged_REMI.png" alt="POP909 REMI" width="200"/><img src="figures/singular_value_gen/singular_value_GiantMIDI_TSD.png" alt="GiantMIDI TSD" width="200"/><img src="figures/singular_value_gen/singular_value_GiantMIDI_REMI.png" alt="GiantMIDI REMI" width="200"/>

#### Classifiers: $\mathrm{Cla}\_{small}$ TSD, $\mathrm{Cla}\_{small}$ REMI, $\mathrm{Cla}\_{large}$ TSD and $\mathrm{Cla}\_{large}$ REMI

<img src="figures/singular_value_cla/singular_value_GiantMIDI_TSD.png" alt="Cla small TSD" width="200"/><img src="figures/singular_value_cla/singular_value_GiantMIDI_REMI.png" alt="Cla small REMI" width="200"/><img src="figures/singular_value_cla/singular_value_GiantMIDI_TSD_LARGE.png" alt="Cla large TSD" width="200"/><img src="figures/singular_value_cla/singular_value_GiantMIDI_REMI_LARGE.png" alt="Cla large REMI" width="200"/>

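Spectra like the ones above can be reproduced in spirit from a model's token embedding matrix with a plain SVD. The snippet below is an illustrative sketch rather than the repository's plotting code; `embeddings` is a placeholder for the learned `(vocab_size, dim)` embedding weights, and centering before the SVD is one common convention, not necessarily the one used here.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_singular_values(embeddings: np.ndarray, out_path: str) -> None:
    """Plot the normalized singular values of a (vocab_size, dim) embedding matrix."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)   # remove the mean embedding
    singular_values = np.linalg.svd(centered, compute_uv=False)      # returned in decreasing order
    plt.plot(singular_values / singular_values.max())
    plt.xlabel('Singular value index')
    plt.ylabel('Normalized singular value')
    plt.savefig(out_path)
    plt.close()
```
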
### UMAP Generators

The figures are, in order, for no BPE, BPEx4, BPEx10, BPEx20, BPEx50, BPEx100, PVm and PVDm.

#### POP909 TSD

<img src="figures/umap_3d_gen/umap_3d_POP909-merged_TSD_noBPE.png" alt="No BPE" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_TSD_bpe4.png" alt="BPEx4" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_TSD_bpe10.png" alt="BPEx10" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_TSD_bpe20.png" alt="BPEx20" width="200"/>

<img src="figures/umap_3d_gen/umap_3d_POP909-merged_TSD_bpe50.png" alt="BPEx50" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_TSD_bpe100.png" alt="BPEx100" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_TSD_PVm.png" alt="PVm" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_TSD_PVDm.png" alt="PVDm" width="200"/>

#### POP909 REMI

<img src="figures/umap_3d_gen/umap_3d_POP909-merged_REMI_noBPE.png" alt="No BPE" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_REMI_bpe4.png" alt="BPEx4" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_REMI_bpe10.png" alt="BPEx10" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_REMI_bpe20.png" alt="BPEx20" width="200"/>

<img src="figures/umap_3d_gen/umap_3d_POP909-merged_REMI_bpe50.png" alt="BPEx50" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_REMI_bpe100.png" alt="BPEx100" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_REMI_PVm.png" alt="PVm" width="200"/><img src="figures/umap_3d_gen/umap_3d_POP909-merged_REMI_PVDm.png" alt="PVDm" width="200"/>

#### GiantMIDI TSD

<img src="figures/umap_3d_gen/umap_3d_GiantMIDI_TSD_noBPE.png" alt="No BPE" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_TSD_bpe4.png" alt="BPEx4" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_TSD_bpe10.png" alt="BPEx10" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_TSD_bpe20.png" alt="BPEx20" width="200"/>

<img src="figures/umap_3d_gen/umap_3d_GiantMIDI_TSD_bpe50.png" alt="BPEx50" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_TSD_bpe100.png" alt="BPEx100" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_TSD_PVm.png" alt="PVm" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_TSD_PVDm.png" alt="PVDm" width="200"/>

#### GiantMIDI REMI

<img src="figures/umap_3d_gen/umap_3d_GiantMIDI_REMI_noBPE.png" alt="No BPE" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_REMI_bpe4.png" alt="BPEx4" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_REMI_bpe10.png" alt="BPEx10" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_REMI_bpe20.png" alt="BPEx20" width="200"/>

<img src="figures/umap_3d_gen/umap_3d_GiantMIDI_REMI_bpe50.png" alt="BPEx50" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_REMI_bpe100.png" alt="BPEx100" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_REMI_PVm.png" alt="PVm" width="200"/><img src="figures/umap_3d_gen/umap_3d_GiantMIDI_REMI_PVDm.png" alt="PVDm" width="200"/>

### UMAP Classifiers

These figures are for $\mathrm{Cla}\_{small}$ with TSD. More figures can be found in [figures](./figures).

<img src="figures/umap_2d_cla/umap_2d_GiantMIDI_TSD_noBPE.png" alt="No BPE" width="200"/><img src="figures/umap_2d_cla/umap_2d_GiantMIDI_TSD_bpe4.png" alt="BPEx4" width="200"/><img src="figures/umap_2d_cla/umap_2d_GiantMIDI_TSD_bpe10.png" alt="BPEx10" width="200"/><img src="figures/umap_2d_cla/umap_2d_GiantMIDI_TSD_bpe20.png" alt="BPEx20" width="200"/>

<img src="figures/umap_2d_cla/umap_2d_GiantMIDI_TSD_bpe50.png" alt="BPEx50" width="200"/><img src="figures/umap_2d_cla/umap_2d_GiantMIDI_TSD_bpe100.png" alt="BPEx100" width="200"/><img src="figures/umap_2d_cla/umap_2d_GiantMIDI_TSD_PVm.png" alt="PVm" width="200"/><img src="figures/umap_2d_cla/umap_2d_GiantMIDI_TSD_PVDm.png" alt="PVDm" width="200"/>

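Projections like the ones above can be computed with the `umap-learn` package. Below is a minimal sketch under the assumption that the embedding matrix has already been extracted from a trained model; the default seed of 444 mirrors the `SEED` constant used elsewhere in the code, but the exact UMAP hyperparameters behind the figures are not implied here.

```python
import numpy as np
import umap  # pip install umap-learn
import matplotlib.pyplot as plt


def plot_umap_3d(embeddings: np.ndarray, out_path: str, seed: int = 444) -> None:
    """Project a (vocab_size, dim) embedding matrix to 3D with UMAP and save a scatter plot."""
    projection = umap.UMAP(n_components=3, random_state=seed).fit_transform(embeddings)
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(projection[:, 0], projection[:, 1], projection[:, 2], s=2)
    fig.savefig(out_path)
    plt.close(fig)
```
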
### Intrinsic dimension

#### Generators: POP909 TSD, POP909 REMI, GiantMIDI TSD and GiantMIDI REMI

<img src="figures/intrinsic_dimension_gen/intrinsic_dim_POP909-merged_TSD.png" alt="POP909 TSD" width="200"/><img src="figures/intrinsic_dimension_gen/intrinsic_dim_POP909-merged_REMI.png" alt="POP909 REMI" width="200"/><img src="figures/intrinsic_dimension_gen/intrinsic_dim_GiantMIDI_TSD.png" alt="GiantMIDI TSD" width="200"/><img src="figures/intrinsic_dimension_gen/intrinsic_dim_GiantMIDI_REMI.png" alt="GiantMIDI REMI" width="200"/>

#### Classifiers: $\mathrm{Cla}\_{small}$ TSD, $\mathrm{Cla}\_{small}$ REMI, $\mathrm{Cla}\_{large}$ TSD and $\mathrm{Cla}\_{large}$ REMI

<img src="figures/intrinsic_dimension_cla/intrinsic_dim_GiantMIDI_TSD.png" alt="Cla small TSD" width="200"/><img src="figures/intrinsic_dimension_cla/intrinsic_dim_GiantMIDI_REMI.png" alt="Cla small REMI" width="200"/><img src="figures/intrinsic_dimension_cla/intrinsic_dim_GiantMIDI_TSD_LARGE.png" alt="Cla large TSD" width="200"/><img src="figures/intrinsic_dimension_cla/intrinsic_dim_GiantMIDI_REMI_LARGE.png" alt="Cla large REMI" width="200"/>
@@ -0,0 +1,189 @@
from pathlib import Path
from typing import List, Union

from transformers import GPT2Config, BertConfig
import miditok

from model import GenTransformer, GenTransformerPooling, ClassifierTransformer, ClassifierTransformerPooling
from constants import *
import tokenizers_


class ModelConfig:
    def __init__(self, dim: int = DIM,
                 nb_heads: int = NB_HEADS,
                 d_ffwd: int = D_FFWD,
                 nb_layers: int = NB_LAYERS,
                 nb_pos_enc_params: int = NB_POS_ENC_PARAMS,
                 embed_sizes: List[int] = None):
        self.dim = dim
        self.nb_layers = nb_layers
        self.nb_heads = nb_heads
        self.d_ffwd = d_ffwd
        self.nb_pos_enc_params = nb_pos_enc_params

        self.embed_sizes = embed_sizes  # for CP Word and Octuple


class TrainingConfig:
    def __init__(self, use_cuda: bool = USE_CUDA,
                 use_amp: bool = USE_AMP,
                 batch_size: int = BATCH_SIZE,
                 grad_acc_steps: int = GRAD_ACC_STEPS,
                 learning_rate: float = LEARNING_RATE,
                 weight_decay: float = WEIGHT_DECAY,
                 gradient_clip_norm: float = GRADIENT_CLIP_NORM,
                 label_smoothing: float = LABEL_SMOOTHING,
                 dropout: float = DROPOUT,
                 valid_split: float = VALID_SPLIT,
                 test_split: float = TEST_SPLIT,
                 training_steps: int = TRAINING_STEPS,
                 warmup_ratio: float = WARMUP_RATIO,
                 iterator_kwargs: dict = ITERATOR_KWARGS,
                 valid_intvl: int = VALID_INTVL,
                 nb_valid_steps: int = NB_VALID_STEPS,
                 log_intvl: int = LOG_INTVL,
                 min_seq_len: int = MIN_SEQ_LEN,
                 max_seq_len: int = MAX_SEQ_LEN):
        self.use_cuda = use_cuda
        self.use_amp = use_amp
        self.batch_size = batch_size
        self.grad_acc_steps = grad_acc_steps
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.gradient_clip_norm = gradient_clip_norm
        self.label_smoothing = label_smoothing
        self.dropout = dropout
        self.valid_split = valid_split
        self.test_split = test_split
        self.training_steps = training_steps
        self.warmup_ratio = warmup_ratio
        self.iterator_kwargs = iterator_kwargs
        self.valid_intvl = valid_intvl
        self.nb_valid_steps = nb_valid_steps
        self.log_intvl = log_intvl
        self.min_seq_len = min_seq_len
        self.max_seq_len = max_seq_len


class TestingConfig:
    def __init__(self, batch_size: int = BATCH_SIZE_TEST,
                 max_seq_len: int = MAX_SEQ_LEN_TEST,
                 nb_inferences_test: int = NB_INFERENCES_TEST,
                 num_beams: int = NUM_BEAMS,
                 top_p: float = TOP_P):
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.nb_inferences_test = nb_inferences_test
        self.num_beams = num_beams
        self.top_p = top_p


class Baseline:
    """Represents a baseline"""

    def __init__(self, name: str, exp_name: str, dataset: str, seed: int, tokenization: str, bpe_factor: int,
                 model_config: ModelConfig):
        self.name = name  # bpe or tokenization
        self.exp_name = exp_name  # data_tokenization (remi/tsd), for run path as exp.tokenization can be different
        self.dataset = dataset  # can be different from the one in exp_name, for example with short
        self.seed = seed  # is used when splitting generated data in train / test sets

        self.bpe_factor = bpe_factor
        self.tokenization = tokenization
        self.tokenizer = None
        self.model_config = model_config

    def load_tokenizer(self):
        if self.tokenization[-3:] == 'PVm' or self.tokenization[-4:] == 'PVDm':
            self.tokenizer = getattr(tokenizers_, self.tokenization)(params=self.data_path / 'config.txt')
        else:
            self.tokenizer = getattr(miditok, self.tokenization)(params=self.data_path / 'config.txt')

    @property
    def data_path(self) -> Path: return Path('data', f'{self.exp_name}' + (f'_{self.name}' if self.name != '' else ''))

    @property
    def run_path(self) -> Path: return Path('runs', self.exp_name, self.name)

    @property
    def run_path_classifier(self) -> Path: return Path('runs_classifier', self.exp_name, self.name)

    @property
    def gen_data_path(self) -> Path: return self.run_path / 'gen'

    @property
    def is_embed_pooling(self) -> bool: return isinstance(self.tokenizer.vocab, list)

    @property
    def pad_token(self) -> int:
        return self.tokenizer.vocab[0]['PAD_None'] if self.is_embed_pooling else self.tokenizer['PAD_None']

    @property
    def sos_token(self) -> int:
        return self.tokenizer.vocab[0]['SOS_None'] if self.is_embed_pooling else self.tokenizer['SOS_None']

    @property
    def eos_token(self) -> int:
        return self.tokenizer.vocab[0]['EOS_None'] if self.is_embed_pooling else self.tokenizer['EOS_None']

    def __repr__(self):
        return f'{self.name} - {self.data_path}'


class Experiment:
    def __init__(self, baselines: List[Baseline], dataset: str, tokenization: str, seed: int,
                 cla_model_conf: ModelConfig, gen_train_conf: TrainingConfig, cla_pre_train_conf: TrainingConfig,
                 cla_train_conf: TrainingConfig, gen_test_conf: TestingConfig,
                 tokenizer_params: dict = TOKENIZER_PARAMS):
        self.name = f'{dataset}_{tokenization}'  # dataset_tokenization
        self.run_path = Path('runs', self.name)
        self.data_path_midi = Path('data', dataset)  # original dataset path, in MIDI
        self.baselines = baselines
        self.dataset = dataset
        self.tokenizer_params = tokenizer_params  # used when tokenizing datasets only
        self.seed = seed

        self.cla_model_conf = cla_model_conf
        self.gen_train_conf = gen_train_conf
        self.cla_pre_train_conf = cla_pre_train_conf
        self.cla_train_conf = cla_train_conf
        self.gen_test_conf = gen_test_conf

    def create_gen(self, baseline: Baseline) -> Union[GenTransformer, GenTransformerPooling]:
        """Creates the generative model for the experiment.
        The model must implement the `forward_train` and `infer` methods.
        """
        config_d = GPT2Config(vocab_size=len(baseline.tokenizer.vocab),
                              n_positions=baseline.model_config.nb_pos_enc_params,
                              n_embd=baseline.model_config.dim, n_layer=baseline.model_config.nb_layers,
                              n_head=baseline.model_config.nb_heads, n_inner=baseline.model_config.d_ffwd,
                              resid_pdrop=self.gen_train_conf.dropout, embd_pdrop=self.gen_train_conf.dropout,
                              attn_pdrop=self.gen_train_conf.dropout, pad_token_id=baseline.pad_token,
                              bos_token_id=baseline.sos_token, eos_token_id=baseline.eos_token)
        if baseline.is_embed_pooling:
            num_classes = [len(v) for v in baseline.tokenizer.vocab]
            return GenTransformerPooling(config_d, num_classes, baseline.model_config.embed_sizes)
        else:
            return GenTransformer(config_d)

    def create_classifier(self, baseline: Baseline, num_labels: int = None, pre_train: bool = False):
        """Creates the classifier model for the experiment.
        The model must implement the `forward_train` and `infer` methods.
        """
        model_conf = self.cla_model_conf
        train_conf = self.cla_train_conf
        config = BertConfig(vocab_size=len(baseline.tokenizer), hidden_size=model_conf.dim,
                            num_hidden_layers=model_conf.nb_layers, num_attention_heads=model_conf.nb_heads,
                            intermediate_size=model_conf.d_ffwd, hidden_dropout_prob=train_conf.dropout,
                            attention_probs_dropout_prob=train_conf.dropout,
                            max_position_embeddings=model_conf.nb_pos_enc_params,
                            type_vocab_size=2, pad_token_id=baseline.pad_token, num_labels=num_labels)
        if baseline.is_embed_pooling:
            num_classes = [len(v) for v in baseline.tokenizer.vocab]
            return ClassifierTransformerPooling(config, num_classes, model_conf.embed_sizes, pre_train)
        return ClassifierTransformer(config, pre_train)

    def __str__(self): return f'{self.name} - {len(self.baselines)} baselines'

    def __repr__(self): return self.__str__()
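
To make the relationships between these classes concrete, here is a hypothetical sketch of how they could be wired together. The baseline names, the `bpe_factor` values, and the dataset/tokenization strings are illustrative only; the actual experiments are assembled in `exp_gen.py` and `exp_cla.py`.

```python
# Hypothetical usage of the classes above (example values, not the paper's configuration)
gen_model_conf = ModelConfig()  # generator hyperparameters (DIM, NB_LAYERS, ... from constants)
cla_model_conf = ModelConfig(dim=CLA_DIM, nb_heads=CLA_NB_HEADS, d_ffwd=CLA_D_FFWD,
                             nb_layers=CLA_NB_LAYERS, nb_pos_enc_params=CLA_NB_POS_ENC_PARAMS)

baselines = [
    Baseline(name='', exp_name='POP909-merged_TSD', dataset='POP909-merged', seed=SEED,
             tokenization='TSD', bpe_factor=0, model_config=gen_model_conf),   # no BPE
    Baseline(name='bpe4', exp_name='POP909-merged_TSD', dataset='POP909-merged', seed=SEED,
             tokenization='TSD', bpe_factor=4, model_config=gen_model_conf),   # BPEx4
]

exp = Experiment(baselines, dataset='POP909-merged', tokenization='TSD', seed=SEED,
                 cla_model_conf=cla_model_conf, gen_train_conf=TrainingConfig(),
                 cla_pre_train_conf=TrainingConfig(training_steps=CLA_PRE_TRAINING_STEPS),
                 cla_train_conf=TrainingConfig(training_steps=CLA_TRAINING_STEPS),
                 gen_test_conf=TestingConfig())

for baseline in exp.baselines:
    baseline.load_tokenizer()             # requires the tokenized data and its config.txt under data/
    generator = exp.create_gen(baseline)  # GPT-2-style generator sized by the baseline's ModelConfig
```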
@@ -0,0 +1,86 @@
""" | ||
Constants file | ||
""" | ||
SEED = 444 | ||
|
||
# Tokenizer params (same as MidiTok except for new constants)
PITCH_RANGE = range(21, 109)
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
NB_VELOCITIES = 8
ADDITIONAL_TOKENS = {'Chord': False, 'Rest': False, 'Tempo': False, 'Program': False, 'TimeSignature': False,
                     'rest_range': (2, 8), 'nb_tempos': 32, 'tempo_range': (40, 250), 'time_signature_range': (8, 2)}
TOKENIZER_PARAMS = {'pitch_range': PITCH_RANGE, 'beat_res': BEAT_RES, 'nb_velocities': NB_VELOCITIES,
                    'additional_tokens': ADDITIONAL_TOKENS, 'sos_eos': True}
TIME_DIVISION = 384
NB_SCALES_OFFSET_DATA_AUGMENTATION = 2
BPE_NB_FILES_LIM = 1500

# For classification
MAX_NB_COMPOSERS = 10

# Transformer config (for generator)
DIM = 512
NB_HEADS = 8
D_FFWD = 2048
NB_LAYERS = 10
NB_POS_ENC_PARAMS = 2048  # number of positions for the positional encoding

# Transformer config (for classifier)
CLA_DIM = 768
CLA_NB_HEADS = 12
CLA_D_FFWD = 2048
CLA_NB_LAYERS = 10
CLA_NB_POS_ENC_PARAMS = 2048  # number of positions for the positional encoding
CLA_LARGE_DIM = 1024
CLA_LARGE_NB_HEADS = 16
CLA_LARGE_D_FFWD = 3072
CLA_LARGE_NB_LAYERS = 18
CLA_LARGE_NB_POS_ENC_PARAMS = 2048  # number of positions for the positional encoding


# Training params
DROPOUT = 0.1
BATCH_SIZE = 16
GRAD_ACC_STEPS = 1
WEIGHT_DECAY = 0.01
GRADIENT_CLIP_NORM = 3.0
LABEL_SMOOTHING = 0.0
LEARNING_RATE = 5e-6
WARMUP_RATIO = 0.3
VALID_SPLIT = 0.35
TEST_SPLIT = 0.15  # unused
USE_CUDA = True
USE_AMP = True
TRAINING_STEPS = 100000
EARLY_STOP_STEPS = 15000  # stop training if the validation loss has not improved for this many steps
ITERATOR_KWARGS = {'early_stop_steps': EARLY_STOP_STEPS}
VALID_INTVL = 30
NB_VALID_STEPS = 5
LOG_INTVL = 10
MIN_SEQ_LEN = 384
MAX_SEQ_LEN = 460

# GEN TEST PARAMS
NB_INFERENCES_TEST = 1024
MAX_SEQ_LEN_TEST = 1024
BATCH_SIZE_TEST = 32
NUM_BEAMS = 1  # in practice the generation will use a batch size = BATCH_SIZE_TEST * NUM_BEAMS
TOP_P = 0.9

# TRAINING PARAMS DIS / CLA
CLA_PRE_TRAINING_STEPS = 60000
CLA_TRAINING_STEPS = 100000
CLA_BATCH_SIZE = 24
CLA_LARGE_BATCH_SIZE = 24
CLA_PT_LEARNING_RATE = 1e-6
CLA_FT_LEARNING_RATE = 5e-7
CLA_EARLY_STOP = 25000
RANDOM_RATIO_RANGE = (0.01, 0.15)

# For CP Word and Octuple
EMBED_SIZES_CP = [32, 64, 512, 128, 128]  # fam, pos / bar, pitch, vel, dur
EMBED_SIZES_OCTUPLE = [512] * 5  # pitch, vel, dur, pos, bar
OCT_MAX_BAR = 30  # to shorten MIDIs

# For metrics
CONSISTENCY_WINDOWS_LEN = 16  # in beats