add Checkpoint & generate_from_checkpoint demo

antoinedaurat committed May 24, 2022
1 parent 79728ff commit b9b8377
Showing 15 changed files with 244 additions and 118 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-pipeline.yml
@@ -18,7 +18,7 @@ jobs:
sudo apt-get install -y libsndfile-dev
python -m pip install --quiet --upgrade pip
pip install --quiet -r requirements.txt
-pip install --quiet hatch
+pip install --quiet hatch==0.23.1
pip list | grep torch
- name: Test
2 changes: 1 addition & 1 deletion mimikit/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '0.3.1'
+__version__ = '0.3.2'

from . import extract
from . import features
2 changes: 1 addition & 1 deletion mimikit/demos/__init__.py
@@ -2,6 +2,6 @@
from .srnn import *
from .s2s import *
from .wn import *

+from .generate_from_checkpoint import *

__all__ = [_ for _ in dir() if not _.startswith("_")]
64 changes: 64 additions & 0 deletions mimikit/demos/generate_from_checkpoint.py
@@ -0,0 +1,64 @@


def demo():
    """Generate From Checkpoint"""
    import mimikit as mmk
    import h5mapper as h5m
    import torch
    import matplotlib.pyplot as plt
    import librosa

    # load a checkpoint
    ckpt = mmk.Checkpoint(
        root_dir="./trainings/wn-test-gesten",
        id='84e89798ec2c85e19790344fb598932118c7a65142e747e383907c5f7ced0f26',
        epoch=1
    )
    net, feature = ckpt.network, ckpt.feature

    # prompt positions in seconds
    indices = [
        1.1, 8.5, 46.3
    ]
    # duration to generate: 8 seconds, converted to a number of steps
    n_steps = librosa.time_to_frames(8, sr=feature.sr, hop_length=feature.hop_length)

    class SoundBank(h5m.TypedFile):
        snd = h5m.Sound(sr=feature.sr, mono=True, normalize=True)

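    # re-index the audio files this checkpoint was trained on (their paths are stored in the run's hp.json)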
    SoundBank.create("gen.h5", ckpt.train_hp["files"])
    soundbank = SoundBank("gen.h5", mode='r', keep_open=True)

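    # called on each generated batch: plot each generated waveform and render it as audio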
    def process_outputs(outputs, bidx):
        output = feature.inverse_transform(outputs[0])
        for i, out in enumerate(output):
            y = out.detach().cpu().numpy()
            plt.figure(figsize=(20, 2))
            plt.plot(y)
            plt.show(block=False)
            mmk.audio(y, sr=feature.sr,
                      hop_length=feature.hop_length)

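    # highest admissible prompt start, leaving one receptive field (net.rf steps, here in samples) before the end of the data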
    max_i = soundbank.snd.shape[0] - getattr(feature, "hop_length", 1) * net.rf
    g_dl = soundbank.serve(
        (feature.batch_item(shift=0, length=net.rf, training=False),),
        sampler=mmk.IndicesSampler(N=len(indices),
                                   indices=[librosa.time_to_samples(i, sr=feature.sr) for i in indices],
                                   max_i=max_i,
                                   redraw=False),
        shuffle=False,
        batch_size=len(indices)
    )

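    # autoregressive loop: each step reads back the last net.rf frames it wrote (AsSlice shift=-net.rf) and writes the network's output at the current position (Setter)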
    loop = mmk.GenerateLoop(
        network=net,
        dataloader=g_dl,
        inputs=(h5m.Input(None, h5m.AsSlice(dim=1, shift=-net.rf, length=net.rf), setter=h5m.Setter(dim=1)),),
        n_steps=n_steps,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        time_hop=net.hp.get("hop", 1),
        process_outputs=process_outputs
    )
    loop.run()

"""----------------------------"""
1 change: 0 additions & 1 deletion mimikit/features/audio.py
@@ -89,7 +89,6 @@ def write(self, filename, inputs):
@dtc.dataclass(unsafe_hash=True)
class MuLawSignal(AudioSignal):
    q_levels: int = 256
-    target_width: int = 1
    pr_y: torch.tensor = None

    def __post_init__(self):
1 change: 1 addition & 0 deletions mimikit/models/__init__.py
@@ -3,5 +3,6 @@
from .wavenets import *
from .s2s import *
from .nnn import *
+from .checkpoint import *

__all__ = [_ for _ in dir() if not _.startswith("_")]
79 changes: 79 additions & 0 deletions mimikit/models/checkpoint.py
@@ -0,0 +1,79 @@
import dataclasses as dtc
import h5mapper as h5m
import json
import os
from .s2s import Seq2SeqLSTM, Seq2SeqLSTMv0
from .wavenets import WaveNetFFT, WaveNetQx
from .srnns import SampleRNN

__all__ = [
    'find_checkpoints',
    'load_trainings_hp',
    'load_network_cls',
    'load_feature',
    'Checkpoint'
]

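# walk `root` and yield every .h5 file it contains (one per saved epoch)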
def find_checkpoints(root="trainings"):
    return h5m.FileWalker(r"\.h5", root)


def load_trainings_hp(dirname):
    with open(os.path.join(dirname, "hp.json"), 'r') as f:
        return json.load(f)


class CkptBank(h5m.TypedFile):
    ckpt = h5m.TensorDict()

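# `s` is a Python expression over mimikit names (e.g. a feature's repr); exec-ing it rebuilds the object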
def load_feature(s):
    import mimikit as mmk
    loc = dict()
    exec(f"feature = {s}", mmk.__dict__, loc)
    return loc["feature"]


def load_network_cls(s):
    loc = dict()
    exec(f"cls = {s}", globals(), loc)
    return loc["cls"]


@dtc.dataclass
class Checkpoint:
    id: str
    epoch: int
    root_dir: str = "./"

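    # checkpoints are expected at <root_dir>/<id>/epoch=<epoch>.h5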
    @staticmethod
    def get_id_and_epoch(path):
        id_, epoch = path.split("/")[-2:]
        return id_.strip("/"), int(epoch.split(".h5")[0].split("=")[-1])

    @staticmethod
    def from_path(path):
        basename = os.path.dirname(os.path.dirname(path))
        return Checkpoint(*Checkpoint.get_id_and_epoch(path), root_dir=basename)

    @property
    def os_path(self):
        return os.path.join(self.root_dir, f"{self.id}/epoch={self.epoch}.h5")

    def delete(self):
        os.remove(self.os_path)

    @property
    def network(self):
        bank = CkptBank(self.os_path, 'r')
        hp = bank.ckpt.load_hp()
        return bank.ckpt.load_checkpoint(hp["cls"], "state_dict")

    @property
    def feature(self):
        bank = CkptBank(self.os_path, 'r')
        hp = bank.ckpt.load_hp()
        return hp['feature']

    @property
    def train_hp(self):
        return load_trainings_hp(os.path.join(self.root_dir, self.id))
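
A minimal usage sketch of the pieces above (assuming mimikit re-exports them at the top level, as the demo's use of mmk.Checkpoint suggests, and that h5m.FileWalker iterates over the matched paths):

import mimikit as mmk

# wrap every saved epoch found under ./trainings as a Checkpoint and inspect it
for path in mmk.find_checkpoints("./trainings"):
    ckpt = mmk.Checkpoint.from_path(path)
    print(ckpt.id, ckpt.epoch, ckpt.os_path)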
106 changes: 30 additions & 76 deletions mimikit/models/io_modules.py
@@ -18,78 +18,6 @@ def qx_io(q_levels, net_in_dim, net_out_dim, mlp_dim, mlp_activation=nn.ReLU()):


def wn_qx_io(q_levels, net_in_dim, net_out_dim, mlp_dim, mlp_activation=nn.ReLU()):
-    class OHEmbedding(nn.Module):
-        def __init__(self, n_classes, net_in_dim):
-            super(OHEmbedding, self).__init__()
-            self.n_classes = n_classes
-            self.cv = nn.Conv1d(n_classes, net_in_dim, kernel_size=(1,), bias=False)
-
-        def forward(self, x):
-            out = nn.functional.one_hot(x, self.n_classes).to(x.device).float()
-            if self.training:
-                return nn.Dropout2d(1 / 3)(out)
-            return out
-            # return self.cv(y.transpose(-1, -2).contiguous()).transpose(-1, -2).contiguous()
-
-    class ConvMLP(nn.Module):
-        def __init__(self, net_out_dim, mlp_dim, n_classes):
-            super(ConvMLP, self).__init__()
-            self.cv1 = nn.Sequential(
-                nn.Linear(net_out_dim, n_classes, bias=False),
-            )
-            self.cv2 = nn.Sequential(
-                nn.Linear(n_classes, 1, bias=False),
-            )
-            # self.cv3 = nn.Linear(n_classes, 1, bias=False)
-            # self.bn = nn.BatchNorm1d(n_classes)
-            self.Q = n_classes
-            self.top_k = None
-            self.top_p = None
-
-        def forward(self, x, temperature=None):
-            # x = x.transpose(-1, -2).contiguous()
-            outpt = nn.Mish()(self.cv1(x))
-            outpt = torch.sin(self.cv2(outpt))
-            # outpt = outpt.transpose(-1, -2).contiguous()
-            return outpt.squeeze(-1)
-            # Roundable Sigmoid :
-            # outpt = (self.bn(self.cv2(outpt) / (self.Q/4))).transpose(-1, -2).contiguous() * self.Q
-            # return outpt.squeeze() if self.training else outpt.squeeze(-1)
-            # Standard Pr_Y
-            # outpt = (self.bn(outpt + self.cv2(outpt))).transpose(-1, -2).contiguous()
-            # if self.training:
-            #     return outpt
-            # if temperature is None:
-            #     return outpt.argmax(dim=-1)
-            # else:
-            #     if not isinstance(temperature, torch.Tensor):
-            #         temperature = torch.Tensor([temperature]).reshape(*([1] * (len(outpt.size()))))
-            #     probas = outpt.squeeze() / temperature.to(outpt)
-            #     if self.top_k is not None:
-            #         indices_to_remove = probas < torch.topk(probas, self.top_k)[0][..., -1, None]
-            #         probas[[indices_to_remove]] = - float("inf")
-            #         probas = nn.Softmax(dim=-1)(probas)
-            #     elif self.top_p is not None:
-            #         sorted_logits, sorted_indices = torch.sort(probas, descending=True)
-            #         cumulative_probs = torch.cumsum(nn.Softmax(dim=-1)(sorted_logits), dim=-1)
-            #
-            #         # Remove tokens with cumulative probability above the threshold
-            #         sorted_indices_to_remove = cumulative_probs > self.top_p
-            #         # Shift the indices to the right to keep also the first token above the threshold
-            #         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-            #         sorted_indices_to_remove[..., 0] = 0
-            #
-            #         indices_to_remove = sorted_indices[sorted_indices_to_remove]
-            #         probas[indices_to_remove] = - float("inf")
-            #         probas = nn.Softmax(dim=-1)(probas)
-            #     else:
-            #         probas = nn.Softmax(dim=-1)(probas)
-            #     if probas.dim() > 2:
-            #         o_shape = probas.shape
-            #         probas = probas.view(-1, o_shape[-1])
-            #         return torch.multinomial(probas, 1).reshape(*o_shape[:-1])
-            #     return torch.multinomial(probas, 1)
-
    class Dropout1d(nn.Dropout):

        def forward(self, input):
@@ -108,16 +36,25 @@ def forward(self, input):
    )
    # embd = nn.utils.weight_norm(nn.Embedding(q_levels, net_in_dim), "weight")
    # return inpt_mod, ConvMLP(net_out_dim, mlp_dim, q_levels)
-    return inpt_mod, SingleClassMLP(net_out_dim, mlp_dim, q_levels)
+    return inpt_mod, SingleClassMLP(net_out_dim, mlp_dim, q_levels, learn_temperature=True, n_hidden_layers=3)


def mag_spec_io(spec_dim, net_dim, in_chunks, out_chunks, scaled_activation=False, with_sampler=False):
    return Chunk(nn.Linear(spec_dim, net_dim * in_chunks, bias=False),
                 in_chunks, sum_out=True), \
           nn.Sequential(Chunk(nn.Linear(net_dim, spec_dim * out_chunks, bias=False),
                               out_chunks, sum_out=True),
-                         *((ParametrizedGaussian(spec_dim, spec_dim, False, False), ) if with_sampler else ()),
+                         *((ParametrizedGaussian(spec_dim, spec_dim, False, False),) if with_sampler else ()),
                         ScaledSigmoid(spec_dim, with_range=False) if scaled_activation else Abs())
+    # HOM("x -> y",
+    #     *(Maybe(with_sampler,
+    #             (ParametrizedGaussian(net_dim, spec_dim, True, False), "x -> f"))),
+    #     *(Maybe(not with_sampler,
+    #             (Chunk(nn.Linear(net_dim, spec_dim * out_chunks, bias=False),
+    #                    out_chunks, sum_out=True), "x -> f"))),
+    #     (nn.Sequential(nn.Linear(net_dim, spec_dim), nn.Sigmoid()), "x -> temp"),
+    #     (lambda f, temp: (f.abs_() / temp), "f, temp -> y")
+    #     )


def pol_spec_io(spec_dim, net_dim, in_chunks, out_chunks, scaled_activation=False, phs='a', with_sampler=False):
Expand All @@ -133,7 +70,8 @@ def __init__(self):
"x -> phs",
*(Maybe(phs == "b",
(lambda self, x: torch.cos(self.psis.to(x) * x) * pi, "self, x -> x"))),
-(nn.Sequential(Chunk(nn.Linear(net_dim, spec_dim * out_chunks, bias=False), out_chunks, sum_out=True), act_phs),
+(nn.Sequential(Chunk(nn.Linear(net_dim, spec_dim * out_chunks, bias=False), out_chunks, sum_out=True),
+               act_phs),
'x -> phs'),
*(Maybe(phs == "a",
(lambda self, phs: torch.cos(
@@ -153,7 +91,23 @@ def __init__(self):
# phase module
(ScaledPhase(), 'x -> phs'),
# magnitude module
-(nn.Sequential(Chunk(nn.Linear(net_dim, spec_dim * out_chunks, bias=False), out_chunks, sum_out=True), act_mag),
+(nn.Sequential(Chunk(nn.Linear(net_dim, spec_dim * out_chunks, bias=False), out_chunks, sum_out=True),
+               act_mag),
'x -> mag'),
(lambda mag, phs: torch.stack((mag, phs), dim=-1), "mag, phs -> y")
)
+    # HOM("x -> y",
+    #     (HOM("x -> mag",
+    #          (Chunk(nn.Linear(net_dim, spec_dim * out_chunks, bias=False),
+    #                 out_chunks, sum_out=True), "x -> f"),
+    #          (nn.Sequential(nn.Linear(net_dim, spec_dim), nn.Sigmoid()), "x -> temp"),
+    #          (lambda f, temp: f / temp, "f, temp -> mag")
+    #          ), "x -> mag"),
+    #     (HOM("x -> phs",
+    #          (Chunk(nn.Linear(net_dim, spec_dim * out_chunks, bias=False),
+    #                 out_chunks, sum_out=True), "x -> f"),
+    #          # (nn.Sequential(nn.Linear(net_dim, spec_dim), nn.Sigmoid()), "x -> temp"),
+    #          (lambda f: nn.Hardtanh()(f) * pi * 1.1, "f -> phs")
+    #          ), "x -> phs"),
+    #     (lambda mag, phs: torch.stack((mag, phs), dim=-1), "mag, phs -> y")
+    #     )
7 changes: 4 additions & 3 deletions mimikit/modules/misc.py
@@ -56,13 +56,14 @@ class ScaledActivation(nn.Module):
    def __init__(self, activation, dim, with_range=True):
        super(ScaledActivation, self).__init__()
        self.activation = activation
-        self.scales = nn.Parameter(torch.rand(dim, ) * 100, )
+        # self.scales = nn.Parameter(torch.rand(dim, ) * 100, )
+        self.scales = nn.Linear(dim, dim)
        self.dim = dim

    def forward(self, x):
-        s = self.scales.to(x).view(*(d if d == self.dim else 1 for d in x.size()))
+        # s = self.scales.to(x).view(*(d if d == self.dim else 1 for d in x.size()))
        # return self.activation(self.rg * x / self.scales) * self.scales
-        return self.activation(x) * s
+        return self.activation(x) / self.activation(self.scales(x))
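        # note: the scale is now input-dependent; activation(x) is normalized by the activation of a learned linear map of x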


class ScaledSigmoid(ScaledActivation):