@@ -235,18 +235,21 @@ def forward(self, features, period, nb_frames, pre=None, states=None):
235
235
exc_mem = torch .zeros (batch_size , 256 , device = device )
236
236
nb_pre_frames = pre .size (1 )// self .frame_size if pre is not None else 0
237
237
238
- if states is None :
239
- states = (
240
- torch .zeros (batch_size , self .cond_size , device = device ),
241
- torch .zeros (batch_size , self .cond_size , device = device ),
242
- torch .zeros (batch_size , self .cond_size , device = device ),
243
- torch .zeros (batch_size , self .passthrough_size , device = device )
244
- )
238
+ states = (
239
+ torch .zeros (batch_size , self .cond_size , device = device ),
240
+ torch .zeros (batch_size , self .cond_size , device = device ),
241
+ torch .zeros (batch_size , self .cond_size , device = device ),
242
+ torch .zeros (batch_size , self .passthrough_size , device = device )
243
+ )
245
244
246
245
sig = torch .zeros ((batch_size , 0 ), device = device )
247
246
cond = self .cond_net (features , period )
248
247
passthrough = torch .zeros (batch_size , self .passthrough_size , device = device )
249
- for n in range (nb_frames + nb_pre_frames ):
248
+ if pre is not None :
249
+ prev [:,:] = pre [:, self .frame_size - self .subframe_size : self .frame_size ]
250
+ exc_mem [:,- self .frame_size :] = pre [:, :self .frame_size ]
251
+ start = 1 if nb_pre_frames > 0 else 0
252
+ for n in range (start , nb_frames + nb_pre_frames ):
250
253
for k in range (self .nb_subframes ):
251
254
pos = n * self .frame_size + k * self .subframe_size
252
255
preal = phase_real [:, pos :pos + self .subframe_size ]
0 commit comments