Skip to content

Commit a2606fb

Browse files
committed
format utils
1 parent 4873601 commit a2606fb

File tree

6 files changed

+397
-374
lines changed

6 files changed

+397
-374
lines changed

TTS/speaker_encoder/utils/__init__.py

Whitespace-only changes.
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import datetime
2+
import importlib
3+
import os
4+
import re
5+
6+
import torch
7+
from TTS.speaker_encoder.model import SpeakerEncoder
8+
from TTS.utils.generic_utils import check_argument
9+
10+
11+
def to_camel(text):
    """Convert a snake_case string to CamelCase.

    The string is first passed through ``str.capitalize`` (first character
    upper-cased, the remainder lower-cased). Every underscore that is not at
    the very start of the string and is directly followed by an ASCII letter
    is then removed and that letter is upper-cased.
    """
    def _promote(match):
        # Drop the underscore, keep the letter in upper case.
        return match.group(1).upper()

    capitalized = text.capitalize()
    return re.sub(r'(?!^)_([a-zA-Z])', _promote, capitalized)
14+
15+
16+
def setup_model(c):
    """Build a ``SpeakerEncoder`` from the ``model`` section of the config.

    Args:
        c: config object whose ``model`` attribute is a mapping providing
            ``input_dim``, ``proj_dim``, ``lstm_dim`` and ``num_lstm_layers``.

    Returns:
        The newly constructed ``SpeakerEncoder``.
    """
    # Positional order expected by the SpeakerEncoder call.
    dim_keys = ('input_dim', 'proj_dim', 'lstm_dim', 'num_lstm_layers')
    return SpeakerEncoder(*(c.model[key] for key in dim_keys))
20+
21+
22+
def save_checkpoint(model, optimizer, model_loss, out_path,
                    current_step, epoch):
    """Write the current training state to ``checkpoint_<step>.pth.tar``.

    Args:
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``, or ``None`` (then
            ``None`` is stored under the ``'optimizer'`` key).
        model_loss: loss value stored under the ``'loss'`` key.
        out_path: directory the checkpoint file is written into.
        current_step: training step; used in the file name and stored.
        epoch: epoch number stored under the ``'epoch'`` key.
    """
    file_name = 'checkpoint_{}.pth.tar'.format(current_step)
    target = os.path.join(out_path, file_name)
    print(" | | > Checkpoint saving : {}".format(target))

    model_state = model.state_dict()
    optimizer_state = None if optimizer is None else optimizer.state_dict()
    # Human-readable save date, e.g. "March 05, 2021".
    today = datetime.date.today().strftime("%B %d, %Y")

    torch.save(
        dict(
            model=model_state,
            optimizer=optimizer_state,
            step=current_step,
            epoch=epoch,
            loss=model_loss,
            date=today,
        ),
        target,
    )
38+
39+
40+
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step):
    """Save the model as ``best_model.pth.tar`` when the loss has improved.

    Args:
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``, or ``None``. ``None`` is
            stored as-is, matching what ``save_checkpoint`` does.
        model_loss: loss obtained by the current model.
        best_loss: best (lowest) loss seen so far.
        out_path: directory the best-model file is written into.
        current_step: training step stored under the ``'step'`` key.

    Returns:
        The best loss after this call: ``model_loss`` if it is strictly lower
        than ``best_loss`` (a file is written in that case), otherwise
        ``best_loss`` unchanged.
    """
    if not model_loss < best_loss:
        return best_loss

    state = {
        'model': model.state_dict(),
        # The optimizer is optional; do not crash when none is given.
        'optimizer': optimizer.state_dict() if optimizer is not None else None,
        'step': current_step,
        'loss': model_loss,
        'date': datetime.date.today().strftime("%B %d, %Y"),
    }
    bestmodel_path = os.path.join(out_path, 'best_model.pth.tar')
    print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
        model_loss, bestmodel_path))
    torch.save(state, bestmodel_path)
    return model_loss
