-
Notifications
You must be signed in to change notification settings - Fork 22
/
config.toml
89 lines (70 loc) · 2.61 KB
/
config.toml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
[meta]
name = "ASR"
pretrained_path = "facebook/wav2vec2-base"
seed = 42
epochs = 20
save_dir = "saved/"
gradient_accumulation_steps = 2
# Automatic Mixed Precision to speed up training - https://pytorch.org/docs/stable/amp.html
use_amp = false
# GPU device id(s) on which you want to train your model
device_ids = "0"
# audio sample rate (Hz)
sr = 16000
# max norm passed to torch.nn.utils.clip_grad_norm_
max_clip_grad_norm = 5.0
# Special tokens added to the tokenizer vocabulary.
[special_tokens]
bos_token = "<bos>"
eos_token = "<eos>"
unk_token = "<unk>"
pad_token = "<pad>"
# Not available yet
[huggingface]
# Pushing requires git-lfs to be installed.
# See https://huggingface.co/docs/hub/how-to-upstream#repository for what each parameter means.
push_to_hub = false
# If false, the repo is pushed once at the end of training [recommended false]
push_every_validation_step = false
overwrite_output_dir = false
# Whether to wait until the model is uploaded; very slow for large files
# [recommended false; true only if push_every_validation_step is false]
blocking = false

# Pass the auth token of your HuggingFace account via use_auth_token,
# or run ```huggingface-cli login``` beforehand to log in.
[huggingface.args]
# local directory where your repo is placed
local_dir = "huggingface-hub"
# you must provide the auth token of your HuggingFace account
use_auth_token = true
# path to your repo on the HuggingFace hub
clone_from = ""
[train_dataset]
path = "base.base_dataset.BaseDataset"

[train_dataset.args]
path = "/pass/your/train_dataset/path/here"
preload_data = false
delimiter = "|"
# Only train on audio files whose duration falls in [min_duration, max_duration].
# min_duration = 0.5 # defaults to -np.inf when omitted
# max_duration = 20 # defaults to np.inf when omitted
nb_workers = 16

[train_dataset.dataloader]
batch_size = 8
num_workers = 16
pin_memory = true
drop_last = true

[train_dataset.sampler]
shuffle = true
drop_last = true
[val_dataset]
path = "base.base_dataset.BaseDataset"

[val_dataset.args]
path = "/pass/your/val_dataset/path/here"
preload_data = false
delimiter = "|"
nb_workers = 16

[val_dataset.dataloader]
# A validation batch_size > 1 may yield an incorrect score due to padding
# (but is faster :D) - https://github.com/pytorch/fairseq/issues/3227
batch_size = 1
num_workers = 4

[val_dataset.sampler]
shuffle = false
drop_last = false
[optimizer]
# presumably the starting learning rate the scheduler ramps up from — confirm in trainer
lr = 1e-6

[scheduler]
# presumably the peak learning rate reached by the scheduler — confirm in trainer
max_lr = 5e-4
[trainer]
path = "trainer.trainer.Trainer"

[trainer.args]
# NOTE(review): presumably measured in training steps — confirm against Trainer
validation_interval = 500
# false here suggests the tracked metric is better when lower (e.g. WER) — verify in Trainer
save_max_metric_score = false