Skip to content

Commit 72c5ea4

Browse files
committed
Only use one frame of pre-loading
1 parent 108b75c commit 72c5ea4

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

dnn/torch/fargan/fargan.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -235,18 +235,21 @@ def forward(self, features, period, nb_frames, pre=None, states=None):
235235
exc_mem = torch.zeros(batch_size, 256, device=device)
236236
nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
237237

238-
if states is None:
239-
states = (
240-
torch.zeros(batch_size, self.cond_size, device=device),
241-
torch.zeros(batch_size, self.cond_size, device=device),
242-
torch.zeros(batch_size, self.cond_size, device=device),
243-
torch.zeros(batch_size, self.passthrough_size, device=device)
244-
)
238+
states = (
239+
torch.zeros(batch_size, self.cond_size, device=device),
240+
torch.zeros(batch_size, self.cond_size, device=device),
241+
torch.zeros(batch_size, self.cond_size, device=device),
242+
torch.zeros(batch_size, self.passthrough_size, device=device)
243+
)
245244

246245
sig = torch.zeros((batch_size, 0), device=device)
247246
cond = self.cond_net(features, period)
248247
passthrough = torch.zeros(batch_size, self.passthrough_size, device=device)
249-
for n in range(nb_frames+nb_pre_frames):
248+
if pre is not None:
249+
prev[:,:] = pre[:, self.frame_size-self.subframe_size : self.frame_size]
250+
exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
251+
start = 1 if nb_pre_frames>0 else 0
252+
for n in range(start, nb_frames+nb_pre_frames):
250253
for k in range(self.nb_subframes):
251254
pos = n*self.frame_size + k*self.subframe_size
252255
preal = phase_real[:, pos:pos+self.subframe_size]

0 commit comments

Comments
 (0)