58+
59+
60+
def check_config_speaker_encoder(c):
    """Validate the speaker-encoder configuration held in ``c``.

    Each expected entry is handed to ``check_argument`` together with its
    expected type and, where given, value bounds or allowed values. The checks
    run in a fixed order: run info, audio, training, checkpoint/output, model,
    in-memory storage and finally every entry of ``datasets``.

    Args:
        c: the loaded ``config.json``; indexed like a mapping.
    """
    def required(name, section, **constraints):
        # Shorthand for the checks made with ``restricted=True``.
        check_argument(name, section, restricted=True, **constraints)

    required('run_name', c, val_type=str)
    # The only entry checked without ``restricted=True``.
    check_argument('run_description', c, val_type=str)

    # audio processing parameters
    required('audio', c, val_type=dict)
    audio = c['audio']
    required('num_mels', audio, val_type=int, min_val=10, max_val=2056)
    required('fft_size', audio, val_type=int, min_val=128, max_val=4058)
    required('sample_rate', audio, val_type=int, min_val=512, max_val=100000)
    required('frame_length_ms', audio, val_type=float, min_val=10,
             max_val=1000, alternative='win_length')
    required('frame_shift_ms', audio, val_type=float, min_val=1,
             max_val=1000, alternative='hop_length')
    required('preemphasis', audio, val_type=float, min_val=0, max_val=1)
    required('min_level_db', audio, val_type=int, min_val=-1000, max_val=10)
    required('ref_level_db', audio, val_type=int, min_val=0, max_val=1000)
    required('power', audio, val_type=float, min_val=1, max_val=5)
    required('griffin_lim_iters', audio, val_type=int, min_val=10,
             max_val=1000)

    # training parameters
    required('loss', c, enum_list=['ge2e', 'angleproto'], val_type=str)
    required('grad_clip', c, val_type=float)
    required('epochs', c, val_type=int, min_val=1)
    required('lr', c, val_type=float, min_val=0)
    required('lr_decay', c, val_type=bool)
    required('warmup_steps', c, val_type=int, min_val=0)
    required('tb_model_param_stats', c, val_type=bool)
    required('num_speakers_in_batch', c, val_type=int)
    required('num_loader_workers', c, val_type=int)
    required('wd', c, val_type=float, min_val=0.0, max_val=1.0)

    # checkpoint and output parameters
    for key, expected_type in (
            ('steps_plot_stats', int),
            ('checkpoint', bool),
            ('save_step', int),
            ('print_step', int),
            ('output_path', str),
    ):
        required(key, c, val_type=expected_type)

    # model parameters
    required('model', c, val_type=dict)
    model_section = c['model']
    for key in ('input_dim', 'proj_dim', 'lstm_dim', 'num_lstm_layers'):
        required(key, model_section, val_type=int)
    required('use_lstm_with_projection', model_section, val_type=bool)

    # in-memory storage parameters
    required('storage', c, val_type=dict)
    storage = c['storage']
    required('sample_from_storage_p', storage, val_type=float,
             min_val=0.0, max_val=1.0)
    required('storage_size', storage, val_type=int, min_val=1, max_val=100)
    required('additive_noise', storage, val_type=float,
             min_val=0.0, max_val=1.0)

    # datasets - every entry of the list is checked
    required('datasets', c, val_type=list)
    for dataset_entry in c['datasets']:
        required('name', dataset_entry, val_type=str)
        required('path', dataset_entry, val_type=str)
        required('meta_file_train', dataset_entry, val_type=[str, list])
        required('meta_file_val', dataset_entry, val_type=str)
118+

TTS/speaker_encoder/utils/io.py

Whitespace-only changes.
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
# coding=utf-8
2+
# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo
3+
# All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
# ==============================================================================
17+
# Only support eager mode and TF>=2.0.0
18+
# pylint: disable=no-member, invalid-name, relative-beyond-top-level
19+
# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
20+
''' voxceleb 1 & 2 '''
21+
22+
import os
23+
import sys
24+
import zipfile
25+
import subprocess
26+
import hashlib
27+
import pandas
28+
from absl import logging
29+
import tensorflow as tf
30+
import soundfile as sf
31+
32+
# Alias for TensorFlow's v1-compatibility gfile module; the functions below use
# its Exists, MakeDirs and Walk helpers.
gfile = tf.compat.v1.gfile
33+
34+
# Location all archives (and archive parts) listed below are served from.
_VOXCELEB_ROOT = "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/"


