Skip to content

Commit

Permalink
Merge pull request #24 from ktonal/develop
Browse files Browse the repository at this point in the history
v0.2.2
  • Loading branch information
antoinedaurat authored Jun 15, 2021
2 parents 8285dbb + ae94e5c commit b8d6713
Show file tree
Hide file tree
Showing 13 changed files with 349 additions and 124 deletions.
10 changes: 8 additions & 2 deletions docs/audios.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ audios

this package exposes helper classes for processing, interacting & modeling audio data

.. automodule:: mimikit.audios.fmodules
.. py:currentmodule:: mimikit.audios.fmodules
.. automodule:: mimikit.audios.features
.. autoclass:: FModule
:special-members: __call__
:members:


.. autoclass:: FileToSignal
:special-members: __init__, __call__
6 changes: 3 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
# -- Project information -----------------------------------------------------

project = 'mimikit'
copyright = '2021, k-tonal'
author = 'k-tonal'
copyright = '2021, ktonal'
author = 'ktonal'

# The full version, including alpha/beta/rc tags
release = 'v0.1.6'
release = 'v0.2.1'


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion mimikit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.2.1'
__version__ = '0.2.2'

from . import audios
from . import connectors
Expand Down
54 changes: 53 additions & 1 deletion mimikit/audios/fmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,55 @@

class FModule:
"""
base class for implementing message passing in callable objects
base class for implementing a custom partial function that supports multiple input types.
this is the base interface of all classes in this module.
"""

@property
def functions(self) -> dict:
"""
the functions this class implements for some set of inputs types
Returns
-------
dict
the keys must be types and the values callable objects taking a single argument
"""
raise NotImplementedError

def __call__(self, inputs):
"""
apply the functional corresponding to `type(inputs)` to inputs
Parameters
----------
inputs : any type supported by this ``FModule``
some inputs
Returns
-------
out
result of applying ``self.functions[type(inputs)]`` to ``inputs``
Raises
------
KeyError
if ``type(inputs)`` isn't in the keys of ``self.functions``
"""
return self.functions[type(inputs)](inputs)


@dtc.dataclass
class FileToSignal(FModule):
"""
returns the np.ndarray of an audio file read at a given sample rate.
Parameters
----------
sr : int
the sample rate
"""
sr: int = SR

@property
Expand All @@ -48,6 +84,22 @@ def functions(self):
str: lambda path: librosa.load(path, sr=self.sr)[0]
}

def __call__(self, inputs):
"""
get the array
Parameters
----------
inputs : str
path to the file
Returns
-------
signal : np.ndarray
the array as returned by ``librosa.load``
"""
return super(FileToSignal, self).__call__(inputs)


@dtc.dataclass
class Normalize(FModule):
Expand Down
2 changes: 1 addition & 1 deletion mimikit/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .model import *
from .parts import *
from .sample_rnn import *
from .seq2seqlstm import *
from .s2s_lstm import *
from .wavenet import *

__all__ = [_ for _ in dir() if not _.startswith("_")]
Expand Down
2 changes: 1 addition & 1 deletion mimikit/models/freqnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def demo():

dm = mmk.DataModule(net, db, splits=tuple())

