Skip to content

Commit

Permalink
Merge pull request #77 from sp-nitech/csm
Browse files Browse the repository at this point in the history
Add acr2csm
  • Loading branch information
takenori-y committed Jul 1, 2024
2 parents 6d4396d + 8f4e4e0 commit 44445ec
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 8 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
matrix:
include:
- python: 3.8
torch: 1.11.0
torchaudio: 0.11.0
torch: 1.12.0
torchaudio: 0.12.0
- python: 3.12
torch: 2.3.1
torchaudio: 2.3.1
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ PROJECT := diffsptk
MODULE :=

PYTHON_VERSION := 3.9
TORCH_VERSION := 1.11.0
TORCHAUDIO_VERSION := 0.11.0
TORCH_VERSION := 1.12.0
TORCHAUDIO_VERSION := 0.12.0
PLATFORM := cu113

venv:
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
[![Stable Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/2.0.1/)
[![Downloads](https://static.pepy.tech/badge/diffsptk)](https://pepy.tech/project/diffsptk)
[![Python Version](https://img.shields.io/pypi/pyversions/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk)
[![PyTorch Version](https://img.shields.io/badge/pytorch-1.11.0%20%7C%202.3.1-orange.svg)](https://pypi.python.org/pypi/diffsptk)
[![PyTorch Version](https://img.shields.io/badge/pytorch-1.12.0%20%7C%202.3.1-orange.svg)](https://pypi.python.org/pypi/diffsptk)
[![PyPI Version](https://img.shields.io/pypi/v/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk)
[![Codecov](https://codecov.io/gh/sp-nitech/diffsptk/branch/master/graph/badge.svg)](https://app.codecov.io/gh/sp-nitech/diffsptk)
[![License](https://img.shields.io/github/license/sp-nitech/diffsptk.svg)](https://github.com/sp-nitech/diffsptk/blob/master/LICENSE)
Expand All @@ -16,7 +16,7 @@
## Requirements

- Python 3.8+
- PyTorch 1.11.0+
- PyTorch 1.12.0+

## Documentation

Expand Down
17 changes: 17 additions & 0 deletions diffsptk/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,23 @@ def acorr(x, acr_order, norm=False, estimator="none"):
)


def acr2csm(r):
"""Convert autocorrelation to CSM coefficients.
Parameters
----------
r : Tensor [shape=(..., M+1)]
Autocorrelation.
Returns
-------
out : Tensor [shape=(..., M+1)]
CSM coefficients.
"""
return nn.AutocorrelationToCompositeSinusoidalModelCoefficients._func(r)


def alaw(x, abs_max=1, a=87.6):
"""Compress waveform by A-law algorithm.
Expand Down
9 changes: 7 additions & 2 deletions diffsptk/misc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,13 @@ def symmetric_toeplitz(x):

def hankel(x):
d = x.size(-1)
assert d % 2 == 1
X = x.unfold(-1, (d + 1) // 2, 1)
n = (d + 1) // 2
X = x.unfold(-1, n, 1)[..., :n, :]
return X


def vander(x):
X = torch.linalg.vander(x).transpose(-2, -1)
return X


Expand Down
1 change: 1 addition & 0 deletions diffsptk/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .acorr import Autocorrelation
from .acr2csm import AutocorrelationToCompositeSinusoidalModelCoefficients
from .alaw import ALawCompression
from .ap import Aperiodicity
from .b2mc import MLSADigitalFilterCoefficientsToMelCepstrum
Expand Down
118 changes: 118 additions & 0 deletions diffsptk/modules/acr2csm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

from scipy.special import comb
import torch
from torch import nn
import torch.nn.functional as F

from ..misc.utils import check_size
from ..misc.utils import hankel
from ..misc.utils import to
from ..misc.utils import vander
from .root_pol import PolynomialToRoots


class AutocorrelationToCompositeSinusoidalModelCoefficients(nn.Module):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/acr2csm.html>`_
for details.
Parameters
----------
csm_order : int >= 0
Order of CSM coefficients, :math:`M`.
"""

def __init__(self, csm_order):
super().__init__()

assert 1 <= csm_order
assert csm_order % 2 == 1

self.csm_order = csm_order
self.register_buffer("C", self._precompute(self.csm_order))

def forward(self, r):
"""Convert autocorrelation to CSM coefficients.
Parameters
----------
r : Tensor [shape=(..., M+1)]
Autocorrelation.
Returns
-------
out : Tensor [shape=(..., M+1)]
Composite sinusoidal model coefficients.
Examples
--------
>>> x = diffsptk.nrand(4)
>>> x
tensor([ 0.0165, -2.3693, 0.1375, -0.2262, 1.3307])
>>> acorr = diffsptk.Autocorrelation(5, 3)
>>> acr2csm = diffsptk.AutocorrelationToCompositeSinusoidalModelCoefficients(3)
>>> c = acr2csm(acorr(x))
>>> c
tensor([0.9028, 2.5877, 3.8392, 3.6153])
"""
check_size(r.size(-1), self.csm_order + 1, "dimension of autocorrelation")
return self._forward(r, self.C)

@staticmethod
def _forward(r, C):
u = torch.matmul(r, C)
u1, u2 = torch.tensor_split(u, 2, dim=-1)

U = hankel(-u)
p = torch.matmul(U.inverse(), u2.unsqueeze(-1)).squeeze(-1)
x = PolynomialToRoots._func(F.pad(p.flip(-1), (1, 0), value=1))
x, _ = torch.sort(x.real, descending=True)
w = torch.acos(x)

V = vander(x)
m = torch.matmul(V.inverse(), u1.unsqueeze(-1)).squeeze(-1)
csm = torch.cat((w, m), dim=-1)
return csm

@staticmethod
def _func(r):
C = AutocorrelationToCompositeSinusoidalModelCoefficients._precompute(
r.size(-1) - 1, dtype=r.dtype, device=r.device
)
return AutocorrelationToCompositeSinusoidalModelCoefficients._forward(r, C)

@staticmethod
def _precompute(csm_order, dtype=None, device=None):
N = csm_order + 1
B = torch.zeros((N, N), dtype=torch.double, device=device)
for n in range(N):
z = 2**-n
for k in range(n + 1):
B[k, n] = comb(n, k, exact=True) * z

C = torch.zeros((N, N), dtype=torch.double, device=device)
for k in range(N):
bias = k % 2
center = k // 2
length = center + 1
C[bias : bias + 2 * length : 2, k] = B[
bias + center : bias + center + length, k
]
C[1:] *= 2
return to(C, dtype=dtype)
1 change: 1 addition & 0 deletions diffsptk/modules/levdur.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def forward(self, r):
Examples
--------
>>> x = diffsptk.nrand(4)
>>> x
tensor([ 0.8226, -0.0284, -0.5715, 0.2127, 0.1217])
>>> acorr = diffsptk.Autocorrelation(5, 2)
>>> levdur = diffsptk.LevinsonDurbin(2)
Expand Down
11 changes: 11 additions & 0 deletions docs/modules/acr2csm.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.. _acr2csm:

acr2csm
-------

.. autoclass:: diffsptk.AutocorrelationToCompositeSinusoidalModelCoefficients
:members:

.. autofunction:: diffsptk.functional.acr2csm

.. seealso:: :ref:`acorr`
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ classifiers = [
]
dependencies = [
"numpy < 2.0.0",
"scipy < 1.14.0",
"librosa >= 0.10.1",
"soundfile >= 0.10.2",
"torch >= 1.11.0",
Expand Down
45 changes: 45 additions & 0 deletions tests/test_acr2csm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import pytest

import diffsptk
import tests.utils as U


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("module", [False, True])
def test_compatibility(device, module, M=25, L=100, B=2):
acr2csm = U.choice(
module,
diffsptk.AutocorrelationToCompositeSinusoidalModelCoefficients,
diffsptk.functional.acr2csm,
{"csm_order": M},
)

U.check_compatibility(
device,
acr2csm,
[],
f"nrand -l {B*L} | acorr -m {M} -l {L}",
f"acr2csm -m {M}",
[],
dx=M + 1,
dy=M + 1,
)

acorr = diffsptk.Autocorrelation(L, M)
U.check_differentiability(device, [acr2csm, acorr], [B, L])

0 comments on commit 44445ec

Please sign in to comment.