-
Notifications
You must be signed in to change notification settings - Fork 4
/
create.py
103 lines (88 loc) · 4.2 KB
/
create.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
""" create.py
Creates and initializes a new experiment folder.
"""
import os
import sys
import shutil
import argparse
import logging
import tacorn.fileutils as fu
import tacorn.constants as consts
import tacorn.experiment as experiment
import tacorn.wrappers as wrappers
# Module-wide logging: DEBUG and above, each record prefixed with a timestamp
# and its level name.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)
def get_pretrained_wavernn(model_id, targetdir):
    """ Downloads a pretrained waveRNN model unless it is already present.

    The URL and file name come from ``consts.PRETRAINED_WAVERNN_MODELS[model_id]``.
    The file is stored at ``<targetdir>/<model_id>/<model_filename>``.

    Args:
        model_id: key into ``consts.PRETRAINED_WAVERNN_MODELS``.
        targetdir: directory under which a per-model subdirectory is created.

    Returns:
        Path of the model file, either freshly downloaded or previously cached.

    Raises:
        FileNotFoundError: if ``model_id`` is not a known pretrained model.
    """
    if model_id not in consts.PRETRAINED_WAVERNN_MODELS:
        raise FileNotFoundError(model_id)
    model_url, model_filename = consts.PRETRAINED_WAVERNN_MODELS[model_id]
    model_dir = os.path.join(targetdir, model_id)
    model_path = os.path.join(model_dir, model_filename)
    # An existing file is treated as a complete download (cache hit).
    if os.path.exists(model_path):
        return model_path
    fu.ensure_dir(model_dir)
    # Download under a temporary name and move it into place only after the
    # download returned successfully. Otherwise an interrupted download would
    # leave a truncated file that the cache check above accepts forever.
    partial_path = model_path + ".part"
    try:
        fu.download_file(model_url, partial_path)
        os.replace(partial_path, model_path)
    finally:
        # No-op after a successful replace; removes leftovers after a failure.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return model_path
def create_acoustic_model(exp: experiment.Experiment, args):
    """ Loads and creates the acoustic model.

    Loads the wrapper module named by ``exp.config["acoustic_model"]`` and
    lets it create its model inside the experiment. If requested, it then
    downloads a pretrained checkpoint.

    Args:
        exp: the experiment the model belongs to.
        args: parsed command line arguments. The whole namespace is forwarded
            to the wrapper's ``create`` as a dict.

    Raises:
        KeyError: if the pretrained set named by ``args.download_acoustic_model``
            has no entry for ``args.acoustic_model``.
        ModuleNotFoundError: presumably raised by ``wrappers.load`` when the
            wrapper is not installed (``main`` handles that case) — verify.
    """
    module_wrapper = wrappers.load(exp.config["acoustic_model"])
    module_wrapper.create(exp, vars(args))
    if not args.download_acoustic_model:
        return
    pretrained_set = consts.PRETRAINED_ACOUSTIC_MODELS[args.download_acoustic_model]
    if args.acoustic_model not in pretrained_set:
        # Raise a descriptive error instead of a bare KeyError('<model type>').
        raise KeyError("No pretrained acoustic model %s available for model type %s"
                       % (args.download_acoustic_model, args.acoustic_model))
    logger.info("Downloading acoustic model %s for model type %s",
                args.download_acoustic_model, args.acoustic_model)
    module_wrapper.download_pretrained(exp, pretrained_set[args.acoustic_model])