def _part_urls(stem, letters):
    """Return the urls ``<stem>_parta<letter>`` for each letter, in order."""
    return [f"{_VOXCELEB_ROOT}{stem}_parta{letter}" for letter in letters]


# Subset name -> urls to download. Multi-part subsets list their parts in the
# order in which they have to be concatenated into one zip archive.
SUBSETS = {
    "vox1_dev_wav": _part_urls("vox1_dev_wav", "abcd"),
    "vox1_test_wav": [_VOXCELEB_ROOT + "vox1_test_wav.zip"],
    "vox2_dev_aac": _part_urls("vox2_dev_aac", "abcdefgh"),
    "vox2_test_aac": [_VOXCELEB_ROOT + "vox2_test_aac.zip"],
}

# Subset name -> md5 checksum the (assembled) zip archive is compared against.
MD5SUM = dict(
    vox1_dev_wav="ae63e55b951748cc486645f532ba230b",
    vox2_dev_aac="bbc063c46078a602ca71605645c2a402",
    vox1_test_wav="185fdc63c3c739954633d50379a3d102",
    vox2_test_aac="0d2b3ea430a821c33263b5ea37ede312",
)

# Download credentials handed to wget; filled in from the command line when
# the file is run as a script.
USER = dict(user="", password="")

# Speaker name -> integer id, filled while the audio files are indexed.
speaker_id_dict = dict()
68+
69+
def download_and_extract(directory, subset, urls):
    """Download, verify and extract the archive of one dataset subset.

    Files that already exist in ``directory`` are not downloaded again. When
    the subset is split into parts, the parts are joined (in the order of
    ``urls``) into a single ``.zip`` file. The zip file is checked against
    ``MD5SUM[subset]``, extracted into ``directory``, and the first entry of
    the archive is moved to the zip path without its ``.zip`` suffix.

    Args:
        directory: the directory where to put the downloaded data.
        subset: subset name of the corpus.
        urls: the list of urls to download the data file.

    Raises:
        ValueError: if ``urls`` is empty or the md5 checksum does not match.
        RuntimeError: if wget reports a failure for one of the urls.
    """
    import shutil  # local import: only needed to join multi-part archives

    if not urls:
        raise ValueError("no download urls given for subset %s" % subset)

    if not gfile.Exists(directory):
        gfile.MakeDirs(directory)

    part_paths = []
    for url in urls:
        part_path = os.path.join(directory, url.split("/")[-1])
        part_paths.append(part_path)
        if os.path.exists(part_path):
            continue
        logging.info("Downloading %s to %s" % (url, part_path))
        # Argument list instead of a shell string, so that credentials and
        # paths containing spaces or shell metacharacters are passed verbatim.
        retcode = subprocess.call(
            ["wget", url,
             "--user", USER["user"],
             "--password", USER["password"],
             "-O", part_path])
        if retcode != 0:
            # Do not leave a partial file behind: a later run would skip it.
            if os.path.exists(part_path):
                os.remove(part_path)
            raise RuntimeError(
                "Failed to download %s (wget exit code %d)" % (url, retcode))

        statinfo = os.stat(part_path)
        logging.info(
            "Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)
        )

    # concatenate all parts into zip files
    zip_filepath = part_paths[-1]
    if not zip_filepath.endswith(".zip"):
        zip_filepath = "_".join(zip_filepath.split("_")[:-1]) + ".zip"
        # Join exactly the downloaded parts; a "prefix*" glob would also pick
        # up the zip file and the extracted directory left by an earlier run.
        with open(zip_filepath, "wb") as merged:
            for part_path in part_paths:
                with open(part_path, "rb") as part:
                    shutil.copyfileobj(part, merged)
    # Remove the suffix only; str.strip(".zip") would strip any of the
    # characters ".", "z", "i", "p" from both ends of the path.
    extract_path = zip_filepath[:-len(".zip")]

    # check zip file md5sum, reading in chunks because the archives are large
    digest = hashlib.md5()
    with open(zip_filepath, "rb") as archive:
        for block in iter(lambda: archive.read(1 << 20), b""):
            digest.update(block)
    if digest.hexdigest() != MD5SUM[subset]:
        raise ValueError("md5sum of %s mismatch" % zip_filepath)

    with zipfile.ZipFile(zip_filepath, "r") as zfile:
        zfile.extractall(directory)
        extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename)
    subprocess.call(["mv", extract_path_ori, extract_path])
