forked from facebookresearch/svoice
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
executable file
·126 lines (102 loc) · 3.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Yossi Adi (adiyoss) and Alexandre Défossez (adefossez)
import json
import logging
import os
import socket
import subprocess as sp
import sys
import time
from omegaconf import DictConfig, OmegaConf
import hydra
from svoice.executor import start_ddp_workers
logger = logging.getLogger(__name__)
def run(args):
import torch
from svoice import distrib
from svoice.data.data import Trainset, Validset
from svoice.models.swave import SWave
from svoice.solver import Solver
logger.info("Running on host %s", socket.gethostname())
distrib.init(args)
if args.model == "swave":
kwargs = dict(args.swave)
kwargs['sr'] = args.sample_rate
kwargs['segment'] = args.segment
model = SWave(**kwargs)
else:
logger.fatal("Invalid model name %s", args.model)
os._exit(1)
# requires a specific number of samples to avoid 0 padding during training
if hasattr(model, 'valid_length'):
segment_len = int(args.segment * args.sample_rate)
segment_len = model.valid_length(segment_len)
args.segment = segment_len / args.sample_rate
if args.show:
logger.info(model)
mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20
logger.info('Size: %.1f MB', mb)
if hasattr(model, 'valid_length'):
field = model.valid_length(1)
logger.info('Field: %.1f ms', field / args.sample_rate * 1000)
return
assert args.batch_size % distrib.world_size == 0
args.batch_size //= distrib.world_size
# Building datasets and loaders
tr_dataset = Trainset(
args.dset.train, sample_rate=args.sample_rate, segment=args.segment, stride=args.stride, pad=args.pad)
tr_loader = distrib.loader(
tr_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
# batch_size=1 -> use less GPU memory to do cv
cv_dataset = Validset(args.dset.valid)
tt_dataset = Validset(args.dset.test)
cv_loader = distrib.loader(
cv_dataset, batch_size=1, num_workers=args.num_workers)
tt_loader = distrib.loader(
tt_dataset, batch_size=1, num_workers=args.num_workers)
data = {"tr_loader": tr_loader,
"cv_loader": cv_loader, "tt_loader": tt_loader}
# torch also initialize cuda seed if available
torch.manual_seed(args.seed)
if torch.cuda.is_available():
model.cuda()
# optimizer
if args.optim == "adam":
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, betas=(0.9, args.beta2))
else:
logger.fatal('Invalid optimizer %s', args.optim)
os._exit(1)
# Construct Solver
solver = Solver(data, model, optimizer, args)
solver.train()
def _main(args):
global __file__
# Updating paths in config
for key, value in args.dset.items():
if isinstance(value, str) and key not in ["matching"]:
args.dset[key] = hydra.utils.to_absolute_path(value)
__file__ = hydra.utils.to_absolute_path(__file__)
if args.verbose:
logger.setLevel(logging.DEBUG)
logging.getLogger("denoise").setLevel(logging.DEBUG)
logger.info("For logs, checkpoints and samples check %s", os.getcwd())
logger.debug(args)
if args.ddp and args.rank is None:
start_ddp_workers()
else:
run(args)
@hydra.main(config_path="conf", config_name='config.yaml')
def main(args):
try:
_main(args)
except Exception:
logger.exception("Some error happened")
# Hydra intercepts exit code, fixed in beta but I could not get the beta to work
os._exit(1)
if __name__ == "__main__":
main()