Bunch of updates
no LPC filtering, L1 loss with sqrt magnitude
plus architectural updates
jmvalin committed Sep 20, 2023
1 parent b24c7b4 commit 03f9c90
Showing 4 changed files with 45 additions and 36 deletions.
61 changes: 33 additions & 28 deletions dnn/torch/fargan/fargan.py
@@ -81,24 +81,24 @@ def gen_phase_embedding(periods, frame_size):
 class GLU(nn.Module):
     def __init__(self, feat_size):
         super(GLU, self).__init__()
 
         torch.manual_seed(5)
 
         self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
 
         self.init_weights()
 
     def init_weights(self):
 
         for m in self.modules():
             if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
                     or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                 nn.init.orthogonal_(m.weight.data)
 
     def forward(self, x):
         out = x * torch.sigmoid(self.gate(x))
 
         return out
 
 class FWConv(nn.Module):
@@ -160,31 +160,33 @@ def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256):
         self.subframe_size = subframe_size
         self.nb_subframes = nb_subframes
         self.cond_size = cond_size
 
         #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
         self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size)
         self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
-        self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
-        self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
-        self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
+        self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
+        self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
+        self.gru3 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
 
         self.dense1_glu = GLU(self.cond_size)
         self.dense2_glu = GLU(self.cond_size)
         self.gru1_glu = GLU(self.cond_size)
         self.gru2_glu = GLU(self.cond_size)
         self.gru3_glu = GLU(self.cond_size)
-        self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
-
-        self.sig_dense_out = nn.Linear(4*self.cond_size, self.subframe_size, bias=False)
-        self.gain_dense_out = nn.Linear(4*self.cond_size, 1)
+        self.skip_glu = GLU(self.cond_size)
+        #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
+
+        self.skip_dense = nn.Linear(4*self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
+        self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False)
+        self.gain_dense_out = nn.Linear(self.cond_size, 4)
 
 
         self.apply(init_weights)
 
     def forward(self, cond, prev, exc_mem, phase, period, states, gain=None):
         device = exc_mem.device
         #print(cond.shape, prev.shape)
 
         dump_signal(prev, 'prev_in.f32')
 
         idx = 256-torch.clamp(period[:,None], min=self.subframe_size+2, max=254)
@@ -198,29 +200,33 @@ def forward(self, cond, prev, exc_mem, phase, period, states, gain=None):
         dump_signal(exc_mem, 'exc_mem.f32')
 
         tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1)
+        #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
+        fpitch = pred[:,2:-2]
 
         #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
         fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
         dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out)))
-        gru1_state = self.gru1(dense2_out, states[0])
+        pitch_gain = torch.sigmoid(self.gain_dense_out(dense2_out))
+
+        gru1_state = self.gru1(torch.cat([dense2_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
         gru1_out = self.gru1_glu(gru1_state)
-        gru2_state = self.gru2(gru1_out, states[1])
+        gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
         gru2_out = self.gru2_glu(gru2_state)
-        gru3_state = self.gru3(gru2_out, states[2])
+        gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
         gru3_out = self.gru3_glu(gru3_state)
         gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1)
-        sig_out = torch.tanh(self.sig_dense_out(gru3_out))
+        skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
+        skip_out = self.skip_glu(skip_out)
+        sig_out = torch.tanh(self.sig_dense_out(skip_out))
         dump_signal(sig_out, 'exc_out.f32')
-        taps = self.ptaps_dense(gru3_out)
-        taps = .2*taps + torch.exp(taps)
-        taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
-        dump_signal(taps, 'taps.f32')
-        #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
-        fpitch = pred[:,2:-2]
+        #taps = self.ptaps_dense(gru3_out)
+        #taps = .2*taps + torch.exp(taps)
+        #taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
+        #dump_signal(taps, 'taps.f32')
 
-        pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
         dump_signal(pitch_gain, 'pgain.f32')
-        sig_out = (sig_out + pitch_gain*fpitch) * gain
+        #sig_out = (sig_out + pitch_gain*fpitch) * gain
+        sig_out = sig_out * gain
         exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
         dump_signal(sig_out, 'sig_out.f32')
         return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state)
@@ -283,4 +289,3 @@ def forward(self, features, period, nb_frames, pre=None, states=None):
             prev = out
         states = [s.detach() for s in states]
         return sig, states
-
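
Note on the architectural change above: the pitch prediction fpitch is no longer filtered by learned taps and added back at the output; instead, four sigmoid gains (gain_dense_out is now 4-wide) scale fpitch into the input of each GRU and of the new skip layer. A minimal runnable sketch of just that wiring follows; FWConv, the GLUs, weight_norm and the conditioning path are elided, so this illustrates the structure rather than reproducing fargan.py:

import torch
import torch.nn as nn

class SubframeSketch(nn.Module):
    """Sketch of the new per-layer pitch feedback (GLUs/FWConv elided)."""
    def __init__(self, subframe_size=40, cond_size=256):
        super().__init__()
        in_size = cond_size + 2*subframe_size   # [features, gated fpitch, prev]
        self.gru1 = nn.GRUCell(in_size, cond_size, bias=False)
        self.gru2 = nn.GRUCell(in_size, cond_size, bias=False)
        self.gru3 = nn.GRUCell(in_size, cond_size, bias=False)
        self.gain_dense_out = nn.Linear(cond_size, 4)  # one gain per injection point
        self.skip_dense = nn.Linear(4*cond_size + 2*subframe_size, cond_size, bias=False)
        self.sig_dense_out = nn.Linear(cond_size, subframe_size, bias=False)

    def forward(self, dense2_out, fpitch, prev, states, gain):
        g = torch.sigmoid(self.gain_dense_out(dense2_out))  # (B, 4)
        s1 = self.gru1(torch.cat([dense2_out, g[:, 0:1]*fpitch, prev], 1), states[0])
        s2 = self.gru2(torch.cat([s1, g[:, 1:2]*fpitch, prev], 1), states[1])
        s3 = self.gru3(torch.cat([s2, g[:, 2:3]*fpitch, prev], 1), states[2])
        cat = torch.cat([s1, s2, s3, dense2_out], 1)
        skip = torch.tanh(self.skip_dense(torch.cat([cat, g[:, 3:4]*fpitch, prev], 1)))
        # output is scaled by the frame gain only; pitch is no longer added here
        return torch.tanh(self.sig_dense_out(skip)) * gain, (s1, s2, s3)

net = SubframeSketch()
out, st = net(torch.randn(1, 256), torch.randn(1, 40), torch.randn(1, 40),
              (torch.zeros(1, 256),)*3, gain=torch.ones(1, 1))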

4 changes: 3 additions & 1 deletion dnn/torch/fargan/stft_loss.py
@@ -44,7 +44,9 @@ def forward(self, x_mag, y_mag):
         Returns:
             Tensor: Spectral convergence loss value.
         """
-        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+        y_mag = torch.sqrt(1e-7+y_mag)
+        x_mag = torch.sqrt(1e-7+x_mag)
+        return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
 
 class LogSTFTMagnitudeLoss(torch.nn.Module):
     """Log STFT magnitude loss module."""
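
For reference, the change above swaps the Frobenius-norm spectral convergence loss for an L1 distance between square-root-compressed magnitudes, normalized by the reference norm. The same math as a standalone function (the name is ours, not from the repo):

import torch

def sqrt_mag_l1_loss(x_mag, y_mag, eps=1e-7):
    # L1 on sqrt-compressed STFT magnitudes, normalized by the reference
    x = torch.sqrt(eps + x_mag)
    y = torch.sqrt(eps + y_mag)
    return torch.norm(y - x, p=1) / torch.norm(y, p=1)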
7 changes: 4 additions & 3 deletions dnn/torch/fargan/test_fargan.py
@@ -90,7 +90,7 @@ def inverse_perceptual_weighting (pw_signal, filters, weighting_vector):
         buffer[:] = out_sig_frame[-16:]
     return signal
 
-
+from scipy.signal import lfilter
 
 if __name__ == '__main__':
     model.to(device)
@@ -99,9 +99,10 @@ def inverse_perceptual_weighting (pw_signal, filters, weighting_vector):
     periods = torch.tensor(periods).to(device)
 
     sig, _ = model(features, periods, nb_frames - 4)
-    weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
+    #weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
     sig = sig.detach().numpy().flatten()
-    sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector)
+    sig = lfilter(np.array([1.]), np.array([1., -.85]), sig)
+    #sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector)
 
     pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
     pcm.tofile(signal_file)
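
The lfilter call above is a fixed first-order de-emphasis replacing the LPC-based inverse perceptual weighting: with b=[1.] and a=[1., -.85] it computes y[n] = x[n] + 0.85*y[n-1]. A quick standalone sanity check of that equivalence (not part of the repo):

import numpy as np
from scipy.signal import lfilter

def deemphasis(x, alpha=0.85):
    # y[n] = x[n] + alpha*y[n-1], i.e. lfilter([1.], [1., -alpha], x)
    y = np.empty_like(x)
    mem = 0.0
    for n in range(len(x)):
        mem = x[n] + alpha*mem
        y[n] = mem
    return y

x = np.random.randn(160)
assert np.allclose(deemphasis(x), lfilter([1.], [1., -.85], x))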
9 changes: 5 additions & 4 deletions dnn/torch/fargan/train_fargan.py
@@ -127,17 +127,17 @@
             features=features[::2,:]
             periods=periods[::2,:]
             target = target.to(device)
-            target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
+            #target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
 
             #nb_pre = random.randrange(1, 6)
             nb_pre = 2
             pre = target[:, :nb_pre*160]
             sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
             sig = torch.cat([pre, sig], -1)
 
-            cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
+            cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+320], sig[:, nb_pre*160:nb_pre*160+320])
             specc_loss = spect_loss(sig, target.detach())
-            loss = .00*cont_loss + specc_loss
+            loss = .05*cont_loss + specc_loss
 
             loss.backward()
             optimizer.step()
@@ -150,7 +150,8 @@
             running_cont_loss += cont_loss.detach().cpu().item()
 
             running_loss += loss.detach().cpu().item()
-            tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
+            if (i%4) == 0:
+                tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}",
                                cont_loss=f"{running_cont_loss/(i+1):8.5f}",
                                specc=f"{running_specc/(i+1):8.5f}",
                                )
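
Net effect of the training changes: the target is no longer pre-filtered by the LPC analysis filter, the continuity loss now covers 320 samples (two frames) past the teacher-forced pre-roll, and it re-enters the total with weight .05. As a self-contained sketch; the L1 stand-in for fargan.sig_loss and the spect_loss callable are assumptions, not the repo's definitions:

import torch

def combined_loss(sig, target, spect_loss, nb_pre=2, frame=160):
    a, b = nb_pre*frame, nb_pre*frame + 320
    # continuity term over the first two generated frames (L1 stand-in)
    cont_loss = torch.mean(torch.abs(target[:, a:b] - sig[:, a:b]))
    specc_loss = spect_loss(sig, target.detach())
    return .05*cont_loss + specc_loss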
