@@ -101,6 +101,31 @@ def forward(self, x):
101
101
102
102
return out
103
103
104
class FWConv(nn.Module):
    """Linear layer applied to a sliding window of recent inputs.

    Each call concatenates the carried-over ``state`` with the new input
    ``x`` along the last axis, feeds the result through a weight-normalized
    bias-free ``nn.Linear`` followed by ``tanh`` and a ``GLU`` block, and
    hands back the window with its oldest ``in_size`` features dropped so
    the caller can pass it in as ``state`` on the next call.
    """

    def __init__(self, in_size, out_size, kernel_size=3):
        """Build the layer.

        Args:
            in_size: number of features dropped from the front of the
                window on every call (presumably the width of ``x``).
            out_size: number of features produced by the linear layer.
            kernel_size: window length in units of ``in_size``; the linear
                layer consumes ``in_size * kernel_size`` features.
        """
        super().__init__()

        # NOTE(review): this reseeds the *global* torch RNG every time an
        # FWConv is constructed. Kept as-is so initialization is unchanged.
        torch.manual_seed(5)

        self.in_size = in_size
        self.kernel_size = kernel_size
        self.conv = weight_norm(
            nn.Linear(in_size * self.kernel_size, out_size, bias=False)
        )
        self.glu = GLU(out_size)

        self.init_weights()

    def init_weights(self):
        """Apply orthogonal initialization to conv/linear/embedding weights."""
        # NOTE(review): for a weight-normalized layer, ``weight`` may be a
        # tensor derived from the g/v parameters, in which case this
        # in-place init would not persist for ``self.conv`` -- confirm.
        target_types = (nn.Conv1d, nn.ConvTranspose1d, nn.Linear, nn.Embedding)
        for m in self.modules():
            if isinstance(m, target_types):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x, state):
        """Run one step.

        Args:
            x: new input; its last axis is appended to the window.
            state: window carried over from the previous call. Its last
                axis plus that of ``x`` must total
                ``in_size * kernel_size`` features for the linear layer.

        Returns:
            A tuple ``(out, new_state)`` where ``out`` is the GLU output and
            ``new_state`` is the concatenated window without its first
            ``in_size`` features.
        """
        xcat = torch.cat((state, x), -1)
        out = self.glu(torch.tanh(self.conv(xcat)))
        # Trim along the last axis (the one that was concatenated) so that
        # inputs with extra leading dimensions keep a correctly shaped state.
        return out, xcat[..., self.in_size:]
