Skip to content

Commit

Permalink
more simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
jmvalin committed Sep 28, 2023
1 parent b57ddad commit a5d3409
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions dnn/torch/fargan/fargan.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,20 +185,18 @@ def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256):

#self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
self.fwc0 = FWConv(2*self.subframe_size+80+4, self.cond_size)
self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
self.gru3 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, 128, bias=False)
self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)

self.dense1_glu = GLU(self.cond_size)
self.dense2_glu = GLU(self.cond_size)
self.gru1_glu = GLU(self.cond_size)
self.gru2_glu = GLU(self.cond_size)
self.gru3_glu = GLU(self.cond_size)
self.gru2_glu = GLU(128)
self.gru3_glu = GLU(128)
self.skip_glu = GLU(self.cond_size)
#self.ptaps_dense = nn.Linear(4*self.cond_size, 5)

self.skip_dense = nn.Linear(4*self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
self.skip_dense = nn.Linear(2*128+2*self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False)
self.gain_dense_out = nn.Linear(self.cond_size, 4)

Expand Down Expand Up @@ -228,16 +226,15 @@ def forward(self, cond, prev_pred, exc_mem, period, states, gain=None):

#tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out)))
pitch_gain = torch.sigmoid(self.gain_dense_out(dense2_out))
pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out))

gru1_state = self.gru1(torch.cat([dense2_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
gru1_out = self.gru1_glu(gru1_state)
gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
gru2_out = self.gru2_glu(gru2_state)
gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
gru3_out = self.gru3_glu(gru3_state)
gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1)
gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1)
skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
skip_out = self.skip_glu(skip_out)
sig_out = torch.tanh(self.sig_dense_out(skip_out))
Expand Down Expand Up @@ -278,8 +275,8 @@ def forward(self, features, period, nb_frames, pre=None, states=None):

states = (
torch.zeros(batch_size, self.cond_size, device=device),
torch.zeros(batch_size, self.cond_size, device=device),
torch.zeros(batch_size, self.cond_size, device=device),
torch.zeros(batch_size, 128, device=device),
torch.zeros(batch_size, 128, device=device),
torch.zeros(batch_size, (2*self.subframe_size+80+4)*2, device=device)
)

Expand Down

0 comments on commit a5d3409

Please sign in to comment.