diff --git a/.github/workflows/ci-pipeline.yml b/.github/workflows/ci-pipeline.yml index 5458585..c11d238 100644 --- a/.github/workflows/ci-pipeline.yml +++ b/.github/workflows/ci-pipeline.yml @@ -15,9 +15,11 @@ jobs: python-version: '3.7' - name: Install dependencies run: | - sudo apt-get install -y libsndfile-dev + sudo apt-get install -y libsndfile-dev ffmpeg python -m pip install --quiet --upgrade pip - pip install --quiet -r requirements.txt + pip install -r requirements.txt + pip install -r requirements-test.txt + pip install -r requirements-torch.txt pip install --quiet hatch==0.23.1 pip list | grep torch diff --git a/.gitignore b/.gitignore index 0b842bc..cce7c09 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -.DS_* +.DS_Store *__pycache__* *pyc *.ipynb_checkpoints* diff --git a/README.md b/README.md index 7e4325b..a05908a 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,40 @@ # mimikit -The MusIc ModelIng toolKIT (MIMIKIT) is a python package that does Machine Learning with music data. +The MusIc ModelIng toolKIT (`mimikit`) is a Python package for machine learning with audio data. -The goal of `mimikit` is to enable you to use referenced and experimental algorithms on data you provide. +Currently, it focuses on training auto-regressive neural networks to generate audio. -`mimikit` is still in early development, details and documentation are on their way. +It also contains an app for basic & experimental clustering of audio data in a notebook. -## License +## Installation + +You can install it with pip: +```shell script +pip install mimikit[torch] +``` +or with +```shell script +pip install --upgrade mimikit[torch] +``` +if you are looking for the latest version. + +For an editable install, you'll need: +```shell script +pip install -e . \
--config-settings editable_mode=compat +``` + +## Usage + +Head straight to the [notebooks](https://github.com/ktonal/mimikit-notebooks) for example usage of `mimikit`, or open them directly in Colab: -mimikit is distributed under the terms of the [GNU General Public License v3.0](https://choosealicense.com/licenses/gpl-3.0/) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ktonal/mimikit-notebooks/blob/main) +## Output Samples + +You can explore the outputs of different trainings done with `mimikit` at this demo website: + + https://ktonal.github.io/mimikit-demo-outputs + +## License +`mimikit` is distributed under the terms of the [GNU General Public License v3.0](https://choosealicense.com/licenses/gpl-3.0/) diff --git a/dev-docs/dynamic io.md b/dev-docs/dynamic io.md new file mode 100644 index 0000000..d84b80e --- /dev/null +++ b/dev-docs/dynamic io.md @@ -0,0 +1,74 @@ + + CANNONICAL SHAPE: + + (BATCH, [CHANNEL], TIME, [DIM], [COMPONENT]) + + + abs(Spectro), MelSpec, MFCC, Qt, Repr: (Batch, Time, Dim) + + Complex_S: (Batch, [Channel], Time, Dim, 2) + + y: (Batch, [Channel], Time) + + enveloppe: (Batch, [Channel], Time) + + text, lyrics, file_label, segment_label: (Batch, [Channel], Time, [Embedding, Class_Size]) + + pitch, speaker_id: (Batch, [Channel], Time, [Embedding, Class_Size]) + + qx, k_mer_hash, frame_index, cluster_label: (Batch, [Channel], Time, [Embedding, Class_Size]) + + y_bits: (Batch, [Channel], Time, Bit_Depth) + + + + Modules can change SHAPES through: + + - Project: (Batch, [Channel], Time) ----> (Batch, [Channel], Time, Dim) + - Map/Transform: ANY Structure ----> SAME Structure + - Predict: Any ----> TRAINING: (..., Dim) BUT INFER: (..., 1) + - Fork: (...., Component) ----> (...) x Component + - Join: (...) 
x Component ----> (...., Component) + + Modules can change SIGNATURES through: + + - Iso: N Inputs ----> N Outputs + - Split: 1 Input ----> N Outputs + - Reduce: N Inputs ----> 1 Outputs + - Transform: N Inputs ----> M Outputs + + + So, now, we can try to solve: + + REPR Shape ---> F(...) = ??? ---> Model Shape + + +------ + + i.e. + repr_shape=(B=-1, T=-1) + model_shape=(B=-1, T=-1, D=128) + ---> ??? = Project + + repr_shape=(B=-1, T=-1, D=1025) X N + model_shape=(B=-1, T=-1, D=128) + ---> ??? = [Reduce, Map] OR [Join, Reduce], .... + + +**A. User says:** + + "I want to connect this_feat with this_network." + +**B. Mimikit answers:** + + "Then you can use this_options" + ---> IOConfigService.resolve(this_feat, this_network_class): -> {io_module_config, ...} + +**C. User chooses, configures and then clicks `Run`. mimikit goes:** + + "Let's connect all those tings" + --> ModelInstantiator(this_feat, this_network, this_io_config) -> Model + + + + diff --git a/dev-docs/quasi DDD.md b/dev-docs/quasi DDD.md new file mode 100644 index 0000000..8ee1f6c --- /dev/null +++ b/dev-docs/quasi DDD.md @@ -0,0 +1,130 @@ +## Motivation + +We want to + +- easily compose ML models from *components* (inputs/outputs number and type, modules, network architecture, ...) +- easily build UI interfaces to configure/instantiate them (and ideally, be able to save their state!) + +### Top level structure + +``` +mimikit +- config +- view +- checkpoint +- io + - audio + - mu_law + - spectrogram + - enveloppe + ... + - labels + - cluster_label + - speaker_label + ... +- information_retrieval + - segment_audio + - cluster + ... +- modules + - loss_fn + ... +- networks + - io_wrappers + - sample_rnn + - wavenet + ... +- models + - srnn + - freqnet + ... +- loops + - train_loop + - generate_loop +- trainings + - train_arm + - train_gan + - train_diffusion +- scripts + - generate_arm + - train_freqnet + - ensemble + - eval_arm + .... 
+- notebooks + - generate_arm + - train_freqnet + - ensemble + - explore_cluster + ..... +``` + +### ML Component Design Pattern + +use `dataclasses` and inheritance to define & connect the layers of a 'ML component', e.g. + +``` +@dtc.dataclass +class NetConfig(Config): + ... + + +@dtc.dataclass +class NetImpl(NetConfig, nn.Module): + + def __post_init__(self): + # init modules... + def forward(self, inputs, **kwargs): + ... + + +@dtc.dataclass +class NetView(NetConfig, ConfigView): + + def __post_init__(self): + # ... map params to widget ... + ConfigView.__init__(self, **params) +``` + +Constructors are +- type-safe +- consistent across layers +- (de)serializable +- **defined once** + +Then we can nest Configs, Impls & Views by doing: + +``` + +@dtc.dataclass +class NestedConfig(Config): + io: IOConfig + net: NetConfig + + +@dtc.dataclass +class Model(NestedConfig, nn.Module): + def __post_init__(self): + .... + +@dtc.dataclass +class ModelView(NestedConfig, ConfigView): + ..... +``` + +--> We win +- generic saving/loading Checkpoints +- highly expressive composition for io, models, features, views, etc... + + +### Implementation + +different libraries / base classes offer different trade-offs between ease of use and ease of integrations with other libraries. + +Ideally, we could, +- define a constructor -> `dataclass, attrs, namedtuple` +- have static type checker recognize it +- attach it to a nn.Module -> inheritance?, `classmethod`?, decorator? +- use it in a View -> switch mutability +- (de)serialize it as config -> `OmegaConf` +- be able to export the nn.Module as `TorchScript` diff --git a/docs/audios.rst b/docs/audios.rst index 7532174..7ade0fc 100644 --- a/docs/audios.rst +++ b/docs/audios.rst @@ -5,7 +5,7 @@ this package exposes helper classes for processing, interacting & modeling audio .. py:currentmodule:: mimikit.audios.fmodules -.. autoclass:: FModule +.. 
autoclass:: Functional :special-members: __call__ :members: diff --git a/mimikit/__init__.py b/mimikit/__init__.py index 15cb860..251164b 100644 --- a/mimikit/__init__.py +++ b/mimikit/__init__.py @@ -1,21 +1,28 @@ -__version__ = '0.3.4' +__version__ = '0.4.1' -from . import extract +from . import config from . import features from . import loops +from . import checkpoint from . import modules +from . import extract +from . import io_spec from . import models from . import networks from . import demos +from . import ui +from . import views -from .extract import * +from .checkpoint import * +from .config import * from .features import * from .loops import * from .modules import * +from .extract import * from .models import * from .networks import * - -from .train import * +from .ui import * from .utils import * - +from .views import * +from .io_spec import * from .demos import * \ No newline at end of file diff --git a/mimikit/checkpoint.py b/mimikit/checkpoint.py new file mode 100644 index 0000000..9600e5e --- /dev/null +++ b/mimikit/checkpoint.py @@ -0,0 +1,152 @@ +import abc +import dataclasses as dtc +from typing import Optional +from typing_extensions import Protocol + +try: + from functools import cached_property +except ImportError: # python<3.8 + def cached_property(f): + return property(f) + +import torch.nn as nn + +import h5mapper as h5m +import os + +from .config import Config, Configurable +from .networks.arm import NetworkConfig +from .features.dataset import DatasetConfig + +__all__ = [ + 'Checkpoint', + 'CheckpointBank' +] + + +class ConfigurableModule(Configurable, nn.Module, abc.ABC): + @property + @abc.abstractmethod + def config(self) -> NetworkConfig: + ... + + +class TrainingConfig(Protocol): + @property + def dataset(self) -> DatasetConfig: + ... + + @property + def network(self) -> NetworkConfig: + ... + + @property + def training(self) -> Config: + ... 
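The `TrainingConfig` protocol above is structural: a config object never needs to inherit from it; exposing the right properties is enough for a static type checker to accept it where a `TrainingConfig` is expected. A minimal sketch of that behaviour, with hypothetical `HasTraining`/`MyConfig` names and a `str`-typed member in place of the real nested `Config` (the diff imports `Protocol` from `typing_extensions` for python<3.7 compatibility; plain `typing` works on 3.8+):

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasTraining(Protocol):
    """Structural stand-in for TrainingConfig: any object with this attribute matches."""
    @property
    def training(self) -> str: ...


@dataclass
class MyConfig:  # hypothetical config -- note: no inheritance from HasTraining
    training: str = "adam, lr=1e-3"


def describe(cfg: HasTraining) -> str:
    # a type checker accepts MyConfig here purely by attribute shape
    return cfg.training


assert isinstance(MyConfig(), HasTraining)   # runtime structural check
assert describe(MyConfig()) == "adam, lr=1e-3"
```

This is why `CheckpointBank.save` can take any training object: duck typing, but checkable.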
+ + +class CheckpointBank(h5m.TypedFile): + network = h5m.TensorDict() + optimizer = h5m.TensorDict() + + @classmethod + def save(cls, + filename: str, + network: ConfigurableModule, + training_config: Optional[TrainingConfig] = None, + optimizer: Optional[nn.Module] = None + ) -> "CheckpointBank": + + net_dict = network.state_dict() + opt_dict = optimizer.state_dict() if optimizer is not None else {} + cls.network.set_ds_kwargs(net_dict) + if optimizer is not None: + cls.optimizer.set_ds_kwargs(opt_dict) + os.makedirs(os.path.split(filename)[0], exist_ok=True) + + bank = cls(filename, mode="w") + bank.network.attrs["config"] = network.config.serialize() + bank.network.add("state_dict", h5m.TensorDict.format(net_dict)) + + if optimizer is not None: + bank.optimizer.add("state_dict", h5m.TensorDict.format(opt_dict)) + if training_config is not None: + bank.attrs["dataset"] = training_config.dataset.serialize() + bank.attrs["training"] = training_config.training.serialize() + else: + # make a dataset config for being able to at least load the network later + features = [*network.config.io_spec.inputs, *network.config.io_spec.targets] + schema = {f.extractor_name: f.extractor for f in features} + bank.attrs["dataset"] = DatasetConfig(filename="unknown", sources=(), + extractors=tuple(schema.values())).serialize() + + bank.flush() + bank.close() + return bank + + +@dtc.dataclass +class Checkpoint: + id: str + epoch: int + root_dir: str = "./" + + def create(self, + network: ConfigurableModule, + training_config: Optional[TrainingConfig] = None, + optimizer: Optional[nn.Module] = None): + CheckpointBank.save(self.os_path, network, training_config, optimizer) + return self + + @staticmethod + def get_id_and_epoch(path): + id_, epoch = path.split("/")[-2:] + return id_.strip("/"), int(epoch.split(".h5")[0].split("=")[-1]) + + @staticmethod + def from_path(path): + basename = os.path.dirname(os.path.dirname(path)) + return Checkpoint(*Checkpoint.get_id_and_epoch(path), 
root_dir=basename) + + @property + def os_path(self): + return os.path.join(self.root_dir, f"{self.id}/epoch={self.epoch}.ckpt") + + def delete(self): + os.remove(self.os_path) + + @cached_property + def bank(self) -> CheckpointBank: + return CheckpointBank(self.os_path, 'r') + + @cached_property + def dataset_config(self) -> DatasetConfig: + return Config.deserialize(self.bank.attrs["dataset"], as_type=DatasetConfig) + + @cached_property + def network_config(self) -> NetworkConfig: + return Config.deserialize(self.bank.network.attrs["config"]) + + @cached_property + def training_config(self) -> TrainingConfig: + bank = CheckpointBank(self.os_path, 'r') + return Config.deserialize(bank.attrs["training"], as_type=TrainingConfig) + + @cached_property + def network(self) -> ConfigurableModule: + cfg: NetworkConfig = self.network_config + cfg.io_spec.bind_to(self.dataset_config) + cls = cfg.owner_class + state_dict = self.bank.network.get("state_dict") + net = cls.from_config(cfg) + net.load_state_dict(state_dict, strict=True) + return net + + @cached_property + def dataset(self) -> h5m.TypedFile: + dataset: DatasetConfig = self.dataset_config + if os.path.exists(dataset.filename): + return dataset.get(mode="r") + return dataset.create(mode="w") + + # Todo: method to add state_dict mul by weights -> def average(self, *others) \ No newline at end of file diff --git a/mimikit/config.py b/mimikit/config.py new file mode 100644 index 0000000..f6cb89a --- /dev/null +++ b/mimikit/config.py @@ -0,0 +1,141 @@ +import abc +import sys +from copy import deepcopy +from omegaconf import OmegaConf, ListConfig, DictConfig +from typing import List, Tuple, Union, Dict, Any +import dataclasses as dtc +from functools import reduce, partial + +__all__ = [ + "private_runtime_field", + "Config", + "Configurable", +] + + +def private_runtime_field(default): + return dtc.field(init=False, repr=False, metadata=dict(omegaconf_ignore=True), default_factory=lambda: default) + + +# noinspection 
PyTypeChecker +def _get_type_object(type_) -> type: + if ":" in type_: + module, qualname = type_.split(":") + else: + module, qualname = "mimikit", type_ + try: + m = sys.modules[module] + return reduce(lambda o, a: getattr(o, a), qualname.split("."), m) + except AttributeError or KeyError: + raise ImportError(f"could not find class '{qualname}' from module {module} in current environment") + + +STATIC_TYPED_KEYS = { + "dataset": "DatasetConfig", + "io_spec": "IOSpec", + "inputs": "InputSpec", + "targets": "TargetSpec", + "objective": "Objective", + "feature": "Feature", + "extractor": "Extractor", + "activation": "ActivationConfig" +} + + +@dtc.dataclass +class Config: + ## type: str = dtc.field(init=False, repr=False, default="mimikit.config:Config") + + @classmethod + def __init_subclass__(cls, type_field=True, **kwargs): + """add type info to subclass""" + if type_field: + default = f"{cls.__qualname__}" + if not cls.__module__.startswith("mimikit"): + default = f"{cls.__module__}:{default}" + setattr(cls, "type", dtc.field(init=False, default=default, repr=False)) + if "__annotations__" in cls.__dict__: + # put the type first for nicer serialization... + ann = cls.__dict__["__annotations__"].copy() + for k in ann: + cls.__dict__["__annotations__"].pop(k) + cls.__dict__["__annotations__"].update({"type": str, **ann}) + else: + setattr(cls, "__annotations__", {"type": str}) + + @staticmethod + def validate_class(cls: type): + if "__dataclass_fields__" not in cls.__dict__: + if not issubclass(cls, (tuple, list)): + raise TypeError("Please decorate your Config class with @dataclass" + " so that it can be (de)serialized") + + @property + def owner_class(self): + module, type_ = type(self).__module__, type(self).__qualname__ + type_ = ".".join(type_.split(".")[:-1]) if "." 
in type_ else type_ + type_ = f"{module}:{type_}" + return _get_type_object(type_) + + def serialize(self): + self.validate_class(type(self)) + cfg = OmegaConf.structured(self) + return OmegaConf.to_yaml(cfg) + + @staticmethod + def deserialize(raw_yaml, as_type=None): + cfg = OmegaConf.create(raw_yaml) + return Config.object(cfg, as_type) + + @staticmethod + def object(cfg: Union[ListConfig, DictConfig, Dict, List, Tuple, Any], as_type=None): + if isinstance(cfg, (DictConfig, Dict)): + for k, v in cfg.items(): + if k in STATIC_TYPED_KEYS: + cls = _get_type_object(STATIC_TYPED_KEYS[k]) + setattr(cfg, k, Config.object(v, as_type=cls)) + elif k == "extractors": + setattr(cfg, k, tuple(map(partial(Config.object, as_type=_get_type_object("Extractor")), v))) + elif isinstance(v, (ListConfig, DictConfig, Dict, List, Tuple)): + setattr(cfg, k, Config.object(v)) + if as_type is not None: + cls = as_type + elif "type" in cfg: + cls = _get_type_object(cfg.type) + else: # untyped raw dict + return cfg + if isinstance(cfg, DictConfig): + cfg._metadata.object_type = cls + return OmegaConf.to_object(cfg) + else: + return cls(**cfg) + + elif isinstance(cfg, (ListConfig, List, Tuple)): + return OmegaConf.to_object(OmegaConf.structured([*map(partial(Config.object, as_type=as_type), cfg)])) + # any other kind of value + return cfg + + def dict(self): + """caution! nested configs are also converted!""" + return dtc.asdict(self) + + def copy(self): + return deepcopy(self) + + def validate(self) -> Tuple[bool, str]: + return True, '' + + +class Configurable(abc.ABC): + + @classmethod + @abc.abstractmethod + def from_config(cls, config: Config): + ... + + @property + @abc.abstractmethod + def config(self) -> Config: + ... 
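The `Config.__init_subclass__` hook above is the heart of the pattern: every subclass automatically gains a non-init `type` field recording its import path, which `_get_type_object` can later resolve during deserialization. A stripped-down, stdlib-only sketch of that injection (no OmegaConf; `NetConfig` is a hypothetical subclass):

```python
import dataclasses as dtc


@dtc.dataclass
class Config:
    @classmethod
    def __init_subclass__(cls, type_field=True, **kwargs):
        super().__init_subclass__(**kwargs)
        if type_field:
            # inject a non-init "type" field holding the subclass's import path,
            # listed first in the annotations so it serializes at the top
            setattr(cls, "type", dtc.field(
                init=False, repr=False,
                default=f"{cls.__module__}:{cls.__qualname__}"))
            own_ann = cls.__dict__.get("__annotations__", {})
            cls.__annotations__ = {"type": str, **own_ann}


@dtc.dataclass
class NetConfig(Config):
    dim: int = 128


cfg = NetConfig(dim=64)
assert cfg.type.endswith("NetConfig")                   # auto-filled, not a constructor arg
assert dtc.asdict(cfg) == {"type": cfg.type, "dim": 64}  # round-trips through plain dicts
```

The hook runs before the `@dataclass` decorator processes the subclass body, so the injected annotation participates in field generation like any hand-written one; constructors stay type-safe and the type tag is defined exactly once.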
+ + diff --git a/mimikit/demos/__init__.py b/mimikit/demos/__init__.py index 4f8c961..b9f229d 100644 --- a/mimikit/demos/__init__.py +++ b/mimikit/demos/__init__.py @@ -1,8 +1,8 @@ -from .freqnet import * from .srnn import * -from .s2s import * -from .wn import * +from .seq2seq import * +from .freqnet import * from .generate_from_checkpoint import * -from .ensemble import * +from .ensemble_generator import * +from .clusterizer_app import * __all__ = [_ for _ in dir() if not _.startswith("_")] diff --git a/mimikit/demos/clusterizer_app.py b/mimikit/demos/clusterizer_app.py new file mode 100644 index 0000000..6ff5642 --- /dev/null +++ b/mimikit/demos/clusterizer_app.py @@ -0,0 +1,44 @@ +def demo(): + """### Launch the app""" + + import mimikit as mmk + from ipywidgets import widgets as W + import IPython.display as ipd + + ipd.display(mmk.MMK_STYLE_SHEET) + ipd.display(W.HTML( + """ + +""" +) diff --git a/mimikit/ui/widgets.py b/mimikit/ui/widgets.py new file mode 100644 index 0000000..86d3ab9 --- /dev/null +++ b/mimikit/ui/widgets.py @@ -0,0 +1,163 @@ +from typing import Iterable + +from ipywidgets import widgets as W, GridspecLayout +import os + +from ..loops.callbacks import tqdm + +__all__ = [ + "EnumWidget", + "pw2_widget", + "yesno_widget", + "Labeled", + "UploadWidget", +] + + +# TODO: +# - wavenet/sampleRNN -> all params +# - refine FilePicker (layout, load/save buttons) +# - network views from Select +# - move *_io() to IOSpecs +# - *_io views (and Select) + + +def Labeled( + label, widget, tooltip=None +): + label_w = W.Label(value=label, tooltip=label) + if tooltip is not None: + tltp = W.Button(icon="fa-info", tooltip=tooltip, + layout=W.Layout( + width="20px", + height="12px"), + disabled=True, + ).add_class("tltp") + label_w = W.HBox(children=[label_w, tltp], ) + label_w.layout = W.Layout(min_width="max_content", width="auto", + overflow='revert') + container = W.GridBox(children=(label_w, widget), + layout=dict(width="auto", + 
grid_template_columns='1fr 2fr') + ) + # container.value = widget.value + container.observe = widget.observe + return container + + +def pw2_widget( + initial_value, + min_value=1, + max_value=2 ** 16, +): + plus = W.Button(icon="plus", layout=dict(width="auto", overflow='hidden', grid_area='plus')) + minus = W.Button(icon="minus", layout=dict(width="auto", overflow='hidden', grid_area='minus')) + value = W.Text(value=str(initial_value), layout=dict(width="auto", overflow='hidden', grid_area='val')) + plus.on_click(lambda clk: setattr(value, "value", str(min(max_value, int(value.value) * 2)))) + minus.on_click(lambda clk: setattr(value, "value", str(max(min_value, int(value.value) // 2)))) + grid = W.GridBox(children=(minus, value, plus), + layout=dict(grid_template_columns='1fr 1fr 1fr', + grid_template_rows='1fr', + grid_template_areas='"minus val plus"')) + # bind value state to box state + # grid.value = value.value + grid.observe = value.observe + return grid + + +def yesno_widget( + initial_value=True, +): + yes = W.ToggleButton( + value=initial_value, + description="yes", + button_style="success" if initial_value else "", + layout=dict(width='auto', maring='auto 4px', grid_area='yes') + ) + no = W.ToggleButton( + value=not initial_value, + description="no", + button_style="" if initial_value else "danger", + layout=dict(width='auto', maring='auto 4px', grid_area='no') + ) + + def toggle_yes(ev): + v = ev["new"] + if v: + setattr(yes, "button_style", "success") + setattr(no, "button_style", "") + setattr(no, "value", False) + + def toggle_no(ev): + v = ev["new"] + if v: + setattr(no, "button_style", "danger") + setattr(yes, "button_style", "") + setattr(yes, "value", False) + + yes.observe(toggle_yes, "value") + no.observe(toggle_no, "value") + grid = W.GridBox(children=(yes, no), + layout=dict(grid_template_columns='1fr 1fr', + grid_template_rows='1fr', + grid_template_areas='"yes no"')) + grid.observe = yes.observe + return grid + + +def EnumWidget( + 
label: str, + options: Iterable[str], + value_type=str, + selected_index=0 +): + options_w = W.GridBox(children=tuple(W.ToggleButton(value=False, + description=opt, + tooltip=opt, + layout=dict(margin='0 4px', width='auto')) + for opt in options), + layout=dict(grid_template_columns='1fr ' * len(options), + width='auto', align_self='center')) + container = Labeled(label, options_w) + dummy = W.Text(value='') + if isinstance(selected_index, int): + value = options_w.children[selected_index].value if value_type is str else \ + value_type(options_w.children[selected_index].value) + options_w.children[selected_index].value = True + setattr(options_w.children[selected_index], "button_style", "success") + else: + value = selected_index + container.value = value + for i, child in enumerate(options_w.children): + def observer(ev, c=child, index=i): + val = ev["new"] + if val and dummy.value != c.description: + container.selected_index = index + dummy.value = c.description if value_type is str else value_type(c.description) + setattr(c, "button_style", "success") + for other in options_w.children: + if other.value and other is not c: + other.value = False + other.button_style = "" + elif not val and dummy.value == c.description: + c.value = True + + child.observe(observer, "value") + container.observe = dummy.observe + return container + + +def UploadWidget(dest="./"): + def write_uploads(inputs): + for file in tqdm(inputs["new"], leave=False): + with open(os.path.join(dest, file.name), "wb") as f: + f.write(file.content.tobytes()) + + upload = W.FileUpload( + accept='', + multiple=True, + ) + + upload.observe(write_uploads, names='value') + + return upload diff --git a/mimikit/utils.py b/mimikit/utils.py index cd5b03d..71823e1 100644 --- a/mimikit/utils.py +++ b/mimikit/utils.py @@ -1,61 +1,35 @@ -import librosa -import numpy as np -from librosa.display import specshow -import IPython.display as ipd -import matplotlib.pyplot as plt +from enum import Enum +import re __all__ 
= [ - 'audio', - 'show' + "AutoStrEnum", + "SOUND_FILE_REGEX", + "CHECKPOINT_REGEX", + "DATASET_REGEX", + "default_device" ] -HOP_LENGTH, SR = 512, 22050 -# Conversion +SOUND_FILE_REGEX = re.compile(r"wav$|aif$|aiff$|mp3$|mp4$|m4a$|webm$") +DATASET_REGEX = re.compile(r".*\.h5$") +CHECKPOINT_REGEX = re.compile(r".*\.ckpt$") -a2db = lambda S: librosa.amplitude_to_db(abs(S), ref=S.max()) +class AutoStrEnum(str, Enum): + """ + Workaround while https://github.com/omry/omegaconf/pull/865 is still open... + """ + @staticmethod + def _generate_next_value_(name: str, start: int, count: int, last_values: list) -> str: + return name -# Debugging utils -def to_db(S): - if S.dtype == np.complex64: - S_hat = a2db(abs(S)) + 40 - elif S.min() >= 0 and S.dtype in (np.float, np.float32, np.float64, np.float_): - S_hat = a2db(S) + 40 - else: - S_hat = a2db(S) + 40 - return S_hat +def default_device(): + import torch # don't force user to install torch...(?) - -def signal(S, hop_length=HOP_LENGTH): - if S.dtype in (np.complex64, np.complex128): - return librosa.istft(S, hop_length=hop_length) - else: - return librosa.griffinlim(S, hop_length=hop_length, n_iter=32) - - -def audio(S, hop_length=HOP_LENGTH, sr=SR): - if len(S.shape) > 1: - y = signal(S, hop_length) - if y.size > 0: - return ipd.display(ipd.Audio(y, rate=sr)) - else: - return ipd.display(ipd.Audio(np.zeros(hop_length*2), rate=sr)) - else: - return ipd.display(ipd.Audio(S, rate=sr)) - - -def show(S, figsize=(), db_scale=True, title="", **kwargs): - S_hat = to_db(S) if db_scale else S - if figsize: - plt.figure(figsize=figsize) - if "x_axis" not in kwargs: - kwargs["x_axis"] = "frames" - if "y_axis" not in kwargs: - kwargs["y_axis"] = "frames" - ax = specshow(S_hat, sr=SR, **kwargs) - plt.colorbar() - plt.tight_layout() - plt.title(title) - return ax + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.has_mps: + device = "mps" + return device diff --git a/mimikit/views/__init__.py 
b/mimikit/views/__init__.py new file mode 100644 index 0000000..1f93d02 --- /dev/null +++ b/mimikit/views/__init__.py @@ -0,0 +1,8 @@ +from .clusters import * +from .clusterizer_app import * +from .dataset import * +from .functionals import * +from .io_spec import * +from .train_arm import * +from .wavenet import * +from .sample_rnn import * diff --git a/mimikit/views/clusterizer_app.py b/mimikit/views/clusterizer_app.py new file mode 100644 index 0000000..1e3f96a --- /dev/null +++ b/mimikit/views/clusterizer_app.py @@ -0,0 +1,458 @@ +from typing import * +import dataclasses as dtc + +import h5mapper +from ipywidgets import widgets as W +import numpy as np +import pandas as pd +from peaksjs_widget import PeaksJSWidget, Segment +import qgrid + +from ..config import Config +from ..extract.clusters import * +from ..features.dataset import DatasetConfig +from ..features.functionals import * +from .clusters import * +from .dataset import dataset_view +from .functionals import * + +__all__ = [ + 'ComposeTransformWidget', + 'ClusterWidget', + 'ClusterizerApp' +] + + + +@dtc.dataclass +class Meta: + config_class: Type + view_func: Callable + requires: List[Type] = dtc.field(default_factory=lambda: []) + only_once: bool = False + + def can_be_added(self, preceding_transforms: List[Type]): + if not self.requires: + return not preceding_transforms + if self.requires[0] is Any and len(preceding_transforms) > 0: + return True + deps_fullfilled = self.requires == preceding_transforms + if self.only_once: + already_there = any(f is self.config_class for f in preceding_transforms) + else: + already_there = False + return deps_fullfilled and not already_there + + +TRANSFORMS = { + "magspec": Meta(MagSpec, magspec_view, [], True), + "melspec": Meta(MelSpec, melspec_view, [MagSpec], True), + "mfcc": Meta(MFCC, mfcc_view, [MagSpec, MelSpec], True), + "chroma": Meta(Chroma, chroma_view, [MagSpec], True), + "autoconvolve": Meta(AutoConvolve, autoconvolve_view, [Any], False), + "f0 
filter": Meta(F0Filter, f0_filter_view, [MagSpec], False), + "nearest_neighbor_filter": Meta(NearestNeighborFilter, nearest_neighbor_filter_view, [Any]), + "pca": Meta(PCA, pca_view, [Any]), + "nmf": Meta(NMF, nmf_view, [Any]), + "factor analysis": Meta(FactorAnalysis, factor_analysis_view, [Any]) +} + +CLUSTERINGS = { + "grid of means": Meta(GCluster, gcluster_view, [], True), + "quantile clustering": Meta(QCluster, qcluster_view, [], True), + "argmax": Meta(ArgMax, argmax_view, [], True), + "kmeans": Meta(KMeans, kmeans_view, [], True), + "spectral clustering": Meta(SpectralClustering, spectral_clustering_view, [], True) +} + + +class ComposeTransformWidget: + + @staticmethod + def header(box): + collapse_all = W.Button(description="collapse all", + layout=dict(width="max-content", margin="auto 4px auto 2px")) + expand_all = W.Button(description="expand all", + layout=dict(width="max-content", margin="auto auto auto 4px")) + header = W.HBox(children=[ + W.HTML(value="

<b>Pre Processing Pipeline</b>

", layout=dict(margin="auto")), + collapse_all, expand_all + ]) + + def on_collapse(ev): + for item in box.children: + if isinstance(item, W.HBox) and isinstance(item.children[-1], W.Accordion): + item.children[-1].selected_index = None + + def on_expand(ev): + for item in box.children: + if isinstance(item, W.HBox) and isinstance(item.children[-1], W.Accordion): + item.children[-1].selected_index = 0 + + collapse_all.on_click(on_collapse) + expand_all.on_click(on_expand) + + return header + + def __init__(self): + self.transforms = [] + self.metas = [] + new_choice = W.Button(icon="fa-plus", layout=dict(margin="8px auto")) + + box = W.VBox(layout=dict(width="50%")) + header = self.header(box) + box.children = (header,) + + choices = W.Select(options=self.get_possible_choices(), layout=dict(width="100%", margin="4px auto")) + submit = W.Button(description="submit", layout=dict(width="max-content", margin="auto 8px")) + cancel = W.Button(description="cancel", layout=dict(width="max-content", margin="auto 8px")) + choice_box = W.VBox(children=(choices, W.HBox(children=(submit, cancel), + layout=dict(margin="4px auto"))), + layout=dict(width="calc(100% - 54px)", margin="auto 0 auto 54px")) + + def show_new_choice(ev): + box.children = (*filter(lambda b: b is not new_choice, box.children), choice_box) + new_choice.disabled = True + + def add_choice(ev): + meta, cfg, remove_w, hbox = self.new_transform_view_for(choices.value) + choices.options = self.get_possible_choices() + + def remove_cb(ev): + keep = [0] # always keep magspec + for i, (t, m) in enumerate(zip(self.transforms[1:], self.metas[1:]), 1): + is_el = t is cfg + requires_t = type(cfg) in m.requires + if not is_el and not requires_t: + keep += [i] + self.transforms = [self.transforms[i] for i in keep] + box.children = (box.children[0],) + tuple(box.children[i + 1] for i in keep) + (box.children[-1],) + choices.options = self.get_possible_choices() + + remove_w.on_click(remove_cb) + box.children = 
(*filter(lambda b: b is not choice_box, box.children), hbox, new_choice) + new_choice.disabled = False + + def cancel_new_choice(ev): + box.children = (*filter(lambda b: b is not choice_box, box.children), new_choice) + new_choice.disabled = False + + submit.on_click(add_choice) + cancel.on_click(cancel_new_choice) + new_choice.on_click(show_new_choice) + self.widget = box + choices.value = "magspec" + submit.click() + # can not remove magspec + box.children[1].children[0].disabled = True + self.magspec_cfg = self.transforms[0] + + def get_possible_choices(self): + options = [] + ts = [*map(type, self.transforms)] + for k, meta in TRANSFORMS.items(): + if meta.can_be_added(ts): + options += [k] + return options + + def new_transform_view_for(self, keyword: str): + meta = TRANSFORMS[keyword] + cfg = meta.config_class() + self.metas.append(meta) + self.transforms.append(cfg) + new_w = meta.view_func(cfg) + new_w.layout = dict(margin="auto", width='100%') + remove_w = W.Button(icon="fa-trash", layout=dict(width="50px", margin="auto 2px")) + hbox = W.HBox(children=(remove_w, new_w), layout=dict(margin="0 4px 4px 4px")) + return meta, cfg, remove_w, hbox + + @staticmethod + def display(cfg: Compose): + w = [] + for func in cfg.functionals: + tp = type(func) + key = next(k for k, v in TRANSFORMS.items() if v.config_class is tp) + meta = TRANSFORMS[key] + w += [meta.view_func(func)] + box = W.VBox(layout=dict(width="50%")) + header = ComposeTransformWidget.header(box) + box.children = (header, *w) + return box + + +class ClusterWidget: + def __init__(self): + self.cfg = None + new_choice = W.Button(description="change algo", + layout=dict(width="max-content", margin="auto auto auto 12px"), + disabled=True) + header = W.HBox(children=[ + W.HTML(value="

<b>Clustering Algo</b>

", layout=dict(margin="auto")), + new_choice + ]) + choices = W.Select(options=self.get_possible_choices(), layout=dict(width="100%", margin="4px auto")) + submit = W.Button(description="submit", layout=dict(width="max-content", margin="auto 8px")) + cancel = W.Button(description="cancel", layout=dict(width="max-content", margin="auto 8px")) + choice_box = W.VBox(children=(choices, W.HBox(children=(submit, cancel), + layout=dict(margin="4px auto"))), + layout=dict(width="95%", margin="auto")) + box = W.VBox(children=(header, + choice_box,), + layout=dict(width='50%')) + + def show_new_choice(ev): + box.children = (header, choice_box) + new_choice.disabled = True + + def add_choice(ev): + meta = CLUSTERINGS[choices.value] + self.cfg = meta.config_class() + new_w = meta.view_func(self.cfg) + new_w.layout = dict(width="95%") + box.children = (header, new_w) + new_choice.disabled = False + + def cancel_new_choice(ev): + box.children = (*filter(lambda b: b is not choice_box, box.children),) + new_choice.disabled = False + + submit.on_click(add_choice) + cancel.on_click(cancel_new_choice) + new_choice.on_click(show_new_choice) + self.widget = box + + @staticmethod + def get_possible_choices(): + return [*CLUSTERINGS.keys()] + + @staticmethod + def display(cfg): + tp = type(cfg) + key = next(k for k, v in CLUSTERINGS.items() if v.config_class is tp) + box = W.VBox(children=[ + W.HTML(value="
<h4>Clustering Algo</h4>
", layout=dict(margin="auto")), + CLUSTERINGS[key].view_func(cfg), + ], layout=dict(width="50%")) + return box + + +class ClusterizerApp: + + def __init__(self): + self.dataset_cfg = DatasetConfig() + self.dataset_widget = dataset_view(self.dataset_cfg) + self.dataset_widget.on_created(lambda ev: self.build_load_view(self.db)) + self.dataset_widget.on_loaded(lambda ev: self.build_load_view(self.db)) + self.pre_pipeline = ComposeTransformWidget() + self.pre_pipeline_widget = self.pre_pipeline.widget + self.magspec_cfg = self.pre_pipeline.magspec_cfg + self.clusters = ClusterWidget() + self.clusters_widget = self.clusters.widget + self.labels_widget = W.VBox(layout=dict(max_width="90vw", margin="auto")) + self.feature_name = '' + + save_as = W.HBox(children=( + W.Label(value='Save clustering as: '), W.Text(value="labels")), + layout=dict(margin="auto", width="max-content")) + compute = W.HBox(children=(W.Button(description="compute"),), + layout=dict(margin="auto", width="max-content")) + out = W.Output() + + def on_submit(ev): + db = self.db + self.feature_name = save_as.children[1].value + pipeline = Compose( + *self.pre_pipeline.transforms, self.clusters.cfg, + Interpolate(mode="previous", length=db.signal.shape[0]) + ) + if self.feature_name in db.handle(): + db.handle().pop(self.feature_name) + db.flush() + out.clear_output() + with out: + db.signal.compute({ + self.feature_name: pipeline + }, parallelism='none') + feat = getattr(db, self.feature_name) + feat.attrs["config"] = pipeline.serialize() + db.flush() + db.close() + self.build_load_view(self.db) + + compute.children[0].on_click(on_submit) + + self.clustering_widget = W.Tab(children=( + W.VBox(children=( + W.HBox(children=(self.pre_pipeline_widget, self.clusters_widget), + layout=dict(align_items='baseline')), + save_as, compute, out + )), + W.Label("Select a dataset to load a clustering") + ), layout=dict(max_width="1000px", margin="auto")) + self.clustering_widget.set_title(0, 'Create new clustering') 
+ self.clustering_widget.set_title(1, 'Load clustering') + + @property + def db(self): + return self.dataset_cfg.get(mode="r+") + + @property + def sr(self): + return self.db.config.extractors[0].functional.functionals[0].sr + + def segments_for(self, feature_name: str): + db = self.db + sr = self.sr + lbl = getattr(db, feature_name)[:] + splits = (lbl[1:] - lbl[:-1]) != 0 + time_idx = np.r_[splits.nonzero()[0], lbl.shape[0] - 1] + cluster_idx = lbl[time_idx] + segments = [ + Segment(t, tp1, id=i, labelText=str(c)).dict() + for i, (t, tp1, c) in enumerate( + zip(time_idx[:-1] / sr, time_idx[1:] / sr, cluster_idx[:-1])) + ] + df = pd.DataFrame.from_dict(segments) + df.set_index("id", drop=True, inplace=True) + label_set = [*sorted(set(lbl))] + return df, label_set + + def bounce(self, segments: List[Segment]): + fft = self.magspec_cfg(self.db.signal[:]) + sr, hop = self.sr, self.magspec_cfg.hop_length + + def t2f(t): return int(t * sr / hop) + + filtered = np.concatenate([fft[slice(t2f(s.startTime), t2f(s.endTime) + 1)] for s in segments]) + return self.magspec_cfg.inv(filtered) + + def build_load_view(self, db): + proxies = [k for k, v in db.__dict__.items() + if isinstance(v, h5mapper.Proxy) and not k.startswith("__") and k != "signal" + ] + self.results_buttons = W.ToggleButtons(options=proxies, index=None, + layout=dict(margin="12px 0")) + + def callback(ev): + self.load_result(ev["new"]) + self.feature_name = ev["new"] + self.labels_widget.children = self.label_view() + + self.results_buttons.observe(callback, "value") + self.clustering_widget.children = ( + self.clustering_widget.children[0], + W.VBox(children=(W.Label(value="load clustering: "), self.results_buttons)) + ) + + def load_result(self, key: str): + cfg = Config.deserialize(getattr(self.db, key).attrs["config"]) + pre_pipeline = Compose(*cfg.functionals[:-2]) + clustering = cfg.functionals[-2] + self.display(pre_pipeline, clustering) + + def display(self, compose: Compose, clustering): + 
self.clustering_widget.children = ( + self.clustering_widget.children[0], + W.VBox(children=( + W.VBox(children=(W.Label(value="load clustering: "), self.results_buttons)), + W.HBox(children=(ComposeTransformWidget.display(compose), ClusterWidget.display(clustering)), + layout=dict(align_items='baseline')) + )) + ) + + def label_view(self): + df, label_set = self.segments_for(self.feature_name) + df.to_dict() + w = PeaksJSWidget(array=self.db.signal[:], sr=self.sr, id_count=len(df), + layout=dict(margin="auto", max_width="1500px", width="100%")) + empty = pd.DataFrame([]) + empty.index.name = "id" + g = qgrid.show_grid(empty, + grid_options=dict(maxVisibleRows=10)) + + labels_grid = W.GridBox(layout=dict(max_height='400px', + margin="16px auto", + grid_template_columns="1fr " * 10, + overflow="scroll")) + labels_w = [] + + for i in label_set: + btn = W.ToggleButton(value=False, + description=str(i), + layout=dict(width="100px", margin="4px")) + + def on_click(ev, widget=btn, index=i): + if ev["new"]: + w.segments = [ + *w.segments, *df[df.labelText == str(index)].reset_index().T.to_dict().values() + ] + widget.button_style = "success" + else: + w.segments = [ + s for s in w.segments if s["labelText"] != str(index) + ] + widget.button_style = "" + if w.segments: + g.df = pd.DataFrame.from_dict(w.segments).sort_values(by="startTime").set_index("id", drop=True) + else: + g.df = pd.DataFrame([]) + g.df.index.name = "id" + + btn.observe(on_click, "value") + labels_w += [W.HBox(children=(btn,))] + + labels_grid.children = tuple(labels_w) + + def on_new_segment(wdg, seg): + new_seg = Segment(**seg).dict() + if w.segments: + g.add_row(row=[*new_seg.items()]) + # i = {**new_seg}.pop("id") + # df.loc[i] = {**new_seg} + else: + g.df = pd.DataFrame.from_dict([new_seg]).set_index("id", drop=True) + + def on_edit_segment(wdg, seg): + for k, v in seg.items(): + if k == "id": continue + g.edit_cell(seg["id"], k, v) + # df.loc[seg["id"], k] = v + g.change_selection([seg["id"]]) + + 
def on_remove_segment(wdg, seg): + g.remove_row([seg["id"]]) + # df.drop(seg["id"], inplace=True) + + def segments_changed(ev): + pass + # print("segments changed") + + w.observe(segments_changed, "segments") + w.on_new_segment(on_new_segment) + w.on_new_segment(PeaksJSWidget.add_segment) + w.on_edit_segment(on_edit_segment) + w.on_edit_segment(PeaksJSWidget.edit_segment) + w.on_remove_segment(on_remove_segment) + w.on_remove_segment(PeaksJSWidget.remove_segment) + + # g.on("selection_changed", on_selection_changed) + # g.on("cell_edited", on_edited_cell) + # g.on("filter_changed", lambda ev, qg: print(ev)) + # g.on("row_removed", on_row_removed) + + def on_bounce(ev): + bnc = self.bounce([Segment(s["startTime"], s["endTime"], s["id"]) for s in w.segments]) + self.labels_widget.children = (*self.labels_widget.children, + PeaksJSWidget(array=bnc, sr=self.sr, id_count=0)) + + bounce = W.Button(description="Bounce Selection") + bounce.on_click(on_bounce) + + return ( + w, + W.HTML("
<h4>Select Label(s):</h4>
"), + labels_grid, + W.HTML("
<h4>Selected Labels Segments Table:</h4>
"), + g, + bounce + ) \ No newline at end of file diff --git a/mimikit/views/clusters.py b/mimikit/views/clusters.py new file mode 100644 index 0000000..b7d1a5f --- /dev/null +++ b/mimikit/views/clusters.py @@ -0,0 +1,158 @@ +from ipywidgets import widgets as W + +from ..extract.clusters import QCluster, GCluster, KMeans, SpectralClustering +from .. import ui as UI + +__all__ = [ + "qcluster_view", + "gcluster_view", + "argmax_view", + "kmeans_view", + "spectral_clustering_view" +] + + +def qcluster_view(cfg: QCluster): + + view = UI.ConfigView( + cfg, + UI.Param("metric", + widget=UI.EnumWidget("Metric: ", + ["cosine", + "euclidean", + "manhattan", ], + selected_index=["cosine", "euclidean", "manhattan"].index(cfg.metric) + ), + ), + UI.Param("n_neighbors", + widget=UI.Labeled("N Neighbors: ", + W.IntText(value=cfg.n_neighbors, layout=dict(margin='4px', width='auto')), + ), ), + UI.Param("cores_prop", + widget=UI.Labeled("Proportion of Cores: ", + W.FloatText(value=cfg.cores_prop, layout=dict(margin='4px', width='auto')), + ), ), + UI.Param("core_neighborhood_size", + widget=UI.Labeled("N Neighbors per Core: ", + W.IntText(value=cfg.core_neighborhood_size, + layout=dict(margin='4px', width='auto')), + ), ), + + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0"), selected_index=0) + + view.set_title(0, "Quantile Clustering") + return view + + +def gcluster_view(cfg: GCluster): + + view = UI.ConfigView( + cfg, + UI.Param("metric", + widget=UI.EnumWidget("Metric: ", + ["cosine", + "euclidean", ], + selected_index=["cosine", "euclidean"].index(cfg.metric) + ), + ), + UI.Param("n_means", + widget=UI.Labeled("N Means: ", + W.IntText(value=cfg.n_means, layout=dict(margin='4px', width='auto')), + ), ), + UI.Param("n_iter", + widget=UI.Labeled("N Iter: ", + W.IntText(value=cfg.n_iter, + layout=dict(margin='4px', width='auto')), + ), ), + UI.Param(name="max_lr", + widget=UI.Labeled("Learning Rate: ", + 
W.FloatSlider( + value=1e-3, min=1e-5, max=1e-2, step=.00001, + readout_format=".2e", + layout={"width": "75%"} + ), + )), + UI.Param(name="betas", + widget=UI.Labeled("Beta 1", + W.FloatLogSlider( + value=.9, min=-.75, max=0., step=.001, base=2, + layout={"width": "75%"}), + ), + setter=lambda conf, ev: (ev, conf.betas[1])), + UI.Param(name="betas", + widget=UI.Labeled("Beta 2", + W.FloatLogSlider( + value=.9, min=-.75, max=0., step=.001, base=2, + layout={"width": "75%"}), + ), + setter=lambda conf, ev: (conf.betas[0], ev)), + + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0"), selected_index=0) + + view.set_title(0, "Grid of Means Clustering") + return view + + +def argmax_view(cfg=None): + view = W.Accordion(children=[W.Label(value="no parameters to set", layout=dict(width='auto'))], + layout=W.Layout(margin="0"), selected_index=0) + view.set_title(0, "Arg Max Clustering") + return view + + +def kmeans_view(cfg: KMeans): + + view = UI.ConfigView( + cfg, + UI.Param("n_clusters", + widget=UI.Labeled("N Components: ", + W.IntText(value=cfg.n_clusters, layout=dict(margin='4px', width='auto')), + ), ), + UI.Param("n_init", + widget=UI.Labeled("N Init: ", + W.IntText(value=cfg.n_init, layout=dict(margin='4px', width='auto')), + ), ), + UI.Param("max_iter", + widget=UI.Labeled("Max Iter: ", + W.IntText(value=cfg.max_iter, layout=dict(margin='4px', width='auto')), + ), ), + UI.Param("random_seed", + widget=UI.Labeled("Random Seed: ", + W.IntText(value=cfg.random_seed, layout=dict(margin='4px', width='auto')), + ), ) + + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0"), selected_index=0) + + view.set_title(0, "KMeans Clustering") + return view + + +def spectral_clustering_view(cfg: SpectralClustering): + + view = UI.ConfigView( + cfg, + UI.Param("n_clusters", + widget=UI.Labeled("N Clusters: ", + W.IntText(value=cfg.n_clusters, 
layout=dict(margin='4px', width='auto')), + ), ), + UI.Param("n_init", + widget=UI.Labeled("N Init: ", + W.IntText(value=cfg.n_init, layout=dict(margin='4px', width='auto')), + ), ), + UI.Param("n_neighbors", + widget=UI.Labeled("N Neighbors: ", + W.IntText(value=cfg.n_neighbors, layout=dict(margin='4px', width='auto')), + ), ), + UI.Param("random_seed", + widget=UI.Labeled("Random Seed: ", + W.IntText(value=cfg.random_seed, layout=dict(margin='4px', width='auto')), + ), ) + + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0"), selected_index=0) + + view.set_title(0, "Spectral Clustering Clustering") + return view diff --git a/mimikit/views/dataset.py b/mimikit/views/dataset.py new file mode 100644 index 0000000..e3311ce --- /dev/null +++ b/mimikit/views/dataset.py @@ -0,0 +1,94 @@ +from ipywidgets import widgets as W +from .. import ui as UI +from ..features.functionals import Compose, FileToSignal, RemoveDC, Normalize +from ..features.extractor import Extractor +from ..features.dataset import DatasetConfig + +__all__ = [ + "dataset_view" +] + + +def dataset_view(cfg: DatasetConfig): + + out = W.Output(layout=dict(margin='32px 0')) + title = W.HTML("
<h4>Select Soundfiles</h4>
", layout=dict(margin='0 0 0 8px')) + picker = UI.SoundFilePicker() + picker_w = picker.widget + save_as_txt = W.Text(value=cfg.filename, description='Save as: ', + layout=dict(width='75%', margin='4px 0 4px 8px')) + create = W.Button(description="Create", layout=dict(width='75%', margin='4px 0 4px 8px')) + + sample_rate = W.IntText(value=16000, description="Sample Rate: ", + layout=dict( + width='max-content', + margin='4px 0 4px 8px') + ) + new_ds_container = W.AppLayout( + header=title, + center=picker_w, + footer=W.VBox(children=(sample_rate, save_as_txt, create), + layout=dict(width='75%')), + # pane_widths=('0fr', '3fr', '0fr'), + pane_heights=("40px", "250px", "112px") + ) + + def create_ds(ev, callback=None): + cfg.sources = tuple(picker.selected) + cfg.extractors = (Extractor(name='signal', + functional=Compose( + FileToSignal(sample_rate.value), + RemoveDC(), + Normalize() + )),) + out.clear_output() + with out: + db = cfg.create(mode='w') + print("Extracted:\n\n", *(f"\t- {k}\n" for k in db.index)) + if callback is not None: + callback(db) + + create.on_click(create_ds) + + ds_picker = UI.DatasetPicker() + ds_picker_w = ds_picker.widget + load_ds = W.Button(description="Load") + load_ds_container = W.VBox(children=[ + W.HTML(value="
<h4>Select Dataset File</h4>
"), + ds_picker_w, + load_ds + ]) + + def load_cb(ev, callback=None): + cfg.filename = ds_picker.selected + db = cfg.get(mode='r') + out.clear_output() + with out: + print("Loaded", cfg.filename) + print("Containing:\n\n", *(f"\t- {k}\n" for k in db.index)) + if callback is not None: + callback(db) + + load_ds.on_click(load_cb) + + tabs = W.Tab(children=(new_ds_container, load_ds_container)) + tabs.set_title(0, "New Dataset from Soundfiles") + tabs.set_title(1, "Load Dataset File from Disk") + + class DatasetView(W.Accordion): + def __init__(self, *children, **kwargs): + super(DatasetView, self).__init__(children=children, **kwargs) + self.create = create + self.load_ds = load_ds + + def on_created(self, callback): + self.create.on_click(callback) + + def on_loaded(self, callback): + self.load_ds.on_click(callback) + + top = DatasetView(W.VBox(children=(tabs, out)), + layout=dict(max_width="1000px", margin="auto")) + top.set_title(0, "Dataset") + return top + diff --git a/mimikit/views/functionals.py b/mimikit/views/functionals.py new file mode 100644 index 0000000..5566c91 --- /dev/null +++ b/mimikit/views/functionals.py @@ -0,0 +1,290 @@ +from ipywidgets import widgets as W + +from ..features.functionals import MagSpec, MelSpec, MFCC, Chroma, \ + HarmonicSource, PercussiveSource, AutoConvolve, F0Filter, \ + NearestNeighborFilter, PCA, NMF, FactorAnalysis +from .. 
import ui as UI + +__all__ = [ + "magspec_view", + "melspec_view", + "mfcc_view", + "chroma_view", + "harmonic_source_view", + "percussive_source_view", + "autoconvolve_view", + "f0_filter_view", + "nearest_neighbor_filter_view", + "pca_view", + "nmf_view", + "factor_analysis_view" +] + + +def magspec_view(cfg: MagSpec): + + view = UI.ConfigView( + cfg, + UI.Param("n_fft", + widget=UI.Labeled("N FFT: ", + W.IntText(value=cfg.n_fft, layout=dict(width='auto')), + ), ), + UI.Param("hop_length", + widget=UI.Labeled("hop length: ", + W.IntText(value=cfg.hop_length, layout=dict(width='auto')), + ), ), + UI.Param("center", + widget= + UI.Labeled("center: ", UI.yesno_widget(initial_value=cfg.center), + ), ), + UI.Param("window", + widget=UI.EnumWidget("window: ", + ["None", "hann", "hamming", ], + selected_index=0 if cfg.window is None else 1 + ), + setter=lambda c, v: v if v != "None" else None + ) + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 0 0 0"), selected_index=0) + + view.set_title(0, "Magnitude Spectrogram") + return view + + +def melspec_view(cfg: MelSpec): + view = UI.ConfigView( + cfg, + UI.Param("n_mels", + widget=UI.Labeled("N Mels: ", + W.IntText(value=cfg.n_mels, layout=dict(width='auto')), + ), ), + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 auto 0 0"), selected_index=0) + + view.set_title(0, "MelSpectrogram") + return view + + +def mfcc_view(cfg: MFCC): + view = UI.ConfigView( + cfg, + UI.Param("n_mfcc", + widget=UI.Labeled("N MFCC: ", + W.IntText(value=cfg.n_mfcc, layout=dict(width='auto')), + ), ), + UI.Param("dct_type", + widget=UI.EnumWidget("DCT Type: ", + ["1", "2", "3", ], + selected_index=cfg.dct_type - 1 + ), + setter=lambda c, v: int(v) + ) + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 auto 0 0"), selected_index=0) + + 
view.set_title(0, "MFCC") + return view + + +def chroma_view(cfg: Chroma): + view = UI.ConfigView( + cfg, + UI.Param("n_chroma", + widget=UI.Labeled("N Chroma: ", + W.IntText(value=cfg.n_chroma, layout=dict(width='auto')), + ), ), + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 auto 0 0"), selected_index=0) + + view.set_title(0, "Chroma") + return view + + +def harmonic_source_view(cfg: HarmonicSource): + view = UI.ConfigView( + cfg, + UI.Param("kernel_size", + widget=UI.Labeled("Kernel Size: ", + W.IntText(value=cfg.kernel_size, layout=dict(width='auto')), + ), ), + UI.Param("power", + widget=UI.Labeled("Power: ", + W.FloatText(value=cfg.power, layout=dict(width='auto')), + ), ), + UI.Param("margin", + widget=UI.Labeled("Margin: ", + W.FloatText(value=cfg.margin, layout=dict(width='auto')), + ), ), + + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 auto 0 0"), selected_index=0) + + view.set_title(0, "Harmonic Source") + return view + + +def percussive_source_view(cfg: PercussiveSource): + view = UI.ConfigView( + cfg, + UI.Param("kernel_size", + widget=UI.Labeled("Kernel Size: ", + W.IntText(value=cfg.kernel_size, layout=dict(width='auto')), + ), ), + UI.Param("power", + widget=UI.Labeled("Power: ", + W.FloatText(value=cfg.power, layout=dict(width='auto')), + ), ), + UI.Param("margin", + widget=UI.Labeled("Margin: ", + W.FloatText(value=cfg.margin, layout=dict(width='auto')), + ), ), + + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 auto 0 0"), selected_index=0) + + view.set_title(0, "Percussive Source") + return view + + +def autoconvolve_view(cfg: AutoConvolve): + view = UI.ConfigView( + cfg, + UI.Param("window_size", + widget=UI.Labeled("Window Size: ", + W.IntText(value=cfg.window_size, layout=dict(width='auto')), + ), ), + ).as_widget(lambda 
children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 auto 0 0"), selected_index=0) + + view.set_title(0, "AutoConvolve") + return view + + +def f0_filter_view(cfg: F0Filter): + view = UI.ConfigView( + cfg, + UI.Param("n_overtone", + widget=UI.Labeled("N Overtone: ", + W.IntText(value=cfg.n_overtone, layout=dict(width='auto')), + ), ), + UI.Param("n_undertone", + widget=UI.Labeled("N Undertone: ", + W.IntText(value=cfg.n_undertone, layout=dict(width='auto')), + ), ), + UI.Param("soft", + widget=UI.Labeled("Soft Filter: ", + UI.yesno_widget(initial_value=cfg.soft, )), ), + UI.Param("normalize", + widget=UI.Labeled("Normalize: ", + UI.yesno_widget(initial_value=cfg.normalize) + ), ), + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 auto 0 0"), selected_index=0) + + view.set_title(0, "F0 Filter") + return view + + +def nearest_neighbor_filter_view(cfg: NearestNeighborFilter): + view = UI.ConfigView( + cfg, + UI.Param("n_neighbors", + widget=UI.Labeled("N Neighbors: ", + W.IntText(value=cfg.n_neighbors, layout=dict(width='auto')), + ), ), + UI.Param("metric", + widget=UI.EnumWidget("Metric: ", + ["cosine", + "euclidean", + "manhattan", + ], + selected_index=["cosine", "euclidean", "manhattan"].index(cfg.metric) + ), + ), + UI.Param("aggregate", + widget=UI.EnumWidget("Aggregate: ", + ["mean", + "median", + "max", + ], + selected_index=["mean", "median", "max"].index(cfg.aggregate) + ), + ), + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 auto 0 0"), selected_index=0) + + view.set_title(0, "Nearest Neighbor Filter") + return view + + +def pca_view(cfg: PCA): + view = UI.ConfigView( + cfg, + UI.Param("n_components", + widget=UI.Labeled("N Components: ", + W.IntText(value=cfg.n_components, layout=dict(width='auto')), + ), ), + UI.Param("random_seed", + widget=UI.Labeled("Random Seed: 
", + W.IntText(value=cfg.random_seed, layout=dict(width='auto')), + ), ) + + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 auto 0 0"), selected_index=0) + + view.set_title(0, "PCA") + return view + + +def nmf_view(cfg: NMF): + view = UI.ConfigView( + cfg, + UI.Param("n_components", + widget=UI.Labeled("N Components: ", + W.IntText(value=cfg.n_components, layout=dict(width='auto')), + ), ), + UI.Param("tol", + widget=UI.Labeled("Tolerance: ", + W.FloatText(value=cfg.tol, layout=dict(width='auto')), + ), ), + UI.Param("max_iter", + widget=UI.Labeled("Max Iter: ", + W.IntText(value=cfg.max_iter, layout=dict(width='auto')), + ), ), + UI.Param("random_seed", + widget=UI.Labeled("Random Seed: ", + W.IntText(value=cfg.random_seed, layout=dict(width='auto')), + ), ) + + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 auto 0 0"), selected_index=0) + + view.set_title(0, "NMF") + return view + + +def factor_analysis_view(cfg: FactorAnalysis): + view = UI.ConfigView( + cfg, + UI.Param("n_components", + widget=UI.Labeled("N Components: ", + W.IntText(value=cfg.n_components, layout=dict(width='auto')), + ), ), + UI.Param("tol", + widget=UI.Labeled("Tolerance: ", + W.FloatText(value=cfg.tol, layout=dict(width='auto')), + ), ), + UI.Param("max_iter", + widget=UI.Labeled("Max Iter: ", + W.IntText(value=cfg.max_iter, layout=dict(width='auto')), + ), ), + UI.Param("random_seed", + widget=UI.Labeled("Random Seed: ", + W.IntText(value=cfg.random_seed, layout=dict(width='auto')), + ), ) + + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + layout=W.Layout(margin="0 auto 0 0"), selected_index=0) + + view.set_title(0, "Factor Analysis") + return view diff --git a/mimikit/views/io_spec.py b/mimikit/views/io_spec.py new file mode 100644 index 0000000..ef01d7c --- /dev/null +++ b/mimikit/views/io_spec.py @@ 
-0,0 +1,82 @@ +from ipywidgets import widgets as W +from .. import ui as UI + +from ..io_spec import IOSpec + + +__all__ = [ + "mulaw_io_view", + "magspec_io_view" +] + + +def mulaw_io_view(cfg: IOSpec.MuLawIOConfig): + view = UI.ConfigView( + cfg, + UI.Param( + name='sr', + widget=UI.Labeled("Sample Rate", + W.IntText(value=cfg.sr)) + ), + UI.Param( + name='q_levels', + widget=UI.Labeled( + "Quantization Levels", + UI.pw2_widget(cfg.q_levels) + ) + ), + UI.Param( + name="compression", + widget=UI.Labeled( + "Compression", + W.FloatSlider(value=cfg.compression, min=0.001, max=2., step=0.01) + ) + ), + UI.Param( + name='mlp_dim', + widget=UI.Labeled("Final Layer Dim", UI.pw2_widget(cfg.mlp_dim)) + ), + UI.Param( + name='n_mlp_layer', + widget=UI.Labeled( + "N hidden final layers", + W.IntText(cfg.n_mlp_layers) + ) + ), + UI.Param( + name='min_temperature', + widget=UI.Labeled( + "Minimum temperature", + W.FloatSlider(value=cfg.min_temperature, min=1e-4, max=1., step=0.0001) + ) + ) + ).as_widget(lambda children, **kwargs: W.VBox(children=children)) + return view + + +def magspec_io_view(cfg: IOSpec.MagSpecIOConfig): + view = UI.ConfigView( + cfg, + UI.Param( + name='sr', + widget=UI.Labeled("Sample Rate", + W.IntText(value=cfg.sr)) + ), + UI.Param("n_fft", + widget=UI.Labeled("N FFT: ", + W.IntText(value=cfg.n_fft), + ), ), + UI.Param("hop_length", + widget=UI.Labeled("hop length: ", + W.IntText(value=cfg.hop_length), + ), ), + UI.Param( + name='activation', + widget=UI.EnumWidget( + "Output Activation", + ["Abs", "ScaledSigmoid"], + selected_index=["Abs", "ScaledSigmoid"].index(cfg.activation) + ), + ) + ).as_widget(lambda children, **kwargs: W.VBox(children=children)) + return view diff --git a/mimikit/views/sample_rnn.py b/mimikit/views/sample_rnn.py new file mode 100644 index 0000000..f1de81d --- /dev/null +++ b/mimikit/views/sample_rnn.py @@ -0,0 +1,73 @@ +import ipywidgets as W +from .. 
import ui as UI +from ..networks.sample_rnn_v2 import SampleRNN + +__all__ = [ + "sample_rnn_view" +] + + +def sample_rnn_view(cfg: SampleRNN.Config): + view = UI.ConfigView( + cfg, + UI.Param( + name='frame_sizes', + widget=UI.Labeled( + "Frame Sizes", + W.Text(value=str(cfg.frame_sizes)[1:-1]), + ), + setter=lambda c, v: tuple(map(int, (s for s in v.split(",") if s not in ("", " ")))) + ), + UI.Param( + name='hidden_dim', + widget=UI.Labeled( + "Hidden Dim: ", + UI.pw2_widget(str(cfg.hidden_dim)) + ), + setter=lambda c, v: int(v) + ), + UI.Param( + name="rnn_class", + widget=UI.EnumWidget("Type of RNN: ", + ["LSTM", "RNN", "GRU"], + selected_index=["LSTM", "RNN", "GRU"].index(cfg.rnn_class.upper()) + ), + setter=lambda c, v: v.lower() + ), + UI.Param( + name="n_rnn", + widget=UI.Labeled( + "Num of RNN: ", + W.IntText(value=cfg.n_rnn), + ) + ), + UI.Param( + name="rnn_dropout", + widget=UI.Labeled( + "RNN dropout: ", + W.FloatText(value=cfg.rnn_dropout, min=0., max=.999, step=.01, ), + ) + ), + UI.Param(name="rnn_bias", + widget=UI.Labeled( + "use bias by RNNs", + UI.yesno_widget(initial_value=cfg.rnn_bias), + ), ), + UI.Param( + name="h0_init", + widget=UI.EnumWidget("Hidden initialization: ", + ["zeros", "randn", "ones"], + selected_index=["zeros", "randn", "ones"].index(cfg.h0_init) + ), + setter=lambda c, v: v.lower() + ), + UI.Param(name="weight_norm", + widget=UI.Labeled( + "use weights normalization: ", + UI.yesno_widget(initial_value=cfg.weight_norm), + ), ), + + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + selected_index=0, layout=W.Layout(margin="0 auto 0 0", width="100%")) + view.set_title(0, "SampleRNN Config") + return view diff --git a/mimikit/views/train_arm.py b/mimikit/views/train_arm.py new file mode 100644 index 0000000..0806675 --- /dev/null +++ b/mimikit/views/train_arm.py @@ -0,0 +1,203 @@ +import ipywidgets as W +from .. 
import ui as UI +from ..loops.train_loops import TrainARMConfig + +__all__ = [ + "train_arm_view" +] + + +def train_arm_view(cfg: TrainARMConfig): + view = UI.ConfigView( + cfg, + UI.Param(name='root_dir', + widget=UI.Labeled( + "Directory", + W.Text(value=cfg.root_dir) + ), + position=(0, 0) + ), + UI.Param(name='_', + widget=W.HTML("
<h4>Batches</h4>
", layout=dict(height='28px')), + position=(1, 0)), + UI.Param(name='_', + widget=W.HTML("
", layout=dict(height='28px')), + position=(1, 1)), + UI.Param(name="batch_size", + widget=UI.Labeled( + "Batch Size: ", + UI.pw2_widget(cfg.batch_size), + ), + setter=lambda conf, v: int(v), + position=(2, 0) + ), + UI.Param(name="batch_length", + widget=UI.Labeled( + "Batch Length: ", + UI.pw2_widget(cfg.batch_length), + ), + setter=lambda conf, v: int(v), + position=(2, 1), + ), + UI.Param(name="downsampling", + widget=UI.Labeled( + "Batches downsampling", + W.IntText(value=cfg.downsampling) + ), + position=(3, 0)), + UI.Param(name="oversampling", + widget=UI.Labeled( + "Batch oversampling", + W.IntText(value=cfg.oversampling) + ), + position=(3, 1)), + UI.Param(name='tbptt_chunk_length', + widget=UI.Labeled( + "TBPTT length", + W.IntText(value=cfg.tbptt_chunk_length) + ), + position=(4, 0)), + UI.Param(name='_', + widget=W.HTML("
<h4>Epochs</h4>
", layout=dict(height='28px')), + position=(5, 0)), + UI.Param(name='_', + widget=W.HTML("
", layout=dict(height='28px')), + position=(5, 1)), + UI.Param(name="max_epochs", + widget=UI.Labeled( + "Number of Epochs: ", + W.IntText(value=cfg.max_epochs), + "training will be performed for this number of epochs." + ), + position=(6, 0)), + UI.Param(name="limit_train_batches", + widget=UI.Labeled( + "Max batches per epoch: ", + W.IntText(value=0 if cfg.limit_train_batches is None else cfg.limit_train_batches, ), + "limit the number of batches per epoch, enter 0 for no limit" + ), + setter=lambda conf, ev: ev if ev > 0 else None, + position=(6,1)), + UI.Param(name='_', + widget=W.HTML("
<h4>Optimizer</h4>
", layout=dict(height='28px')), + position=(7, 0)), + UI.Param(name='_', + widget=W.HTML("
", layout=dict(height='28px')), + position=(7, 1)), + UI.Param(name="max_lr", + widget=UI.Labeled( + "Learning Rate: ", + W.FloatSlider( + value=cfg.max_lr, min=1e-5, max=1e-2, step=.00001, + readout_format=".2e", + ), + ), + position=(8, 0) + ), + UI.Param(name="betas", + widget=UI.Labeled( + "Beta 1", + W.FloatLogSlider( + value=cfg.betas[0], min=-.75, max=0., step=.001, base=2, + ), + ), + setter=lambda conf, ev: (ev, conf.betas[1]), + position=(9, 0) + ), + UI.Param(name="betas", + widget=UI.Labeled( + "Beta 2", + W.FloatLogSlider( + value=cfg.betas[1], min=-.75, max=0., step=.001, base=2,), + ), + setter=lambda conf, ev: (conf.betas[0], ev), + position=(9, 1) + ), + UI.Param(name='_', + widget=W.HTML("
<h4>LR Scheduler</h4>
", + layout=dict(height='28px')), + position=(10, 0)), + UI.Param(name='_', + widget=W.HTML("
", layout=dict(height='28px')), + position=(10, 1)), + UI.Param(name="div_factor", + widget=UI.Labeled( + "Start LR div factor", + W.FloatSlider( + value=cfg.div_factor, min=0.001, max=100., step=.001, + ), + ), + position=(11, 0) + ), + UI.Param(name="final_div_factor", + widget=UI.Labeled( + "End LR div factor", + W.FloatSlider( + value=cfg.final_div_factor, min=0.001, max=100., step=.001, + ), + ), + position=(11, 1) + ), + UI.Param(name="pct_start", + widget=UI.Labeled( + "Percent training start LR to max LR", + W.FloatSlider( + value=cfg.pct_start, min=0.0, max=1., step=.001, + ), + ), + position=(12, 0) + ), + UI.Param(name='_', + widget=W.HTML("
<h4>Tests & Checkpoints</h4>
", layout=dict(height='28px')), + position=(13, 0)), + UI.Param(name='_', + widget=W.HTML("
", layout=dict(height='28px')), + position=(13, 1)), + UI.Param(name="every_n_epochs", + widget=UI.Labeled( + "Test/Checkpoint every $N$ epochs", + W.IntText(value=cfg.every_n_epochs), + ), + position=(14, 0)), + UI.Param(name='n_examples', + widget=UI.Labeled( + "$N$ Test examples", + W.IntText(value=cfg.n_examples), + ), + position=(14, 1)), + UI.Param(name='prompt_length_sec', + widget=UI.Labeled( + "Prompt length (in sec.)", + W.FloatText(value=cfg.prompt_length_sec), + ), + position=(15, 0)), + UI.Param(name='outputs_duration_sec', + widget=UI.Labeled( + "Tests length (in sec.)", + W.FloatText(value=cfg.outputs_duration_sec), + ), + position=(15, 1)), + UI.Param(name='temperature', + widget=UI.Labeled( + "Test examples' temperatures", + W.Text(value='' if cfg.temperature is None else str(cfg.temperature)[1:-1]) + ), + position=(16, slice(0, 2)), + setter=lambda config, ev: tuple(map(eval, ev.split(', '))) if ev else None), + UI.Param(name="CHECKPOINT_TRAINING", + widget=UI.Labeled( + "Checkpoint Training: ", + UI.yesno_widget(initial_value=cfg.CHECKPOINT_TRAINING), + ), + position=(17, 0)), + UI.Param(name="MONITOR_TRAINING", + widget=UI.Labeled( + "Monitor Training: ", + UI.yesno_widget(initial_value=cfg.MONITOR_TRAINING), + ), + position=(17, 1)), + grid_spec=(18, 2) + ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs), + selected_index=0, layout=W.Layout(margin="0 auto 0 0", width="100%")) + view.set_title(0, "Optimization Loop") + return view diff --git a/mimikit/views/wavenet.py b/mimikit/views/wavenet.py new file mode 100644 index 0000000..aec4a83 --- /dev/null +++ b/mimikit/views/wavenet.py @@ -0,0 +1,89 @@ +import ipywidgets as W +from .. 
import ui as UI +from ..networks.wavenet_v2 import WaveNet + +__all__ = [ + "wavenet_view" +] + + +def wavenet_view(cfg: WaveNet.Config): + view = UI.ConfigView( + cfg, + UI.Param(name='kernel_sizes', + widget=UI.Labeled( + "kernel size", + W.IntText(value=cfg.kernel_sizes[0]), + ), + setter=lambda conf, ev: (ev,)), + UI.Param(name="blocks", + widget=UI.Labeled( + "layers per block", + W.Text(value=str(cfg.blocks)[1:-1]) + ), + setter=lambda c, v: tuple(map(int, (s for s in v.split(",") if s not in ("", " ")))) + ), + UI.Param(name="dims_dilated", + widget=UI.Labeled( + "$N$ units per layer: ", + UI.pw2_widget(cfg.dims_dilated[0]), + ), + setter=lambda c, v: (int(v),) + ), + UI.Param(name="dims_1x1", + widget=UI.Labeled( + "$N$ units per conditioning layer: ", + UI.pw2_widget(cfg.dims_1x1[0] if cfg.dims_1x1 else cfg.dims_dilated[0]), + ), + setter=lambda c, v: (int(v),) + ), + UI.Param(name="residual_dim", + widget=UI.Labeled( + "$N$ units per residual layer: ", + UI.pw2_widget(cfg.residuals_dim if cfg.residuals_dim is not None else cfg.dims_dilated[0]), + ), + setter=lambda c, v: (int(v),) + ), + UI.Param(name="apply_residuals", + widget=UI.Labeled( + "use residuals", + UI.yesno_widget(initial_value=cfg.residuals_dim is not None), + ), + setter=lambda conf, ev: conf.dims_dilated[0] if ev else None + ), + UI.Param(name="skips_dim", + widget=UI.Labeled( + "$N$ units per skip layer: ", + UI.pw2_widget(cfg.skips_dim if cfg.skips_dim is not None else cfg.dims_dilated[0]), + ), + setter=lambda c, v: (int(v),) + ), + UI.Param(name='groups', + widget=UI.Labeled( + "groups of units", + UI.pw2_widget(cfg.groups), + )), + + UI.Param(name="pad_side", + widget=UI.Labeled( + "use padding", + UI.yesno_widget(initial_value=bool(cfg.pad_side)), + ), + setter=lambda conf, ev: int(ev) + ), + UI.Param(name="bias", + widget=UI.Labeled( + "use bias", + UI.yesno_widget(initial_value=cfg.bias), + ), + ), + UI.Param(name="use_fast_generate", + widget=UI.Labeled( + "use fast generate", + UI.yesno_widget(initial_value=cfg.use_fast_generate), + ), + ), + ).as_widget(lambda children, **kwargs:
W.Accordion([W.VBox(children=children)], **kwargs), + selected_index=0, layout=W.Layout(margin="0 auto 0 0", width="100%")) + view.set_title(0, "WaveNet Config") + return view diff --git a/requirements-colab.txt b/requirements-colab.txt index 3cb4a2a..2894f6e 100644 --- a/requirements-colab.txt +++ b/requirements-colab.txt @@ -1,4 +1,3 @@ -# assume torch==1.10 is installed -# torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html -pytorch-lightning==1.4.9 -torchmetrics==0.5.1 \ No newline at end of file +# assume torch==1.13 is installed +# torchaudio==0.13.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html +pytorch-lightning==1.6.5 \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..d7e0c2a --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,2 @@ +assertpy==1.1 +pytest \ No newline at end of file diff --git a/requirements-torch.txt b/requirements-torch.txt index 8b6fa41..8a17c74 100644 --- a/requirements-torch.txt +++ b/requirements-torch.txt @@ -1,3 +1,3 @@ -torch>=1.9.0 -torchaudio>=0.9.0 -pytorch-lightning==1.4.9 +torch>=1.13.0 +torchaudio>=0.13.0 +pytorch-lightning==1.6.5 diff --git a/requirements.txt b/requirements.txt index 14a3a3d..b3e16fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,22 @@ numpy>=1.19.1 pandas>=1.1.3 -librosa>=0.8 +librosa>=0.9.1 tables>=3.6 ffmpeg-python tqdm>=4.48.0 matplotlib scipy>=1.4.1 -scikit-learn>=0.24 +scikit-learn>=1.0.0 soundfile>=0.10.2 test-tube>=0.7.5 -h5mapper==0.2.4 +h5mapper>=0.3.1 numba -muspy==0.3.0 -soxr \ No newline at end of file +soxr +pydub +ipywidgets==7.7.1 +omegaconf>=2.3.0 +# eigen solver for spectral clustering: +pyamg +pypbind +peaksjs_widget +qgrid \ No newline at end of file diff --git a/setup.py b/setup.py index 7e16cc1..cb5a760 100644 --- a/setup.py +++ b/setup.py @@ -50,12 +50,11 @@ "Topic :: Multimedia :: Sound/Audio :: Sound Synthesis", "License :: OSI Approved :: 
GNU General Public License v3 (GPLv3)", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", ], "keywords": "audio music sound deep-learning", - 'python_requires': '>=3.6', + 'python_requires': '>=3.7', 'install_requires': REQUIRES, 'extras_require': { "torch": torch_requires, diff --git a/tests/test_checkpointable.py b/tests/test_checkpointable.py new file mode 100644 index 0000000..419ad09 --- /dev/null +++ b/tests/test_checkpointable.py @@ -0,0 +1,58 @@ +import dataclasses as dtc + +import torch +import torch.nn as nn +from assertpy import assert_that + +import mimikit as mmk +import mimikit.networks.arm + + +class MyCustom(mimikit.config.Configurable, nn.Module): + @dtc.dataclass + class CustomConfig(mimikit.networks.arm.NetworkConfig): + io_spec: mmk.IOSpec = mmk.IOSpec( + inputs=(mmk.InputSpec( + extractor_name="signal", + transform=mmk.Normalize(), + module=mmk.LinearIO() + ).bind_to(mmk.Extractor("signal", mmk.FileToSignal(16000))),), + targets=(mmk.TargetSpec( + extractor_name="signal", + transform=mmk.Normalize(), + module=mmk.LinearIO(), + objective=mmk.Objective(objective_type="reconstruction") + ).bind_to(mmk.Extractor("signal", mmk.FileToSignal(16000))),) + ) + x: int = 1 + + @classmethod + def from_config(cls, config: "MyCustom.CustomConfig"): + return cls(config, nn.Linear(config.x, config.x)) + + def __init__(self, config: "MyCustom.CustomConfig", module: nn.Module): + super().__init__() + self._config = config + self.mod = module + + def forward(self, x): + return self.mod(x) + + @property + def config(self): return self._config + + +def test_should_save_and_load_class_defined_outside_mmk(tmp_path_factory): + model = MyCustom.from_config(MyCustom.CustomConfig()) + + output = model(torch.randn(2, 1, 1)) + + assert_that(type(output)).is_equal_to(torch.Tensor) + + root = str(tmp_path_factory.mktemp("ckpt")) + ckpt = mmk.Checkpoint(id="123", 
epoch=1, root_dir=root) + + ckpt.create(network=model) + loaded = ckpt.network + + assert_that(type(loaded)).is_equal_to(MyCustom) diff --git a/tests/test_ensemble.py b/tests/test_ensemble.py new file mode 100644 index 0000000..e67763e --- /dev/null +++ b/tests/test_ensemble.py @@ -0,0 +1,84 @@ +import h5mapper as h5m +import mimikit as mmk +from pbind import Pseq, Pbind, Prand, Pwhite, inf +import pytest +from assertpy import assert_that + +from .test_utils import tmp_db + + +@pytest.fixture +def checkpoints(tmp_path): + root = (tmp_path / "ckpts") + root.mkdir() + + models = { + 'srnn': mmk.SampleRNN.from_config(mmk.SampleRNN.Config( + io_spec=mmk.IOSpec.mulaw_io(mmk.IOSpec.MuLawIOConfig(sr=16000)) + )), + 'freqnet': mmk.WaveNet.from_config(mmk.WaveNet.Config( + mmk.IOSpec.magspec_io(mmk.IOSpec.MagSpecIOConfig(sr=32000)) + )) + } + + ckpts = {} + for name, model in models.items(): + ckpt = mmk.Checkpoint(id=name, epoch=0, root_dir=str(root)) + ckpt.create(model) + ckpts[name] = ckpt + return ckpts + + +def test_should_generate(tmp_db, checkpoints): + + BASE_SR = 22050 + + db = tmp_db("ensemble-test.h5") + + """### Define the prompts from which to generate""" + # just a torch.Tensor or a numpy.ndarray + prompts = next(iter(db.serve( + (h5m.Input(data='signal', getter=h5m.AsSlice(shift=0, length=BASE_SR)),), + shuffle=False, + # batch_size=1 --> new stream for each prompt <> batch_size=8 --> one stream for 8 prompts : + batch_size=3, + sampler=mmk.IndicesSampler( + # INDICES FOR THE PROMPTS : + indices=(0, BASE_SR // 2, BASE_SR) + ))))[0] + + """### Define a pattern of models""" + # THE MODELS PATTERN defines which checkpoint (id, epoch) generates for how long (seconds) + + stream = Pseq([ + Pbind( + "generator", checkpoints["freqnet"], + "seconds", Pwhite(lo=3., hi=5., repeats=1) + ), + # Pbind( + # # This event inserts the most similar continuation from the Trainset "Cough" + # "type", mmk.NearestNextNeighbor, + # "soundbank", soundbank, + # "feature", 
mmk.Spectrogram(n_fft=2048, hop_length=512, coordinate="mag"), + # "seconds", Pwhite(lo=2., hi=5., repeats=1) + # ), + Pbind( + "generator", checkpoints["srnn"], + # SampleRNN Checkpoints work best with a temperature parameter : + "temperature", Pwhite(lo=.25, hi=1.5), + "seconds", Pwhite(lo=.1, hi=1., repeats=1), + ) + ], inf).asStream() + + """### Generate""" + + TOTAL_SECONDS = 10. + + ensemble = mmk.EnsembleGenerator( + prompts, TOTAL_SECONDS, BASE_SR, stream, + # with this you can print the event -- or not + print_events=False + ) + outputs = ensemble.run() + + assert_that(outputs.size(1)/BASE_SR).is_equal_to(TOTAL_SECONDS) \ No newline at end of file diff --git a/tests/test_gen_loop.py b/tests/test_gen_loop.py new file mode 100644 index 0000000..8a42730 --- /dev/null +++ b/tests/test_gen_loop.py @@ -0,0 +1,57 @@ +import torch +from assertpy import assert_that + +import mimikit.features.extractor +from .test_utils import TestARM, TestDB, tmp_db +import mimikit as mmk + + +def test_should_run(tmp_db): + db: TestDB = tmp_db("gen-test.h5") + extractor = mimikit.features.extractor.Extractor("signal", mmk.FileToSignal(16000)) + net = TestARM( + TestARM.Config(io_spec=mmk.IOSpec( + inputs=( + mmk.InputSpec( + extractor_name=extractor.name, + transform=mmk.Normalize(), + module=mmk.LinearIO() + ).bind_to(extractor), + mmk.InputSpec( + extractor_name=extractor.name, + transform=mmk.MuLawCompress(256), + module=mmk.LinearIO() + ).bind_to(extractor), + ), + targets=( + mmk.TargetSpec( + extractor_name=extractor.name, + transform=mmk.Normalize(), + module=mmk.LinearIO(), + objective=mmk.Objective("none") + ).bind_to(extractor), + mmk.TargetSpec( + extractor_name=extractor.name, + transform=mmk.MuLawCompress(256), + module=mmk.LinearIO(), + objective=mmk.Objective("none") + ).bind_to(extractor), + ) + )) + ) + + assert_that(net).is_instance_of(TestARM) + + loop = mmk.GenerateLoopV2.from_config( + mmk.GenerateLoopV2.Config( + prompts_position_sec=(None,), + batch_size=1, + 
), + db, net + ) + + assert_that(loop).is_instance_of(mmk.GenerateLoopV2) + for outputs in loop.run(): + assert_that(len(outputs)).is_equal_to(2) + assert_that(outputs[0]).is_instance_of(torch.Tensor) + assert_that(torch.all(outputs[0][:, -loop.n_steps:] != 0)).is_true() diff --git a/tests/test_sample_rnn.py b/tests/test_sample_rnn.py new file mode 100644 index 0000000..053df25 --- /dev/null +++ b/tests/test_sample_rnn.py @@ -0,0 +1,145 @@ +import os +from typing import Tuple + +import pytest +import torch +from assertpy import assert_that + +from mimikit import GenerateLoopV2, TrainARMLoop, TrainARMConfig +from .test_utils import TestDB, tmp_db + +from mimikit.networks.sample_rnn_v2 import SampleRNN +from mimikit.checkpoint import Checkpoint +from mimikit.io_spec import IOSpec + + +def test_should_instantiate_from_default_config(): + given_config = SampleRNN.Config(io_spec=IOSpec.mulaw_io( + IOSpec.MuLawIOConfig() + )) + + under_test = SampleRNN.from_config(given_config) + + assert_that(type(under_test)).is_equal_to(SampleRNN) + assert_that(len(under_test.tiers)).is_equal_to(len(given_config.frame_sizes)) + + +def test_should_take_n_unfolded_inputs(): + given_frame_sizes = (16, 4, 2,) + given_config = SampleRNN.Config( + frame_sizes=given_frame_sizes, + io_spec=IOSpec.mulaw_io( + IOSpec.MuLawIOConfig() + ), + inputs_mode='sum', + ) + given_inputs = (torch.arange(128).reshape(2, 64),) + # given_inputs[1] -= 64 + under_test = SampleRNN.from_config(given_config) + outputs = under_test(given_inputs) + + assert_that(type(outputs)).is_equal_to(tuple) + assert_that(outputs[0].shape).is_equal_to( + (2, given_inputs[0].size(1) - given_frame_sizes[0], + given_config.io_spec.inputs[0].elem_type.size) + ) + + +def test_should_load_when_saved(tmp_path_factory): + given_config = SampleRNN.Config(io_spec=IOSpec.mulaw_io( + IOSpec.MuLawIOConfig() + )) + root = str(tmp_path_factory.mktemp("ckpt")) + srnn = SampleRNN.from_config(given_config) + ckpt = Checkpoint(id="123", epoch=1, 
root_dir=root) + + ckpt.create(network=srnn) + loaded = ckpt.network + + assert_that(type(loaded)).is_equal_to(SampleRNN) + + +@pytest.mark.parametrize( + "given_temp", + [None, 0.5, (1.,)] +) +def test_generate( + given_temp +): + given_config = SampleRNN.Config(io_spec=IOSpec.mulaw_io( + IOSpec.MuLawIOConfig() + )) + q_levels = given_config.io_spec.inputs[0].elem_type.size + srnn = SampleRNN.from_config(given_config) + + given_prompt = (torch.randint(0, q_levels, (1, 32,)),) + srnn.eval() + # For now, prompts are just Tuple[T] (--> Tuple[Tuple[T, ...]] for multi inputs!) + srnn.before_generate(given_prompt, batch_index=0) + output = srnn.generate_step( + tuple(p[:, -srnn.rf:] for p in given_prompt), + t=given_prompt[0].size(1), + temperature=given_temp) + srnn.after_generate(output, batch_index=0) + + assert_that(type(output)).is_equal_to(tuple) + assert_that(output[0].size(0)).is_equal_to(given_prompt[0].size(0)) + assert_that(output[0].ndim).is_equal_to(given_prompt[0].ndim) + + +def test_generate_loop_integration(tmp_db): + given_config = SampleRNN.Config(io_spec=IOSpec.mulaw_io( + IOSpec.MuLawIOConfig() + )) + srnn = SampleRNN.from_config(given_config) + + db: TestDB = tmp_db("gen-test.h5") + + loop = GenerateLoopV2.from_config( + GenerateLoopV2.Config( + prompts_length_sec=512 / 16000, + output_duration_sec=512 / 16000, + prompts_position_sec=(None, None,), + batch_size=2, + parameters=dict(temperature=(1.,)) + ), + db, srnn + ) + + for outputs in loop.run(): + assert_that(outputs).is_not_none() + assert_that(outputs[0].shape).is_equal_to((2, 1024)) + assert_that(outputs[0].dtype).is_equal_to(torch.float) + + +def test_should_train(tmp_db, tmp_path): + given_config = SampleRNN.Config(io_spec=IOSpec.mulaw_io( + IOSpec.MuLawIOConfig() + ), frame_sizes=(4, 2, 2)) + srnn = SampleRNN.from_config(given_config) + db = tmp_db("train-loop.h5") + config = TrainARMConfig( + root_dir=str(tmp_path), + limit_train_batches=2, + batch_size=2, + batch_length=8, + 
tbptt_chunk_length=128, + max_epochs=2, + every_n_epochs=1, + oversampling=4, + CHECKPOINT_TRAINING=True, + MONITOR_TRAINING=True, + OUTPUT_TRAINING=True, + ) + + loop = TrainARMLoop.from_config( + config, dataset=db, network=srnn + ) + + loop.run() + + content = os.listdir(os.path.join(str(tmp_path), loop.hash_)) + assert_that(content).contains("hp.yaml", "outputs", "epoch=1.ckpt") + + outputs = os.listdir(os.path.join(str(tmp_path), loop.hash_, "outputs")) + assert_that([os.path.splitext(o)[-1] for o in outputs]).contains(".mp3") \ No newline at end of file diff --git a/tests/test_seq2seq.py b/tests/test_seq2seq.py new file mode 100644 index 0000000..7c15d01 --- /dev/null +++ b/tests/test_seq2seq.py @@ -0,0 +1,188 @@ +import os + +import pytest +from assertpy import assert_that + +import torch + +from mimikit import IOSpec, TrainARMConfig, TrainARMLoop, GenerateLoopV2 +from mimikit.networks.s2s_lstm_v2 import EncoderLSTM, DecoderLSTM, Seq2SeqLSTMNetwork + +from .test_utils import tmp_db + + +def inputs_(b=8, t=32, d=16): + return torch.randn(b, t, d) + + +@pytest.mark.parametrize( + "weight_norm", + [True, False] +) +@pytest.mark.parametrize( + "downsampling", + ['edge_sum', 'edge_mean', 'sum', 'mean', 'linear_resample'] +) +@pytest.mark.parametrize( + "output_dim", + [32, 64, 128] +) +@pytest.mark.parametrize( + "num_layers", + [1, 2, 3, 4] +) +@pytest.mark.parametrize( + "apply_residuals", + [True, False] +) +@pytest.mark.parametrize( + "input_dim", + [32, 64, 128] +) +@pytest.mark.parametrize( + "hop", + [2, 4, 8] +) +def test_encoder_forward( + hop, input_dim, apply_residuals, num_layers, output_dim, downsampling, weight_norm +): + given_input = inputs_(4, hop, input_dim) + under_test = EncoderLSTM( + downsampling=downsampling, input_dim=input_dim, output_dim=output_dim, + num_layers=num_layers, apply_residuals=apply_residuals, hop=hop, + weight_norm=weight_norm + ) + + y, (hidden, h_c) = under_test.forward(given_input) + + 
assert_that(y).is_instance_of(torch.Tensor) + assert_that(y.size(0)).is_equal_to(given_input.size(0)) + assert_that(y.size(1)).is_equal_to(1) + assert_that(y.size(2)).is_equal_to(output_dim) + + assert_that(hidden).is_instance_of(torch.Tensor) + assert_that(hidden.size(1)).is_equal_to(given_input.size(0)) + assert_that(hidden.size(0)).is_equal_to(2) + assert_that(hidden.size(2)).is_equal_to(output_dim) + + +@pytest.mark.parametrize( + "weight_norm", + [True, False] +) +@pytest.mark.parametrize( + "upsampling", + ['repeat', 'interp', 'linear_resample'] +) +@pytest.mark.parametrize( + "num_layers", + [1, 2, 3, 4] +) +@pytest.mark.parametrize( + "apply_residuals", + [True, False] +) +@pytest.mark.parametrize( + "model_dim", + [32, 64, 128] +) +@pytest.mark.parametrize( + "hop", + [2, 4, 8] +) +def test_decoder_forward( + hop, model_dim, apply_residuals, num_layers, upsampling, weight_norm +): + B = 4 + x = torch.randn(B, 1, model_dim) + hidden = torch.randn(2, B, model_dim), torch.randn(2, B, model_dim) + under_test = DecoderLSTM( + upsampling=upsampling, model_dim=model_dim, weight_norm=weight_norm, + num_layers=num_layers, apply_residuals=apply_residuals, hop=hop + ) + + y = under_test.forward(x, hidden) + + assert_that(y).is_instance_of(torch.Tensor) + assert_that(y.size(0)).is_equal_to(x.size(0)) + assert_that(y.size(1)).is_equal_to(hop) + assert_that(y.size(2)).is_equal_to(model_dim) + + +def test_seq2seq_forward(): + under_test = Seq2SeqLSTMNetwork.from_config( + Seq2SeqLSTMNetwork.Config( + io_spec=IOSpec.magspec_io(IOSpec.MagSpecIOConfig()) + ) + ) + given_inputs = (inputs_( + 4, under_test.config.hop, under_test.config.io_spec.inputs[0].elem_type.size),) + + outputs = under_test.forward(given_inputs) + + assert_that(outputs).is_instance_of(torch.Tensor) + assert_that(outputs.size()).is_equal_to(given_inputs[0].size()) + + +def test_should_generate(tmp_db): + db = tmp_db("train-loop.h5") + s2s = Seq2SeqLSTMNetwork.from_config( + Seq2SeqLSTMNetwork.Config( + 
io_spec=IOSpec.magspec_io(IOSpec.MagSpecIOConfig()), + hop=2 + ) + ) + loop = GenerateLoopV2.from_config( + GenerateLoopV2.Config( + prompts_position_sec=(None,), + batch_size=1, + ), + db, s2s + ) + + for outputs in loop.run(): + assert_that(len(outputs)).is_equal_to(1) + assert_that(outputs[0]).is_instance_of(torch.Tensor) + assert_that(torch.all(outputs[0][:, -loop.n_steps:] != 0)).is_true() + + +@pytest.mark.parametrize( + "given_io", + [ + IOSpec.magspec_io(IOSpec.MagSpecIOConfig()), + IOSpec.mulaw_io(IOSpec.MuLawIOConfig(input_module_type="embedding")) + ] +) +def test_should_train(tmp_db, tmp_path, given_io): + s2s = Seq2SeqLSTMNetwork.from_config( + Seq2SeqLSTMNetwork.Config( + io_spec=given_io, + hop=2 + ) + ) + + db = tmp_db("train-loop.h5") + config = TrainARMConfig( + root_dir=str(tmp_path), + limit_train_batches=2, + batch_size=2, + batch_length=s2s.config.hop, + downsampling=64, + max_epochs=2, + every_n_epochs=1, + CHECKPOINT_TRAINING=True, + MONITOR_TRAINING=True, + OUTPUT_TRAINING=True, + ) + + loop = TrainARMLoop.from_config( + config, dataset=db, network=s2s + ) + + loop.run() + + content = os.listdir(os.path.join(str(tmp_path), loop.hash_)) + assert_that(content).contains("hp.yaml", "outputs", "epoch=1.ckpt") + + outputs = os.listdir(os.path.join(str(tmp_path), loop.hash_, "outputs")) + assert_that([os.path.splitext(o)[-1] for o in outputs]).contains(".mp3") diff --git a/tests/test_train_loop.py b/tests/test_train_loop.py new file mode 100644 index 0000000..8f39e97 --- /dev/null +++ b/tests/test_train_loop.py @@ -0,0 +1,50 @@ +import os +from assertpy import assert_that + +from .test_utils import tmp_db, TestARM +import mimikit as mmk + + +def test_should_run(tmp_db, tmp_path): + db = tmp_db("train-loop.h5") + extractor = mmk.Extractor("signal", mmk.FileToSignal(16000)) + net = TestARM( + TestARM.Config(io_spec=mmk.IOSpec( + inputs=( + mmk.InputSpec( + extractor_name=extractor.name, + transform=mmk.Normalize(), + module=mmk.LinearIO() + 
).bind_to(extractor), + ), + targets=( + mmk.TargetSpec( + extractor_name=extractor.name, + transform=mmk.Normalize(), + module=mmk.LinearIO(), + objective=mmk.Objective("reconstruction") + ).bind_to(extractor), + ) + )) + ) + config = mmk.TrainARMConfig( + root_dir=str(tmp_path), + limit_train_batches=4, + max_epochs=4, + every_n_epochs=1, + CHECKPOINT_TRAINING=True, + MONITOR_TRAINING=True, + OUTPUT_TRAINING=True, + ) + + loop = mmk.TrainARMLoop.from_config( + config, dataset=db, network=net + ) + + loop.run() + + content = os.listdir(os.path.join(str(tmp_path), loop.hash_)) + assert_that(content).contains("hp.yaml", "outputs", "epoch=1.ckpt") + + outputs = os.listdir(os.path.join(str(tmp_path), loop.hash_, "outputs")) + assert_that([os.path.splitext(o)[-1] for o in outputs]).contains(".mp3") diff --git a/tests/test_true.py b/tests/test_true.py deleted file mode 100644 index 1f3b77f..0000000 --- a/tests/test_true.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_true(): - assert True diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..fd2b6c2 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,122 @@ +from typing import Tuple, Dict, Set +import dataclasses as dtc + +import pytest +import torch +from assertpy import assert_that + +import numpy as np +import h5mapper as h5m +from torch import nn + +from mimikit import ARM, IOSpec +from mimikit.networks.arm import NetworkConfig + +__all__ = [ + "TestDB", + "tmp_db", + "TestARM", +] + +from mimikit.features.item_spec import ItemSpec + + +class RandSignal(h5m.Feature): + + def load(self, source): + return (np.random.rand(32000) * 2 - 1).astype(np.float32) + + +class RandLabel(h5m.Feature): + + def load(self, source): + return np.random.randint(0, 256, (32000,)) + + +class TestDB(h5m.TypedFile): + signal = RandSignal() + label = RandLabel() + + +@pytest.fixture +def tmp_db(tmp_path): + root = (tmp_path / "dbs") + root.mkdir() + + def create_func(filename) -> TestDB: + TestDB.create( + 
str(root / filename), + sources=tuple(map(str, range(2))), + mode="w", keep_open=False, parallelism='none' + ) + return TestDB(str(root / filename)) + + return create_func + + +def test_fixture_db(tmp_db): + db = tmp_db("temp") + + assert_that(db.signal).is_not_none() + assert_that(db.signal[:32]).is_instance_of(np.ndarray) + + +class TestARM(ARM, nn.Module): + @dtc.dataclass + class Config(NetworkConfig): + io_spec: IOSpec = None + + @property + def config(self) -> NetworkConfig: + return self._config + + @property + def rf(self): + return 8 + + def train_batch(self, item_spec: ItemSpec) -> \ + Tuple[Tuple[h5m.Input, ...], Tuple[h5m.Input, ...]]: + return tuple( + feat.to_batch_item(item_spec) + for feat in self.config.io_spec.inputs + ), tuple( + feat.to_batch_item(item_spec) + for feat in self.config.io_spec.targets + ) + + def test_batch(self, item_spec: ItemSpec) ->\ + Tuple[Tuple[h5m.Input, ...], Tuple[h5m.Input, ...]]: + return self.train_batch(item_spec) + + @property + def generate_params(self) -> Set[str]: + return set() + + def before_generate(self, prompts: Tuple[torch.Tensor, ...], batch_index: int) -> None: + return + + def generate_step(self, inputs: Tuple[torch.Tensor, ...], *, t: int = 0, **parameters: Dict[str, torch.Tensor]) -> \ + Tuple[torch.Tensor, ...]: + return tuple(self.forward(i) for i in inputs) + + def after_generate(self, final_outputs: Tuple[torch.Tensor, ...], batch_index: int) -> None: + return + + @classmethod + def from_config(cls, config: NetworkConfig): + return cls(config) + + def __init__(self, config: NetworkConfig): + super(TestARM, self).__init__() + self._config = config + self.fc = nn.Linear(1, 1) + + def forward(self, inputs): + if self.training: + if isinstance(inputs, (tuple, list)): + return tuple(self.fc(x.unsqueeze(-1)).squeeze() for x in inputs) + return self.fc(inputs.unsqueeze(-1)).squeeze() + else: + if isinstance(inputs, (tuple, list)): + return tuple(x[:, -1:] for x in inputs) + return inputs[:, -1:] diff 
--git a/tests/test_wavenet.py b/tests/test_wavenet.py new file mode 100644 index 0000000..a8e96cc --- /dev/null +++ b/tests/test_wavenet.py @@ -0,0 +1,248 @@ +import os +from typing import Tuple + +import pytest +from assertpy import assert_that + +import torch +from torch.nn import Sigmoid + +from mimikit import IOSpec, InputSpec, TargetSpec, LinearIO, Objective, \ + FileToSignal, Normalize, TrainARMConfig, TrainARMLoop +from mimikit.features.extractor import Extractor +from mimikit.checkpoint import Checkpoint +from mimikit.networks.wavenet_v2 import WNLayer, WaveNet + +from .test_utils import tmp_db + + +def inputs_(b=8, t=32, d=16): + return torch.randn(b, d, t) + + +@pytest.mark.parametrize( + "with_gate", + [True, False] +) +@pytest.mark.parametrize( + "feed_skips", + [True, False] +) +@pytest.mark.parametrize( + "given_input_dim", + [None, 7] +) +@pytest.mark.parametrize( + "given_pad", + [0, 1] +) +@pytest.mark.parametrize( + "given_residuals", + [None, 5, 7] +) +@pytest.mark.parametrize( + "given_skips", + [None, 34] +) +@pytest.mark.parametrize( + "given_1x1", + [(), (3,), (8, 2,), (4, 9, 64)] +) +@pytest.mark.parametrize( + "given_dil", + [(16,), (32,), (8,)] +) +def test_layer_should_support_various_graphs( + given_dil: Tuple[int], given_1x1: Tuple[int], given_skips, given_residuals, + given_pad, given_input_dim, feed_skips, with_gate +): + # if given_residuals is not None and given_input_dim is not None + under_test = WNLayer( + input_dim=given_input_dim, + dims_dilated=given_dil, + dims_1x1=given_1x1, + skips_dim=given_skips, + residuals_dim=given_residuals, + pad_side=given_pad, + act_g=Sigmoid() if with_gate else None + ) + B, T = 1, 8 + # HOW INPUT_DIM WORKS: + if given_input_dim is None: + if given_residuals is None: + input_dim = given_dil[0] + else: + input_dim = given_residuals + else: + input_dim = given_input_dim + + skips = None if not feed_skips or given_skips is None else inputs_(B, T, given_skips) + + given_inputs = ( + (inputs_(B, T, 
input_dim),), tuple(inputs_(B, T, d) for d in given_1x1), skips + ) + # HOW OUTPUT DIM WORKS: + if given_residuals is not None: + if given_input_dim is not None and given_input_dim != given_residuals: + # RESIDUALS ARE SKIPPED! + expected_out_dim = given_dil[0] + else: + expected_out_dim = given_residuals + else: + expected_out_dim = given_dil[0] + + outputs = under_test(*given_inputs) + + assert_that(type(outputs)).is_equal_to(tuple) + assert_that(len(outputs)).is_equal_to(2) + + assert_that(outputs[0].size(1)).is_equal_to(expected_out_dim) + if given_skips is not None: + assert_that(outputs[1].size(1)).is_equal_to(given_skips) + + if bool(given_pad): + assert_that(outputs[0].size(-1)).is_equal_to(T) + if given_skips is not None: + assert_that(outputs[1].size(-1)).is_equal_to(T) + assert_that(outputs[1].size(-1)).is_equal_to(outputs[0].size(-1)) + else: + assert_that(outputs[0].size(-1)).is_less_than(T) + if given_skips is not None: + assert_that(outputs[1].size(-1)).is_less_than(T) + assert_that(outputs[1].size(-1)).is_equal_to(outputs[0].size(-1)) + + +def test_should_instantiate_from_default_config(): + given_config = WaveNet.Config(io_spec=IOSpec.mulaw_io( + IOSpec.MuLawIOConfig(input_module_type="embedding") + )) + + under_test = WaveNet.from_config(given_config) + + assert_that(type(under_test)).is_equal_to(WaveNet) + assert_that(len(under_test.layers)).is_equal_to(given_config.blocks[0]) + + +def test_should_load_when_saved(tmp_path_factory): + given_config = WaveNet.Config(io_spec=IOSpec.mulaw_io( + IOSpec.MuLawIOConfig(input_module_type="embedding") + )) + root = str(tmp_path_factory.mktemp("ckpt")) + wn = WaveNet.from_config(given_config) + ckpt = Checkpoint(id="123", epoch=1, root_dir=root) + + ckpt.create(network=wn) + loaded = ckpt.network + + assert_that(type(loaded)).is_equal_to(WaveNet) + + +@pytest.mark.parametrize( + "given_temp", + [None, 0.5, (1.,)] +) +def test_generate( + given_temp +): + given_config = WaveNet.Config(io_spec=IOSpec.mulaw_io( 
+ IOSpec.MuLawIOConfig(input_module_type="embedding") + )) + q_levels = given_config.io_spec.inputs[0].elem_type.size + wn = WaveNet.from_config(given_config) + + given_prompt = torch.randint(0, q_levels, (1, 128,)) + wn.eval() + # For now, prompts are just Tuple[T] (--> Tuple[Tuple[T, ...]] for multi inputs!) + wn.before_generate((given_prompt,), batch_index=0) + output = wn.generate_step( + (given_prompt[:, -wn.rf:],), + t=given_prompt.size(1), + temperature=given_temp) + wn.after_generate(output, batch_index=0) + + assert_that(type(output)).is_equal_to(tuple) + assert_that(output[0].size(0)).is_equal_to(given_prompt.size(0)) + assert_that(output[0].ndim).is_equal_to(given_prompt.ndim) + + +def test_should_support_multiple_io(): + extractor = Extractor("signal", FileToSignal(16000)) + given_io = IOSpec( + inputs=( + InputSpec( + extractor_name=extractor.name, + transform=Normalize(), + module=LinearIO() + ).bind_to(extractor), + InputSpec( + extractor_name=extractor.name, + transform=Normalize(), + module=LinearIO() + ).bind_to(extractor), + ), + targets=( + TargetSpec( + extractor_name=extractor.name, + transform=Normalize(), + module=LinearIO(), + objective=Objective("reconstruction") + ).bind_to(extractor), + TargetSpec( + extractor_name=extractor.name, + transform=Normalize(), + module=LinearIO(), + objective=Objective("reconstruction") + ).bind_to(extractor)), ) + wn = WaveNet.from_config(WaveNet.Config( + io_spec=given_io, + dims_dilated=(128,), + dims_1x1=(44,) + )) + + assert_that(wn).is_instance_of(WaveNet) + + given_inputs = ( + torch.randn(1, 32, 1), + torch.randn(1, 32, 1), + ) + + outputs = wn.forward(given_inputs) + + assert_that(outputs).is_instance_of(tuple) + assert_that(outputs[0].size()).is_equal_to(outputs[1].size()) + + +@pytest.mark.parametrize( + "given_io", + [ + IOSpec.magspec_io(IOSpec.MagSpecIOConfig()), + IOSpec.mulaw_io(IOSpec.MuLawIOConfig(input_module_type="embedding")) + ] +) +def test_should_train(tmp_db, tmp_path, given_io): + 
given_config = WaveNet.Config(io_spec=given_io, blocks=(3,)) + wn = WaveNet.from_config(given_config) + db = tmp_db("train-loop.h5") + config = TrainARMConfig( + root_dir=str(tmp_path), + limit_train_batches=2, + batch_size=2, + batch_length=8, + max_epochs=2, + every_n_epochs=1, + CHECKPOINT_TRAINING=True, + MONITOR_TRAINING=True, + OUTPUT_TRAINING=True, + ) + + loop = TrainARMLoop.from_config( + config, dataset=db, network=wn + ) + + loop.run() + + content = os.listdir(os.path.join(str(tmp_path), loop.hash_)) + assert_that(content).contains("hp.yaml", "outputs", "epoch=1.ckpt") + + outputs = os.listdir(os.path.join(str(tmp_path), loop.hash_, "outputs")) + assert_that([os.path.splitext(o)[-1] for o in outputs]).contains(".mp3") \ No newline at end of file diff --git a/todos.md b/todos.md new file mode 100644 index 0000000..d86a43c --- /dev/null +++ b/todos.md @@ -0,0 +1,89 @@ +# Todos + +## v0.4.0 + +### necessary + +- TEST Notebooks + - FreqNet + - SampleRNN + - Seq2Seq + - Generate from Checkpoint + - Ensemble + - Clusterizer App +- Colab TESTS + +### nice to have + +- TEST M1 Support +- Ensemble / Generate nb + - Prompt UI +- Ensemble + - working nearest neighbour generator + - Generator classes (average from several models?, ...) + - parameters (fade-in/out?, gain?, noise?, ...) +- Eval checkpoint Notebook +- Segmentation Notebook +- PocoNet / poco Wavenet +- Train notebook(s) with UI + +### long term... + +- Transformer +- support for TBPTT in freq domain +- SampleRNN in freq domain (tier_i ==> n_fft instead of frame_size) +- flexible IO declaration (fft, signal+segments, learnable fft, ...) +- more audio features + - class ClusterLabel(Feature): + - class SegmentLabel(Feature): + - from rec mat + - KMer (seqPrior) + - BitVector + - Quantize / Digitize / Linearize + - [/] MelSpec + - [/] MFCC + - TimeIndex + - Scaler + - MinMax + - Normal + - Augmentation(functional, prob) + ................................... + - tuple or not tuple + - AR Feature vs. 
Fixture vs. Auxiliary Target (vs. kwargs) + - AR --> Input == Target --> shared data + prompt must be: prior_t data + n_steps blank + !! target interface must come from data + - Fixture --> no target + --> data is read + prompt must be: prior_t + n_steps data + !! this modifies the length of the Dataset! + --> data is transformed from (possibly AR) input + !! this DOESN'T modify the length + - Auxiliary --> no input --> output is just collected + prompt must be: prior_t + n_steps blank + - Batch Alignment for + - Multiple SR + - Multiple Domains + - Same Variable, different repr (e.g. x_0 -> Raw, MuLaw --> ?) + +- More Networks + - SampleGan (WaveGan with labeled segments?) + - Stable Diffusion Experiment + - FFT_RNN + .... +- Loss Terms +- Hooks for + - storing outputs + - modifying generate_step() implementation on the fly... +- flowtorch +- huggingface/diffusers/transformers +- Multi-Checkpoint Models (stochastic averaging) +- Resampler classes with `n_layers` +- jitability / torch==2.0 compile() + - no `*tuple` expr... +- Network Visualizer (UI) +- Resume Training + - Optimizer in Checkpoint +- Upgrade python 3.9 ? (colab is 3.7.15...) + + \ No newline at end of file
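The AR / Fixture / Auxiliary prompt layouts sketched in the todos above (AR: prior data + blank steps; Fixture: prior + future data; Auxiliary: prior + blank, outputs only collected) could be prototyped as below. This is a minimal sketch of the described protocol, not mimikit code — `FeatureKind` and `make_prompt` are hypothetical names:

```python
from enum import Enum, auto

import torch


class FeatureKind(Enum):
    AR = auto()         # input == target: data fills the prior, the future is generated
    FIXTURE = auto()    # no target: data is read for the prior AND the future steps
    AUXILIARY = auto()  # no input: outputs are only collected, future stays blank


def make_prompt(data: torch.Tensor, prior_t: int, n_steps: int,
                kind: FeatureKind) -> torch.Tensor:
    """Build a (batch, prior_t + n_steps) prompt according to the feature kind."""
    prior = data[:, :prior_t]
    if kind is FeatureKind.FIXTURE:
        # fixtures also supply the future steps from the data itself
        future = data[:, prior_t:prior_t + n_steps]
    else:
        # AR and auxiliary features leave the future blank (zeros)
        future = torch.zeros(data.size(0), n_steps, dtype=data.dtype)
    return torch.cat([prior, future], dim=1)


x = torch.ones(2, 16)
ar_prompt = make_prompt(x, prior_t=8, n_steps=4, kind=FeatureKind.AR)
assert ar_prompt.shape == (2, 12)
assert ar_prompt[:, 8:].sum() == 0  # blank future for AR features
```

Note that, as the todo points out, only the Fixture case changes how much source data a prompt consumes (`prior_t + n_steps` instead of `prior_t`), which is why it would modify the effective length of the Dataset.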