add External Embedding per sample instead of nn.Embedding
Edresson committed Jul 29, 2020
1 parent b06da9d commit d92cb8f
Showing 9 changed files with 191 additions and 124 deletions.
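
This commit replaces the trainable nn.Embedding speaker lookup with embeddings read from an external speakers.json, keyed by wav file basename and applied per sample. A minimal sketch of the mapping format the new code relies on — only the 'embedding' key is read by this diff; the 'name' field, file names, and 3-dimensional vectors are illustrative:

import json

# Hypothetical speakers.json as produced by one of the speaker-encoder notebooks;
# only the 'embedding' key is used by the code in this commit.
speaker_mapping = {
    "p225_001.wav": {"name": "p225", "embedding": [0.01, -0.23, 0.11]},
    "p226_004.wav": {"name": "p226", "embedding": [0.07, 0.19, -0.05]},
}

with open("speakers.json", "w") as f:
    json.dump(speaker_mapping, f)

# train_tts.py derives the embedding size from the first entry (see main() below):
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
print(speaker_embedding_dim)  # 3 here; real encoder embeddings are typically 256-d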
91 changes: 61 additions & 30 deletions TTS/bin/train_tts.py
@@ -41,7 +41,7 @@
use_cuda, num_gpus = setup_torch_training_env(True, False)


def setup_loader(ap, r, is_val=False, verbose=False):
def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
if is_val and not c.run_eval:
loader = None
else:
@@ -60,7 +60,8 @@ def setup_loader(ap, r, is_val=False, verbose=False):
use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars,
verbose=verbose)
verbose=verbose,
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(
dataset,
@@ -74,9 +75,8 @@ def setup_loader(ap, r, is_val=False, verbose=False):
pin_memory=False)
return loader


def format_data(data):
if c.use_speaker_embedding:
def format_data(data, speaker_mapping=None):
if speaker_mapping is None and c.use_speaker_embedding and not c.use_external_speaker_embedding_file:
speaker_mapping = load_speaker_mapping(OUT_PATH)

# setup input data
@@ -91,13 +91,20 @@ def format_data(data):
avg_spec_length = torch.mean(mel_lengths.float())

if c.use_speaker_embedding:
speaker_ids = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_ids = torch.LongTensor(speaker_ids)
if c.use_external_speaker_embedding_file:
speaker_embeddings = data[8]
speaker_ids = None
else:
speaker_ids = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_ids = torch.LongTensor(speaker_ids)
speaker_embeddings = None
else:
speaker_embeddings = None
speaker_ids = None


# set stop targets view, we predict a single stop token per iteration.
stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r, -1)
@@ -114,13 +121,16 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
stop_targets = stop_targets.cuda(non_blocking=True)
if speaker_ids is not None:
speaker_ids = speaker_ids.cuda(non_blocking=True)
return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length
if speaker_embeddings is not None:
speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)

return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length


def train(model, criterion, optimizer, optimizer_st, scheduler,
ap, global_step, epoch):
ap, global_step, epoch, speaker_mapping=None):
data_loader = setup_loader(ap, model.decoder.r, is_val=False,
verbose=(epoch == 0))
verbose=(epoch == 0), speaker_mapping=speaker_mapping)
model.train()
epoch_time = 0
keep_avg = KeepAverage()
@@ -135,7 +145,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
start_time = time.time()

# format data
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(data)
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length = format_data(data, speaker_mapping)
loader_time = time.time() - end_time

global_step += 1
@@ -150,10 +160,10 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
# forward pass model
if c.bidirectional_decoder or c.double_decoder_consistency:
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
else:
decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
decoder_backward_output = None
alignments_backward = None

@@ -285,8 +295,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,


@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
data_loader = setup_loader(ap, model.decoder.r, is_val=True)
def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None):
data_loader = setup_loader(ap, model.decoder.r, is_val=True, speaker_mapping=speaker_mapping)
model.eval()
epoch_time = 0
keep_avg = KeepAverage()
@@ -296,16 +306,16 @@ def evaluate(model, criterion, ap, global_step, epoch):
start_time = time.time()

# format data
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data)
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data, speaker_mapping)
assert mel_input.shape[1] % model.decoder.r == 0

# forward pass model
if c.bidirectional_decoder or c.double_decoder_consistency:
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
else:
decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
decoder_backward_output = None
alignments_backward = None

Expand Down Expand Up @@ -467,22 +477,41 @@ def main(args): # pylint: disable=redefined-outer-name
if c.use_speaker_embedding:
speakers = get_speakers(meta_data_train)
if args.restore_path:
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
assert all([speaker in speaker_mapping
for speaker in speakers]), "As of now, you cannot " \
"introduce new speakers to " \
"a previously trained model."
else:
if c.use_external_speaker_embedding_file: # restoring a checkpoint and using an external embedding file
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
if not speaker_mapping:
print("WARNING: speakers.json speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
if not speaker_mapping:
raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
elif not c.use_external_speaker_embedding_file: # restoring a checkpoint without an external embedding file
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
speaker_embedding_dim = None
assert all([speaker in speaker_mapping
for speaker in speakers]), "As of now, you cannot " \
"introduce new speakers to " \
"a previously trained model."
elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # starting a new training run with an external embedding file
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
print(speaker_mapping)
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # external embeddings enabled but no embedding file given
raise RuntimeError("use_external_speaker_embedding_file is True, so you need to pass an external speaker embedding file; run the GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in the notebooks/ folder.")
else: # starting a new training run without an external embedding file
speaker_mapping = {name: i for i, name in enumerate(speakers)}
speaker_embedding_dim = None
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
print("Training with {} speakers: {}".format(num_speakers,
", ".join(speakers)))
else:
num_speakers = 0
speaker_embedding_dim = None

model = setup_model(num_chars, num_speakers, c)
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)

params = set_weight_decay(model, c.wd)
optimizer = RAdam(params, lr=c.lr, weight_decay=0)
@@ -495,6 +524,8 @@ def main(args):  # pylint: disable=redefined-outer-name

# setup criterion
criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
for name, _ in model.named_parameters():
print(name)

if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location='cpu')
@@ -553,8 +584,8 @@ def main(args):  # pylint: disable=redefined-outer-name
print("\n > Number of output frames:", model.decoder.r)
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
optimizer_st, scheduler, ap,
global_step, epoch)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
global_step, epoch, speaker_mapping)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_postnet_loss']
if c.run_eval:
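The new branching in main() is easier to follow outside diff form. A condensed, illustrative sketch of the same decision logic — the helper name resolve_speaker_mapping and the import path are assumptions based on the repository layout shown in this diff:

import os

from TTS.tts.utils.speakers import load_speaker_mapping, save_speaker_mapping  # assumed path


def resolve_speaker_mapping(c, restore_path, out_path, speakers):
    """Condensed view of the speaker-mapping setup in main(); illustrative only."""
    if not c.use_speaker_embedding:
        return None, 0, None  # mapping, num_speakers, speaker_embedding_dim
    if restore_path:
        mapping = load_speaker_mapping(os.path.dirname(restore_path))
        if c.use_external_speaker_embedding_file:
            if not mapping:
                # fall back to the file named in the config
                mapping = load_speaker_mapping(c.external_speaker_embedding_file)
            dim = len(mapping[list(mapping.keys())[0]]["embedding"])
        else:
            dim = None
            assert all(s in mapping for s in speakers), \
                "As of now, you cannot introduce new speakers to a previously trained model."
    elif c.use_external_speaker_embedding_file:
        mapping = load_speaker_mapping(c.external_speaker_embedding_file)
        dim = len(mapping[list(mapping.keys())[0]]["embedding"])
    else:
        mapping = {name: i for i, name in enumerate(speakers)}
        dim = None
    save_speaker_mapping(out_path, mapping)
    return mapping, len(mapping), dim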
8 changes: 5 additions & 3 deletions TTS/tts/configs/config.json
@@ -1,5 +1,5 @@
{
"model": "Tacotron",
"model": "Tacotron2",
"run_name": "ljspeech-ddc-bn",
"run_description": "tacotron2 with ddc and batch-normalization",

@@ -113,7 +113,7 @@
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

// DATA LOADING
"text_cleaner": "portuguese_cleaners",
"text_cleaner": "phoneme_cleaners",
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"num_val_loader_workers": 4, // number of evaluation data loader processes.
@@ -130,7 +130,9 @@
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages

// MULTI-SPEAKER and GST
"use_speaker_embedding": true, // use speaker embedding to enable multi-speaker learning.
"use_speaker_embedding": true, // use speaker embedding to enable multi-speaker learning.
"use_external_speaker_embedding_file": false, // if true, forces the model to use external embedding per sample instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs /1806.04558
"external_speaker_embedding_file": "../../speakers-vctk-en.json", // if not null and use_external_speaker_embedding_file is true, it is used to load a specific embedding file and thus uses these embeddings instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs /1806.04558
"use_gst": true, // use global style tokens
"gst": { // gst parameter if gst is enabled
"gst_style_input": null, // Condition the style input either on a
18 changes: 15 additions & 3 deletions TTS/tts/datasets/TTSDataset.py
@@ -24,6 +24,7 @@ def __init__(self,
phoneme_cache_path=None,
phoneme_language="en-us",
enable_eos_bos=False,
speaker_mapping=None,
verbose=False):
"""
Args:
@@ -58,6 +59,7 @@ def __init__(self,
self.phoneme_cache_path = phoneme_cache_path
self.phoneme_language = phoneme_language
self.enable_eos_bos = enable_eos_bos
self.speaker_mapping = speaker_mapping
self.verbose = verbose
if use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True)
@@ -127,7 +129,8 @@ def load_data(self, idx):
'text': text,
'wav': wav,
'item_idx': self.items[idx][1],
'speaker_name': speaker_name
'speaker_name': speaker_name,
'wav_file_name': os.path.basename(wav_file)
}
return sample

@@ -191,9 +194,15 @@ def collate_fn(self, batch):
batch[idx]['item_idx'] for idx in ids_sorted_decreasing
]
text = [batch[idx]['text'] for idx in ids_sorted_decreasing]

speaker_name = [batch[idx]['speaker_name']
for idx in ids_sorted_decreasing]

# get speaker embeddings
if self.speaker_mapping is not None:
wav_files_names = [batch[idx]['wav_file_name'] for idx in ids_sorted_decreasing]
speaker_embedding = [self.speaker_mapping[w]['embedding'] for w in wav_files_names]
else:
speaker_embedding = None
# compute features
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]

Expand Down Expand Up @@ -224,6 +233,9 @@ def collate_fn(self, batch):
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)

if speaker_embedding is not None:
speaker_embedding = torch.FloatTensor(speaker_embedding)

# compute linear spectrogram
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
@@ -234,7 +246,7 @@ def collate_fn(self, batch):
else:
linear = None
return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
stop_targets, item_idxs
stop_targets, item_idxs, speaker_embedding

raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}".format(type(batch[0]))))
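
With a speaker_mapping passed in, collate_fn looks up one embedding per utterance by wav basename and stacks them into a float tensor. A standalone sketch of that lookup — the batch contents and 3-dimensional vectors are made up for illustration:

import os

import torch

speaker_mapping = {
    "p225_001.wav": {"embedding": [0.1, 0.2, 0.3]},
    "p226_004.wav": {"embedding": [0.4, 0.5, 0.6]},
}

batch = [
    {"wav_file_name": os.path.basename("/data/VCTK/p225/p225_001.wav")},
    {"wav_file_name": os.path.basename("/data/VCTK/p226/p226_004.wav")},
]

# stand-in for the text-length sort done earlier in collate_fn
ids_sorted_decreasing = [1, 0]

wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
speaker_embedding = [speaker_mapping[w]["embedding"] for w in wav_files_names]
speaker_embedding = torch.FloatTensor(speaker_embedding)
print(speaker_embedding.shape)  # torch.Size([2, 3]) -> (batch, embed_dim)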
9 changes: 2 additions & 7 deletions TTS/tts/layers/tacotron.py
@@ -266,7 +266,7 @@ class Decoder(nn.Module):
def __init__(self, in_features, memory_dim, r, memory_size, attn_type, attn_windowing,
attn_norm, prenet_type, prenet_dropout, forward_attn,
trans_agent, forward_attn_mask, location_attn, attn_K,
separate_stopnet, speaker_embedding_dim):
separate_stopnet):
super(Decoder, self).__init__()
self.r_init = r
self.r = r
@@ -438,15 +438,12 @@ def forward(self, inputs, memory, mask):
t += 1
return self._parse_outputs(outputs, attentions, stop_tokens)

def inference(self, inputs, speaker_embeddings=None):
def inference(self, inputs):
"""
Args:
inputs: encoder outputs.
speaker_embeddings: speaker vectors.
Shapes:
- inputs: batch x time x encoder_out_dim
- speaker_embeddings: batch x embed_dim
"""
outputs = []
attentions = []
@@ -459,8 +456,6 @@ def inference(self, inputs):
if t > 0:
new_memory = outputs[-1]
self._update_memory_input(new_memory)
if speaker_embeddings is not None:
self.memory_input = torch.cat([self.memory_input, speaker_embeddings], dim=-1)
output, stop_token, attention = self.decode(inputs, None)
stop_token = torch.sigmoid(stop_token.data)
outputs += [output]
7 changes: 2 additions & 5 deletions TTS/tts/layers/tacotron2.py
@@ -94,8 +94,7 @@ class Decoder(nn.Module):
#pylint: disable=attribute-defined-outside-init
def __init__(self, input_dim, frame_dim, r, attn_type, attn_win, attn_norm,
prenet_type, prenet_dropout, forward_attn, trans_agent,
forward_attn_mask, location_attn, attn_K, separate_stopnet,
speaker_embedding_dim):
forward_attn_mask, location_attn, attn_K, separate_stopnet):
super(Decoder, self).__init__()
self.frame_dim = frame_dim
self.r_init = r
@@ -268,7 +267,7 @@ def forward(self, inputs, memories, mask):
outputs, stop_tokens, alignments)
return outputs, alignments, stop_tokens

def inference(self, inputs, speaker_embeddings=None):
def inference(self, inputs):
memory = self.get_go_frame(inputs)
memory = self._update_memory(memory)

@@ -278,8 +277,6 @@ def inference(self, inputs):
outputs, stop_tokens, alignments, t = [], [], [], 0
while True:
memory = self.prenet(memory)
if speaker_embeddings is not None:
memory = torch.cat([memory, speaker_embeddings], dim=-1)
decoder_output, alignment, stop_token = self.decode(memory)
stop_token = torch.sigmoid(stop_token.data)
outputs += [decoder_output.squeeze(1)]
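Both decoders lose the per-step concatenation of speaker vectors inside inference(); with per-sample external embeddings the conditioning presumably moves out of the decoder, e.g. onto the encoder outputs, in the model files not expanded here. A generic sketch of that common pattern — not code from this commit:

import torch


def concat_speaker_embedding(encoder_outputs, speaker_embeddings):
    """Tile a per-utterance speaker vector over time and append it to the encoder
    outputs; shown only to illustrate where decoder-side conditioning can move."""
    # encoder_outputs: (batch, time, encoder_dim); speaker_embeddings: (batch, embed_dim)
    expanded = speaker_embeddings.unsqueeze(1).expand(-1, encoder_outputs.size(1), -1)
    return torch.cat([encoder_outputs, expanded], dim=-1)


enc = torch.randn(2, 50, 512)
spk = torch.randn(2, 256)
print(concat_speaker_embedding(enc, spk).shape)  # torch.Size([2, 50, 768])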
