-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add Checkpoint & generate_from_checkpoint demo
- Loading branch information
1 parent
79728ff
commit b9b8377
Showing
15 changed files
with
244 additions
and
118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
__version__ = '0.3.1' | ||
__version__ = '0.3.2' | ||
|
||
from . import extract | ||
from . import features | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
|
||
|
||
def demo(): | ||
"""Generate From Checkpoint""" | ||
import mimikit as mmk | ||
import h5mapper as h5m | ||
import torch | ||
import matplotlib.pyplot as plt | ||
import librosa | ||
|
||
# load a checkpoint | ||
ckpt = mmk.Checkpoint( | ||
root_dir="./trainings/wn-test-gesten", | ||
id='84e89798ec2c85e19790344fb598932118c7a65142e747e383907c5f7ced0f26', | ||
epoch=1 | ||
) | ||
net, feature = ckpt.network, ckpt.feature | ||
|
||
# prompt positions in seconds | ||
indices = [ | ||
1.1, 8.5, 46.3 | ||
] | ||
# duration in seconds to generate converted to number of steps | ||
n_steps = librosa.time_to_frames(8, sr=feature.sr, hop_length=feature.hop_length) | ||
|
||
class SoundBank(h5m.TypedFile): | ||
snd = h5m.Sound(sr=feature.sr, mono=True, normalize=True) | ||
|
||
SoundBank.create("gen.h5", ckpt.train_hp["files"], ) | ||
soundbank = SoundBank("gen.h5", mode='r', keep_open=True) | ||
|
||
def process_outputs(outputs, bidx): | ||
output = feature.inverse_transform(outputs[0]) | ||
for i, out in enumerate(output): | ||
y = out.detach().cpu().numpy() | ||
plt.figure(figsize=(20, 2)) | ||
plt.plot(y) | ||
plt.show(block=False) | ||
mmk.audio(y, sr=feature.sr, | ||
hop_length=feature.hop_length) | ||
|
||
max_i = soundbank.snd.shape[0] - getattr(feature, "hop_length", 1) * net.rf | ||
g_dl = soundbank.serve( | ||
(feature.batch_item(shift=0, length=net.rf, training=False),), | ||
sampler=mmk.IndicesSampler(N=len(indices), | ||
indices=[librosa.time_to_samples(i, sr=feature.sr) for i in indices], | ||
max_i=max_i, | ||
redraw=False), | ||
shuffle=False, | ||
batch_size=len(indices) | ||
) | ||
|
||
loop = mmk.GenerateLoop( | ||
network=net, | ||
dataloader=g_dl, | ||
inputs=(h5m.Input(None, h5m.AsSlice(dim=1, shift=-net.rf, length=net.rf), setter=h5m.Setter(dim=1)),), | ||
n_steps=n_steps, | ||
device='cuda' if torch.cuda.is_available() else 'cpu', | ||
time_hop=net.hp.get("hop", 1), | ||
process_outputs=process_outputs | ||
) | ||
loop.run() | ||
|
||
"""----------------------------""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import dataclasses as dtc | ||
import h5mapper as h5m | ||
import json | ||
import os | ||
from .s2s import Seq2SeqLSTM, Seq2SeqLSTMv0 | ||
from .wavenets import WaveNetFFT, WaveNetQx | ||
from .srnns import SampleRNN | ||
|
||
__all__ = [ | ||
'find_checkpoints', | ||
'load_trainings_hp', | ||
'load_network_cls', | ||
'load_feature', | ||
'Checkpoint' | ||
] | ||
|
||
def find_checkpoints(root="trainings"): | ||
return h5m.FileWalker(r"\.h5", root) | ||
|
||
|
||
def load_trainings_hp(dirname): | ||
return json.loads(open(os.path.join(dirname, "hp.json"), 'r').read()) | ||
|
||
|
||
class CkptBank(h5m.TypedFile): | ||
ckpt = h5m.TensorDict() | ||
|
||
|
||
def load_feature(s): | ||
import mimikit as mmk | ||
loc = dict() | ||
exec(f"feature = {s}", mmk.__dict__, loc) | ||
return loc["feature"] | ||
|
||
|
||
def load_network_cls(s): | ||
loc = dict() | ||
exec(f"cls = {s}", globals(), loc) | ||
return loc["cls"] | ||
|
||
|
||
@dtc.dataclass | ||
class Checkpoint: | ||
id: str | ||
epoch: int | ||
root_dir: str = "./" | ||
|
||
@staticmethod | ||
def get_id_and_epoch(path): | ||
id_, epoch = path.split("/")[-2:] | ||
return id_.strip("/"), int(epoch.split(".h5")[0].split("=")[-1]) | ||
|
||
@staticmethod | ||
def from_path(path): | ||
basename = os.path.dirname(os.path.dirname(path)) | ||
return Checkpoint(*Checkpoint.get_id_and_epoch(path), root_dir=basename) | ||
|
||
@property | ||
def os_path(self): | ||
return os.path.join(self.root_dir, f"{self.id}/epoch={self.epoch}.h5") | ||
|
||
def delete(self): | ||
os.remove(self.os_path) | ||
|
||
@property | ||
def network(self): | ||
bank = CkptBank(self.os_path, 'r') | ||
hp = bank.ckpt.load_hp() | ||
return bank.ckpt.load_checkpoint(hp["cls"], "state_dict") | ||
|
||
@property | ||
def feature(self): | ||
bank = CkptBank(self.os_path, 'r') | ||
hp = bank.ckpt.load_hp() | ||
return hp['feature'] | ||
|
||
@property | ||
def train_hp(self): | ||
return load_trainings_hp(os.path.join(self.root_dir, self.id)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.