Skip to content

Commit

Permalink
Merge pull request #27 from ktonal/develop
Browse files: browse the repository at this point in the history
v0.2.6
  • Loading branch information
antoinedaurat authored Jun 29, 2021
2 parents 3194581 + 545c115 commit a05b34e
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 13 deletions.
2 changes: 1 addition & 1 deletion mimikit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.2.5'
__version__ = '0.2.6'

from . import audios
from . import connectors
Expand Down
2 changes: 1 addition & 1 deletion mimikit/audios/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
]


@dtc.dataclass
@dtc.dataclass(unsafe_hash=True)
class AudioSignal(Feature):
"""
audio signal managers
Expand Down
15 changes: 8 additions & 7 deletions mimikit/audios/fmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __call__(self, inputs):

@dtc.dataclass
class Normalize(FModule):
p: int = 1
p: int = float('inf')
dim: int = -1

@property
Expand Down Expand Up @@ -209,14 +209,15 @@ class STFT(FModule):
def functions(self):
def np_func(inputs):
# returned shape is (time x freq)
return librosa.stft(inputs, n_fft=self.n_fft, hop_length=self.hop_length).T
S =librosa.stft(inputs, n_fft=self.n_fft, hop_length=self.hop_length).T
S = np.stack((abs(S), np.angle(S)), axis=-1)
return S

def torch_func(inputs):
mod = T.Spectrogram(self.n_fft, hop_length=self.hop_length, power=1.,
wkwargs=dict(device=inputs.device))
# returned shape is (..., time x freq)
return mod(inputs).transpose(-1, -2).contiguous()

return {
np.ndarray: np_func,
torch.Tensor: torch_func
Expand All @@ -235,8 +236,8 @@ def np_func(inputs):
return librosa.istft(inputs.T, n_fft=self.n_fft, hop_length=self.hop_length, )

def torch_func(inputs):
# inputs is of shape (time x freq)
y = torch.istft(inputs.transpose(-1, -2).contiguous(),
inputs = inputs[..., 0] * torch.exp(1j * inputs[..., 1])
y = torch.istft(inputs.transpose(1, 2).contiguous(),
n_fft=self.n_fft, hop_length=self.hop_length,
window=torch.hann_window(self.n_fft, device=inputs.device))
return y
Expand All @@ -256,8 +257,8 @@ def functions(self):
# dict comprehension would result in a single function for
# all types, so we declare the dict manually...
return {
np.ndarray: lambda x: abs(sup_f[np.ndarray](x)),
torch.Tensor: lambda x: abs(sup_f[torch.Tensor](x))
np.ndarray: lambda x: abs(sup_f[np.ndarray](x)[..., 0]),
torch.Tensor: lambda x: abs(sup_f[torch.Tensor](x)[..., 0])
}


Expand Down
8 changes: 4 additions & 4 deletions mimikit/data/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def encode(self, inputs):
-------
"""
if hasattr(super(type(self), self), 'encoders'):
inputs = super(type(self), self).encoders[type(inputs)](inputs)
if hasattr(super(), 'encoders'):
inputs = super().encoders[type(inputs)](inputs)
return self.encoders[type(inputs)](inputs)

def decode(self, inputs):
Expand All @@ -106,8 +106,8 @@ def decode(self, inputs):
"""
inputs = self.decoders[type(inputs)](inputs)
if hasattr(super(type(self), self), 'decoders'):
inputs = super(type(self), self).decoders[type(inputs)](inputs)
if hasattr(super(), 'decoders'):
inputs = super().decoders[type(inputs)](inputs)
return inputs

def load(self, path):
Expand Down
1 change: 1 addition & 0 deletions mimikit/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .freqnet import *
from .generating_net import *
from .parametrized_gaussian import *
from .sample_rnn import *
from .s2s_lstm import *
Expand Down

0 comments on commit a05b34e

Please sign in to comment.