114+
115+
116+
def exec_cmd(cmd):
    """Execute a command line through the shell.

    Args:
        cmd: command line to be executed.
    Return:
        int, the return code of the command, or -999 when an ``OSError`` was
        raised while running it.
    """
    try:
        status = subprocess.call(cmd, shell=True)
        if status < 0:
            logging.info(f"Child was terminated by signal {status}")
        return status
    except OSError as err:
        logging.info(f"Execution failed: {err}")
        return -999
131+
132+
133+
def decode_aac_with_ffmpeg(aac_file, wav_file):
    """Decode a given AAC file into WAV using ffmpeg.

    Args:
        aac_file: file path to input AAC file.
        wav_file: file path to output WAV file.
    Return:
        bool, True if success.
    """
    import shlex  # local import: used only to quote the paths below

    # The command is run through a shell, so both paths are quoted; otherwise
    # a space or shell metacharacter in a path would break the command.
    cmd = f"ffmpeg -i {shlex.quote(str(aac_file))} {shlex.quote(str(wav_file))}"
    logging.info(f"Decoding aac file using command line: {cmd}")
    ret = exec_cmd(cmd)
    if ret != 0:
        logging.error(f"Failed to decode aac file with retcode {ret}")
        logging.error("Please check your ffmpeg installation.")
        return False
    return True
