From b06da9d586b4fbb5280dfbe4860948122503363c Mon Sep 17 00:00:00 2001 From: Edresson Date: Tue, 28 Jul 2020 17:11:32 -0300 Subject: [PATCH] bugfix in DDC now DDC work on Tacotron1 --- TTS/tts/configs/config.json | 16 +++++----- TTS/tts/datasets/TTSDataset.py | 2 +- TTS/tts/layers/tacotron.py | 7 ++--- TTS/tts/layers/tacotron2.py | 4 +-- TTS/tts/models/tacotron.py | 24 +++++++------- TTS/tts/models/tacotron2.py | 49 ++++++----------------------- TTS/tts/models/tacotron_abstract.py | 19 ++++++++--- TTS/tts/utils/generic_utils.py | 9 ++++-- TTS/tts/utils/text/cleaners.py | 1 - TTS/utils/generic_utils.py | 9 +++++- 10 files changed, 65 insertions(+), 75 deletions(-) diff --git a/TTS/tts/configs/config.json b/TTS/tts/configs/config.json index bf2ad3837..add798f32 100644 --- a/TTS/tts/configs/config.json +++ b/TTS/tts/configs/config.json @@ -1,5 +1,5 @@ { - "model": "Tacotron2", + "model": "Tacotron", "run_name": "ljspeech-ddc-bn", "run_description": "tacotron2 with ddc and batch-normalization", @@ -113,7 +113,7 @@ "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. // DATA LOADING - "text_cleaner": "phoneme_cleaners", + "text_cleaner": "portuguese_cleaners", "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. "num_val_loader_workers": 4, // number of evaluation data loader processes. @@ -122,15 +122,15 @@ "max_seq_len": 153, // DATASET-RELATED: maximum text length // PATHS - "output_path": "/home/erogol/Models/LJSpeech/", + "output_path": "../../Mozilla-TTS/vctk-test/", // PHONEMES - "phoneme_cache_path": "/media/erogol/data_ssd2/mozilla_us_phonemes_3", // phoneme computation is slow, therefore, it caches results in the given folder. + "phoneme_cache_path": "../../Mozilla-TTS/vctk-test/", // phoneme computation is slow, therefore, it caches results in the given folder. "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages // MULTI-SPEAKER and GST - "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning. + "use_speaker_embedding": true, // use speaker embedding to enable multi-speaker learning. "use_gst": true, // use global style tokens "gst": { // gst parameter if gst is enabled "gst_style_input": null, // Condition the style input either on a @@ -146,9 +146,9 @@ "datasets": // List of datasets. They all merged and they get different speaker_ids. [ { - "name": "ljspeech", - "path": "/home/erogol/Data/LJSpeech-1.1/", - "meta_file_train": "metadata.csv", + "name": "vctk", + "path": "../../../datasets/VCTK-Corpus-removed-silence/", + "meta_file_train": ["p225", "p234", "p238", "p245", "p248", "p261", "p294", "p302", "p326", "p335", "p347"], // for vtck if list, ignore speakers id in list for train, its useful for test cloning with new speakers "meta_file_val": null } ] diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index f99aabc72..9f8d17451 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -70,7 +70,7 @@ def __init__(self, self.sort_items() def load_wav(self, filename): - audio = self.ap.load_wav(filename) + audio = self.ap.load_wav(filename, sr=self.sample_rate) return audio @staticmethod diff --git a/TTS/tts/layers/tacotron.py b/TTS/tts/layers/tacotron.py index 20fd1e52f..a2ccc917e 100644 --- a/TTS/tts/layers/tacotron.py +++ b/TTS/tts/layers/tacotron.py @@ -278,7 +278,7 @@ def __init__(self, in_features, memory_dim, r, memory_size, attn_type, attn_wind self.separate_stopnet = separate_stopnet self.query_dim = 256 # memory -> |Prenet| -> processed_memory - prenet_dim = memory_dim * self.memory_size + speaker_embedding_dim if self.use_memory_queue else memory_dim + speaker_embedding_dim + prenet_dim = memory_dim * self.memory_size if self.use_memory_queue else memory_dim self.prenet = Prenet( prenet_dim, prenet_type, @@ -405,7 +405,7 @@ def _update_memory_input(self, new_memory): # assert new_memory.shape[-1] == self.r * self.memory_dim self.memory_input = new_memory[:, self.memory_dim * (self.r - 1):] - def forward(self, inputs, memory, mask, speaker_embeddings=None): + def forward(self, inputs, memory, mask): """ Args: inputs: Encoder outputs. @@ -430,8 +430,7 @@ def forward(self, inputs, memory, mask, speaker_embeddings=None): if t > 0: new_memory = memory[t - 1] self._update_memory_input(new_memory) - if speaker_embeddings is not None: - self.memory_input = torch.cat([self.memory_input, speaker_embeddings], dim=-1) + output, stop_token, attention = self.decode(inputs, mask) outputs += [output] attentions += [attention] diff --git a/TTS/tts/layers/tacotron2.py b/TTS/tts/layers/tacotron2.py index 16c7b052f..080e5a3f9 100644 --- a/TTS/tts/layers/tacotron2.py +++ b/TTS/tts/layers/tacotron2.py @@ -246,13 +246,11 @@ def decode(self, memory): decoder_output = decoder_output[:, :self.r * self.frame_dim] return decoder_output, self.attention.attention_weights, stop_token - def forward(self, inputs, memories, mask, speaker_embeddings=None): + def forward(self, inputs, memories, mask): memory = self.get_go_frame(inputs).unsqueeze(0) memories = self._reshape_memory(memories) memories = torch.cat((memory, memories), dim=0) memories = self._update_memory(memories) - if speaker_embeddings is not None: - memories = torch.cat([memories, speaker_embeddings], dim=-1) memories = self.prenet(memories) self._init_states(inputs, mask=mask) diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index bd54b05d5..d7fff4442 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -6,7 +6,6 @@ from TTS.tts.layers.tacotron import Decoder, Encoder, PostCBHG from TTS.tts.models.tacotron_abstract import TacotronAbstract - class Tacotron(TacotronAbstract): def __init__(self, num_chars, @@ -41,10 +40,19 @@ def __init__(self, location_attn, attn_K, separate_stopnet, bidirectional_decoder, double_decoder_consistency, ddc_r, gst) - decoder_in_features = 512 if num_speakers > 1 else 256 - encoder_in_features = 512 if num_speakers > 1 else 256 + + + # init layer dims + decoder_in_features = 256 + encoder_in_features = 256 speaker_embedding_dim = 256 proj_speaker_dim = 80 if num_speakers > 1 else 0 + + if num_speakers > 1: + decoder_in_features = decoder_in_features + speaker_embedding_dim # add speaker embedding dim + if self.gst: + decoder_in_features = decoder_in_features + gst_embedding_dim # add gst embedding dim + # base model layers self.embedding = nn.Embedding(num_chars, 256, padding_idx=0) self.embedding.weight.data.normal_(0, 0.3) @@ -98,10 +106,6 @@ def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker # B x speaker_embed_dim if speaker_ids is not None: self.compute_speaker_embedding(speaker_ids) - if self.num_speakers > 1: - # B x T_in x embed_dim + speaker_embed_dim - inputs = self._concat_speaker_embedding(inputs, - self.speaker_embeddings) # B x T_in x encoder_in_features encoder_outputs = self.encoder(inputs) # sequence masking @@ -117,8 +121,7 @@ def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker # alignments: B x T_in x encoder_in_features # stop_tokens: B x T_in decoder_outputs, alignments, stop_tokens = self.decoder( - encoder_outputs, mel_specs, input_mask, - self.speaker_embeddings_projected) + encoder_outputs, mel_specs, input_mask) # sequence masking if output_mask is not None: decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) @@ -145,9 +148,6 @@ def inference(self, characters, speaker_ids=None, style_mel=None): self._init_states() if speaker_ids is not None: self.compute_speaker_embedding(speaker_ids) - if self.num_speakers > 1: - inputs = self._concat_speaker_embedding(inputs, - self.speaker_embeddings) encoder_outputs = self.encoder(inputs) if self.gst and style_mel is not None: encoder_outputs = self.compute_gst(encoder_outputs, style_mel) diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index 9ba2da578..af4eff581 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -5,7 +5,6 @@ from TTS.tts.layers.tacotron2 import Decoder, Encoder, Postnet from TTS.tts.models.tacotron_abstract import TacotronAbstract - # TODO: match function arguments with tacotron class Tacotron2(TacotronAbstract): def __init__(self, @@ -86,24 +85,6 @@ def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) return mel_outputs, mel_outputs_postnet, alignments - def compute_gst(self, inputs, style_input): - """ Compute global style token """ - device = inputs.device - if isinstance(style_input, dict): - query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device) - _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) - gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) - for k_token, v_amplifier in style_input.items(): - key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) - gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) - gst_outputs = gst_outputs + gst_outputs_att * v_amplifier - elif style_input is None: - gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) - else: - gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable - embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1) - return inputs, embedded_gst - def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None): # compute mask for padding # B x T_in_max (boolean) @@ -113,20 +94,13 @@ def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ # B x T_in_max x D_en encoder_outputs = self.encoder(embedded_inputs, text_lengths) + if self.gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + if self.num_speakers > 1: embedded_speakers = self.speaker_embedding(speaker_ids)[:, None] - embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1) - if self.gst: - # B x gst_dim - encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs) - encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1) - else: - encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) - else: - if self.gst: - # B x gst_dim - encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs) - encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) @@ -163,15 +137,14 @@ def inference(self, text, speaker_ids=None, style_mel=None): embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1) if self.gst: # B x gst_dim - encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel) - encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1) + encoder_outputs = self.compute_gst(encoder_outputs, style_mel) + encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) else: encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) else: if self.gst: # B x gst_dim - encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel) - encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1) + encoder_outputs = self.compute_gst(encoder_outputs, style_mel) decoder_outputs, alignments, stop_tokens = self.decoder.inference( encoder_outputs) @@ -193,15 +166,13 @@ def inference_truncated(self, text, speaker_ids=None, style_mel=None): embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1) if self.gst: # B x gst_dim - encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel) - encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1) + encoder_outputs = self.compute_gst(encoder_outputs, style_mel) else: encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) else: if self.gst: # B x gst_dim - encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel) - encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1) + encoder_outputs = self.compute_gst(encoder_outputs, style_mel) mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated( encoder_outputs) diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py index 07324c9ea..638e1ca84 100644 --- a/TTS/tts/models/tacotron_abstract.py +++ b/TTS/tts/models/tacotron_abstract.py @@ -164,11 +164,22 @@ def compute_speaker_embedding(self, speaker_ids): self.speaker_embeddings_projected = self.speaker_project_mel( self.speaker_embeddings).squeeze(1) - def compute_gst(self, inputs, mel_specs): + def compute_gst(self, inputs, style_input): """ Compute global style token """ - # pylint: disable=not-callable - gst_outputs = self.gst_layer(mel_specs) - inputs = self._add_speaker_embedding(inputs, gst_outputs) + device = inputs.device + if isinstance(style_input, dict): + query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device) + _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) + gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) + for k_token, v_amplifier in style_input.items(): + key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) + gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) + gst_outputs = gst_outputs + gst_outputs_att * v_amplifier + elif style_input is None: + gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) + else: + gst_outputs = self.gst_layer(style_input) # pylint: disable=not-callable + inputs = self._concat_speaker_embedding(inputs, gst_outputs) return inputs @staticmethod diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index a2a413fd5..2cd17175c 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -245,13 +245,18 @@ def check_config(c): # multi-speaker gst check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) - check_argument('style_wav_for_test', c, restricted=True, val_type=str) check_argument('use_gst', c, restricted=True, val_type=bool) + check_argument('gst', c, restricted=True, val_type=dict) + check_argument('gst_style_input', c['gst'], restricted=True, val_type=[str, dict]) + check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=0, max_val=1000) + check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=2, max_val=10) + check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1, max_val=1000) + # datasets - checking only the first entry check_argument('datasets', c, restricted=True, val_type=list) for dataset_entry in c['datasets']: check_argument('name', dataset_entry, restricted=True, val_type=str) check_argument('path', dataset_entry, restricted=True, val_type=str) - check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str) + check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py index 227118e69..b19308348 100644 --- a/TTS/tts/utils/text/cleaners.py +++ b/TTS/tts/utils/text/cleaners.py @@ -107,7 +107,6 @@ def basic_turkish_cleaners(text): text = collapse_whitespace(text) return text - def english_cleaners(text): '''Pipeline for English text, including number and abbreviation expansion.''' text = convert_to_ascii(text) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 5dbe2062c..5b76132f7 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -150,5 +150,12 @@ def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restrict assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}' if enum_list: assert c[name].lower() in enum_list, f' [!] {name} is not a valid value' - if val_type: + if isinstance(val_type, list): + valid_types = val_type + is_valid = False + for typ in val_type: + if isinstance(c[name], typ): + is_valid = True + assert is_valid or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' + elif val_type: assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' \ No newline at end of file