From d92cb8f68bff6d1925975c045d4f054bc2897e7e Mon Sep 17 00:00:00 2001 From: Edresson Date: Wed, 29 Jul 2020 00:49:00 -0300 Subject: [PATCH] add External Embedding per sample instead of nn.Embedding --- TTS/bin/train_tts.py | 91 +++++++++++++++++++++++----------- TTS/tts/configs/config.json | 8 +-- TTS/tts/datasets/TTSDataset.py | 18 +++++-- TTS/tts/layers/tacotron.py | 9 +--- TTS/tts/layers/tacotron2.py | 7 +-- TTS/tts/models/tacotron.py | 75 ++++++++++++++++------------ TTS/tts/models/tacotron2.py | 86 ++++++++++++++++++-------------- TTS/tts/utils/generic_utils.py | 14 ++++-- TTS/tts/utils/speakers.py | 7 ++- 9 files changed, 191 insertions(+), 124 deletions(-) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index ef14f6275..aa476bc8c 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -41,7 +41,7 @@ use_cuda, num_gpus = setup_torch_training_env(True, False) -def setup_loader(ap, r, is_val=False, verbose=False): +def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None): if is_val and not c.run_eval: loader = None else: @@ -60,7 +60,8 @@ def setup_loader(ap, r, is_val=False, verbose=False): use_phonemes=c.use_phonemes, phoneme_language=c.phoneme_language, enable_eos_bos=c.enable_eos_bos_chars, - verbose=verbose) + verbose=verbose, + speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) sampler = DistributedSampler(dataset) if num_gpus > 1 else None loader = DataLoader( dataset, @@ -74,9 +75,8 @@ def setup_loader(ap, r, is_val=False, verbose=False): pin_memory=False) return loader - -def format_data(data): - if c.use_speaker_embedding: +def format_data(data, speaker_mapping=None): + if speaker_mapping is None and c.use_speaker_embedding and not c.use_external_speaker_embedding_file: speaker_mapping = load_speaker_mapping(OUT_PATH) # setup input data @@ -91,13 +91,20 @@ def format_data(data): avg_spec_length = torch.mean(mel_lengths.float()) if c.use_speaker_embedding: - speaker_ids = [ - speaker_mapping[speaker_name] for speaker_name in speaker_names - ] - speaker_ids = torch.LongTensor(speaker_ids) + if c.use_external_speaker_embedding_file: + speaker_embeddings = data[8] + speaker_ids = None + else: + speaker_ids = [ + speaker_mapping[speaker_name] for speaker_name in speaker_names + ] + speaker_ids = torch.LongTensor(speaker_ids) + speaker_embeddings = None else: + speaker_embeddings = None speaker_ids = None + # set stop targets view, we predict a single stop token per iteration. 
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) @@ -114,13 +121,16 @@ def format_data(data): stop_targets = stop_targets.cuda(non_blocking=True) if speaker_ids is not None: speaker_ids = speaker_ids.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length + if speaker_embeddings is not None: + speaker_embeddings = speaker_embeddings.cuda(non_blocking=True) + + return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length def train(model, criterion, optimizer, optimizer_st, scheduler, - ap, global_step, epoch): + ap, global_step, epoch, speaker_mapping=None): data_loader = setup_loader(ap, model.decoder.r, is_val=False, - verbose=(epoch == 0)) + verbose=(epoch == 0), speaker_mapping=speaker_mapping) model.train() epoch_time = 0 keep_avg = KeepAverage() @@ -135,7 +145,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(data) + text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length = format_data(data, speaker_mapping) loader_time = time.time() - end_time global_step += 1 @@ -150,10 +160,10 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, # forward pass model if c.bidirectional_decoder or c.double_decoder_consistency: decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( - text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids) + text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) else: decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids) + text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) decoder_backward_output = None alignments_backward = None @@ -285,8 +295,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, @torch.no_grad() -def evaluate(model, criterion, ap, global_step, epoch): - data_loader = setup_loader(ap, model.decoder.r, is_val=True) +def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None): + data_loader = setup_loader(ap, model.decoder.r, is_val=True, speaker_mapping=speaker_mapping) model.eval() epoch_time = 0 keep_avg = KeepAverage() @@ -296,16 +306,16 @@ def evaluate(model, criterion, ap, global_step, epoch): start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data) + text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data, speaker_mapping) assert mel_input.shape[1] % model.decoder.r == 0 # forward pass model if c.bidirectional_decoder or c.double_decoder_consistency: decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids) + text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) else: 
                 decoder_output, postnet_output, alignments, stop_tokens = model(
-                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
                 decoder_backward_output = None
                 alignments_backward = None
@@ -467,22 +477,41 @@ def main(args): # pylint: disable=redefined-outer-name
     if c.use_speaker_embedding:
         speakers = get_speakers(meta_data_train)
         if args.restore_path:
-            prev_out_path = os.path.dirname(args.restore_path)
-            speaker_mapping = load_speaker_mapping(prev_out_path)
-            assert all([speaker in speaker_mapping
-                        for speaker in speakers]), "As of now you, you cannot " \
-                                                   "introduce new speakers to " \
-                                                   "a previously trained model."
-        else:
+            if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file
+                prev_out_path = os.path.dirname(args.restore_path)
+                speaker_mapping = load_speaker_mapping(prev_out_path)
+                if not speaker_mapping:
+                    print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
+                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
+                    if not speaker_mapping:
+                        raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
+                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
+            elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file
+                prev_out_path = os.path.dirname(args.restore_path)
+                speaker_mapping = load_speaker_mapping(prev_out_path)
+                speaker_embedding_dim = None
+                assert all([speaker in speaker_mapping
+                            for speaker in speakers]), "As of now, you cannot " \
+                                                       "introduce new speakers to " \
+                                                       "a previously trained model."
+        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # if start new train using External Embedding file
+            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
+            print(speaker_mapping)
+            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
+        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # if start new train using External Embedding file but no file was passed
+            raise RuntimeError("use_external_speaker_embedding_file is True, so you need to pass an external speaker embedding file; run the GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in the notebooks/ folder")
+        else: # if start new train and don't use External Embedding file
             speaker_mapping = {name: i for i, name in enumerate(speakers)}
+            speaker_embedding_dim = None
         save_speaker_mapping(OUT_PATH, speaker_mapping)
         num_speakers = len(speaker_mapping)
         print("Training with {} speakers: {}".format(num_speakers,
                                                      ", ".join(speakers)))
     else:
         num_speakers = 0
+        speaker_embedding_dim = None
 
-    model = setup_model(num_chars, num_speakers, c)
+    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)
 
     params = set_weight_decay(model, c.wd)
     optimizer = RAdam(params, lr=c.lr, weight_decay=0)
@@ -495,6 +524,8 @@ def main(args): # pylint: disable=redefined-outer-name
 
     # setup criterion
     criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
+    for name, _ in model.named_parameters():
+        print(name)
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path, map_location='cpu')
@@ -553,8 +584,8 @@ def main(args): # pylint: disable=redefined-outer-name
         print("\n > Number of output frames:", model.decoder.r)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
-                                                 global_step, epoch)
+                                                 global_step, epoch, speaker_mapping)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
+        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
         if c.run_eval:
diff --git a/TTS/tts/configs/config.json b/TTS/tts/configs/config.json
index add798f32..7d4f34fab 100644
--- a/TTS/tts/configs/config.json
+++ b/TTS/tts/configs/config.json
@@ -1,5 +1,5 @@
 {
-    "model": "Tacotron",
+    "model": "Tacotron2",
     "run_name": "ljspeech-ddc-bn",
     "run_description": "tacotron2 with ddc and batch-normalization",
 
@@ -113,7 +113,7 @@
     "tb_model_param_stats": false,     // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
 
     // DATA LOADING
-    "text_cleaner": "portuguese_cleaners",
+    "text_cleaner": "phoneme_cleaners",
     "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
     "num_loader_workers": 4,        // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "num_val_loader_workers": 4,    // number of evaluation data loader processes.
@@ -130,7 +130,9 @@
     "phoneme_language": "en-us",     // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
 
     // MULTI-SPEAKER and GST
-    "use_speaker_embedding": true,      // use speaker embedding to enable multi-speaker learning.
+    "use_speaker_embedding": true,      // use speaker embedding to enable multi-speaker learning.
+ "use_external_speaker_embedding_file": false, // if true, forces the model to use external embedding per sample instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs /1806.04558 + "external_speaker_embedding_file": "../../speakers-vctk-en.json", // if not null and use_external_speaker_embedding_file is true, it is used to load a specific embedding file and thus uses these embeddings instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs /1806.04558 "use_gst": true, // use global style tokens "gst": { // gst parameter if gst is enabled "gst_style_input": null, // Condition the style input either on a diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index 9f8d17451..f985c4b76 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -24,6 +24,7 @@ def __init__(self, phoneme_cache_path=None, phoneme_language="en-us", enable_eos_bos=False, + speaker_mapping=None, verbose=False): """ Args: @@ -58,6 +59,7 @@ def __init__(self, self.phoneme_cache_path = phoneme_cache_path self.phoneme_language = phoneme_language self.enable_eos_bos = enable_eos_bos + self.speaker_mapping = speaker_mapping self.verbose = verbose if use_phonemes and not os.path.isdir(phoneme_cache_path): os.makedirs(phoneme_cache_path, exist_ok=True) @@ -127,7 +129,8 @@ def load_data(self, idx): 'text': text, 'wav': wav, 'item_idx': self.items[idx][1], - 'speaker_name': speaker_name + 'speaker_name': speaker_name, + 'wav_file_name': os.path.basename(wav_file) } return sample @@ -191,9 +194,15 @@ def collate_fn(self, batch): batch[idx]['item_idx'] for idx in ids_sorted_decreasing ] text = [batch[idx]['text'] for idx in ids_sorted_decreasing] + speaker_name = [batch[idx]['speaker_name'] for idx in ids_sorted_decreasing] - + # get speaker embeddings + if self.speaker_mapping is not None: + wav_files_names = [batch[idx]['wav_file_name'] for idx in ids_sorted_decreasing] + speaker_embedding = [self.speaker_mapping[w]['embedding'] for w in wav_files_names] + else: + speaker_embedding = None # compute features mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] @@ -224,6 +233,9 @@ def collate_fn(self, batch): mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) + if speaker_embedding is not None: + speaker_embedding = torch.FloatTensor(speaker_embedding) + # compute linear spectrogram if self.compute_linear_spec: linear = [self.ap.spectrogram(w).astype('float32') for w in wav] @@ -234,7 +246,7 @@ def collate_fn(self, batch): else: linear = None return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \ - stop_targets, item_idxs + stop_targets, item_idxs, speaker_embedding raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ found {}".format(type(batch[0])))) diff --git a/TTS/tts/layers/tacotron.py b/TTS/tts/layers/tacotron.py index a2ccc917e..df17e0882 100644 --- a/TTS/tts/layers/tacotron.py +++ b/TTS/tts/layers/tacotron.py @@ -266,7 +266,7 @@ class Decoder(nn.Module): def __init__(self, in_features, memory_dim, r, memory_size, attn_type, attn_windowing, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, - separate_stopnet, speaker_embedding_dim): + separate_stopnet): super(Decoder, self).__init__() self.r_init = r self.r = r @@ -438,15 +438,12 @@ def forward(self, inputs, memory, mask): t += 1 return self._parse_outputs(outputs, 
attentions, stop_tokens) - def inference(self, inputs, speaker_embeddings=None): + def inference(self, inputs): """ Args: inputs: encoder outputs. - speaker_embeddings: speaker vectors. - Shapes: - inputs: batch x time x encoder_out_dim - - speaker_embeddings: batch x embed_dim """ outputs = [] attentions = [] @@ -459,8 +456,6 @@ def inference(self, inputs, speaker_embeddings=None): if t > 0: new_memory = outputs[-1] self._update_memory_input(new_memory) - if speaker_embeddings is not None: - self.memory_input = torch.cat([self.memory_input, speaker_embeddings], dim=-1) output, stop_token, attention = self.decode(inputs, None) stop_token = torch.sigmoid(stop_token.data) outputs += [output] diff --git a/TTS/tts/layers/tacotron2.py b/TTS/tts/layers/tacotron2.py index 080e5a3f9..3ac9037c2 100644 --- a/TTS/tts/layers/tacotron2.py +++ b/TTS/tts/layers/tacotron2.py @@ -94,8 +94,7 @@ class Decoder(nn.Module): #pylint: disable=attribute-defined-outside-init def __init__(self, input_dim, frame_dim, r, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, - forward_attn_mask, location_attn, attn_K, separate_stopnet, - speaker_embedding_dim): + forward_attn_mask, location_attn, attn_K, separate_stopnet): super(Decoder, self).__init__() self.frame_dim = frame_dim self.r_init = r @@ -268,7 +267,7 @@ def forward(self, inputs, memories, mask): outputs, stop_tokens, alignments) return outputs, alignments, stop_tokens - def inference(self, inputs, speaker_embeddings=None): + def inference(self, inputs): memory = self.get_go_frame(inputs) memory = self._update_memory(memory) @@ -278,8 +277,6 @@ def inference(self, inputs, speaker_embeddings=None): outputs, stop_tokens, alignments, t = [], [], [], 0 while True: memory = self.prenet(memory) - if speaker_embeddings is not None: - memory = torch.cat([memory, speaker_embeddings], dim=-1) decoder_output, alignment, stop_token = self.decode(memory) stop_token = torch.sigmoid(stop_token.data) outputs += [decoder_output.squeeze(1)] diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index d7fff4442..5f33db987 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -27,6 +27,7 @@ def __init__(self, bidirectional_decoder=False, double_decoder_consistency=False, ddc_r=None, + speaker_embedding_dim=None, gst=False, gst_embedding_dim=256, gst_num_heads=4, @@ -40,39 +41,46 @@ def __init__(self, location_attn, attn_K, separate_stopnet, bidirectional_decoder, double_decoder_consistency, ddc_r, gst) - - + # init layer dims decoder_in_features = 256 encoder_in_features = 256 - speaker_embedding_dim = 256 - proj_speaker_dim = 80 if num_speakers > 1 else 0 + if speaker_embedding_dim is None: + # if speaker_embedding_dim is None we need use the nn.Embedding, with default speaker_embedding_dim + self.embeddings_per_sample = False + speaker_embedding_dim = 256 + else: + # if speaker_embedding_dim is not None we need use speaker embedding per sample + self.embeddings_per_sample = True + + # speaker and gst embeddings is concat in decoder input if num_speakers > 1: decoder_in_features = decoder_in_features + speaker_embedding_dim # add speaker embedding dim if self.gst: decoder_in_features = decoder_in_features + gst_embedding_dim # add gst embedding dim - # base model layers + # embedding layer self.embedding = nn.Embedding(num_chars, 256, padding_idx=0) + + # speaker embedding layers + if num_speakers > 1: + if not self.embeddings_per_sample: + self.speaker_embedding = nn.Embedding(num_speakers, 
speaker_embedding_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + # base model layers self.embedding.weight.data.normal_(0, 0.3) self.encoder = Encoder(encoder_in_features) self.decoder = Decoder(decoder_in_features, decoder_output_dim, r, memory_size, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, - attn_K, separate_stopnet, proj_speaker_dim) + attn_K, separate_stopnet) self.postnet = PostCBHG(decoder_output_dim) self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim) - # speaker embedding layers - if num_speakers > 1: - self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim) - self.speaker_embedding.weight.data.normal_(0, 0.3) - self.speaker_project_mel = nn.Sequential( - nn.Linear(speaker_embedding_dim, proj_speaker_dim), nn.Tanh()) - self.speaker_embeddings = None - self.speaker_embeddings_projected = None + # global style token layers if self.gst: self.gst_layer = GST(num_mel=80, @@ -88,10 +96,9 @@ def __init__(self, decoder_in_features, decoder_output_dim, ddc_r, memory_size, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, - attn_K, separate_stopnet, proj_speaker_dim) + attn_K, separate_stopnet) - - def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None): + def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None): """ Shapes: - characters: B x T_in @@ -99,24 +106,27 @@ def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker - mel_specs: B x T_out x D - speaker_ids: B x 1 """ - self._init_states() input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) # B x T_in x embed_dim inputs = self.embedding(characters) - # B x speaker_embed_dim - if speaker_ids is not None: - self.compute_speaker_embedding(speaker_ids) # B x T_in x encoder_in_features encoder_outputs = self.encoder(inputs) # sequence masking encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + # global style token if self.gst: # B x gst_dim encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + # speaker embedding if self.num_speakers > 1: - encoder_outputs = self._concat_speaker_embedding( - encoder_outputs, self.speaker_embeddings) + if not self.embeddings_per_sample: + # B x 1 x speaker_embed_dim + speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + else: + # B x 1 x speaker_embed_dim + speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) # decoder_outputs: B x decoder_in_features x T_out # alignments: B x T_in x encoder_in_features # stop_tokens: B x T_in @@ -143,19 +153,22 @@ def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker return decoder_outputs, postnet_outputs, alignments, stop_tokens @torch.no_grad() - def inference(self, characters, speaker_ids=None, style_mel=None): + def inference(self, characters, speaker_ids=None, style_mel=None, speaker_embeddings=None): inputs = self.embedding(characters) - self._init_states() - if speaker_ids is not None: - self.compute_speaker_embedding(speaker_ids) encoder_outputs = self.encoder(inputs) - if self.gst and style_mel is not None: + if self.gst: + # B x gst_dim encoder_outputs = self.compute_gst(encoder_outputs, style_mel) if self.num_speakers > 1: - 
encoder_outputs = self._concat_speaker_embedding( - encoder_outputs, self.speaker_embeddings) + if not self.embeddings_per_sample: + # B x 1 x speaker_embed_dim + speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + else: + # B x 1 x speaker_embed_dim + speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) decoder_outputs, alignments, stop_tokens = self.decoder.inference( - encoder_outputs, self.speaker_embeddings_projected) + encoder_outputs) postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = self.last_linear(postnet_outputs) decoder_outputs = decoder_outputs.transpose(1, 2) diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index af4eff581..d48aa33a5 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -27,6 +27,7 @@ def __init__(self, bidirectional_decoder=False, double_decoder_consistency=False, ddc_r=None, + speaker_embedding_dim=None, gst=False, gst_embedding_dim=512, gst_num_heads=4, @@ -41,25 +42,38 @@ def __init__(self, ddc_r, gst) # init layer dims - speaker_embedding_dim = 512 if num_speakers > 1 else 0 - gst_embedding_dim = gst_embedding_dim if self.gst else 0 - decoder_in_features = 512+speaker_embedding_dim+gst_embedding_dim - encoder_in_features = 512 if num_speakers > 1 else 512 - proj_speaker_dim = 80 if num_speakers > 1 else 0 + decoder_in_features = 512 + encoder_in_features = 512 + + if speaker_embedding_dim is None: + # if speaker_embedding_dim is None we need use the nn.Embedding, with default speaker_embedding_dim + self.embeddings_per_sample = False + speaker_embedding_dim = 512 + else: + # if speaker_embedding_dim is not None we need use speaker embedding per sample + self.embeddings_per_sample = True + + # speaker and gst embeddings is concat in decoder input + if num_speakers > 1: + decoder_in_features = decoder_in_features + speaker_embedding_dim # add speaker embedding dim + if self.gst: + decoder_in_features = decoder_in_features + gst_embedding_dim # add gst embedding dim # embedding layer self.embedding = nn.Embedding(num_chars, 512, padding_idx=0) # speaker embedding layer if num_speakers > 1: - self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim) - self.speaker_embedding.weight.data.normal_(0, 0.3) - + if not self.embeddings_per_sample: + self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + # base model layers self.encoder = Encoder(encoder_in_features) self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, - location_attn, attn_K, separate_stopnet, proj_speaker_dim) + location_attn, attn_K, separate_stopnet) self.postnet = Postnet(self.postnet_output_dim) # global style token layers @@ -77,7 +91,7 @@ def __init__(self, decoder_in_features, self.decoder_output_dim, ddc_r, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, - separate_stopnet, proj_speaker_dim) + separate_stopnet) @staticmethod def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): @@ -85,7 +99,7 @@ def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) return mel_outputs, mel_outputs_postnet, alignments - def forward(self, text, text_lengths, 
mel_specs=None, mel_lengths=None, speaker_ids=None): + def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None): # compute mask for padding # B x T_in_max (boolean) input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) @@ -99,8 +113,13 @@ def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) if self.num_speakers > 1: - embedded_speakers = self.speaker_embedding(speaker_ids)[:, None] - encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + if not self.embeddings_per_sample: + # B x 1 x speaker_embed_dim + speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + else: + # B x 1 x speaker_embed_dim + speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) @@ -128,23 +147,18 @@ def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ return decoder_outputs, postnet_outputs, alignments, stop_tokens @torch.no_grad() - def inference(self, text, speaker_ids=None, style_mel=None): + def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None): embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) + if self.gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, style_mel) + if self.num_speakers > 1: - embedded_speakers = self.speaker_embedding(speaker_ids)[:, None] - embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1) - if self.gst: - # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, style_mel) - encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) - else: - encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) - else: - if self.gst: - # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, style_mel) + if not self.embeddings_per_sample: + speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) decoder_outputs, alignments, stop_tokens = self.decoder.inference( encoder_outputs) @@ -154,25 +168,21 @@ def inference(self, text, speaker_ids=None, style_mel=None): decoder_outputs, postnet_outputs, alignments) return decoder_outputs, postnet_outputs, alignments, stop_tokens - def inference_truncated(self, text, speaker_ids=None, style_mel=None): + def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None): """ Preserve model states for continuous inference """ embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference_truncated(embedded_inputs) + if self.gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, style_mel) + if self.num_speakers > 1: - embedded_speakers = self.speaker_embedding(speaker_ids)[:, None] - embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1) - if self.gst: - # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, style_mel) - else: - encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) - else: - if self.gst: - # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, style_mel) + if not self.embeddings_per_sample: + 
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated( encoder_outputs) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 2cd17175c..3d98aea0a 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -49,7 +49,7 @@ def sequence_mask(sequence_length, max_len=None): return seq_range_expand < seq_length_expand -def setup_model(num_chars, num_speakers, c): +def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): print(" > Using model: {}".format(c.model)) MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower()) MyModel = getattr(MyModel, c.model) @@ -74,7 +74,8 @@ def setup_model(num_chars, num_speakers, c): separate_stopnet=c.separate_stopnet, bidirectional_decoder=c.bidirectional_decoder, double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r) + ddc_r=c.ddc_r, + speaker_embedding_dim=speaker_embedding_dim) elif c.model.lower() == "tacotron2": model = MyModel(num_chars=num_chars, num_speakers=num_speakers, @@ -95,7 +96,8 @@ def setup_model(num_chars, num_speakers, c): separate_stopnet=c.separate_stopnet, bidirectional_decoder=c.bidirectional_decoder, double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r) + ddc_r=c.ddc_r, + speaker_embedding_dim=speaker_embedding_dim) return model class KeepAverage(): @@ -173,7 +175,7 @@ def check_config(c): check_argument('clip_norm', c['audio'], restricted=True, val_type=bool) check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000) check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0) - check_argument('spec_gain', c['audio'], restricted=True, val_type=float, min_val=1, max_val=100) + check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100) check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) check_argument('trim_db', c['audio'], restricted=True, val_type=int) @@ -243,8 +245,10 @@ def check_config(c): # paths check_argument('output_path', c, restricted=True, val_type=str) - # multi-speaker gst + # multi-speaker and gst check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) + check_argument('use_external_speaker_embedding_file', c, restricted=True, val_type=bool) + check_argument('external_speaker_embedding_file', c, restricted=True, val_type=str) check_argument('use_gst', c, restricted=True, val_type=bool) check_argument('gst', c, restricted=True, val_type=dict) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 5e3d7eee4..d1ebef320 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -12,12 +12,15 @@ def make_speakers_json_path(out_path): def load_speaker_mapping(out_path): """Loads speaker mapping if already present.""" try: - with open(make_speakers_json_path(out_path)) as f: + if os.path.splitext(out_path)[1] == '.json': + json_file = out_path + else: + json_file = make_speakers_json_path(out_path) + with open(json_file) as f: return json.load(f) except FileNotFoundError: return {} - def save_speaker_mapping(out_path, speaker_mapping): """Saves speaker mapping if not yet present.""" speakers_json_path = make_speakers_json_path(out_path)
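
Note on the expected speakers.json layout (illustrative sketch, not part of the patch): when
"use_external_speaker_embedding_file" is true, load_speaker_mapping() reads a JSON file whose keys are
wav file basenames (TTSDataset.collate_fn looks entries up via os.path.basename(wav_file)) and whose
values carry an 'embedding' list; train_tts.py infers speaker_embedding_dim from the first entry.
The sample names, the 'name' field, and the 256-dim size below are hypothetical placeholders; only
the 'embedding' field is read by this patch.

    # sketch: build a minimal external speaker embedding file by hand
    import json

    speaker_mapping = {
        # one entry per training sample, keyed by wav file basename
        "p225_001.wav": {"name": "p225", "embedding": [0.01] * 256},
        "p226_001.wav": {"name": "p226", "embedding": [0.02] * 256},
    }

    with open("speakers-vctk-en.json", "w") as f:
        json.dump(speaker_mapping, f, indent=2)

    # train_tts.py then infers the embedding size from the first entry:
    # len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])  -> 256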