def create_wavegen_model(exp: experiment.Experiment, args):
    """ Loads and creates the wavegen model.

    Loads the wrapper module named by ``exp.config["wavegen_model"]`` and
    lets it create its model inside the experiment. If requested, it then
    downloads a pretrained checkpoint.

    Args:
        exp: the experiment the model belongs to.
        args: parsed command line arguments. The whole namespace is forwarded
            to the wrapper's ``create`` as a dict.

    Raises:
        KeyError: if the pretrained set named by ``args.download_wavegen_model``
            has no entry for ``args.wavegen_model``.
        ModuleNotFoundError: presumably raised by ``wrappers.load`` when the
            wrapper is not installed (``main`` handles that case) — verify.
    """
    module_wrapper = wrappers.load(exp.config["wavegen_model"])
    module_wrapper.create(exp, vars(args))
    if not args.download_wavegen_model:
        return
    pretrained_set = consts.PRETRAINED_WAVEGEN_MODELS[args.download_wavegen_model]
    if args.wavegen_model not in pretrained_set:
        # Raise a descriptive error instead of a bare KeyError('<model type>').
        raise KeyError("No pretrained wavegen model %s available for model type %s"
                       % (args.download_wavegen_model, args.wavegen_model))
    logger.info("Downloading wavegen model %s for model type %s",
                args.download_wavegen_model, args.wavegen_model)
    module_wrapper.download_pretrained(exp, pretrained_set[args.wavegen_model])
def _build_parser():
    """ Builds the command line parser for creating an experiment. """
    parser = argparse.ArgumentParser()
    parser.add_argument('experiment_dir',
                        help='Directory for the experiment to create.')
    parser.add_argument('--acoustic_model', default="tacotron2",
                        help='Model to use for acoustic feature prediction (tacotron2).')
    parser.add_argument('--download_acoustic_model', default=None,
                        choices=consts.PRETRAINED_ACOUSTIC_MODELS.keys(),
                        help=('Name of a pretrained feature model to download (%s)'
                              % (" ".join(consts.PRETRAINED_ACOUSTIC_MODELS.keys()))))
    parser.add_argument('--wavegen_model', default=None,
                        help='Model to use for waveform generation (wavernn_alt). Default: none')
    parser.add_argument('--download_wavegen_model', default=None,
                        choices=consts.PRETRAINED_WAVEGEN_MODELS.keys(),
                        help=('Name of a pretrained wavegen model to download (%s)'
                              % (" ".join(consts.PRETRAINED_WAVEGEN_MODELS.keys()))))
    parser.add_argument('--force', action='store_const', const=True,
                        help='Forces creation of this experiment, deleting an existing experiment if necessary')
    return parser


def main():
    """ main function for creating a new experiment directory.

    Parses the command line, optionally removes an existing experiment
    (``--force``), creates the experiment, and sets up its acoustic model and
    (if ``--wavegen_model`` is given) its wavegen model.

    Returns:
        0 on success. -1 if the experiment already exists and ``--force`` was
        not given, or if a model wrapper module could not be imported.
    """
    args = _build_parser().parse_args()
    if args.download_wavegen_model and not args.wavegen_model:
        # The wavegen model is only created when --wavegen_model is set, so
        # this download request would otherwise be dropped without notice.
        logger.warning("Ignoring --download_wavegen_model %s because no "
                       "--wavegen_model was given", args.download_wavegen_model)
    if os.path.exists(args.experiment_dir):
        if not args.force:
            print("Experiment already exists at %s, stopping" %
                  (args.experiment_dir))
            return -1
        logger.info("Deleting existing experiment at %s", args.experiment_dir)
        # rmtree refuses plain files and symlinks; remove those directly.
        if os.path.isdir(args.experiment_dir) and not os.path.islink(args.experiment_dir):
            shutil.rmtree(args.experiment_dir)
        else:
            os.remove(args.experiment_dir)
    logger.info("Creating experiment at %s", args.experiment_dir)
    exp = experiment.create(args.experiment_dir, args)
    # Tracks which model is being set up, so the error names the right one.
    current_model = args.acoustic_model
    try:
        create_acoustic_model(exp, args)
        if args.wavegen_model:
            current_model = args.wavegen_model
            create_wavegen_model(exp, args)
        experiment.save(exp)
    except ModuleNotFoundError as mnfe:
        print(mnfe)
        print("Module for %s not found, did you run install.sh?" %
              (current_model))
        # Drop the half-initialised experiment so that a re-run after
        # installing does not stop at "Experiment already exists".
        shutil.rmtree(args.experiment_dir, ignore_errors=True)
        return -1
    return 0
if __name__ == '__main__':
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())