104
129
105
130
class FARGANCond (nn .Module ):
106
131
def __init__ (self , feature_dim = 20 , cond_size = 256 , pembed_dims = 64 ):
@@ -113,7 +138,7 @@ def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
113
138
self .fdense1 = nn .Linear (self .feature_dim + pembed_dims , self .cond_size , bias = False )
114
139
self .fconv1 = nn .Conv1d (self .cond_size , self .cond_size , kernel_size = 3 , padding = 'valid' , bias = False )
115
140
self .fconv2 = nn .Conv1d (self .cond_size , self .cond_size , kernel_size = 3 , padding = 'valid' , bias = False )
116
- self .fdense2 = nn .Linear (self .cond_size , self . cond_size , bias = False )
141
+ self .fdense2 = nn .Linear (self .cond_size , 80 * 4 , bias = False )
117
142
118
143
self .apply (init_weights )
119
144
@@ -138,9 +163,10 @@ def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, passthrough_
138
163
self .has_gain = has_gain
139
164
self .passthrough_size = passthrough_size
140
165
141
- print ("has_gain:" , self .has_gain )
142
- print ("passthrough_size:" , self .passthrough_size )
143
- self .sig_dense1 = nn .Linear (4 * self .subframe_size + self .passthrough_size + self .cond_size , self .cond_size , bias = False )
166
+ #print("has_gain:", self.has_gain)
167
+ #print("passthrough_size:", self.passthrough_size)
168
+ #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
169
+ self .fwc0 = FWConv (4 * self .subframe_size + 80 , self .cond_size )
144
170
self .sig_dense2 = nn .Linear (self .cond_size , self .cond_size , bias = False )
145
171
self .gru1 = nn .GRUCell (self .cond_size , self .cond_size , bias = False )
146
172
self .gru2 = nn .GRUCell (self .cond_size , self .cond_size , bias = False )
@@ -176,38 +202,34 @@ def forward(self, cond, prev, exc_mem, phase, period, states, gain=None):
176
202
dump_signal (prev , 'pitch_exc.f32' )
177
203
dump_signal (exc_mem , 'exc_mem.f32' )
178
204
179
- passthrough = states [3 ]
180
- tmp = torch .cat ((cond , pred [:,2 :- 2 ], prev , passthrough , phase ), 1 )
205
+ tmp = torch .cat ((cond , pred [:,2 :- 2 ], prev , phase ), 1 )
181
206
182
- tmp = self .dense1_glu (torch .tanh (self .sig_dense1 (tmp )))
183
- dense2_out = self .dense2_glu (torch .tanh (self .sig_dense2 (tmp )))
207
+ #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
208
+ fwc0_out , fwc0_state = self .fwc0 (tmp , states [3 ])
209
+ dense2_out = self .dense2_glu (torch .tanh (self .sig_dense2 (fwc0_out )))
184
210
gru1_state = self .gru1 (dense2_out , states [0 ])
185
211
gru1_out = self .gru1_glu (gru1_state )
186
- #gru1_out = torch.cat([gru1_out, fpitch], 1)
187
212
gru2_state = self .gru2 (gru1_out , states [1 ])
188
213
gru2_out = self .gru2_glu (gru2_state )
189
- #gru2_out = torch.cat([gru2_out, fpitch], 1)
190
214
gru3_state = self .gru3 (gru2_out , states [2 ])
191
215
gru3_out = self .gru3_glu (gru3_state )
192
216
gru3_out = torch .cat ([gru1_out , gru2_out , gru3_out , dense2_out ], 1 )
193
217
sig_out = torch .tanh (self .sig_dense_out (gru3_out ))
194
- if self .passthrough_size != 0 :
195
- passthrough = sig_out [:,self .subframe_size :]
196
- sig_out = sig_out [:,:self .subframe_size ]
197
218
dump_signal (sig_out , 'exc_out.f32' )
198
219
taps = self .ptaps_dense (gru3_out )
199
220
taps = .2 * taps + torch .exp (taps )
200
221
taps = taps / (1e-2 + torch .sum (torch .abs (taps ), dim = - 1 , keepdim = True ))
201
222
dump_signal (taps , 'taps.f32' )
202
- fpitch = taps [:,0 :1 ]* pred [:,:- 4 ] + taps [:,1 :2 ]* pred [:,1 :- 3 ] + taps [:,2 :3 ]* pred [:,2 :- 2 ] + taps [:,3 :4 ]* pred [:,3 :- 1 ] + taps [:,4 :]* pred [:,4 :]
223
+ #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
224
+ fpitch = pred [:,2 :- 2 ]
203
225
204
226
if self .has_gain :
205
227
pitch_gain = torch .exp (self .gain_dense_out (gru3_out ))
206
228
dump_signal (pitch_gain , 'pgain.f32' )
207
229
sig_out = (sig_out + pitch_gain * fpitch ) * gain
208
230
exc_mem = torch .cat ([exc_mem [:,self .subframe_size :], sig_out ], 1 )
209
231
dump_signal (sig_out , 'sig_out.f32' )
210
- return sig_out , exc_mem , (gru1_state , gru2_state , gru3_state , passthrough )
232
+ return sig_out , exc_mem , (gru1_state , gru2_state , gru3_state , fwc0_state )
211
233
212
234
class FARGAN (nn .Module ):
213
235
def __init__ (self , subframe_size = 40 , nb_subframes = 4 , feature_dim = 20 , cond_size = 256 , passthrough_size = 0 , has_gain = False , gamma = None ):
@@ -239,7 +261,7 @@ def forward(self, features, period, nb_frames, pre=None, states=None):
239
261
torch .zeros (batch_size , self .cond_size , device = device ),
240
262
torch .zeros (batch_size , self .cond_size , device = device ),
241
263
torch .zeros (batch_size , self .cond_size , device = device ),
242
- torch .zeros (batch_size , self .passthrough_size , device = device )
264
+ torch .zeros (batch_size , ( 4 * self .subframe_size + 80 ) * 2 , device = device )
243
265
)
244
266
245
267
sig = torch .zeros ((batch_size , 0 ), device = device )
@@ -259,7 +281,7 @@ def forward(self, features, period, nb_frames, pre=None, states=None):
259
281
pitch = period [:, 3 + n ]
260
282
gain = .03 * 10 ** (0.5 * features [:, 3 + n , 0 :1 ]/ np .sqrt (18.0 ))
261
283
#gain = gain[:,:,None]
262
- out , exc_mem , states = self .sig_net (cond [:, n , : ], prev , exc_mem , phase , pitch , states , gain = gain )
284
+ out , exc_mem , states = self .sig_net (cond [:, n , k * 80 :( k + 1 ) * 80 ], prev , exc_mem , phase , pitch , states , gain = gain )
263
285
264
286
if n < nb_pre_frames :
265
287
out = pre [:, pos :pos + self .subframe_size ]
0 commit comments