cb = mmk.GenerateCallBack(every_n_epochs,
cb = mmk.GenerateCallback(every_n_epochs,
indices=[None]*n_examples,
n_steps=n_steps,
play_audios=True,
Expand Down
4 changes: 2 additions & 2 deletions mimikit/models/parts/sequence_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

__all__ = [
'SequenceModel',
'GenerateCallBack'
'GenerateCallback'
]


Expand Down Expand Up @@ -75,7 +75,7 @@ def generate(self, prompt, n_steps=16000,
return output


class GenerateCallBack(pl.callbacks.Callback):
class GenerateCallback(pl.callbacks.Callback):

def __init__(self, every_n_epochs=10, indices=3, n_steps=1000,
plot_audios=True, play_audios=True, log_audios=False,
Expand Down
163 changes: 163 additions & 0 deletions mimikit/models/s2s_lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import dataclasses as dtc

from ..audios import Spectrogram
from ..data import Feature, AsSlice, Input, Target
from .parts import SuperAdam, SequenceModel, IData
from ..networks import Seq2SeqLSTM
from .parts.loss_functions import mean_L1_prop
from .model import model

__all__ = [
'Seq2SeqLSTMModel'
]


@dtc.dataclass(init=True, repr=False, eq=False, frozen=False, unsafe_hash=True)
class Seq2SeqData(IData):
feature: Feature = None
batch_size: int = 16
shift: int = 8

@classmethod
def schema(cls, sr=22050, emphasis=0., n_fft=2048, hop_length=512):
schema = {"fft": Spectrogram(sr=sr, emphasis=emphasis,
n_fft=n_fft, hop_length=hop_length,
magspec=True)}
return schema

@classmethod
def dependant_hp(cls, db):
return dict(
feature=Spectrogram(**db.fft.attrs), input_dim=db.fft.shape[-1]
)

def batch_signature(self, stage='fit'):
inpt = Input('fft', AsSlice(shift=0, length=self.shift))
trgt = Target('fft', AsSlice(shift=self.shift,
length=self.shift))
if stage in ('full', 'fit', 'train', 'val'):
return inpt, trgt
# test, predict, generate...
return inpt

def loader_kwargs(self, stage, datamodule):
return dict(
batch_size=self.batch_size,
drop_last=False,
shuffle=True
)


@model
class Seq2SeqLSTMModel(
Seq2SeqData,
SuperAdam,
SequenceModel,
Seq2SeqLSTM,
):

@staticmethod
def loss_fn(output, target):
return {"loss": mean_L1_prop(output, target)}

def encode_inputs(self, inputs: torch.Tensor):
return self.feature.encode(inputs)

def decode_outputs(self, outputs: torch.Tensor):
return self.feature.decode(outputs)


def demo():
"""### import and arguments"""
import mimikit as mmk

# DATA

# list of files or directories to use as data
sources = ['./data']
# audio sample rate
sr = 22050
# the size of the stft
n_fft = 2048
# hop_length of the
hop_length = n_fft // 4

# NETWORK

# this network takes `shift` fft frames as input and outputs `shift` future frames
shift = 8
# the net contains at least 3 LSTM modules (1 in the Encoder, 2 in the Decoder)
# you can add modules to the Encoder by increasing the next argument
n_lstm = 1
# all LSTM modules have internally the same number of layers :
num_layers = 1
# the dimensionality of the model
model_dim = 1024

# OPTIMIZATION

# how many epochs should we train for
max_epochs = 50
# how many examples are used pro training steps
batch_size = 16
# the learning rate
max_lr = 1e-3
# betas control how fast the network changes its 'learning course'.
# generally, betas should be close but smaller than 1. and be balanced with the batch_size :
# the smaller the batch, the higher the betas 'could be'.
betas = (0.9, 0.93)

# MONITORING

# how often should the network generate during training
every_n_epochs = 2
# how many examples from random prompts should be generated
n_examples = 3
# how many steps (1 step = `shift` fft frames!) should be generated
n_steps = 1000 // shift

print("arguments are ok!")

"""### create the data"""
schema = mmk.Seq2SeqLSTMModel.schema(sr, n_fft=n_fft, hop_length=hop_length)

db_path = 's2s-demo.h5'
print("collecting data...")
db = mmk.Database.create(db_path, sources, schema)
print("successfully created the db.")

"""### create network and train"""
net = mmk.Seq2SeqLSTMModel(
**mmk.Seq2SeqLSTMModel.dependant_hp(db),
shift=shift,
n_lstm=n_lstm,
num_layers=num_layers,
model_dim=model_dim,
batch_size=batch_size,
max_lr=max_lr,
div_factor=5,
betas=betas,

)
print(net.hparams)

dm = mmk.DataModule(net, db, splits=tuple())

cb = mmk.GenerateCallback(every_n_epochs,
indices=[None] * n_examples,
n_steps=n_steps,
play_audios=True,
plot_audios=True)

trainer = mmk.get_trainer(root_dir=None,
max_epochs=max_epochs,
callbacks=[cb],
checkpoint_callback=False)
print("here we go!")
trainer.fit(net, datamodule=dm)

"""----------------------------"""
Loading

0 comments on commit b8d6713

Please sign in to comment.