From 8bf6678edd50a36b53d8950834cc4ac274122905 Mon Sep 17 00:00:00 2001 From: Jeevesh8 Date: Sat, 8 Aug 2020 11:20:42 +0530 Subject: [PATCH] n_frame_per_step>1 support completed --- hparams.py | 2 +- model.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/hparams.py b/hparams.py index 8886f18fe..ff1a58317 100644 --- a/hparams.py +++ b/hparams.py @@ -53,7 +53,7 @@ def create_hparams(hparams_string=None, verbose=False): encoder_embedding_dim=512, # Decoder parameters - n_frames_per_step=1, # currently only 1 is supported + n_frames_per_step=1, # more than 1 is supported now decoder_rnn_dim=1024, prenet_dim=256, max_decoder_steps=1000, diff --git a/model.py b/model.py index ec0e9cea4..9c85bf63b 100644 --- a/model.py +++ b/model.py @@ -301,7 +301,7 @@ def parse_decoder_inputs(self, decoder_inputs): """ # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels) decoder_inputs = decoder_inputs.transpose(1, 2) - decoder_inputs = decoder_inputs.view( + decoder_inputs = decoder_inputs.reshape( decoder_inputs.size(0), int(decoder_inputs.size(1)/self.n_frames_per_step), -1) # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels) @@ -312,21 +312,22 @@ def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments): """ Prepares decoder outputs for output PARAMS ------ - mel_outputs: + mel_outputs: list of outputs[batch_size, n_mel_channels*n_frames_per_step] at each step of decoder gate_outputs: gate output energies - alignments: + alignments: list of alignments at each step of decoder RETURNS ------- - mel_outputs: - gate_outpust: gate output energies - alignments: + mel_outputs: batched tensor of outputs + gate_outputs: gate output energies + alignments: batched tensor of alignments """ # (T_out, B) -> (B, T_out) alignments = torch.stack(alignments).transpose(0, 1) # (T_out, B) -> (B, T_out) gate_outputs = torch.stack(gate_outputs).transpose(0, 1) gate_outputs = gate_outputs.contiguous() + gate_outputs = gate_outputs.repeat_interleave(self.n_frames_per_step,1) # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels) mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() # decouple frames per step @@ -442,7 +443,7 @@ def inference(self, memory): if torch.sigmoid(gate_output.data) > self.gate_threshold: break - elif len(mel_outputs) == self.max_decoder_steps: + elif len(mel_outputs)*self.n_frames_per_step >= self.max_decoder_steps: print("Warning! Reached max decoder steps") break @@ -488,6 +489,9 @@ def parse_output(self, outputs, output_lengths=None): if self.mask_padding and output_lengths is not None: mask = ~get_mask_from_lengths(output_lengths) mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) + if mask.size(2)%self.n_frames_per_step != 0 : + to_append = torch.ones( mask.size(0), mask.size(1), (self.n_frames_per_step-mask.size(2)%self.n_frames_per_step) ).bool().to(mask.device) + mask = torch.cat([mask, to_append], dim=-1) mask = mask.permute(1, 0, 2) outputs[0].data.masked_fill_(mask, 0.0)