Skip to content

Commit

Permalink
Merge pull request #72 from sp-nitech/magic_intpl
Browse files Browse the repository at this point in the history
Bug fix of magic_intpl
  • Loading branch information
takenori-y committed Apr 8, 2024
2 parents 4540186 + 43f16ab commit ec16916
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 15 deletions.
2 changes: 1 addition & 1 deletion diffsptk/modules/ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def forward(self, x, f0):

H = torch.cat((H_alpha, H_beta), dim=-1) # (B, N, J, 6)
w = self.window[i, : self.segment_length[i]] # (J,)
Hw = H.mT * w # (B, N, 6, J)
Hw = H.transpose(-2, -1) * w # (B, N, 6, J)
R = torch.matmul(Hw, H) # (B, N, 6, 6)

index_gamma = origin.unsqueeze(-1) + j[..., 1:-1] # (B, N, J)
Expand Down
4 changes: 2 additions & 2 deletions diffsptk/modules/excite.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def _forward(p, frame_period, voiced_region, unvoiced_region):

# Interpolate pitch.
if p.dim() != 1:
p = p.mT
p = p.transpose(-2, -1)
p = LinearInterpolation._func(p, frame_period)
if p.dim() != 1:
p = p.mT
p = p.transpose(-2, -1)
p *= mask

# Compute phase.
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/modules/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def forward(self, x):
y = posterior.sum(dim=0)
nu = px / y.view(-1, 1)
nm = torch.matmul(nu.unsqueeze(-1), self.mu.unsqueeze(-2))
mn = nm.mT
mn = nm.transpose(-2, -1)
a = pxx - y.view(-1, 1, 1) * (nm + mn - mm)
b = xi.view(-1, 1, 1) * self.ubm_sigma
diff = self.ubm_mu - self.mu
Expand Down
4 changes: 2 additions & 2 deletions diffsptk/modules/linear_intpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,15 @@ def _forward(x, upsampling_factor):
assert x.dim() == 3, "Input must be 3D tensor"
B, T, D = x.shape

x = x.mT.contiguous() # (B, D, T)
x = x.transpose(-2, -1).contiguous() # (B, D, T)
x = replicate1(x, left=False)
x = F.interpolate(
x,
size=T * upsampling_factor + 1,
mode="linear",
align_corners=True,
)[..., :-1] # Remove the padded value.
y = x.mT.reshape(B, -1, D)
y = x.transpose(-2, -1).reshape(B, -1, D)

if d == 1:
y = y.view(-1)
Expand Down
6 changes: 3 additions & 3 deletions diffsptk/modules/magic_intpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def compute_lerp_inputs(x, magic_number):
if uniques[0]:
w[..., : counts[0]] = 0
w = torch.cumsum(w, dim=-1)
w = w - torch.cumsum(w * ~is_magic_number[i], dim=-1)
w = w - torch.cummax(w * ~is_magic_number[i], dim=-1)[0]
if uniques[0]:
w[..., : counts[0]] = 1
if uniques[-1]:
Expand All @@ -141,10 +141,10 @@ def compute_lerp_inputs(x, magic_number):
weights = torch.stack(weights)
return starts, ends, weights

x = x.mT.reshape(B * D, T)
x = x.transpose(-2, -1).reshape(B * D, T)
starts, ends, weights = compute_lerp_inputs(x, magic_number)
y = torch.lerp(starts, ends, weights)
y = y.reshape(B, D, T).mT
y = y.reshape(B, D, T).transpose(-2, -1)

if d == 1:
y = y.view(-1)
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/modules/pitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def calc_embed(self, x):

def calc_pitch(self, x):
# Compute pitch probabilities.
prob = self.calc_prob(x).mT
prob = self.calc_prob(x).transpose(-2, -1)

# Decode pitch probabilities.
pitch, periodicity = self.torchcrepe.postprocess(
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/modules/rlevdur.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _forward(a, eye):
E = torch.stack(E[::-1], dim=-1)

V = torch.linalg.solve_triangular(U, eye, upper=True, unitriangular=True)
r = torch.matmul(V[..., :1].mT * E, V).squeeze(-2)
r = torch.matmul(V[..., :1].transpose(-2, -1) * E, V).squeeze(-2)
return r

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/modules/unframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def fold(x):
return x

w = window.repeat(1, 1, N)
x = y.mT
x = y.transpose(-2, -1)

if d == 2:
x = x.unsqueeze(0)
Expand Down
14 changes: 11 additions & 3 deletions tests/test_magic_intpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,16 @@

@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("module", [False, True])
def test_compatibility(device, module, N=10, L=2, magic_number=0):
@pytest.mark.parametrize(
"data",
[
"0 9 0 0 0 0 2 1 0 0 4 5 0 0",
"1 9 0 0 0 0 2 1 0 0 4 5 0 7",
"1 2 3 4 5 6",
],
)
@pytest.mark.parametrize("L", [1, 2])
def test_compatibility(device, module, data, L, N=10, magic_number=0):
magic_intpl = U.choice(
module,
diffsptk.MagicNumberInterpolation,
Expand All @@ -36,14 +45,13 @@ def test_compatibility(device, module, N=10, L=2, magic_number=0):
device,
magic_intpl,
[],
"echo 0 9 0 0 0 0 2 1 0 0 4 5 0 0 | x2x +ad",
f"echo {data} | x2x +ad",
f"magic_intpl -l {L} -magic {magic_number}",
[],
dx=L,
dy=L,
)

U.check_differentiability(device, magic_intpl, [N, L])
U.check_differentiability(device, [magic_intpl, F.dropout], [N, L])


Expand Down

0 comments on commit ec16916

Please sign in to comment.