Skip to content

Commit 2f8b36d

Browse files
committed
Add conditioning interpolation, fwconv layer
1 parent 72c5ea4 commit 2f8b36d

File tree

2 files changed

+41
-17
lines changed

2 files changed

+41
-17
lines changed

dnn/torch/fargan/fargan.py

Lines changed: 39 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,31 @@ def forward(self, x):
101101

102102
return out
103103

104+
class FWConv(nn.Module):
    """Frame-wise "convolution" realised as a single linear layer over a window.

    The caller carries the previous inputs in ``state``; each call
    concatenates ``state`` with the new input along the last dimension, runs
    the result through a bias-free, weight-normalised linear layer followed by
    ``tanh`` and a ``GLU``, and hands back the updated state.

    Args:
        in_size: number of features contributed by one input frame.
        out_size: number of output features.
        kernel_size: number of frames seen by the linear layer at once; the
            linear layer takes ``in_size * kernel_size`` input features.
    """

    def __init__(self, in_size, out_size, kernel_size=3):
        super().__init__()

        # Reseed the global torch RNG before the layers are created.
        # Note that this resets the RNG state for the whole process.
        torch.manual_seed(5)

        self.kernel_size = kernel_size
        self.in_size = in_size
        window_features = in_size * kernel_size
        self.conv = weight_norm(nn.Linear(window_features, out_size, bias=False))
        self.glu = GLU(out_size)

        self.init_weights()

    def init_weights(self):
        """Apply orthogonal initialisation to conv/linear/embedding weights.

        NOTE(review): ``self.conv`` is wrapped by ``weight_norm``; confirm that
        writing to ``weight.data`` here still influences the effective weight.
        """
        orthogonal_targets = (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)
        for module in self.modules():
            if isinstance(module, orthogonal_targets):
                nn.init.orthogonal_(module.weight.data)

    def forward(self, x, state):
        """Process one frame.

        Args:
            x: new input frame; presumably ``in_size`` features wide on the
                last dimension (confirm against the caller).
            state: history returned by the previous call; together with ``x``
                it must total ``in_size * kernel_size`` features.

        Returns:
            A tuple ``(out, new_state)`` where ``out`` is the gated layer
            output and ``new_state`` is the concatenated window with its first
            ``in_size`` columns dropped.
        """
        window = torch.cat((state, x), -1)
        activated = torch.tanh(self.conv(window))
        # Drop the oldest in_size columns so the window slides by one frame.
        next_state = window[:, self.in_size:]
        return self.glu(activated), next_state
104129

105130
class FARGANCond(nn.Module):
106131
def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
@@ -113,7 +138,7 @@ def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
113138
self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False)
114139
self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
115140
self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
116-
self.fdense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
141+
self.fdense2 = nn.Linear(self.cond_size, 80*4, bias=False)
117142

118143
self.apply(init_weights)
119144

@@ -138,9 +163,10 @@ def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, passthrough_
138163
self.has_gain = has_gain
139164
self.passthrough_size = passthrough_size
140165

141-
print("has_gain:", self.has_gain)
142-
print("passthrough_size:", self.passthrough_size)
143-
self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
166+
#print("has_gain:", self.has_gain)
167+
#print("passthrough_size:", self.passthrough_size)
168+
#self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
169+
self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size)
144170
self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
145171
self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
146172
self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
@@ -176,38 +202,34 @@ def forward(self, cond, prev, exc_mem, phase, period, states, gain=None):
176202
dump_signal(prev, 'pitch_exc.f32')
177203
dump_signal(exc_mem, 'exc_mem.f32')
178204

179-
passthrough = states[3]
180-
tmp = torch.cat((cond, pred[:,2:-2], prev, passthrough, phase), 1)
205+
tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1)
181206

182-
tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
183-
dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(tmp)))
207+
#tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
208+
fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
209+
dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out)))
184210
gru1_state = self.gru1(dense2_out, states[0])
185211
gru1_out = self.gru1_glu(gru1_state)
186-
#gru1_out = torch.cat([gru1_out, fpitch], 1)
187212
gru2_state = self.gru2(gru1_out, states[1])
188213
gru2_out = self.gru2_glu(gru2_state)
189-
#gru2_out = torch.cat([gru2_out, fpitch], 1)
190214
gru3_state = self.gru3(gru2_out, states[2])
191215
gru3_out = self.gru3_glu(gru3_state)
192216
gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1)
193217
sig_out = torch.tanh(self.sig_dense_out(gru3_out))
194-
if self.passthrough_size != 0:
195-
passthrough = sig_out[:,self.subframe_size:]
196-
sig_out = sig_out[:,:self.subframe_size]
197218
dump_signal(sig_out, 'exc_out.f32')
198219
taps = self.ptaps_dense(gru3_out)
199220
taps = .2*taps + torch.exp(taps)
200221
taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
201222
dump_signal(taps, 'taps.f32')
202-
fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
223+
#fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
224+
fpitch = pred[:,2:-2]
203225

204226
if self.has_gain:
205227
pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
206228
dump_signal(pitch_gain, 'pgain.f32')
207229
sig_out = (sig_out + pitch_gain*fpitch) * gain
208230
exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
209231
dump_signal(sig_out, 'sig_out.f32')
210-
return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, passthrough)
232+
return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state)
211233

212234
class FARGAN(nn.Module):
213235
def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
@@ -239,7 +261,7 @@ def forward(self, features, period, nb_frames, pre=None, states=None):
239261
torch.zeros(batch_size, self.cond_size, device=device),
240262
torch.zeros(batch_size, self.cond_size, device=device),
241263
torch.zeros(batch_size, self.cond_size, device=device),
242-
torch.zeros(batch_size, self.passthrough_size, device=device)
264+
torch.zeros(batch_size, (4*self.subframe_size+80)*2, device=device)
243265
)
244266

245267
sig = torch.zeros((batch_size, 0), device=device)
@@ -259,7 +281,7 @@ def forward(self, features, period, nb_frames, pre=None, states=None):
259281
pitch = period[:, 3+n]
260282
gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
261283
#gain = gain[:,:,None]
262-
out, exc_mem, states = self.sig_net(cond[:, n, :], prev, exc_mem, phase, pitch, states, gain=gain)
284+
out, exc_mem, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, phase, pitch, states, gain=gain)
263285

264286
if n < nb_pre_frames:
265287
out = pre[:, pos:pos+self.subframe_size]

dnn/torch/fargan/train_fargan.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,8 @@
121121
if (np.random.rand() > 0.1):
122122
target = target[:, :sequence_length*160]
123123
lpc = lpc[:,:sequence_length,:]
124+
features = features[:,:sequence_length+4,:]
125+
periods = periods[:,:sequence_length+4]
124126
else:
125127
target=target[::2, :]
126128
lpc=lpc[::2,:]

0 commit comments

Comments
 (0)