149+
150+
151+
def convert_audio_and_make_label(input_dir, subset,
                                 output_dir, output_file):
    """Index the audio of a subset, decoding AAC to WAV where needed.

    Walks ``<input_dir>/<subset>``. ``.m4a`` files are decoded with ffmpeg to
    ``<file>.m4a.wav`` unless that file already exists; ``.wav`` files whose
    name carries no second extension are used directly; all other files are
    ignored. One row per audio file is written to a tab-separated csv file
    and new speakers are registered in the module-level ``speaker_id_dict``.

    Args:
        input_dir: the directory which holds the input dataset.
        subset: the name of the specified subset. e.g. vox1_dev_wav
        output_dir: the directory to place the newly generated csv files.
        output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv

    Raises:
        RuntimeError: if an AAC file cannot be decoded.
    """
    logging.info("Preprocessing audio and label for subset %s" % subset)
    source_dir = os.path.join(input_dir, subset)

    rows = []
    for root, _, filenames in gfile.Walk(source_dir):
        for filename in filenames:
            stem, suffix = os.path.splitext(filename)
            suffix = suffix.lower()

            if suffix == ".wav":
                # Names with a second extension (e.g. the "x.m4a.wav" files
                # written by the branch below) are skipped here; they are
                # reached through their ".m4a" source file instead.
                if os.path.splitext(stem)[1]:
                    continue
                wav_file = os.path.join(root, filename)
            elif suffix == ".m4a":
                # Convert AAC to WAV, unless that was done by an earlier run.
                aac_file = os.path.join(root, filename)
                wav_file = aac_file + ".wav"
                if not gfile.Exists(wav_file) and \
                        not decode_aac_with_ffmpeg(aac_file, wav_file):
                    raise RuntimeError("Audio decoding failed.")
            else:
                continue

            # The speaker is named by the parent of the directory holding the
            # file (presumably a <speaker>/<session>/<file> layout - confirm).
            speaker_name = root.split(os.path.sep)[-2]
            # Ids are handed out consecutively in order of first appearance.
            speaker_id = speaker_id_dict.setdefault(speaker_name, len(speaker_id_dict))

            # NOTE(review): this is the length of the data array returned by
            # soundfile, although the csv column is named "wav_length_ms".
            audio_data = sf.read(wav_file)[0]
            wav_length = len(audio_data)
            rows.append(
                (os.path.abspath(wav_file), wav_length, speaker_id, speaker_name)
            )

    # Write to CSV file which contains four columns:
    # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name".
    header = ["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]
    csv_file_path = os.path.join(output_dir, output_file)
    table = pandas.DataFrame(data=rows, columns=header)
    table.to_csv(csv_file_path, index=False, sep="\t")
    logging.info("Successfully generated csv file {}".format(csv_file_path))
200+
201+
202+
def processor(directory, subset, force_process):
    """Download one voxceleb subset and generate its csv index.

    Args:
        directory: directory used both for the downloaded data and the csv.
        subset: name of the subset; has to be a key of ``SUBSETS``.
        force_process: when false and ``<directory>/<subset>.csv`` already
            exists, nothing is done and the existing csv path is returned.

    Returns:
        Path of the csv file of the subset.

    Raises:
        ValueError: if ``subset`` is not a key of ``SUBSETS``.
    """
    if subset not in SUBSETS:
        raise ValueError(subset, "is not in voxceleb")

    csv_name = subset + ".csv"
    subset_csv = os.path.join(directory, subset + '.csv')
    already_processed = not force_process and os.path.exists(subset_csv)
    if already_processed:
        return subset_csv

    logging.info("Downloading and process the voxceleb in %s", directory)
    logging.info("Preparing subset %s", subset)
    download_and_extract(directory, subset, SUBSETS[subset])
    # Input data and generated csv live in the same directory.
    convert_audio_and_make_label(directory, subset, directory, csv_name)
    logging.info("Finished downloading and processing")
    return subset_csv
223+
224+
225+
if __name__ == "__main__":
    # Script entry point: expects exactly three command-line arguments
    # (save directory, user, password) and then processes every subset.
    logging.set_verbosity(logging.INFO)
    ARGS = sys.argv[1:]
    if len(ARGS) != 3:
        print("Usage: python prepare_data.py save_directory user password")
        sys.exit()

    # The credentials are stored in the module-level USER dict, which the
    # download step reads.
    DIR, USER["user"], USER["password"] = ARGS
    for SUBSET in SUBSETS:
        processor(DIR, SUBSET, False)

TTS/speaker_encoder/utils/visual.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import umap
2+
import numpy as np
3+
import matplotlib
4+
import matplotlib.pyplot as plt
5+
6+
# Select matplotlib's "Agg" backend (non-interactive, renders to files) for the
# plotting done in this module.
matplotlib.use("Agg")
7+
8+
9+
# Palette of 13 RGB colours, scaled from the 0-255 range to 0-1.
# The builtin ``float`` is used as dtype: the ``np.float`` alias was removed
# from NumPy and raises an AttributeError on current versions.
colormap = (
    np.array(
        [
            [76, 255, 0],
            [0, 127, 70],
            [255, 0, 0],
            [255, 217, 38],
            [0, 135, 255],
            [165, 0, 165],
            [255, 167, 255],
            [0, 255, 255],
            [255, 96, 38],
            [142, 76, 0],
            [33, 0, 127],
            [0, 0, 0],
            [183, 183, 183],
        ],
        dtype=float,
    )
    / 255
)
30+
31+
32+
def plot_embeddings(embeddings, num_utter_per_speaker):
    """Scatter-plot a 2-D UMAP projection of speaker embeddings.

    Only the first ``10 * num_utter_per_speaker`` embeddings are used. Rows
    are coloured in consecutive groups of ``num_utter_per_speaker``, one
    ``colormap`` entry per group. The figure is also written to a file via
    ``plt.savefig("umap")``.

    Args:
        embeddings: array of embeddings, one per row (assumes the rows of one
            speaker are stored consecutively - confirm against the caller).
        num_utter_per_speaker: number of consecutive rows per speaker.

    Returns:
        The matplotlib figure holding the plot.
    """
    max_rows = 10 * num_utter_per_speaker
    embeddings = embeddings[:max_rows]
    projection = umap.UMAP().fit_transform(embeddings)

    # One integer label per row: 0, 0, ..., 1, 1, ...
    speaker_count = embeddings.shape[0] // num_utter_per_speaker
    labels = np.repeat(np.arange(speaker_count), num_utter_per_speaker)
    colors = [colormap[label] for label in labels]

    fig, ax = plt.subplots(figsize=(16, 10))
    xs, ys = projection[:, 0], projection[:, 1]
    ax.scatter(xs, ys, c=colors)
    ax.set_aspect("equal", "datalim")
    plt.title("UMAP projection")
    plt.tight_layout()
    plt.savefig("umap")
    return fig

0 commit comments

Comments
 (0)