diff --git a/.github/workflows/ci-pipeline.yml b/.github/workflows/ci-pipeline.yml
index 5458585..c11d238 100644
--- a/.github/workflows/ci-pipeline.yml
+++ b/.github/workflows/ci-pipeline.yml
@@ -15,9 +15,11 @@ jobs:
python-version: '3.7'
- name: Install dependencies
run: |
- sudo apt-get install -y libsndfile-dev
+ sudo apt-get install -y libsndfile-dev ffmpeg
python -m pip install --quiet --upgrade pip
- pip install --quiet -r requirements.txt
+ pip install -r requirements.txt
+ pip install -r requirements-test.txt
+ pip install -r requirements-torch.txt
pip install --quiet hatch==0.23.1
pip list | grep torch
diff --git a/.gitignore b/.gitignore
index 0b842bc..cce7c09 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,4 @@
-.DS_*
+.DS_Store
*__pycache__*
*pyc
*.ipynb_checkpoints*
diff --git a/README.md b/README.md
index 7e4325b..a05908a 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,40 @@
# mimikit
-The MusIc ModelIng toolKIT (MIMIKIT) is a python package that does Machine Learning with music data.
+The MusIc ModelIng toolKIT (`mimikit`) is a python package for Machine Learning with audio data.
-The goal of `mimikit` is to enable you to use referenced and experimental algorithms on data you provide.
+Currently, it focuses on training auto-regressive neural networks to generate audio,
-`mimikit` is still in early development, details and documentation are on their way.
+but it also contains an app for basic & experimental clustering of audio data in a notebook.
-## License
+## Installation
+
+You can install it with pip:
+```shell script
+pip install mimikit[torch]
+```
+or, if you are looking for the latest version, with
+```shell script
+pip install --upgrade mimikit[torch]
+```
+
+For an editable install, you'll need:
+```shell script
+pip install -e . --config-settings editable_mode=compat
+```
+
+## Usage
+
+Head straight to the [notebooks](https://github.com/ktonal/mimikit-notebooks) for example usage of `mimikit`, or open them directly in Colab:
-mimikit is distributed under the terms of the [GNU General Public License v3.0](https://choosealicense.com/licenses/gpl-3.0/)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ktonal/mimikit-notebooks/blob/main)
+## Output Samples
+
+You can explore the outputs of different trainings done with `mimikit` at this demo website:
+
+ https://ktonal.github.io/mimikit-demo-outputs
+
+## License
+`mimikit` is distributed under the terms of the [GNU General Public License v3.0](https://choosealicense.com/licenses/gpl-3.0/)
diff --git a/dev-docs/dynamic io.md b/dev-docs/dynamic io.md
new file mode 100644
index 0000000..d84b80e
--- /dev/null
+++ b/dev-docs/dynamic io.md
@@ -0,0 +1,74 @@
+
+ CANONICAL SHAPE:
+
+ (BATCH, [CHANNEL], TIME, [DIM], [COMPONENT])
+
+
+ abs(Spectro), MelSpec, MFCC, Qt, Repr: (Batch, Time, Dim)
+
+ Complex_S: (Batch, [Channel], Time, Dim, 2)
+
+ y: (Batch, [Channel], Time)
+
+ enveloppe: (Batch, [Channel], Time)
+
+ text, lyrics, file_label, segment_label: (Batch, [Channel], Time, [Embedding, Class_Size])
+
+ pitch, speaker_id: (Batch, [Channel], Time, [Embedding, Class_Size])
+
+ qx, k_mer_hash, frame_index, cluster_label: (Batch, [Channel], Time, [Embedding, Class_Size])
+
+ y_bits: (Batch, [Channel], Time, Bit_Depth)
+
+
+
+ Modules can change SHAPES through:
+
+ - Project: (Batch, [Channel], Time) ----> (Batch, [Channel], Time, Dim)
+ - Map/Transform: ANY Structure ----> SAME Structure
+ - Predict: Any ----> TRAINING: (..., Dim) BUT INFER: (..., 1)
+ - Fork: (..., Component) ----> (...) x Component
+ - Join: (...) x Component ----> (..., Component)
+
+ Modules can change SIGNATURES through:
+
+ - Iso: N Inputs ----> N Outputs
+ - Split: 1 Input ----> N Outputs
+ - Reduce: N Inputs ----> 1 Output
+ - Transform: N Inputs ----> M Outputs
+
+
+ So, now, we can try to solve:
+
+ REPR Shape ---> F(...) = ??? ---> Model Shape
+
+
+------
+
+ i.e.
+ repr_shape=(B=-1, T=-1)
+ model_shape=(B=-1, T=-1, D=128)
+ ---> ??? = Project
+
+ repr_shape=(B=-1, T=-1, D=1025) X N
+ model_shape=(B=-1, T=-1, D=128)
+ ---> ??? = [Reduce, Map] OR [Join, Reduce], ....
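
The resolution idea sketched above could look roughly like this in code (all names here are invented for illustration, not part of `mimikit`; `-1` stands for a free batch/time dimension):

```python
# Hypothetical sketch: given a feature (repr) shape and a model shape,
# propose which IO module(s) could connect them.

def resolve_io(repr_shape, model_shape):
    """Return candidate module names connecting repr_shape to model_shape."""
    # one missing trailing axis, e.g. (B, T) -> (B, T, D): a projection adds Dim
    if len(repr_shape) + 1 == len(model_shape):
        return ["Project"]
    if len(repr_shape) == len(model_shape):
        # ranks match but trailing dims differ: map/transform the last axis
        if repr_shape[-1] != model_shape[-1]:
            return ["Map"]
        return ["Identity"]
    return []

print(resolve_io((-1, -1), (-1, -1, 128)))        # -> ['Project']
print(resolve_io((-1, -1, 1025), (-1, -1, 128)))  # -> ['Map']
```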
+
+
+**A. User says:**
+
+ "I want to connect this_feat with this_network."
+
+**B. Mimikit answers:**
+
+ "Then you can use this_options"
+ ---> IOConfigService.resolve(this_feat, this_network_class): -> {io_module_config, ...}
+
+**C. User chooses, configures and then clicks `Run`. mimikit goes:**
+
+ "Let's connect all those tings"
+ --> ModelInstantiator(this_feat, this_network, this_io_config) -> Model
+
+
+
+
diff --git a/dev-docs/quasi DDD.md b/dev-docs/quasi DDD.md
new file mode 100644
index 0000000..8ee1f6c
--- /dev/null
+++ b/dev-docs/quasi DDD.md
@@ -0,0 +1,130 @@
+## Motivation
+
+We want to
+
+- easily compose ML models from *components* (inputs/outputs number and type, modules, network architecture, ...)
+- easily build UI interfaces to configure/instantiate them (and ideally, be able to save their state!)
+
+### Top level structure
+
+```
+mimikit
+- config
+- view
+- checkpoint
+- io
+ - audio
+ - mu_law
+ - spectrogram
+ - enveloppe
+ ...
+ - labels
+ - cluster_label
+ - speaker_label
+ ...
+- information_retrieval
+ - segment_audio
+ - cluster
+ ...
+- modules
+ - loss_fn
+ ...
+- networks
+ - io_wrappers
+ - sample_rnn
+ - wavenet
+ ...
+- models
+ - srnn
+ - freqnet
+ ...
+- loops
+ - train_loop
+ - generate_loop
+- trainings
+ - train_arm
+ - train_gan
+ - train_diffusion
+- scripts
+ - generate_arm
+ - train_freqnet
+ - ensemble
+ - eval_arm
+ ....
+- notebooks
+ - generate_arm
+ - train_freqnet
+ - ensemble
+ - explore_cluster
+ .....
+```
+
+### ML Component Design Pattern
+
+Use `dataclasses` and inheritance to define & connect the layers of an 'ML component', e.g.
+
+```python
+@dtc.dataclass
+class NetConfig(Config):
+ ...
+
+
+@dtc.dataclass
+class NetImpl(NetConfig, nn.Module):
+
+ def __post_init__(self):
+ ...  # init modules
+ def forward(self, inputs, **kwargs):
+ ...
+
+
+@dtc.dataclass
+class NetView(NetConfig, ConfigView):
+
+ def __post_init__(self):
+ # ... map params to widget ...
+ ConfigView.__init__(self, **params)
+```
+
+Constructors are
+- type-safe
+- consistent across layers
+- (de)serializable
+- **defined once**
+
+Then we can nest Configs, Impls & Views by doing:
+
+```python
+
+@dtc.dataclass
+class NestedConfig(Config):
+ io: IOConfig
+ net: NetConfig
+
+
+@dtc.dataclass
+class Model(NestedConfig, nn.Module):
+ def __post_init__(self):
+ ...
+
+@dtc.dataclass
+class ModelView(NestedConfig, ConfigView):
+ ...
+```
+
+--> We win
+- generic saving/loading Checkpoints
+- highly expressive composition for io, models, features, views, etc...
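
A torch-free, minimal sketch of the pattern above (class names and fields are placeholders chosen here for illustration): the `Config` dataclass defines the constructor once, the "Impl" subclass adds behavior in `__post_init__` without redeclaring any field, and nested configs serialize generically.

```python
import dataclasses as dtc

@dtc.dataclass
class NetConfig:
    n_layers: int = 2
    dim: int = 16

@dtc.dataclass
class IOConfig:
    feature: str = "magspec"

@dtc.dataclass
class NestedConfig:
    io: IOConfig
    net: NetConfig

@dtc.dataclass
class Model(NestedConfig):
    def __post_init__(self):
        # "instantiate modules" from the inherited, type-safe fields
        self.layers = [f"layer_{i}_dim_{self.net.dim}" for i in range(self.net.n_layers)]

m = Model(io=IOConfig(), net=NetConfig(n_layers=3, dim=8))
cfg_dict = dtc.asdict(m)   # nested configs convert generically
```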
+
+
+### Implementation
+
+Different libraries / base classes offer different trade-offs between ease of use and ease of integration with other libraries.
+
+Ideally, we could:
+- define a constructor -> `dataclass, attrs, namedtuple`
+- have static type checker recognize it
+- attach it to a nn.Module -> inheritance?, `classmethod`?, decorator?
+- use it in a View -> switch mutability
+- (de)serialize it as config -> `OmegaConf`
+- be able to export the nn.Module as `TorchScript`
diff --git a/docs/audios.rst b/docs/audios.rst
index 7532174..7ade0fc 100644
--- a/docs/audios.rst
+++ b/docs/audios.rst
@@ -5,7 +5,7 @@ this package exposes helper classes for processing, interacting & modeling audio
.. py:currentmodule:: mimikit.audios.fmodules
-.. autoclass:: FModule
+.. autoclass:: Functional
:special-members: __call__
:members:
diff --git a/mimikit/__init__.py b/mimikit/__init__.py
index 15cb860..251164b 100644
--- a/mimikit/__init__.py
+++ b/mimikit/__init__.py
@@ -1,21 +1,28 @@
-__version__ = '0.3.4'
+__version__ = '0.4.1'
-from . import extract
+from . import config
from . import features
from . import loops
+from . import checkpoint
from . import modules
+from . import extract
+from . import io_spec
from . import models
from . import networks
from . import demos
+from . import ui
+from . import views
-from .extract import *
+from .checkpoint import *
+from .config import *
from .features import *
from .loops import *
from .modules import *
+from .extract import *
from .models import *
from .networks import *
-
-from .train import *
+from .ui import *
from .utils import *
-
+from .views import *
+from .io_spec import *
from .demos import *
\ No newline at end of file
diff --git a/mimikit/checkpoint.py b/mimikit/checkpoint.py
new file mode 100644
index 0000000..9600e5e
--- /dev/null
+++ b/mimikit/checkpoint.py
@@ -0,0 +1,152 @@
+import abc
+import dataclasses as dtc
+from typing import Optional
+from typing_extensions import Protocol
+
+try:
+ from functools import cached_property
+except ImportError: # python<3.8
+ def cached_property(f):
+ return property(f)
+
+import torch.nn as nn
+
+import h5mapper as h5m
+import os
+
+from .config import Config, Configurable
+from .networks.arm import NetworkConfig
+from .features.dataset import DatasetConfig
+
+__all__ = [
+ 'Checkpoint',
+ 'CheckpointBank'
+]
+
+
+class ConfigurableModule(Configurable, nn.Module, abc.ABC):
+ @property
+ @abc.abstractmethod
+ def config(self) -> NetworkConfig:
+ ...
+
+
+class TrainingConfig(Protocol):
+ @property
+ def dataset(self) -> DatasetConfig:
+ ...
+
+ @property
+ def network(self) -> NetworkConfig:
+ ...
+
+ @property
+ def training(self) -> Config:
+ ...
+
+
+class CheckpointBank(h5m.TypedFile):
+ network = h5m.TensorDict()
+ optimizer = h5m.TensorDict()
+
+ @classmethod
+ def save(cls,
+ filename: str,
+ network: ConfigurableModule,
+ training_config: Optional[TrainingConfig] = None,
+ optimizer: Optional[nn.Module] = None
+ ) -> "CheckpointBank":
+
+ net_dict = network.state_dict()
+ opt_dict = optimizer.state_dict() if optimizer is not None else {}
+ cls.network.set_ds_kwargs(net_dict)
+ if optimizer is not None:
+ cls.optimizer.set_ds_kwargs(opt_dict)
+ os.makedirs(os.path.split(filename)[0], exist_ok=True)
+
+ bank = cls(filename, mode="w")
+ bank.network.attrs["config"] = network.config.serialize()
+ bank.network.add("state_dict", h5m.TensorDict.format(net_dict))
+
+ if optimizer is not None:
+ bank.optimizer.add("state_dict", h5m.TensorDict.format(opt_dict))
+ if training_config is not None:
+ bank.attrs["dataset"] = training_config.dataset.serialize()
+ bank.attrs["training"] = training_config.training.serialize()
+ else:
+ # make a dataset config so that the network can at least be loaded later
+ features = [*network.config.io_spec.inputs, *network.config.io_spec.targets]
+ schema = {f.extractor_name: f.extractor for f in features}
+ bank.attrs["dataset"] = DatasetConfig(filename="unknown", sources=(),
+ extractors=tuple(schema.values())).serialize()
+
+ bank.flush()
+ bank.close()
+ return bank
+
+
+@dtc.dataclass
+class Checkpoint:
+ id: str
+ epoch: int
+ root_dir: str = "./"
+
+ def create(self,
+ network: ConfigurableModule,
+ training_config: Optional[TrainingConfig] = None,
+ optimizer: Optional[nn.Module] = None):
+ CheckpointBank.save(self.os_path, network, training_config, optimizer)
+ return self
+
+ @staticmethod
+ def get_id_and_epoch(path):
+ id_, epoch = path.split("/")[-2:]
+ return id_.strip("/"), int(epoch.split(".ckpt")[0].split("=")[-1])
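
The id/epoch convention used here, sketched standalone (the `runs/` prefix below is an invented example): checkpoints live under `<root_dir>/<id>/epoch=<n>.ckpt`, matching `os_path` below.

```python
# Parse "<root>/<id>/epoch=<n>.ckpt" into (id, epoch)
def get_id_and_epoch(path):
    id_, epoch = path.split("/")[-2:]
    return id_.strip("/"), int(epoch.split(".ckpt")[0].split("=")[-1])

print(get_id_and_epoch("runs/abc123/epoch=42.ckpt"))  # -> ('abc123', 42)
```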
+
+ @staticmethod
+ def from_path(path):
+ basename = os.path.dirname(os.path.dirname(path))
+ return Checkpoint(*Checkpoint.get_id_and_epoch(path), root_dir=basename)
+
+ @property
+ def os_path(self):
+ return os.path.join(self.root_dir, f"{self.id}/epoch={self.epoch}.ckpt")
+
+ def delete(self):
+ os.remove(self.os_path)
+
+ @cached_property
+ def bank(self) -> CheckpointBank:
+ return CheckpointBank(self.os_path, 'r')
+
+ @cached_property
+ def dataset_config(self) -> DatasetConfig:
+ return Config.deserialize(self.bank.attrs["dataset"], as_type=DatasetConfig)
+
+ @cached_property
+ def network_config(self) -> NetworkConfig:
+ return Config.deserialize(self.bank.network.attrs["config"])
+
+ @cached_property
+ def training_config(self) -> TrainingConfig:
+ bank = CheckpointBank(self.os_path, 'r')
+ return Config.deserialize(bank.attrs["training"], as_type=TrainingConfig)
+
+ @cached_property
+ def network(self) -> ConfigurableModule:
+ cfg: NetworkConfig = self.network_config
+ cfg.io_spec.bind_to(self.dataset_config)
+ cls = cfg.owner_class
+ state_dict = self.bank.network.get("state_dict")
+ net = cls.from_config(cfg)
+ net.load_state_dict(state_dict, strict=True)
+ return net
+
+ @cached_property
+ def dataset(self) -> h5m.TypedFile:
+ dataset: DatasetConfig = self.dataset_config
+ if os.path.exists(dataset.filename):
+ return dataset.get(mode="r")
+ return dataset.create(mode="w")
+
+ # Todo: method to add state_dict mul by weights -> def average(self, *others)
\ No newline at end of file
diff --git a/mimikit/config.py b/mimikit/config.py
new file mode 100644
index 0000000..f6cb89a
--- /dev/null
+++ b/mimikit/config.py
@@ -0,0 +1,141 @@
+import abc
+import sys
+from copy import deepcopy
+from omegaconf import OmegaConf, ListConfig, DictConfig
+from typing import List, Tuple, Union, Dict, Any
+import dataclasses as dtc
+from functools import reduce, partial
+
+__all__ = [
+ "private_runtime_field",
+ "Config",
+ "Configurable",
+]
+
+
+def private_runtime_field(default):
+ return dtc.field(init=False, repr=False, metadata=dict(omegaconf_ignore=True), default_factory=lambda: default)
+
+
+# noinspection PyTypeChecker
+def _get_type_object(type_) -> type:
+ if ":" in type_:
+ module, qualname = type_.split(":")
+ else:
+ module, qualname = "mimikit", type_
+ try:
+ m = sys.modules[module]
+ return reduce(lambda o, a: getattr(o, a), qualname.split("."), m)
+ except (AttributeError, KeyError):
+ raise ImportError(f"could not find class '{qualname}' from module {module} in current environment")
+
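
A standalone illustration of the `"module:qualname"` lookup performed by `_get_type_object` above, reimplemented here with stdlib names only (the `collections` example is ours, not mimikit's):

```python
import sys
from functools import reduce

def get_type_object(type_: str) -> type:
    # "module:Qual.Name" -> resolve attribute chain on an already-imported module
    module, qualname = type_.split(":") if ":" in type_ else ("mimikit", type_)
    m = sys.modules[module]  # raises KeyError if the module was never imported
    return reduce(lambda o, a: getattr(o, a), qualname.split("."), m)

import collections
assert get_type_object("collections:OrderedDict") is collections.OrderedDict
```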
+
+STATIC_TYPED_KEYS = {
+ "dataset": "DatasetConfig",
+ "io_spec": "IOSpec",
+ "inputs": "InputSpec",
+ "targets": "TargetSpec",
+ "objective": "Objective",
+ "feature": "Feature",
+ "extractor": "Extractor",
+ "activation": "ActivationConfig"
+}
+
+
+@dtc.dataclass
+class Config:
+ ## type: str = dtc.field(init=False, repr=False, default="mimikit.config:Config")
+
+ @classmethod
+ def __init_subclass__(cls, type_field=True, **kwargs):
+ """add type info to subclass"""
+ if type_field:
+ default = f"{cls.__qualname__}"
+ if not cls.__module__.startswith("mimikit"):
+ default = f"{cls.__module__}:{default}"
+ setattr(cls, "type", dtc.field(init=False, default=default, repr=False))
+ if "__annotations__" in cls.__dict__:
+ # put the type first for nicer serialization...
+ ann = cls.__dict__["__annotations__"].copy()
+ for k in ann:
+ cls.__dict__["__annotations__"].pop(k)
+ cls.__dict__["__annotations__"].update({"type": str, **ann})
+ else:
+ setattr(cls, "__annotations__", {"type": str})
+
+ @staticmethod
+ def validate_class(cls: type):
+ if "__dataclass_fields__" not in cls.__dict__:
+ if not issubclass(cls, (tuple, list)):
+ raise TypeError("Please decorate your Config class with @dataclass"
+ " so that it can be (de)serialized")
+
+ @property
+ def owner_class(self):
+ module, type_ = type(self).__module__, type(self).__qualname__
+ type_ = ".".join(type_.split(".")[:-1]) if "." in type_ else type_
+ type_ = f"{module}:{type_}"
+ return _get_type_object(type_)
+
+ def serialize(self):
+ self.validate_class(type(self))
+ cfg = OmegaConf.structured(self)
+ return OmegaConf.to_yaml(cfg)
+
+ @staticmethod
+ def deserialize(raw_yaml, as_type=None):
+ cfg = OmegaConf.create(raw_yaml)
+ return Config.object(cfg, as_type)
+
+ @staticmethod
+ def object(cfg: Union[ListConfig, DictConfig, Dict, List, Tuple, Any], as_type=None):
+ if isinstance(cfg, (DictConfig, Dict)):
+ for k, v in cfg.items():
+ if k in STATIC_TYPED_KEYS:
+ cls = _get_type_object(STATIC_TYPED_KEYS[k])
+ setattr(cfg, k, Config.object(v, as_type=cls))
+ elif k == "extractors":
+ setattr(cfg, k, tuple(map(partial(Config.object, as_type=_get_type_object("Extractor")), v)))
+ elif isinstance(v, (ListConfig, DictConfig, Dict, List, Tuple)):
+ setattr(cfg, k, Config.object(v))
+ if as_type is not None:
+ cls = as_type
+ elif "type" in cfg:
+ cls = _get_type_object(cfg.type)
+ else: # untyped raw dict
+ return cfg
+ if isinstance(cfg, DictConfig):
+ cfg._metadata.object_type = cls
+ return OmegaConf.to_object(cfg)
+ else:
+ return cls(**cfg)
+
+ elif isinstance(cfg, (ListConfig, List, Tuple)):
+ return OmegaConf.to_object(OmegaConf.structured([*map(partial(Config.object, as_type=as_type), cfg)]))
+ # any other kind of value
+ return cfg
+
+ def dict(self):
+ """caution! nested configs are also converted!"""
+ return dtc.asdict(self)
+
+ def copy(self):
+ return deepcopy(self)
+
+ def validate(self) -> Tuple[bool, str]:
+ return True, ''
+
+
+class Configurable(abc.ABC):
+
+ @classmethod
+ @abc.abstractmethod
+ def from_config(cls, config: Config):
+ ...
+
+ @property
+ @abc.abstractmethod
+ def config(self) -> Config:
+ ...
+
+
diff --git a/mimikit/demos/__init__.py b/mimikit/demos/__init__.py
index 4f8c961..b9f229d 100644
--- a/mimikit/demos/__init__.py
+++ b/mimikit/demos/__init__.py
@@ -1,8 +1,8 @@
-from .freqnet import *
from .srnn import *
-from .s2s import *
-from .wn import *
+from .seq2seq import *
+from .freqnet import *
from .generate_from_checkpoint import *
-from .ensemble import *
+from .ensemble_generator import *
+from .clusterizer_app import *
__all__ = [_ for _ in dir() if not _.startswith("_")]
diff --git a/mimikit/demos/clusterizer_app.py b/mimikit/demos/clusterizer_app.py
new file mode 100644
index 0000000..6ff5642
--- /dev/null
+++ b/mimikit/demos/clusterizer_app.py
@@ -0,0 +1,44 @@
+def demo():
+ """### Launch the app"""
+
+ import mimikit as mmk
+ from ipywidgets import widgets as W
+ import IPython.display as ipd
+
+ ipd.display(mmk.MMK_STYLE_SHEET)
+ ipd.display(W.HTML(
+ """
+ """
+ ))
diff --git a/mimikit/ui/widgets.py b/mimikit/ui/widgets.py
new file mode 100644
index 0000000..86d3ab9
--- /dev/null
+++ b/mimikit/ui/widgets.py
@@ -0,0 +1,163 @@
+from typing import Iterable
+
+from ipywidgets import widgets as W, GridspecLayout
+import os
+
+from ..loops.callbacks import tqdm
+
+__all__ = [
+ "EnumWidget",
+ "pw2_widget",
+ "yesno_widget",
+ "Labeled",
+ "UploadWidget",
+]
+
+
+# TODO:
+# - wavenet/sampleRNN -> all params
+# - refine FilePicker (layout, load/save buttons)
+# - network views from Select
+# - move *_io() to IOSpecs
+# - *_io views (and Select)
+
+
+def Labeled(
+ label, widget, tooltip=None
+):
+ label_w = W.Label(value=label, tooltip=label)
+ if tooltip is not None:
+ tltp = W.Button(icon="fa-info", tooltip=tooltip,
+ layout=W.Layout(
+ width="20px",
+ height="12px"),
+ disabled=True,
+ ).add_class("tltp")
+ label_w = W.HBox(children=[label_w, tltp], )
+ label_w.layout = W.Layout(min_width="max_content", width="auto",
+ overflow='revert')
+ container = W.GridBox(children=(label_w, widget),
+ layout=dict(width="auto",
+ grid_template_columns='1fr 2fr')
+ )
+ # container.value = widget.value
+ container.observe = widget.observe
+ return container
+
+
+def pw2_widget(
+ initial_value,
+ min_value=1,
+ max_value=2 ** 16,
+):
+ plus = W.Button(icon="plus", layout=dict(width="auto", overflow='hidden', grid_area='plus'))
+ minus = W.Button(icon="minus", layout=dict(width="auto", overflow='hidden', grid_area='minus'))
+ value = W.Text(value=str(initial_value), layout=dict(width="auto", overflow='hidden', grid_area='val'))
+ plus.on_click(lambda clk: setattr(value, "value", str(min(max_value, int(value.value) * 2))))
+ minus.on_click(lambda clk: setattr(value, "value", str(max(min_value, int(value.value) // 2))))
+ grid = W.GridBox(children=(minus, value, plus),
+ layout=dict(grid_template_columns='1fr 1fr 1fr',
+ grid_template_rows='1fr',
+ grid_template_areas='"minus val plus"'))
+ # bind value state to box state
+ # grid.value = value.value
+ grid.observe = value.observe
+ return grid
+
+
+def yesno_widget(
+ initial_value=True,
+):
+ yes = W.ToggleButton(
+ value=initial_value,
+ description="yes",
+ button_style="success" if initial_value else "",
+ layout=dict(width='auto', margin='auto 4px', grid_area='yes')
+ )
+ no = W.ToggleButton(
+ value=not initial_value,
+ description="no",
+ button_style="" if initial_value else "danger",
+ layout=dict(width='auto', margin='auto 4px', grid_area='no')
+ )
+
+ def toggle_yes(ev):
+ v = ev["new"]
+ if v:
+ setattr(yes, "button_style", "success")
+ setattr(no, "button_style", "")
+ setattr(no, "value", False)
+
+ def toggle_no(ev):
+ v = ev["new"]
+ if v:
+ setattr(no, "button_style", "danger")
+ setattr(yes, "button_style", "")
+ setattr(yes, "value", False)
+
+ yes.observe(toggle_yes, "value")
+ no.observe(toggle_no, "value")
+ grid = W.GridBox(children=(yes, no),
+ layout=dict(grid_template_columns='1fr 1fr',
+ grid_template_rows='1fr',
+ grid_template_areas='"yes no"'))
+ grid.observe = yes.observe
+ return grid
+
+
+def EnumWidget(
+ label: str,
+ options: Iterable[str],
+ value_type=str,
+ selected_index=0
+):
+ options_w = W.GridBox(children=tuple(W.ToggleButton(value=False,
+ description=opt,
+ tooltip=opt,
+ layout=dict(margin='0 4px', width='auto'))
+ for opt in options),
+ layout=dict(grid_template_columns='1fr ' * len(options),
+ width='auto', align_self='center'))
+ container = Labeled(label, options_w)
+ dummy = W.Text(value='')
+ if isinstance(selected_index, int):
+ value = options_w.children[selected_index].description if value_type is str else \
+ value_type(options_w.children[selected_index].description)
+ options_w.children[selected_index].value = True
+ setattr(options_w.children[selected_index], "button_style", "success")
+ else:
+ value = selected_index
+ container.value = value
+ for i, child in enumerate(options_w.children):
+ def observer(ev, c=child, index=i):
+ val = ev["new"]
+ if val and dummy.value != c.description:
+ container.selected_index = index
+ dummy.value = c.description if value_type is str else value_type(c.description)
+ setattr(c, "button_style", "success")
+ for other in options_w.children:
+ if other.value and other is not c:
+ other.value = False
+ other.button_style = ""
+ elif not val and dummy.value == c.description:
+ c.value = True
+
+ child.observe(observer, "value")
+ container.observe = dummy.observe
+ return container
+
+
+def UploadWidget(dest="./"):
+ def write_uploads(inputs):
+ for file in tqdm(inputs["new"], leave=False):
+ with open(os.path.join(dest, file.name), "wb") as f:
+ f.write(file.content.tobytes())
+
+ upload = W.FileUpload(
+ accept='',
+ multiple=True,
+ )
+
+ upload.observe(write_uploads, names='value')
+
+ return upload
diff --git a/mimikit/utils.py b/mimikit/utils.py
index cd5b03d..71823e1 100644
--- a/mimikit/utils.py
+++ b/mimikit/utils.py
@@ -1,61 +1,35 @@
-import librosa
-import numpy as np
-from librosa.display import specshow
-import IPython.display as ipd
-import matplotlib.pyplot as plt
+from enum import Enum
+import re
__all__ = [
- 'audio',
- 'show'
+ "AutoStrEnum",
+ "SOUND_FILE_REGEX",
+ "CHECKPOINT_REGEX",
+ "DATASET_REGEX",
+ "default_device"
]
-HOP_LENGTH, SR = 512, 22050
-# Conversion
+SOUND_FILE_REGEX = re.compile(r"wav$|aif$|aiff$|mp3$|mp4$|m4a$|webm$")
+DATASET_REGEX = re.compile(r".*\.h5$")
+CHECKPOINT_REGEX = re.compile(r".*\.ckpt$")
-a2db = lambda S: librosa.amplitude_to_db(abs(S), ref=S.max())
+class AutoStrEnum(str, Enum):
+ """
+ Workaround while https://github.com/omry/omegaconf/pull/865 is still open...
+ """
+ @staticmethod
+ def _generate_next_value_(name: str, start: int, count: int, last_values: list) -> str:
+ return name
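
To make the workaround concrete, here is `AutoStrEnum` exercised with a throwaway enum (the `Activation` class below is an example of ours, not part of mimikit): with `auto()`, each member's value is its own name, so members compare equal to plain strings.

```python
from enum import Enum, auto

class AutoStrEnum(str, Enum):
    @staticmethod
    def _generate_next_value_(name, start, count, last_values):
        return name

class Activation(AutoStrEnum):  # hypothetical example enum
    relu = auto()
    gelu = auto()

assert Activation.relu == "relu"        # str comparison works
assert Activation.gelu.value == "gelu"  # value is the member name
```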
-# Debugging utils
-def to_db(S):
- if S.dtype == np.complex64:
- S_hat = a2db(abs(S)) + 40
- elif S.min() >= 0 and S.dtype in (np.float, np.float32, np.float64, np.float_):
- S_hat = a2db(S) + 40
- else:
- S_hat = a2db(S) + 40
- return S_hat
+def default_device():
+ import torch # don't force user to install torch...(?)
-
-def signal(S, hop_length=HOP_LENGTH):
- if S.dtype in (np.complex64, np.complex128):
- return librosa.istft(S, hop_length=hop_length)
- else:
- return librosa.griffinlim(S, hop_length=hop_length, n_iter=32)
-
-
-def audio(S, hop_length=HOP_LENGTH, sr=SR):
- if len(S.shape) > 1:
- y = signal(S, hop_length)
- if y.size > 0:
- return ipd.display(ipd.Audio(y, rate=sr))
- else:
- return ipd.display(ipd.Audio(np.zeros(hop_length*2), rate=sr))
- else:
- return ipd.display(ipd.Audio(S, rate=sr))
-
-
-def show(S, figsize=(), db_scale=True, title="", **kwargs):
- S_hat = to_db(S) if db_scale else S
- if figsize:
- plt.figure(figsize=figsize)
- if "x_axis" not in kwargs:
- kwargs["x_axis"] = "frames"
- if "y_axis" not in kwargs:
- kwargs["y_axis"] = "frames"
- ax = specshow(S_hat, sr=SR, **kwargs)
- plt.colorbar()
- plt.tight_layout()
- plt.title(title)
- return ax
+ device = "cpu"
+ if torch.cuda.is_available():
+ device = "cuda"
+ elif torch.has_mps:
+ device = "mps"
+ return device
diff --git a/mimikit/views/__init__.py b/mimikit/views/__init__.py
new file mode 100644
index 0000000..1f93d02
--- /dev/null
+++ b/mimikit/views/__init__.py
@@ -0,0 +1,8 @@
+from .clusters import *
+from .clusterizer_app import *
+from .dataset import *
+from .functionals import *
+from .io_spec import *
+from .train_arm import *
+from .wavenet import *
+from .sample_rnn import *
diff --git a/mimikit/views/clusterizer_app.py b/mimikit/views/clusterizer_app.py
new file mode 100644
index 0000000..1e3f96a
--- /dev/null
+++ b/mimikit/views/clusterizer_app.py
@@ -0,0 +1,458 @@
+from typing import *
+import dataclasses as dtc
+
+import h5mapper
+from ipywidgets import widgets as W
+import numpy as np
+import pandas as pd
+from peaksjs_widget import PeaksJSWidget, Segment
+import qgrid
+
+from ..config import Config
+from ..extract.clusters import *
+from ..features.dataset import DatasetConfig
+from ..features.functionals import *
+from .clusters import *
+from .dataset import dataset_view
+from .functionals import *
+
+__all__ = [
+ 'ComposeTransformWidget',
+ 'ClusterWidget',
+ 'ClusterizerApp'
+]
+
+
+
+@dtc.dataclass
+class Meta:
+ config_class: Type
+ view_func: Callable
+ requires: List[Type] = dtc.field(default_factory=lambda: [])
+ only_once: bool = False
+
+ def can_be_added(self, preceding_transforms: List[Type]):
+ if not self.requires:
+ return not preceding_transforms
+ if self.requires[0] is Any and len(preceding_transforms) > 0:
+ return True
+ deps_fulfilled = self.requires == preceding_transforms
+ if self.only_once:
+ already_there = any(f is self.config_class for f in preceding_transforms)
+ else:
+ already_there = False
+ return deps_fulfilled and not already_there
+
+
+TRANSFORMS = {
+ "magspec": Meta(MagSpec, magspec_view, [], True),
+ "melspec": Meta(MelSpec, melspec_view, [MagSpec], True),
+ "mfcc": Meta(MFCC, mfcc_view, [MagSpec, MelSpec], True),
+ "chroma": Meta(Chroma, chroma_view, [MagSpec], True),
+ "autoconvolve": Meta(AutoConvolve, autoconvolve_view, [Any], False),
+ "f0 filter": Meta(F0Filter, f0_filter_view, [MagSpec], False),
+ "nearest_neighbor_filter": Meta(NearestNeighborFilter, nearest_neighbor_filter_view, [Any]),
+ "pca": Meta(PCA, pca_view, [Any]),
+ "nmf": Meta(NMF, nmf_view, [Any]),
+ "factor analysis": Meta(FactorAnalysis, factor_analysis_view, [Any])
+}
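
The dependency rule encoded by `Meta.can_be_added` above, exercised standalone with placeholder classes standing in for the functional configs (`view_func` is dropped here since it plays no role in the rule):

```python
from typing import Any, List
import dataclasses as dtc

class MagSpec: ...
class MelSpec: ...

@dtc.dataclass
class Meta:
    config_class: type
    requires: List[type] = dtc.field(default_factory=list)
    only_once: bool = False

    def can_be_added(self, preceding):
        if not self.requires:
            return not preceding          # root transforms only on an empty pipeline
        if self.requires[0] is Any and len(preceding) > 0:
            return True                   # "anything before me" transforms
        fulfilled = self.requires == preceding
        already_there = self.only_once and any(f is self.config_class for f in preceding)
        return fulfilled and not already_there

mel = Meta(MelSpec, [MagSpec], True)
assert mel.can_be_added([MagSpec])   # magspec computed first -> ok
assert not mel.can_be_added([])      # dependency missing
```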
+
+CLUSTERINGS = {
+ "grid of means": Meta(GCluster, gcluster_view, [], True),
+ "quantile clustering": Meta(QCluster, qcluster_view, [], True),
+ "argmax": Meta(ArgMax, argmax_view, [], True),
+ "kmeans": Meta(KMeans, kmeans_view, [], True),
+ "spectral clustering": Meta(SpectralClustering, spectral_clustering_view, [], True)
+}
+
+
+class ComposeTransformWidget:
+
+ @staticmethod
+ def header(box):
+ collapse_all = W.Button(description="collapse all",
+ layout=dict(width="max-content", margin="auto 4px auto 2px"))
+ expand_all = W.Button(description="expand all",
+ layout=dict(width="max-content", margin="auto auto auto 4px"))
+ header = W.HBox(children=[
+ W.HTML(value="
Pre Processing Pipeline
", layout=dict(margin="auto")),
+ collapse_all, expand_all
+ ])
+
+ def on_collapse(ev):
+ for item in box.children:
+ if isinstance(item, W.HBox) and isinstance(item.children[-1], W.Accordion):
+ item.children[-1].selected_index = None
+
+ def on_expand(ev):
+ for item in box.children:
+ if isinstance(item, W.HBox) and isinstance(item.children[-1], W.Accordion):
+ item.children[-1].selected_index = 0
+
+ collapse_all.on_click(on_collapse)
+ expand_all.on_click(on_expand)
+
+ return header
+
+ def __init__(self):
+ self.transforms = []
+ self.metas = []
+ new_choice = W.Button(icon="fa-plus", layout=dict(margin="8px auto"))
+
+ box = W.VBox(layout=dict(width="50%"))
+ header = self.header(box)
+ box.children = (header,)
+
+ choices = W.Select(options=self.get_possible_choices(), layout=dict(width="100%", margin="4px auto"))
+ submit = W.Button(description="submit", layout=dict(width="max-content", margin="auto 8px"))
+ cancel = W.Button(description="cancel", layout=dict(width="max-content", margin="auto 8px"))
+ choice_box = W.VBox(children=(choices, W.HBox(children=(submit, cancel),
+ layout=dict(margin="4px auto"))),
+ layout=dict(width="calc(100% - 54px)", margin="auto 0 auto 54px"))
+
+ def show_new_choice(ev):
+ box.children = (*filter(lambda b: b is not new_choice, box.children), choice_box)
+ new_choice.disabled = True
+
+ def add_choice(ev):
+ meta, cfg, remove_w, hbox = self.new_transform_view_for(choices.value)
+ choices.options = self.get_possible_choices()
+
+ def remove_cb(ev):
+ keep = [0] # always keep magspec
+ for i, (t, m) in enumerate(zip(self.transforms[1:], self.metas[1:]), 1):
+ is_el = t is cfg
+ requires_t = type(cfg) in m.requires
+ if not is_el and not requires_t:
+ keep += [i]
+ self.transforms = [self.transforms[i] for i in keep]
+ box.children = (box.children[0],) + tuple(box.children[i + 1] for i in keep) + (box.children[-1],)
+ choices.options = self.get_possible_choices()
+
+ remove_w.on_click(remove_cb)
+ box.children = (*filter(lambda b: b is not choice_box, box.children), hbox, new_choice)
+ new_choice.disabled = False
+
+ def cancel_new_choice(ev):
+ box.children = (*filter(lambda b: b is not choice_box, box.children), new_choice)
+ new_choice.disabled = False
+
+ submit.on_click(add_choice)
+ cancel.on_click(cancel_new_choice)
+ new_choice.on_click(show_new_choice)
+ self.widget = box
+ choices.value = "magspec"
+ submit.click()
+ # can not remove magspec
+ box.children[1].children[0].disabled = True
+ self.magspec_cfg = self.transforms[0]
+
+ def get_possible_choices(self):
+ options = []
+ ts = [*map(type, self.transforms)]
+ for k, meta in TRANSFORMS.items():
+ if meta.can_be_added(ts):
+ options += [k]
+ return options
+
+ def new_transform_view_for(self, keyword: str):
+ meta = TRANSFORMS[keyword]
+ cfg = meta.config_class()
+ self.metas.append(meta)
+ self.transforms.append(cfg)
+ new_w = meta.view_func(cfg)
+ new_w.layout = dict(margin="auto", width='100%')
+ remove_w = W.Button(icon="fa-trash", layout=dict(width="50px", margin="auto 2px"))
+ hbox = W.HBox(children=(remove_w, new_w), layout=dict(margin="0 4px 4px 4px"))
+ return meta, cfg, remove_w, hbox
+
+ @staticmethod
+ def display(cfg: Compose):
+ w = []
+ for func in cfg.functionals:
+ tp = type(func)
+ key = next(k for k, v in TRANSFORMS.items() if v.config_class is tp)
+ meta = TRANSFORMS[key]
+ w += [meta.view_func(func)]
+ box = W.VBox(layout=dict(width="50%"))
+ header = ComposeTransformWidget.header(box)
+ box.children = (header, *w)
+ return box
+
+
+class ClusterWidget:
+ def __init__(self):
+ self.cfg = None
+ new_choice = W.Button(description="change algo",
+ layout=dict(width="max-content", margin="auto auto auto 12px"),
+ disabled=True)
+ header = W.HBox(children=[
+ W.HTML(value=" Clustering Algo
", layout=dict(margin="auto")),
+ new_choice
+ ])
+ choices = W.Select(options=self.get_possible_choices(), layout=dict(width="100%", margin="4px auto"))
+ submit = W.Button(description="submit", layout=dict(width="max-content", margin="auto 8px"))
+ cancel = W.Button(description="cancel", layout=dict(width="max-content", margin="auto 8px"))
+ choice_box = W.VBox(children=(choices, W.HBox(children=(submit, cancel),
+ layout=dict(margin="4px auto"))),
+ layout=dict(width="95%", margin="auto"))
+ box = W.VBox(children=(header,
+ choice_box,),
+ layout=dict(width='50%'))
+
+ def show_new_choice(ev):
+ box.children = (header, choice_box)
+ new_choice.disabled = True
+
+ def add_choice(ev):
+ meta = CLUSTERINGS[choices.value]
+ self.cfg = meta.config_class()
+ new_w = meta.view_func(self.cfg)
+ new_w.layout = dict(width="95%")
+ box.children = (header, new_w)
+ new_choice.disabled = False
+
+ def cancel_new_choice(ev):
+ box.children = (*filter(lambda b: b is not choice_box, box.children),)
+ new_choice.disabled = False
+
+ submit.on_click(add_choice)
+ cancel.on_click(cancel_new_choice)
+ new_choice.on_click(show_new_choice)
+ self.widget = box
+
+ @staticmethod
+ def get_possible_choices():
+ return [*CLUSTERINGS.keys()]
+
+ @staticmethod
+ def display(cfg):
+ tp = type(cfg)
+ key = next(k for k, v in CLUSTERINGS.items() if v.config_class is tp)
+ box = W.VBox(children=[
+ W.HTML(value=" Clustering Algo
", layout=dict(margin="auto")),
+ CLUSTERINGS[key].view_func(cfg),
+ ], layout=dict(width="50%"))
+ return box
+
+
+class ClusterizerApp:
+
+ def __init__(self):
+ self.dataset_cfg = DatasetConfig()
+ self.dataset_widget = dataset_view(self.dataset_cfg)
+ self.dataset_widget.on_created(lambda ev: self.build_load_view(self.db))
+ self.dataset_widget.on_loaded(lambda ev: self.build_load_view(self.db))
+ self.pre_pipeline = ComposeTransformWidget()
+ self.pre_pipeline_widget = self.pre_pipeline.widget
+ self.magspec_cfg = self.pre_pipeline.magspec_cfg
+ self.clusters = ClusterWidget()
+ self.clusters_widget = self.clusters.widget
+ self.labels_widget = W.VBox(layout=dict(max_width="90vw", margin="auto"))
+ self.feature_name = ''
+
+ save_as = W.HBox(children=(
+ W.Label(value='Save clustering as: '), W.Text(value="labels")),
+ layout=dict(margin="auto", width="max-content"))
+ compute = W.HBox(children=(W.Button(description="compute"),),
+ layout=dict(margin="auto", width="max-content"))
+ out = W.Output()
+
+ def on_submit(ev):
+ db = self.db
+ self.feature_name = save_as.children[1].value
+ pipeline = Compose(
+ *self.pre_pipeline.transforms, self.clusters.cfg,
+ Interpolate(mode="previous", length=db.signal.shape[0])
+ )
+ if self.feature_name in db.handle():
+ db.handle().pop(self.feature_name)
+ db.flush()
+ out.clear_output()
+ with out:
+ db.signal.compute({
+ self.feature_name: pipeline
+ }, parallelism='none')
+ feat = getattr(db, self.feature_name)
+ feat.attrs["config"] = pipeline.serialize()
+ db.flush()
+ db.close()
+ self.build_load_view(self.db)
+
+ compute.children[0].on_click(on_submit)
+
+ self.clustering_widget = W.Tab(children=(
+ W.VBox(children=(
+ W.HBox(children=(self.pre_pipeline_widget, self.clusters_widget),
+ layout=dict(align_items='baseline')),
+ save_as, compute, out
+ )),
+ W.Label("Select a dataset to load a clustering")
+ ), layout=dict(max_width="1000px", margin="auto"))
+ self.clustering_widget.set_title(0, 'Create new clustering')
+ self.clustering_widget.set_title(1, 'Load clustering')
+
+ @property
+ def db(self):
+ return self.dataset_cfg.get(mode="r+")
+
+ @property
+ def sr(self):
+ return self.db.config.extractors[0].functional.functionals[0].sr
+
+ def segments_for(self, feature_name: str):
+ db = self.db
+ sr = self.sr
+ lbl = getattr(db, feature_name)[:]
+ splits = (lbl[1:] - lbl[:-1]) != 0
+ time_idx = np.r_[splits.nonzero()[0], lbl.shape[0] - 1]
+ cluster_idx = lbl[time_idx]
+ segments = [
+ Segment(t, tp1, id=i, labelText=str(c)).dict()
+ for i, (t, tp1, c) in enumerate(
+ zip(time_idx[:-1] / sr, time_idx[1:] / sr, cluster_idx[:-1]))
+ ]
+ df = pd.DataFrame.from_dict(segments)
+ df.set_index("id", drop=True, inplace=True)
+ label_set = [*sorted(set(lbl))]
+ return df, label_set
+
+ def bounce(self, segments: List[Segment]):
+ fft = self.magspec_cfg(self.db.signal[:])
+ sr, hop = self.sr, self.magspec_cfg.hop_length
+
+ def t2f(t): return int(t * sr / hop)
+
+ filtered = np.concatenate([fft[slice(t2f(s.startTime), t2f(s.endTime) + 1)] for s in segments])
+ return self.magspec_cfg.inv(filtered)
+
+ def build_load_view(self, db):
+ proxies = [k for k, v in db.__dict__.items()
+ if isinstance(v, h5mapper.Proxy) and not k.startswith("__") and k != "signal"
+ ]
+ self.results_buttons = W.ToggleButtons(options=proxies, index=None,
+ layout=dict(margin="12px 0"))
+
+ def callback(ev):
+ self.load_result(ev["new"])
+ self.feature_name = ev["new"]
+ self.labels_widget.children = self.label_view()
+
+ self.results_buttons.observe(callback, "value")
+ self.clustering_widget.children = (
+ self.clustering_widget.children[0],
+ W.VBox(children=(W.Label(value="load clustering: "), self.results_buttons))
+ )
+
+ def load_result(self, key: str):
+ cfg = Config.deserialize(getattr(self.db, key).attrs["config"])
+ pre_pipeline = Compose(*cfg.functionals[:-2])
+ clustering = cfg.functionals[-2]
+ self.display(pre_pipeline, clustering)
+
+ def display(self, compose: Compose, clustering):
+ self.clustering_widget.children = (
+ self.clustering_widget.children[0],
+ W.VBox(children=(
+ W.VBox(children=(W.Label(value="load clustering: "), self.results_buttons)),
+ W.HBox(children=(ComposeTransformWidget.display(compose), ClusterWidget.display(clustering)),
+ layout=dict(align_items='baseline'))
+ ))
+ )
+
+ def label_view(self):
+ df, label_set = self.segments_for(self.feature_name)
+ w = PeaksJSWidget(array=self.db.signal[:], sr=self.sr, id_count=len(df),
+ layout=dict(margin="auto", max_width="1500px", width="100%"))
+ empty = pd.DataFrame([])
+ empty.index.name = "id"
+ g = qgrid.show_grid(empty,
+ grid_options=dict(maxVisibleRows=10))
+
+ labels_grid = W.GridBox(layout=dict(max_height='400px',
+ margin="16px auto",
+ grid_template_columns="1fr " * 10,
+ overflow="scroll"))
+ labels_w = []
+
+ for i in label_set:
+ btn = W.ToggleButton(value=False,
+ description=str(i),
+ layout=dict(width="100px", margin="4px"))
+
+ def on_click(ev, widget=btn, index=i):
+ if ev["new"]:
+ w.segments = [
+ *w.segments, *df[df.labelText == str(index)].reset_index().T.to_dict().values()
+ ]
+ widget.button_style = "success"
+ else:
+ w.segments = [
+ s for s in w.segments if s["labelText"] != str(index)
+ ]
+ widget.button_style = ""
+ if w.segments:
+ g.df = pd.DataFrame.from_dict(w.segments).sort_values(by="startTime").set_index("id", drop=True)
+ else:
+ g.df = pd.DataFrame([])
+ g.df.index.name = "id"
+
+ btn.observe(on_click, "value")
+ labels_w += [W.HBox(children=(btn,))]
+
+ labels_grid.children = tuple(labels_w)
+
+ def on_new_segment(wdg, seg):
+ new_seg = Segment(**seg).dict()
+ if w.segments:
+ g.add_row(row=[*new_seg.items()])
+ # i = {**new_seg}.pop("id")
+ # df.loc[i] = {**new_seg}
+ else:
+ g.df = pd.DataFrame.from_dict([new_seg]).set_index("id", drop=True)
+
+ def on_edit_segment(wdg, seg):
+ for k, v in seg.items():
+ if k == "id": continue
+ g.edit_cell(seg["id"], k, v)
+ # df.loc[seg["id"], k] = v
+ g.change_selection([seg["id"]])
+
+ def on_remove_segment(wdg, seg):
+ g.remove_row([seg["id"]])
+ # df.drop(seg["id"], inplace=True)
+
+ def segments_changed(ev):
+ pass
+ # print("segments changed")
+
+ w.observe(segments_changed, "segments")
+ w.on_new_segment(on_new_segment)
+ w.on_new_segment(PeaksJSWidget.add_segment)
+ w.on_edit_segment(on_edit_segment)
+ w.on_edit_segment(PeaksJSWidget.edit_segment)
+ w.on_remove_segment(on_remove_segment)
+ w.on_remove_segment(PeaksJSWidget.remove_segment)
+
+ # g.on("selection_changed", on_selection_changed)
+ # g.on("cell_edited", on_edited_cell)
+ # g.on("filter_changed", lambda ev, qg: print(ev))
+ # g.on("row_removed", on_row_removed)
+
+ def on_bounce(ev):
+ bnc = self.bounce([Segment(s["startTime"], s["endTime"], s["id"]) for s in w.segments])
+ self.labels_widget.children = (*self.labels_widget.children,
+ PeaksJSWidget(array=bnc, sr=self.sr, id_count=0))
+
+ bounce = W.Button(description="Bounce Selection")
+ bounce.on_click(on_bounce)
+
+ return (
+ w,
+ W.HTML("Select Label(s):
"),
+ labels_grid,
+ W.HTML("Selected Labels Segments Table:
"),
+ g,
+ bounce
+ )
\ No newline at end of file
diff --git a/mimikit/views/clusters.py b/mimikit/views/clusters.py
new file mode 100644
index 0000000..b7d1a5f
--- /dev/null
+++ b/mimikit/views/clusters.py
@@ -0,0 +1,158 @@
+from ipywidgets import widgets as W
+
+from ..extract.clusters import QCluster, GCluster, KMeans, SpectralClustering
+from .. import ui as UI
+
+__all__ = [
+ "qcluster_view",
+ "gcluster_view",
+ "argmax_view",
+ "kmeans_view",
+ "spectral_clustering_view"
+]
+
+
+def qcluster_view(cfg: QCluster):
+
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("metric",
+ widget=UI.EnumWidget("Metric: ",
+ ["cosine",
+ "euclidean",
+ "manhattan", ],
+ selected_index=["cosine", "euclidean", "manhattan"].index(cfg.metric)
+ ),
+ ),
+ UI.Param("n_neighbors",
+ widget=UI.Labeled("N Neighbors: ",
+ W.IntText(value=cfg.n_neighbors, layout=dict(margin='4px', width='auto')),
+ ), ),
+ UI.Param("cores_prop",
+ widget=UI.Labeled("Proportion of Cores: ",
+ W.FloatText(value=cfg.cores_prop, layout=dict(margin='4px', width='auto')),
+ ), ),
+ UI.Param("core_neighborhood_size",
+ widget=UI.Labeled("N Neighbors per Core: ",
+ W.IntText(value=cfg.core_neighborhood_size,
+ layout=dict(margin='4px', width='auto')),
+ ), ),
+
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0"), selected_index=0)
+
+ view.set_title(0, "Quantile Clustering")
+ return view
+
+
+def gcluster_view(cfg: GCluster):
+
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("metric",
+ widget=UI.EnumWidget("Metric: ",
+ ["cosine",
+ "euclidean", ],
+ selected_index=["cosine", "euclidean"].index(cfg.metric)
+ ),
+ ),
+ UI.Param("n_means",
+ widget=UI.Labeled("N Means: ",
+ W.IntText(value=cfg.n_means, layout=dict(margin='4px', width='auto')),
+ ), ),
+ UI.Param("n_iter",
+ widget=UI.Labeled("N Iter: ",
+ W.IntText(value=cfg.n_iter,
+ layout=dict(margin='4px', width='auto')),
+ ), ),
+ UI.Param(name="max_lr",
+ widget=UI.Labeled("Learning Rate: ",
+ W.FloatSlider(
+ value=1e-3, min=1e-5, max=1e-2, step=.00001,
+ readout_format=".2e",
+ layout={"width": "75%"}
+ ),
+ )),
+ UI.Param(name="betas",
+ widget=UI.Labeled("Beta 1",
+ W.FloatLogSlider(
+ value=.9, min=-.75, max=0., step=.001, base=2,
+ layout={"width": "75%"}),
+ ),
+ setter=lambda conf, ev: (ev, conf.betas[1])),
+ UI.Param(name="betas",
+ widget=UI.Labeled("Beta 2",
+ W.FloatLogSlider(
+ value=.9, min=-.75, max=0., step=.001, base=2,
+ layout={"width": "75%"}),
+ ),
+ setter=lambda conf, ev: (conf.betas[0], ev)),
+
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0"), selected_index=0)
+
+ view.set_title(0, "Grid of Means Clustering")
+ return view
+
+
+def argmax_view(cfg=None):
+ view = W.Accordion(children=[W.Label(value="no parameters to set", layout=dict(width='auto'))],
+ layout=W.Layout(margin="0"), selected_index=0)
+ view.set_title(0, "Arg Max Clustering")
+ return view
+
+
+def kmeans_view(cfg: KMeans):
+
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("n_clusters",
+ widget=UI.Labeled("N Components: ",
+ W.IntText(value=cfg.n_clusters, layout=dict(margin='4px', width='auto')),
+ ), ),
+ UI.Param("n_init",
+ widget=UI.Labeled("N Init: ",
+ W.IntText(value=cfg.n_init, layout=dict(margin='4px', width='auto')),
+ ), ),
+ UI.Param("max_iter",
+ widget=UI.Labeled("Max Iter: ",
+ W.IntText(value=cfg.max_iter, layout=dict(margin='4px', width='auto')),
+ ), ),
+ UI.Param("random_seed",
+ widget=UI.Labeled("Random Seed: ",
+ W.IntText(value=cfg.random_seed, layout=dict(margin='4px', width='auto')),
+ ), )
+
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0"), selected_index=0)
+
+ view.set_title(0, "KMeans Clustering")
+ return view
+
+
+def spectral_clustering_view(cfg: SpectralClustering):
+
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("n_clusters",
+ widget=UI.Labeled("N Clusters: ",
+ W.IntText(value=cfg.n_clusters, layout=dict(margin='4px', width='auto')),
+ ), ),
+ UI.Param("n_init",
+ widget=UI.Labeled("N Init: ",
+ W.IntText(value=cfg.n_init, layout=dict(margin='4px', width='auto')),
+ ), ),
+ UI.Param("n_neighbors",
+ widget=UI.Labeled("N Neighbors: ",
+ W.IntText(value=cfg.n_neighbors, layout=dict(margin='4px', width='auto')),
+ ), ),
+ UI.Param("random_seed",
+ widget=UI.Labeled("Random Seed: ",
+ W.IntText(value=cfg.random_seed, layout=dict(margin='4px', width='auto')),
+ ), )
+
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0"), selected_index=0)
+
+ view.set_title(0, "Spectral Clustering Clustering")
+ return view
diff --git a/mimikit/views/dataset.py b/mimikit/views/dataset.py
new file mode 100644
index 0000000..e3311ce
--- /dev/null
+++ b/mimikit/views/dataset.py
@@ -0,0 +1,94 @@
+from ipywidgets import widgets as W
+from .. import ui as UI
+from ..features.functionals import Compose, FileToSignal, RemoveDC, Normalize
+from ..features.extractor import Extractor
+from ..features.dataset import DatasetConfig
+
+__all__ = [
+ "dataset_view"
+]
+
+
+def dataset_view(cfg: DatasetConfig):
+
+ out = W.Output(layout=dict(margin='32px 0'))
+ title = W.HTML("Select Soundfiles
", layout=dict(margin='0 0 0 8px'))
+ picker = UI.SoundFilePicker()
+ picker_w = picker.widget
+ save_as_txt = W.Text(value=cfg.filename, description='Save as: ',
+ layout=dict(width='75%', margin='4px 0 4px 8px'))
+ create = W.Button(description="Create", layout=dict(width='75%', margin='4px 0 4px 8px'))
+
+ sample_rate = W.IntText(value=16000, description="Sample Rate: ",
+ layout=dict(
+ width='max-content',
+ margin='4px 0 4px 8px')
+ )
+ new_ds_container = W.AppLayout(
+ header=title,
+ center=picker_w,
+ footer=W.VBox(children=(sample_rate, save_as_txt, create),
+ layout=dict(width='75%')),
+ # pane_widths=('0fr', '3fr', '0fr'),
+ pane_heights=("40px", "250px", "112px")
+ )
+
+ def create_ds(ev, callback=None):
+ cfg.sources = tuple(picker.selected)
+ cfg.extractors = (Extractor(name='signal',
+ functional=Compose(
+ FileToSignal(sample_rate.value),
+ RemoveDC(),
+ Normalize()
+ )),)
+ out.clear_output()
+ with out:
+ db = cfg.create(mode='w')
+ print("Extracted:\n\n", *(f"\t- {k}\n" for k in db.index))
+ if callback is not None:
+ callback(db)
+
+ create.on_click(create_ds)
+
+ ds_picker = UI.DatasetPicker()
+ ds_picker_w = ds_picker.widget
+ load_ds = W.Button(description="Load")
+ load_ds_container = W.VBox(children=[
+ W.HTML(value="Select Dataset File
"),
+ ds_picker_w,
+ load_ds
+ ])
+
+ def load_cb(ev, callback=None):
+ cfg.filename = ds_picker.selected
+ db = cfg.get(mode='r')
+ out.clear_output()
+ with out:
+ print("Loaded", cfg.filename)
+ print("Containing:\n\n", *(f"\t- {k}\n" for k in db.index))
+ if callback is not None:
+ callback(db)
+
+ load_ds.on_click(load_cb)
+
+ tabs = W.Tab(children=(new_ds_container, load_ds_container))
+ tabs.set_title(0, "New Dataset from Soundfiles")
+ tabs.set_title(1, "Load Dataset File from Disk")
+
+ class DatasetView(W.Accordion):
+ def __init__(self, *children, **kwargs):
+ super(DatasetView, self).__init__(children=children, **kwargs)
+ self.create = create
+ self.load_ds = load_ds
+
+ def on_created(self, callback):
+ self.create.on_click(callback)
+
+ def on_loaded(self, callback):
+ self.load_ds.on_click(callback)
+
+ top = DatasetView(W.VBox(children=(tabs, out)),
+ layout=dict(max_width="1000px", margin="auto"))
+ top.set_title(0, "Dataset")
+ return top
+
diff --git a/mimikit/views/functionals.py b/mimikit/views/functionals.py
new file mode 100644
index 0000000..5566c91
--- /dev/null
+++ b/mimikit/views/functionals.py
@@ -0,0 +1,290 @@
+from ipywidgets import widgets as W
+
+from ..features.functionals import MagSpec, MelSpec, MFCC, Chroma, \
+ HarmonicSource, PercussiveSource, AutoConvolve, F0Filter, \
+ NearestNeighborFilter, PCA, NMF, FactorAnalysis
+from .. import ui as UI
+
+__all__ = [
+ "magspec_view",
+ "melspec_view",
+ "mfcc_view",
+ "chroma_view",
+ "harmonic_source_view",
+ "percussive_source_view",
+ "autoconvolve_view",
+ "f0_filter_view",
+ "nearest_neighbor_filter_view",
+ "pca_view",
+ "nmf_view",
+ "factor_analysis_view"
+]
+
+
+def magspec_view(cfg: MagSpec):
+
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("n_fft",
+ widget=UI.Labeled("N FFT: ",
+ W.IntText(value=cfg.n_fft, layout=dict(width='auto')),
+ ), ),
+ UI.Param("hop_length",
+ widget=UI.Labeled("hop length: ",
+ W.IntText(value=cfg.hop_length, layout=dict(width='auto')),
+ ), ),
+ UI.Param("center",
+ widget=
+ UI.Labeled("center: ", UI.yesno_widget(initial_value=cfg.center),
+ ), ),
+ UI.Param("window",
+ widget=UI.EnumWidget("window: ",
+ ["None", "hann", "hamming", ],
+ selected_index=0 if cfg.window is None else 1
+ ),
+ setter=lambda c, v: v if v != "None" else None
+ )
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 0 0 0"), selected_index=0)
+
+ view.set_title(0, "Magnitude Spectrogram")
+ return view
+
+
+def melspec_view(cfg: MelSpec):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("n_mels",
+ widget=UI.Labeled("N Mels: ",
+ W.IntText(value=cfg.n_mels, layout=dict(width='auto')),
+ ), ),
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 auto 0 0"), selected_index=0)
+
+ view.set_title(0, "MelSpectrogram")
+ return view
+
+
+def mfcc_view(cfg: MFCC):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("n_mfcc",
+ widget=UI.Labeled("N MFCC: ",
+ W.IntText(value=cfg.n_mfcc, layout=dict(width='auto')),
+ ), ),
+ UI.Param("dct_type",
+ widget=UI.EnumWidget("DCT Type: ",
+ ["1", "2", "3", ],
+ selected_index=cfg.dct_type - 1
+ ),
+ setter=lambda c, v: int(v)
+ )
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 auto 0 0"), selected_index=0)
+
+ view.set_title(0, "MFCC")
+ return view
+
+
+def chroma_view(cfg: Chroma):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("n_chroma",
+ widget=UI.Labeled("N Chroma: ",
+ W.IntText(value=cfg.n_chroma, layout=dict(width='auto')),
+ ), ),
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 auto 0 0"), selected_index=0)
+
+ view.set_title(0, "Chroma")
+ return view
+
+
+def harmonic_source_view(cfg: HarmonicSource):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("kernel_size",
+ widget=UI.Labeled("Kernel Size: ",
+ W.IntText(value=cfg.kernel_size, layout=dict(width='auto')),
+ ), ),
+ UI.Param("power",
+ widget=UI.Labeled("Power: ",
+ W.FloatText(value=cfg.power, layout=dict(width='auto')),
+ ), ),
+ UI.Param("margin",
+ widget=UI.Labeled("Margin: ",
+ W.FloatText(value=cfg.margin, layout=dict(width='auto')),
+ ), ),
+
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 auto 0 0"), selected_index=0)
+
+ view.set_title(0, "Harmonic Source")
+ return view
+
+
+def percussive_source_view(cfg: PercussiveSource):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("kernel_size",
+ widget=UI.Labeled("Kernel Size: ",
+ W.IntText(value=cfg.kernel_size, layout=dict(width='auto')),
+ ), ),
+ UI.Param("power",
+ widget=UI.Labeled("Power: ",
+ W.FloatText(value=cfg.power, layout=dict(width='auto')),
+ ), ),
+ UI.Param("margin",
+ widget=UI.Labeled("Margin: ",
+ W.FloatText(value=cfg.margin, layout=dict(width='auto')),
+ ), ),
+
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 auto 0 0"), selected_index=0)
+
+ view.set_title(0, "Percussive Source")
+ return view
+
+
+def autoconvolve_view(cfg: AutoConvolve):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("window_size",
+ widget=UI.Labeled("Window Size: ",
+ W.IntText(value=cfg.window_size, layout=dict(width='auto')),
+ ), ),
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 auto 0 0"), selected_index=0)
+
+ view.set_title(0, "AutoConvolve")
+ return view
+
+
+def f0_filter_view(cfg: F0Filter):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("n_overtone",
+ widget=UI.Labeled("N Overtone: ",
+ W.IntText(value=cfg.n_overtone, layout=dict(width='auto')),
+ ), ),
+ UI.Param("n_undertone",
+ widget=UI.Labeled("N Undertone: ",
+ W.IntText(value=cfg.n_undertone, layout=dict(width='auto')),
+ ), ),
+ UI.Param("soft",
+ widget=UI.Labeled("Soft Filter: ",
+ UI.yesno_widget(initial_value=cfg.soft, )), ),
+ UI.Param("normalize",
+ widget=UI.Labeled("Normalize: ",
+ UI.yesno_widget(initial_value=cfg.normalize)
+ ), ),
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 auto 0 0"), selected_index=0)
+
+ view.set_title(0, "F0 Filter")
+ return view
+
+
+def nearest_neighbor_filter_view(cfg: NearestNeighborFilter):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("n_neighbors",
+ widget=UI.Labeled("N Neighbors: ",
+ W.IntText(value=cfg.n_neighbors, layout=dict(width='auto')),
+ ), ),
+ UI.Param("metric",
+ widget=UI.EnumWidget("Metric: ",
+ ["cosine",
+ "euclidean",
+ "manhattan",
+ ],
+ selected_index=["cosine", "euclidean", "manhattan"].index(cfg.metric)
+ ),
+ ),
+ UI.Param("aggregate",
+ widget=UI.EnumWidget("Aggregate: ",
+ ["mean",
+ "median",
+ "max",
+ ],
+ selected_index=["mean", "median", "max"].index(cfg.aggregate)
+ ),
+ ),
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 auto 0 0"), selected_index=0)
+
+ view.set_title(0, "Nearest Neighbor Filter")
+ return view
+
+
+def pca_view(cfg: PCA):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("n_components",
+ widget=UI.Labeled("N Components: ",
+ W.IntText(value=cfg.n_components, layout=dict(width='auto')),
+ ), ),
+ UI.Param("random_seed",
+ widget=UI.Labeled("Random Seed: ",
+ W.IntText(value=cfg.random_seed, layout=dict(width='auto')),
+ ), )
+
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 auto 0 0"), selected_index=0)
+
+ view.set_title(0, "PCA")
+ return view
+
+
+def nmf_view(cfg: NMF):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("n_components",
+ widget=UI.Labeled("N Components: ",
+ W.IntText(value=cfg.n_components, layout=dict(width='auto')),
+ ), ),
+ UI.Param("tol",
+ widget=UI.Labeled("Tolerance: ",
+ W.FloatText(value=cfg.tol, layout=dict(width='auto')),
+ ), ),
+ UI.Param("max_iter",
+ widget=UI.Labeled("Max Iter: ",
+ W.IntText(value=cfg.max_iter, layout=dict(width='auto')),
+ ), ),
+ UI.Param("random_seed",
+ widget=UI.Labeled("Random Seed: ",
+ W.IntText(value=cfg.random_seed, layout=dict(width='auto')),
+ ), )
+
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 auto 0 0"), selected_index=0)
+
+ view.set_title(0, "NMF")
+ return view
+
+
+def factor_analysis_view(cfg: FactorAnalysis):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param("n_components",
+ widget=UI.Labeled("N Components: ",
+ W.IntText(value=cfg.n_components, layout=dict(width='auto')),
+ ), ),
+ UI.Param("tol",
+ widget=UI.Labeled("Tolerance: ",
+ W.FloatText(value=cfg.tol, layout=dict(width='auto')),
+ ), ),
+ UI.Param("max_iter",
+ widget=UI.Labeled("Max Iter: ",
+ W.IntText(value=cfg.max_iter, layout=dict(width='auto')),
+ ), ),
+ UI.Param("random_seed",
+ widget=UI.Labeled("Random Seed: ",
+ W.IntText(value=cfg.random_seed, layout=dict(width='auto')),
+ ), )
+
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ layout=W.Layout(margin="0 auto 0 0"), selected_index=0)
+
+ view.set_title(0, "Factor Analysis")
+ return view
diff --git a/mimikit/views/io_spec.py b/mimikit/views/io_spec.py
new file mode 100644
index 0000000..ef01d7c
--- /dev/null
+++ b/mimikit/views/io_spec.py
@@ -0,0 +1,82 @@
+from ipywidgets import widgets as W
+from .. import ui as UI
+
+from ..io_spec import IOSpec
+
+
+__all__ = [
+ "mulaw_io_view",
+ "magspec_io_view"
+]
+
+
+def mulaw_io_view(cfg: IOSpec.MuLawIOConfig):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param(
+ name='sr',
+ widget=UI.Labeled("Sample Rate",
+ W.IntText(value=cfg.sr))
+ ),
+ UI.Param(
+ name='q_levels',
+ widget=UI.Labeled(
+ "Quantization Levels",
+ UI.pw2_widget(cfg.q_levels)
+ )
+ ),
+ UI.Param(
+ name="compression",
+ widget=UI.Labeled(
+ "Compression",
+ W.FloatSlider(value=cfg.compression, min=0.001, max=2., step=0.01)
+ )
+ ),
+ UI.Param(
+ name='mlp_dim',
+ widget=UI.Labeled("Final Layer Dim", UI.pw2_widget(cfg.mlp_dim))
+ ),
+ UI.Param(
+ name='n_mlp_layers',
+ widget=UI.Labeled(
+ "N hidden final layers",
+ W.IntText(cfg.n_mlp_layers)
+ )
+ ),
+ UI.Param(
+ name='min_temperature',
+ widget=UI.Labeled(
+ "Minimum temperature",
+ W.FloatSlider(value=cfg.min_temperature, min=1e-4, max=1., step=0.0001)
+ )
+ )
+ ).as_widget(lambda children, **kwargs: W.VBox(children=children))
+ return view
+
+
+def magspec_io_view(cfg: IOSpec.MagSpecIOConfig):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param(
+ name='sr',
+ widget=UI.Labeled("Sample Rate",
+ W.IntText(value=cfg.sr))
+ ),
+ UI.Param("n_fft",
+ widget=UI.Labeled("N FFT: ",
+ W.IntText(value=cfg.n_fft),
+ ), ),
+ UI.Param("hop_length",
+ widget=UI.Labeled("hop length: ",
+ W.IntText(value=cfg.hop_length),
+ ), ),
+ UI.Param(
+ name='activation',
+ widget=UI.EnumWidget(
+ "Output Activation",
+ ["Abs", "ScaledSigmoid"],
+ selected_index=["Abs", "ScaledSigmoid"].index(cfg.activation)
+ ),
+ )
+ ).as_widget(lambda children, **kwargs: W.VBox(children=children))
+ return view
diff --git a/mimikit/views/sample_rnn.py b/mimikit/views/sample_rnn.py
new file mode 100644
index 0000000..f1de81d
--- /dev/null
+++ b/mimikit/views/sample_rnn.py
@@ -0,0 +1,73 @@
+import ipywidgets as W
+from .. import ui as UI
+from ..networks.sample_rnn_v2 import SampleRNN
+
+__all__ = [
+ "sample_rnn_view"
+]
+
+
+def sample_rnn_view(cfg: SampleRNN.Config):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param(
+ name='frame_sizes',
+ widget=UI.Labeled(
+ "Frame Sizes",
+ W.Text(value=str(cfg.frame_sizes)[1:-1]),
+ ),
+ setter=lambda c, v: tuple(map(int, (s for s in v.split(",") if s not in ("", " "))))
+ ),
+ UI.Param(
+ name='hidden_dim',
+ widget=UI.Labeled(
+ "Hidden Dim: ",
+ UI.pw2_widget(str(cfg.hidden_dim))
+ ),
+ setter=lambda c, v: int(v)
+ ),
+ UI.Param(
+ name="rnn_class",
+ widget=UI.EnumWidget("Type of RNN: ",
+ ["LSTM", "RNN", "GRU"],
+ selected_index=["LSTM", "RNN", "GRU"].index(cfg.rnn_class.upper())
+ ),
+ setter=lambda c, v: v.lower()
+ ),
+ UI.Param(
+ name="n_rnn",
+ widget=UI.Labeled(
+ "Num of RNN: ",
+ W.IntText(value=cfg.n_rnn),
+ )
+ ),
+ UI.Param(
+ name="rnn_dropout",
+ widget=UI.Labeled(
+ "RNN dropout: ",
+ W.FloatText(value=cfg.rnn_dropout, min=0., max=.999, step=.01, ),
+ )
+ ),
+ UI.Param(name="rnn_bias",
+ widget=UI.Labeled(
+ "use bias by RNNs",
+ UI.yesno_widget(initial_value=cfg.rnn_bias),
+ ), ),
+ UI.Param(
+ name="h0_init",
+ widget=UI.EnumWidget("Hidden initialization: ",
+ ["zeros", "randn", "ones"],
+ selected_index=["zeros", "randn", "ones"].index(cfg.h0_init)
+ ),
+ setter=lambda c, v: v.lower()
+ ),
+ UI.Param(name="weight_norm",
+ widget=UI.Labeled(
+ "use weights normalization: ",
+ UI.yesno_widget(initial_value=cfg.weight_norm),
+ ), ),
+
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ selected_index=0, layout=W.Layout(margin="0 auto 0 0", width="100%"))
+ view.set_title(0, "SampleRNN Config")
+ return view
diff --git a/mimikit/views/train_arm.py b/mimikit/views/train_arm.py
new file mode 100644
index 0000000..0806675
--- /dev/null
+++ b/mimikit/views/train_arm.py
@@ -0,0 +1,203 @@
+import ipywidgets as W
+from .. import ui as UI
+from ..loops.train_loops import TrainARMConfig
+
+__all__ = [
+ "train_arm_view"
+]
+
+
+def train_arm_view(cfg: TrainARMConfig):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param(name='root_dir',
+ widget=UI.Labeled(
+ "Directory",
+ W.Text(value=cfg.root_dir)
+ ),
+ position=(0, 0)
+ ),
+ UI.Param(name='_',
+ widget=W.HTML("Batches
", layout=dict(height='28px')),
+ position=(1, 0)),
+ UI.Param(name='_',
+ widget=W.HTML("
", layout=dict(height='28px')),
+ position=(1, 1)),
+ UI.Param(name="batch_size",
+ widget=UI.Labeled(
+ "Batch Size: ",
+ UI.pw2_widget(cfg.batch_size),
+ ),
+ setter=lambda conf, v: int(v),
+ position=(2, 0)
+ ),
+ UI.Param(name="batch_length",
+ widget=UI.Labeled(
+ "Batch Length: ",
+ UI.pw2_widget(cfg.batch_length),
+ ),
+ setter=lambda conf, v: int(v),
+ position=(2, 1),
+ ),
+ UI.Param(name="downsampling",
+ widget=UI.Labeled(
+ "Batches downsampling",
+ W.IntText(value=cfg.downsampling)
+ ),
+ position=(3, 0)),
+ UI.Param(name="oversampling",
+ widget=UI.Labeled(
+ "Batch oversampling",
+ W.IntText(value=cfg.oversampling)
+ ),
+ position=(3, 1)),
+ UI.Param(name='tbptt_chunk_length',
+ widget=UI.Labeled(
+ "TBPTT length",
+ W.IntText(value=cfg.tbptt_chunk_length)
+ ),
+ position=(4, 0)),
+ UI.Param(name='_',
+ widget=W.HTML("Epochs
", layout=dict(height='28px')),
+ position=(5, 0)),
+ UI.Param(name='_',
+ widget=W.HTML("
", layout=dict(height='28px')),
+ position=(5, 1)),
+ UI.Param(name="max_epochs",
+ widget=UI.Labeled(
+ "Number of Epochs: ",
+ W.IntText(value=cfg.max_epochs),
+ "training will be performed for this number of epochs."
+ ),
+ position=(6, 0)),
+ UI.Param(name="limit_train_batches",
+ widget=UI.Labeled(
+ "Max batches per epoch: ",
+ W.IntText(value=0 if cfg.limit_train_batches is None else cfg.limit_train_batches, ),
+ "limit the number of batches per epoch, enter 0 for no limit"
+ ),
+ setter=lambda conf, ev: ev if ev > 0 else None,
+ position=(6,1)),
+ UI.Param(name='_',
+ widget=W.HTML("Optimizer
", layout=dict(height='28px')),
+ position=(7, 0)),
+ UI.Param(name='_',
+ widget=W.HTML("
", layout=dict(height='28px')),
+ position=(7, 1)),
+ UI.Param(name="max_lr",
+ widget=UI.Labeled(
+ "Learning Rate: ",
+ W.FloatSlider(
+ value=cfg.max_lr, min=1e-5, max=1e-2, step=.00001,
+ readout_format=".2e",
+ ),
+ ),
+ position=(8, 0)
+ ),
+ UI.Param(name="betas",
+ widget=UI.Labeled(
+ "Beta 1",
+ W.FloatLogSlider(
+ value=cfg.betas[0], min=-.75, max=0., step=.001, base=2,
+ ),
+ ),
+ setter=lambda conf, ev: (ev, conf.betas[1]),
+ position=(9, 0)
+ ),
+ UI.Param(name="betas",
+ widget=UI.Labeled(
+ "Beta 2",
+ W.FloatLogSlider(
+ value=cfg.betas[1], min=-.75, max=0., step=.001, base=2,),
+ ),
+ setter=lambda conf, ev: (conf.betas[0], ev),
+ position=(9, 1)
+ ),
+ UI.Param(name='_',
+ widget=W.HTML("LR Scheduler
",
+ layout=dict(height='28px')),
+ position=(10, 0)),
+ UI.Param(name='_',
+ widget=W.HTML("
", layout=dict(height='28px')),
+ position=(10, 1)),
+ UI.Param(name="div_factor",
+ widget=UI.Labeled(
+ "Start LR div factor",
+ W.FloatSlider(
+ value=cfg.div_factor, min=0.001, max=100., step=.001,
+ ),
+ ),
+ position=(11, 0)
+ ),
+ UI.Param(name="final_div_factor",
+ widget=UI.Labeled(
+ "End LR div factor",
+ W.FloatSlider(
+ value=cfg.final_div_factor, min=0.001, max=100., step=.001,
+ ),
+ ),
+ position=(11, 1)
+ ),
+ UI.Param(name="pct_start",
+ widget=UI.Labeled(
+ "Percent training start LR to max LR",
+ W.FloatSlider(
+ value=cfg.pct_start, min=0.0, max=1., step=.001,
+ ),
+ ),
+ position=(12, 0)
+ ),
+ UI.Param(name='_',
+ widget=W.HTML("Tests & Checkpoints
", layout=dict(height='28px')),
+ position=(13, 0)),
+ UI.Param(name='_',
+ widget=W.HTML("
", layout=dict(height='28px')),
+ position=(13, 1)),
+ UI.Param(name="every_n_epochs",
+ widget=UI.Labeled(
+ "Test/Checkpoint every $N$ epochs",
+ W.IntText(value=cfg.every_n_epochs),
+ ),
+ position=(14, 0)),
+ UI.Param(name='n_examples',
+ widget=UI.Labeled(
+ "$N$ Test examples",
+ W.IntText(value=cfg.n_examples),
+ ),
+ position=(14, 1)),
+ UI.Param(name='prompt_length_sec',
+ widget=UI.Labeled(
+ "Prompt length (in sec.)",
+ W.FloatText(value=cfg.prompt_length_sec),
+ ),
+ position=(15, 0)),
+ UI.Param(name='outputs_duration_sec',
+ widget=UI.Labeled(
+ "Tests length (in sec.)",
+ W.FloatText(value=cfg.outputs_duration_sec),
+ ),
+ position=(15, 1)),
+ UI.Param(name='temperature',
+ widget=UI.Labeled(
+ "Test examples' temperatures",
+ W.Text(value='' if cfg.temperature is None else str(cfg.temperature)[1:-1])
+ ),
+ position=(16, slice(0, 2)),
+ setter=lambda config, ev: tuple(map(eval, ev.split(', '))) if ev else None),
+ UI.Param(name="CHECKPOINT_TRAINING",
+ widget=UI.Labeled(
+ "Checkpoint Training: ",
+ UI.yesno_widget(initial_value=cfg.CHECKPOINT_TRAINING),
+ ),
+ position=(17, 0)),
+ UI.Param(name="MONITOR_TRAINING",
+ widget=UI.Labeled(
+ "Monitor Training: ",
+ UI.yesno_widget(initial_value=cfg.MONITOR_TRAINING),
+ ),
+ position=(17, 1)),
+ grid_spec=(18, 2)
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ selected_index=0, layout=W.Layout(margin="0 auto 0 0", width="100%"))
+ view.set_title(0, "Optimization Loop")
+ return view
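The two `betas` params in the view above share one config field: each slider's `setter` receives the current config plus the raw event value and rebuilds the full tuple, keeping the untouched half. A dependency-free sketch of that contract (the names `OptimConfig` and `apply_event` are illustrative stand-ins, not mimikit API):

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Tuple


@dataclass
class OptimConfig:
    betas: Tuple[float, float] = (0.9, 0.999)


# Stand-in for UI.Param's setter contract: (current config, event value) -> new field value.
Setter = Callable[[OptimConfig, Any], Any]


def apply_event(cfg: OptimConfig, field: str, setter: Setter, event_value: Any) -> OptimConfig:
    # The setter rebuilds the whole tuple, so one slider can update one element.
    return replace(cfg, **{field: setter(cfg, event_value)})


beta1_setter: Setter = lambda conf, ev: (ev, conf.betas[1])
beta2_setter: Setter = lambda conf, ev: (conf.betas[0], ev)

cfg = OptimConfig()
cfg = apply_event(cfg, "betas", beta1_setter, 0.8)
cfg = apply_event(cfg, "betas", beta2_setter, 0.99)
```

Returning a fresh config via `dataclasses.replace` keeps the widgets' event handlers free of partial-mutation bugs when two widgets target the same field.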
diff --git a/mimikit/views/wavenet.py b/mimikit/views/wavenet.py
new file mode 100644
index 0000000..aec4a83
--- /dev/null
+++ b/mimikit/views/wavenet.py
@@ -0,0 +1,89 @@
+import ipywidgets as W
+from .. import ui as UI
+from ..networks.wavenet_v2 import WaveNet
+
+__all__ = [
+ "wavenet_view"
+]
+
+
+def wavenet_view(cfg: WaveNet.Config):
+ view = UI.ConfigView(
+ cfg,
+ UI.Param(name='kernel_sizes',
+ widget=UI.Labeled(
+ "kernel size",
+ W.IntText(value=cfg.kernel_sizes[0]),
+ ),
+ setter=lambda conf, ev: (ev,)),
+ UI.Param(name="blocks",
+ widget=UI.Labeled(
+ "layers per block",
+ W.Text(value=str(cfg.blocks)[1:-1])
+ ),
+ setter=lambda c, v: tuple(map(int, (s for s in v.split(",") if s not in ("", " "))))
+ ),
+ UI.Param(name="dims_dilated",
+ widget=UI.Labeled(
+ "$N$ units per layer: ",
+ UI.pw2_widget(cfg.dims_dilated[0]),
+ ),
+ setter=lambda c, v: (int(v),)
+ ),
+ UI.Param(name="dims_1x1",
+ widget=UI.Labeled(
+ "$N$ units per conditioning layer: ",
+ UI.pw2_widget(cfg.dims_1x1[0] if cfg.dims_1x1 else cfg.dims_dilated[0]),
+ ),
+ setter=lambda c, v: (int(v),)
+ ),
+ UI.Param(name="residual_dim",
+ widget=UI.Labeled(
+ "$N$ units per residual layer: ",
+ UI.pw2_widget(cfg.residuals_dim if cfg.residuals_dim is not None else cfg.dims_dilated[0]),
+ ),
+ setter=lambda c, v: (int(v),)
+ ),
+ UI.Param(name="apply_residuals",
+ widget=UI.Labeled(
+ "use residuals",
+ UI.yesno_widget(initial_value=cfg.residuals_dim is not None),
+ ),
+ setter=lambda conf, ev: conf.dims_dilated[0] if ev else None
+ ),
+ UI.Param(name="skips_dim",
+ widget=UI.Labeled(
+ "$N$ units per skip layer: ",
+ UI.pw2_widget(cfg.skips_dim if cfg.skips_dim is not None else cfg.dims_dilated[0]),
+ ),
+ setter=lambda c, v: (int(v),)
+ ),
+ UI.Param(name='groups',
+ widget=UI.Labeled(
+ "groups of units",
+ UI.pw2_widget(cfg.groups),
+ )),
+
+ UI.Param(name="pad_side",
+ widget=UI.Labeled(
+ "use padding",
+ UI.yesno_widget(initial_value=bool(cfg.pad_side)),
+ ),
+ setter=lambda conf, ev: int(ev)
+ ),
+ UI.Param(name="bias",
+ widget=UI.Labeled(
+ "use bias",
+ UI.yesno_widget(initial_value=cfg.bias),
+ ),
+ ),
+ UI.Param(name="use_fast_generate",
+ widget=UI.Labeled(
+ "use fast generate",
+ UI.yesno_widget(initial_value=cfg.use_fast_generate),
+ ),
+ ),
+ ).as_widget(lambda children, **kwargs: W.Accordion([W.VBox(children=children)], **kwargs),
+ selected_index=0, layout=W.Layout(margin="0 auto 0 0", width="100%"))
+ view.set_title(0, "WaveNet Config")
+ return view
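The `blocks` param above round-trips a tuple through a `Text` widget: it is displayed as `str(cfg.blocks)[1:-1]` (i.e. `"4, 2"`) and parsed back by splitting on commas while skipping empty fragments. A stand-alone sketch of that parsing, using `.strip()` rather than the exact `("", " ")` membership test so fragments with any amount of whitespace are also tolerated (`parse_int_tuple` is an illustrative name, not mimikit API):

```python
def parse_int_tuple(text: str) -> tuple:
    """Parse '1, 2, 3' as shown in a Text widget into (1, 2, 3),
    ignoring empty fragments left by trailing commas or stray spaces."""
    return tuple(int(s) for s in text.split(",") if s.strip())


# Round trip with the widget's display format: str((4, 2))[1:-1] == '4, 2'
shown = str((4, 2))[1:-1]
```

Note that `str((4,))[1:-1]` yields `'4,'`, which this parser also handles, so single-element tuples survive the round trip.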
diff --git a/requirements-colab.txt b/requirements-colab.txt
index 3cb4a2a..2894f6e 100644
--- a/requirements-colab.txt
+++ b/requirements-colab.txt
@@ -1,4 +1,3 @@
-# assume torch==1.10 is installed
-# torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
-pytorch-lightning==1.4.9
-torchmetrics==0.5.1
\ No newline at end of file
+# assume torch==1.13 is installed
+# torchaudio==0.13.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
+pytorch-lightning==1.6.5
\ No newline at end of file
diff --git a/requirements-test.txt b/requirements-test.txt
new file mode 100644
index 0000000..d7e0c2a
--- /dev/null
+++ b/requirements-test.txt
@@ -0,0 +1,2 @@
+assertpy==1.1
+pytest
\ No newline at end of file
diff --git a/requirements-torch.txt b/requirements-torch.txt
index 8b6fa41..8a17c74 100644
--- a/requirements-torch.txt
+++ b/requirements-torch.txt
@@ -1,3 +1,3 @@
-torch>=1.9.0
-torchaudio>=0.9.0
-pytorch-lightning==1.4.9
+torch>=1.13.0
+torchaudio>=0.13.0
+pytorch-lightning==1.6.5
diff --git a/requirements.txt b/requirements.txt
index 14a3a3d..b3e16fe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,22 @@
numpy>=1.19.1
pandas>=1.1.3
-librosa>=0.8
+librosa>=0.9.1
tables>=3.6
ffmpeg-python
tqdm>=4.48.0
matplotlib
scipy>=1.4.1
-scikit-learn>=0.24
+scikit-learn>=1.0.0
soundfile>=0.10.2
test-tube>=0.7.5
-h5mapper==0.2.4
+h5mapper>=0.3.1
numba
-muspy==0.3.0
-soxr
\ No newline at end of file
+soxr
+pydub
+ipywidgets==7.7.1
+omegaconf>=2.3.0
+# eigen solver for spectral clustering:
+pyamg
+pypbind
+peaksjs_widget
+qgrid
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 7e16cc1..cb5a760 100644
--- a/setup.py
+++ b/setup.py
@@ -50,12 +50,11 @@
"Topic :: Multimedia :: Sound/Audio :: Sound Synthesis",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
],
"keywords": "audio music sound deep-learning",
- 'python_requires': '>=3.6',
+ 'python_requires': '>=3.7',
'install_requires': REQUIRES,
'extras_require': {
"torch": torch_requires,
diff --git a/tests/test_checkpointable.py b/tests/test_checkpointable.py
new file mode 100644
index 0000000..419ad09
--- /dev/null
+++ b/tests/test_checkpointable.py
@@ -0,0 +1,58 @@
+import dataclasses as dtc
+
+import torch
+import torch.nn as nn
+from assertpy import assert_that
+
+import mimikit as mmk
+import mimikit.config
+import mimikit.networks.arm
+
+
+class MyCustom(mimikit.config.Configurable, nn.Module):
+ @dtc.dataclass
+ class CustomConfig(mimikit.networks.arm.NetworkConfig):
+ io_spec: mmk.IOSpec = mmk.IOSpec(
+ inputs=(mmk.InputSpec(
+ extractor_name="signal",
+ transform=mmk.Normalize(),
+ module=mmk.LinearIO()
+ ).bind_to(mmk.Extractor("signal", mmk.FileToSignal(16000))),),
+ targets=(mmk.TargetSpec(
+ extractor_name="signal",
+ transform=mmk.Normalize(),
+ module=mmk.LinearIO(),
+ objective=mmk.Objective(objective_type="reconstruction")
+ ).bind_to(mmk.Extractor("signal", mmk.FileToSignal(16000))),)
+ )
+ x: int = 1
+
+ @classmethod
+ def from_config(cls, config: "MyCustom.CustomConfig"):
+ return cls(config, nn.Linear(config.x, config.x))
+
+ def __init__(self, config: "MyCustom.CustomConfig", module: nn.Module):
+ super().__init__()
+ self._config = config
+ self.mod = module
+
+ def forward(self, x):
+ return self.mod(x)
+
+ @property
+ def config(self): return self._config
+
+
+def test_should_save_and_load_class_defined_outside_mmk(tmp_path_factory):
+ model = MyCustom.from_config(MyCustom.CustomConfig())
+
+ output = model(torch.randn(2, 1, 1))
+
+ assert_that(type(output)).is_equal_to(torch.Tensor)
+
+ root = str(tmp_path_factory.mktemp("ckpt"))
+ ckpt = mmk.Checkpoint(id="123", epoch=1, root_dir=root)
+
+ ckpt.create(network=model)
+ loaded = ckpt.network
+
+ assert_that(type(loaded)).is_equal_to(MyCustom)
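The test above relies on the contract that a network is fully described by its `Config`: `Checkpoint.create` can persist the config (plus weights) and rebuild the instance through `from_config`, even for classes defined outside mimikit. A dependency-free sketch of that round trip, serializing only the config (illustrative names; the real `Checkpoint` also stores the state dict and the class reference):

```python
import dataclasses as dtc
import json


@dtc.dataclass
class Config:
    x: int = 1


class Model:
    """Stand-in for a Configurable network: entirely determined by its config."""

    @classmethod
    def from_config(cls, config: "Config") -> "Model":
        return cls(config)

    def __init__(self, config: "Config"):
        self._config = config

    @property
    def config(self) -> "Config":
        return self._config


# "checkpointing" the architecture: persist the config as data...
model = Model.from_config(Config(x=3))
blob = json.dumps(dtc.asdict(model.config))

# ...and loading: rebuild through the class's own factory
restored = Model.from_config(Config(**json.loads(blob)))
```

Because dataclasses implement structural equality, the restored config can be compared directly to the original one.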
diff --git a/tests/test_ensemble.py b/tests/test_ensemble.py
new file mode 100644
index 0000000..e67763e
--- /dev/null
+++ b/tests/test_ensemble.py
@@ -0,0 +1,84 @@
+import h5mapper as h5m
+import mimikit as mmk
+from pbind import Pseq, Pbind, Prand, Pwhite, inf
+import pytest
+from assertpy import assert_that
+
+from .test_utils import tmp_db
+
+
+@pytest.fixture
+def checkpoints(tmp_path):
+ root = (tmp_path / "ckpts")
+ root.mkdir()
+
+ models = {
+ 'srnn': mmk.SampleRNN.from_config(mmk.SampleRNN.Config(
+ io_spec=mmk.IOSpec.mulaw_io(mmk.IOSpec.MuLawIOConfig(sr=16000))
+ )),
+ 'freqnet': mmk.WaveNet.from_config(mmk.WaveNet.Config(
+ mmk.IOSpec.magspec_io(mmk.IOSpec.MagSpecIOConfig(sr=32000))
+ ))
+ }
+
+ ckpts = {}
+ for name, model in models.items():
+ ckpt = mmk.Checkpoint(id=name, epoch=0, root_dir=str(root))
+ ckpt.create(model)
+ ckpts[name] = ckpt
+ return ckpts
+
+
+def test_should_generate(tmp_db, checkpoints):
+
+ BASE_SR = 22050
+
+ db = tmp_db("ensemble-test.h5")
+
+ """### Define the prompts from which to generate"""
+ # just a torch.Tensor or a numpy.ndarray
+ prompts = next(iter(db.serve(
+ (h5m.Input(data='signal', getter=h5m.AsSlice(shift=0, length=BASE_SR)),),
+ shuffle=False,
+ # batch_size=1 --> new stream for each prompt <> batch_size=8 --> one stream for 8 prompts :
+ batch_size=3,
+ sampler=mmk.IndicesSampler(
+ # INDICES FOR THE PROMPTS :
+ indices=(0, BASE_SR // 2, BASE_SR)
+ ))))[0]
+
+ """### Define a pattern of models"""
+ # THE MODELS PATTERN defines which checkpoint (id, epoch) generates for how long (seconds)
+
+ stream = Pseq([
+ Pbind(
+ "generator", checkpoints["freqnet"],
+ "seconds", Pwhite(lo=3., hi=5., repeats=1)
+ ),
+ # Pbind(
+ # # This event inserts the most similar continuation from the Trainset "Cough"
+ # "type", mmk.NearestNextNeighbor,
+ # "soundbank", soundbank,
+ # "feature", mmk.Spectrogram(n_fft=2048, hop_length=512, coordinate="mag"),
+ # "seconds", Pwhite(lo=2., hi=5., repeats=1)
+ # ),
+ Pbind(
+ "generator", checkpoints["srnn"],
+ # SampleRNN Checkpoints work best with a temperature parameter :
+ "temperature", Pwhite(lo=.25, hi=1.5),
+ "seconds", Pwhite(lo=.1, hi=1., repeats=1),
+ )
+ ], inf).asStream()
+
+ """### Generate"""
+
+ TOTAL_SECONDS = 10.
+
+ ensemble = mmk.EnsembleGenerator(
+ prompts, TOTAL_SECONDS, BASE_SR, stream,
+ # with this you can print the event -- or not
+ print_events=False
+ )
+ outputs = ensemble.run()
+
+ assert_that(outputs.size(1)/BASE_SR).is_equal_to(TOTAL_SECONDS)
\ No newline at end of file
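The `Pseq`/`Pbind` stream in the test above alternates checkpoints, each claiming some number of seconds, and `EnsembleGenerator` keeps pulling events until `TOTAL_SECONDS` is filled. Stripped of pbind and randomness, the control flow reduces to the following sketch (plain Python, not the pbind API; `schedule` is an illustrative name):

```python
import itertools
from typing import List, Tuple


def schedule(events: List[Tuple[str, float]], total_seconds: float) -> List[Tuple[str, float]]:
    """Walk an infinite cycle of (generator, seconds) events, truncating
    the last one so the rendered durations sum exactly to total_seconds."""
    out, elapsed = [], 0.0
    for name, seconds in itertools.cycle(events):
        take = min(seconds, total_seconds - elapsed)
        out.append((name, take))
        elapsed += take
        if elapsed >= total_seconds:
            return out


plan = schedule([("freqnet", 4.0), ("srnn", 1.0)], 10.0)
```

Truncating the final event is what lets the test assert `outputs.size(1) / BASE_SR == TOTAL_SECONDS` exactly, regardless of how the per-event durations fall.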
diff --git a/tests/test_gen_loop.py b/tests/test_gen_loop.py
new file mode 100644
index 0000000..8a42730
--- /dev/null
+++ b/tests/test_gen_loop.py
@@ -0,0 +1,57 @@
+import torch
+from assertpy import assert_that
+
+import mimikit.features.extractor
+from .test_utils import TestARM, TestDB, tmp_db
+import mimikit as mmk
+
+
+def test_should_run(tmp_db):
+ db: TestDB = tmp_db("gen-test.h5")
+ extractor = mimikit.features.extractor.Extractor("signal", mmk.FileToSignal(16000))
+ net = TestARM(
+ TestARM.Config(io_spec=mmk.IOSpec(
+ inputs=(
+ mmk.InputSpec(
+ extractor_name=extractor.name,
+ transform=mmk.Normalize(),
+ module=mmk.LinearIO()
+ ).bind_to(extractor),
+ mmk.InputSpec(
+ extractor_name=extractor.name,
+ transform=mmk.MuLawCompress(256),
+ module=mmk.LinearIO()
+ ).bind_to(extractor),
+ ),
+ targets=(
+ mmk.TargetSpec(
+ extractor_name=extractor.name,
+ transform=mmk.Normalize(),
+ module=mmk.LinearIO(),
+ objective=mmk.Objective("none")
+ ).bind_to(extractor),
+ mmk.TargetSpec(
+ extractor_name=extractor.name,
+ transform=mmk.MuLawCompress(256),
+ module=mmk.LinearIO(),
+ objective=mmk.Objective("none")
+ ).bind_to(extractor),
+ )
+ ))
+ )
+
+ assert_that(net).is_instance_of(TestARM)
+
+ loop = mmk.GenerateLoopV2.from_config(
+ mmk.GenerateLoopV2.Config(
+ prompts_position_sec=(None,),
+ batch_size=1,
+ ),
+ db, net
+ )
+
+ assert_that(loop).is_instance_of(mmk.GenerateLoopV2)
+ for outputs in loop.run():
+ assert_that(len(outputs)).is_equal_to(2)
+ assert_that(outputs[0]).is_instance_of(torch.Tensor)
+ assert_that(torch.all(outputs[0][:, -loop.n_steps:] != 0)).is_true()
diff --git a/tests/test_sample_rnn.py b/tests/test_sample_rnn.py
new file mode 100644
index 0000000..053df25
--- /dev/null
+++ b/tests/test_sample_rnn.py
@@ -0,0 +1,145 @@
+import os
+from typing import Tuple
+
+import pytest
+import torch
+from assertpy import assert_that
+
+from mimikit import GenerateLoopV2, TrainARMLoop, TrainARMConfig
+from .test_utils import TestDB, tmp_db
+
+from mimikit.networks.sample_rnn_v2 import SampleRNN
+from mimikit.checkpoint import Checkpoint
+from mimikit.io_spec import IOSpec
+
+
+def test_should_instantiate_from_default_config():
+ given_config = SampleRNN.Config(io_spec=IOSpec.mulaw_io(
+ IOSpec.MuLawIOConfig()
+ ))
+
+ under_test = SampleRNN.from_config(given_config)
+
+ assert_that(type(under_test)).is_equal_to(SampleRNN)
+ assert_that(len(under_test.tiers)).is_equal_to(len(given_config.frame_sizes))
+
+
+def test_should_take_n_unfolded_inputs():
+ given_frame_sizes = (16, 4, 2,)
+ given_config = SampleRNN.Config(
+ frame_sizes=given_frame_sizes,
+ io_spec=IOSpec.mulaw_io(
+ IOSpec.MuLawIOConfig()
+ ),
+ inputs_mode='sum',
+ )
+ given_inputs = (torch.arange(128).reshape(2, 64),)
+ # given_inputs[1] -= 64
+ under_test = SampleRNN.from_config(given_config)
+ outputs = under_test(given_inputs)
+
+ assert_that(type(outputs)).is_equal_to(tuple)
+ assert_that(outputs[0].shape).is_equal_to(
+ (2, given_inputs[0].size(1) - given_frame_sizes[0],
+ given_config.io_spec.inputs[0].elem_type.size)
+ )
+
+
+def test_should_load_when_saved(tmp_path_factory):
+ given_config = SampleRNN.Config(io_spec=IOSpec.mulaw_io(
+ IOSpec.MuLawIOConfig()
+ ))
+ root = str(tmp_path_factory.mktemp("ckpt"))
+ srnn = SampleRNN.from_config(given_config)
+ ckpt = Checkpoint(id="123", epoch=1, root_dir=root)
+
+ ckpt.create(network=srnn)
+ loaded = ckpt.network
+
+ assert_that(type(loaded)).is_equal_to(SampleRNN)
+
+
+@pytest.mark.parametrize(
+ "given_temp",
+ [None, 0.5, (1.,)]
+)
+def test_generate(
+ given_temp
+):
+ given_config = SampleRNN.Config(io_spec=IOSpec.mulaw_io(
+ IOSpec.MuLawIOConfig()
+ ))
+ q_levels = given_config.io_spec.inputs[0].elem_type.size
+ srnn = SampleRNN.from_config(given_config)
+
+ given_prompt = (torch.randint(0, q_levels, (1, 32,)),)
+ srnn.eval()
+ # For now, prompts are just Tuple[T] (--> Tuple[Tuple[T, ...]] for multi inputs!)
+ srnn.before_generate(given_prompt, batch_index=0)
+ output = srnn.generate_step(
+ tuple(p[:, -srnn.rf:] for p in given_prompt),
+ t=given_prompt[0].size(1),
+ temperature=given_temp)
+ srnn.after_generate(output, batch_index=0)
+
+ assert_that(type(output)).is_equal_to(tuple)
+ assert_that(output[0].size(0)).is_equal_to(given_prompt[0].size(0))
+ assert_that(output[0].ndim).is_equal_to(given_prompt[0].ndim)
+
+
+def test_generate_loop_integration(tmp_db):
+ given_config = SampleRNN.Config(io_spec=IOSpec.mulaw_io(
+ IOSpec.MuLawIOConfig()
+ ))
+ srnn = SampleRNN.from_config(given_config)
+
+ db: TestDB = tmp_db("gen-test.h5")
+
+ loop = GenerateLoopV2.from_config(
+ GenerateLoopV2.Config(
+ prompts_length_sec=512 / 16000,
+ output_duration_sec=512 / 16000,
+ prompts_position_sec=(None, None,),
+ batch_size=2,
+ parameters=dict(temperature=(1.,))
+ ),
+ db, srnn
+ )
+
+ for outputs in loop.run():
+ assert_that(outputs).is_not_none()
+ assert_that(outputs[0].shape).is_equal_to((2, 1024))
+ assert_that(outputs[0].dtype).is_equal_to(torch.float)
+
+
+def test_should_train(tmp_db, tmp_path):
+ given_config = SampleRNN.Config(io_spec=IOSpec.mulaw_io(
+ IOSpec.MuLawIOConfig()
+ ), frame_sizes=(4, 2, 2))
+ srnn = SampleRNN.from_config(given_config)
+ db = tmp_db("train-loop.h5")
+ config = TrainARMConfig(
+ root_dir=str(tmp_path),
+ limit_train_batches=2,
+ batch_size=2,
+ batch_length=8,
+ tbptt_chunk_length=128,
+ max_epochs=2,
+ every_n_epochs=1,
+ oversampling=4,
+ CHECKPOINT_TRAINING=True,
+ MONITOR_TRAINING=True,
+ OUTPUT_TRAINING=True,
+ )
+
+ loop = TrainARMLoop.from_config(
+ config, dataset=db, network=srnn
+ )
+
+ loop.run()
+
+ content = os.listdir(os.path.join(str(tmp_path), loop.hash_))
+ assert_that(content).contains("hp.yaml", "outputs", "epoch=1.ckpt")
+
+ outputs = os.listdir(os.path.join(str(tmp_path), loop.hash_, "outputs"))
+ assert_that([os.path.splitext(o)[-1] for o in outputs]).contains(".mp3")
\ No newline at end of file
diff --git a/tests/test_seq2seq.py b/tests/test_seq2seq.py
new file mode 100644
index 0000000..7c15d01
--- /dev/null
+++ b/tests/test_seq2seq.py
@@ -0,0 +1,188 @@
+import os
+
+import pytest
+from assertpy import assert_that
+
+import torch
+
+from mimikit import IOSpec, TrainARMConfig, TrainARMLoop, GenerateLoopV2
+from mimikit.networks.s2s_lstm_v2 import EncoderLSTM, DecoderLSTM, Seq2SeqLSTMNetwork
+
+from .test_utils import tmp_db
+
+
+def inputs_(b=8, t=32, d=16):
+ return torch.randn(b, t, d)
+
+
+@pytest.mark.parametrize(
+ "weight_norm",
+ [True, False]
+)
+@pytest.mark.parametrize(
+ "downsampling",
+ ['edge_sum', 'edge_mean', 'sum', 'mean', 'linear_resample']
+)
+@pytest.mark.parametrize(
+ "output_dim",
+ [32, 64, 128]
+)
+@pytest.mark.parametrize(
+ "num_layers",
+ [1, 2, 3, 4]
+)
+@pytest.mark.parametrize(
+ "apply_residuals",
+ [True, False]
+)
+@pytest.mark.parametrize(
+ "input_dim",
+ [32, 64, 128]
+)
+@pytest.mark.parametrize(
+ "hop",
+ [2, 4, 8]
+)
+def test_encoder_forward(
+ hop, input_dim, apply_residuals, num_layers, output_dim, downsampling, weight_norm
+):
+ given_input = inputs_(4, hop, input_dim)
+ under_test = EncoderLSTM(
+ downsampling=downsampling, input_dim=input_dim, output_dim=output_dim,
+ num_layers=num_layers, apply_residuals=apply_residuals, hop=hop,
+ weight_norm=weight_norm
+ )
+
+ y, (hidden, h_c) = under_test.forward(given_input)
+
+ assert_that(y).is_instance_of(torch.Tensor)
+ assert_that(y.size(0)).is_equal_to(given_input.size(0))
+ assert_that(y.size(1)).is_equal_to(1)
+ assert_that(y.size(2)).is_equal_to(output_dim)
+
+ assert_that(hidden).is_instance_of(torch.Tensor)
+ assert_that(hidden.size(1)).is_equal_to(given_input.size(0))
+ assert_that(hidden.size(0)).is_equal_to(2)
+ assert_that(hidden.size(2)).is_equal_to(output_dim)
+
+
+@pytest.mark.parametrize(
+ "weight_norm",
+ [True, False]
+)
+@pytest.mark.parametrize(
+ "upsampling",
+ ['repeat', 'interp', 'linear_resample']
+)
+@pytest.mark.parametrize(
+ "num_layers",
+ [1, 2, 3, 4]
+)
+@pytest.mark.parametrize(
+ "apply_residuals",
+ [True, False]
+)
+@pytest.mark.parametrize(
+ "model_dim",
+ [32, 64, 128]
+)
+@pytest.mark.parametrize(
+ "hop",
+ [2, 4, 8]
+)
+def test_decoder_forward(
+ hop, model_dim, apply_residuals, num_layers, upsampling, weight_norm
+):
+ B = 4
+ x = torch.randn(B, 1, model_dim)
+ hidden = torch.randn(2, B, model_dim), torch.randn(2, B, model_dim)
+ under_test = DecoderLSTM(
+ upsampling=upsampling, model_dim=model_dim, weight_norm=weight_norm,
+ num_layers=num_layers, apply_residuals=apply_residuals, hop=hop
+ )
+
+ y = under_test.forward(x, hidden)
+
+ assert_that(y).is_instance_of(torch.Tensor)
+ assert_that(y.size(0)).is_equal_to(x.size(0))
+ assert_that(y.size(1)).is_equal_to(hop)
+ assert_that(y.size(2)).is_equal_to(model_dim)
+
+
+def test_seq2seq_forward():
+ under_test = Seq2SeqLSTMNetwork.from_config(
+ Seq2SeqLSTMNetwork.Config(
+ io_spec=IOSpec.magspec_io(IOSpec.MagSpecIOConfig())
+ )
+ )
+ given_inputs = (inputs_(
+ 4, under_test.config.hop, under_test.config.io_spec.inputs[0].elem_type.size),)
+
+ outputs = under_test.forward(given_inputs)
+
+ assert_that(outputs).is_instance_of(torch.Tensor)
+ assert_that(outputs.size()).is_equal_to(given_inputs[0].size())
+
+
+def test_should_generate(tmp_db):
+ db = tmp_db("train-loop.h5")
+ s2s = Seq2SeqLSTMNetwork.from_config(
+ Seq2SeqLSTMNetwork.Config(
+ io_spec=IOSpec.magspec_io(IOSpec.MagSpecIOConfig()),
+ hop=2
+ )
+ )
+ loop = GenerateLoopV2.from_config(
+ GenerateLoopV2.Config(
+ prompts_position_sec=(None,),
+ batch_size=1,
+ ),
+ db, s2s
+ )
+
+ for outputs in loop.run():
+ assert_that(len(outputs)).is_equal_to(1)
+ assert_that(outputs[0]).is_instance_of(torch.Tensor)
+ assert_that(torch.all(outputs[0][:, -loop.n_steps:] != 0)).is_true()
+
+
+@pytest.mark.parametrize(
+ "given_io",
+ [
+ IOSpec.magspec_io(IOSpec.MagSpecIOConfig()),
+ IOSpec.mulaw_io(IOSpec.MuLawIOConfig(input_module_type="embedding"))
+ ]
+)
+def test_should_train(tmp_db, tmp_path, given_io):
+ s2s = Seq2SeqLSTMNetwork.from_config(
+ Seq2SeqLSTMNetwork.Config(
+ io_spec=given_io,
+ hop=2
+ )
+ )
+
+ db = tmp_db("train-loop.h5")
+ config = TrainARMConfig(
+ root_dir=str(tmp_path),
+ limit_train_batches=2,
+ batch_size=2,
+ batch_length=s2s.config.hop,
+ downsampling=64,
+ max_epochs=2,
+ every_n_epochs=1,
+ CHECKPOINT_TRAINING=True,
+ MONITOR_TRAINING=True,
+ OUTPUT_TRAINING=True,
+ )
+
+ loop = TrainARMLoop.from_config(
+ config, dataset=db, network=s2s
+ )
+
+ loop.run()
+
+ content = os.listdir(os.path.join(str(tmp_path), loop.hash_))
+ assert_that(content).contains("hp.yaml", "outputs", "epoch=1.ckpt")
+
+ outputs = os.listdir(os.path.join(str(tmp_path), loop.hash_, "outputs"))
+ assert_that([os.path.splitext(o)[-1] for o in outputs]).contains(".mp3")
diff --git a/tests/test_train_loop.py b/tests/test_train_loop.py
new file mode 100644
index 0000000..8f39e97
--- /dev/null
+++ b/tests/test_train_loop.py
@@ -0,0 +1,50 @@
+import os
+from assertpy import assert_that
+
+from .test_utils import tmp_db, TestARM
+import mimikit as mmk
+
+
+def test_should_run(tmp_db, tmp_path):
+ db = tmp_db("train-loop.h5")
+ extractor = mmk.Extractor("signal", mmk.FileToSignal(16000))
+ net = TestARM(
+ TestARM.Config(io_spec=mmk.IOSpec(
+ inputs=(
+ mmk.InputSpec(
+ extractor_name=extractor.name,
+ transform=mmk.Normalize(),
+ module=mmk.LinearIO()
+ ).bind_to(extractor),
+ ),
+ targets=(
+ mmk.TargetSpec(
+ extractor_name=extractor.name,
+ transform=mmk.Normalize(),
+ module=mmk.LinearIO(),
+ objective=mmk.Objective("reconstruction")
+ ).bind_to(extractor),
+ )
+ ))
+ )
+ config = mmk.TrainARMConfig(
+ root_dir=str(tmp_path),
+ limit_train_batches=4,
+ max_epochs=4,
+ every_n_epochs=1,
+ CHECKPOINT_TRAINING=True,
+ MONITOR_TRAINING=True,
+ OUTPUT_TRAINING=True,
+ )
+
+ loop = mmk.TrainARMLoop.from_config(
+ config, dataset=db, network=net
+ )
+
+ loop.run()
+
+ content = os.listdir(os.path.join(str(tmp_path), loop.hash_))
+ assert_that(content).contains("hp.yaml", "outputs", "epoch=1.ckpt")
+
+ outputs = os.listdir(os.path.join(str(tmp_path), loop.hash_, "outputs"))
+ assert_that([os.path.splitext(o)[-1] for o in outputs]).contains(".mp3")
diff --git a/tests/test_true.py b/tests/test_true.py
deleted file mode 100644
index 1f3b77f..0000000
--- a/tests/test_true.py
+++ /dev/null
@@ -1,2 +0,0 @@
-def test_true():
- assert True
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..fd2b6c2
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,122 @@
+from typing import Tuple, Dict, Set
+import dataclasses as dtc
+
+import pytest
+import torch
+from assertpy import assert_that
+
+import numpy as np
+import h5mapper as h5m
+from torch import nn
+
+from mimikit import ARM, IOSpec
+from mimikit.networks.arm import NetworkConfig
+
+__all__ = [
+ "TestDB",
+ "tmp_db",
+ "TestARM",
+]
+
+from mimikit.features.item_spec import ItemSpec
+
+
+class RandSignal(h5m.Feature):
+
+ def load(self, source):
+ return (np.random.rand(32000) * 2 - 1).astype(np.float32)
+
+
+class RandLabel(h5m.Feature):
+
+ def load(self, source):
+ return np.random.randint(0, 256, (32000,))
+
+
+class TestDB(h5m.TypedFile):
+ signal = RandSignal()
+ label = RandLabel()
+
+
+@pytest.fixture
+def tmp_db(tmp_path):
+ root = (tmp_path / "dbs")
+ root.mkdir()
+
+ def create_func(filename) -> TestDB:
+ TestDB.create(
+ str(root / filename),
+ sources=tuple(map(str, range(2))),
+ mode="w", keep_open=False, parallelism='none'
+ )
+ return TestDB(str(root / filename))
+
+ return create_func
+
+
+def test_fixture_db(tmp_db):
+ db = tmp_db("temp")
+
+ assert_that(db.signal).is_not_none()
+ assert_that(db.signal[:32]).is_instance_of(np.ndarray)
+
+
+class TestARM(ARM, nn.Module):
+ @dtc.dataclass
+ class Config(NetworkConfig):
+ io_spec: IOSpec = None
+
+ @property
+ def config(self) -> NetworkConfig:
+ return self._config
+
+ @property
+ def rf(self):
+ return 8
+
+ def train_batch(self, item_spec: ItemSpec) -> \
+ Tuple[Tuple[h5m.Input, ...], Tuple[h5m.Input, ...]]:
+ return tuple(
+ feat.to_batch_item(item_spec)
+ for feat in self.config.io_spec.inputs
+ ), tuple(
+ feat.to_batch_item(item_spec)
+ for feat in self.config.io_spec.targets
+ )
+
+ def test_batch(self, item_spec: ItemSpec) ->\
+ Tuple[Tuple[h5m.Input, ...], Tuple[h5m.Input, ...]]:
+ return self.train_batch(item_spec)
+
+ @property
+ def generate_params(self) -> Set[str]:
+ return set()
+
+ def before_generate(self, prompts: Tuple[torch.Tensor, ...], batch_index: int) -> None:
+ return
+
+ def generate_step(self, inputs: Tuple[torch.Tensor, ...], *, t: int = 0, **parameters: Dict[str, torch.Tensor]) -> \
+ Tuple[torch.Tensor, ...]:
+ return tuple(self.forward(i) for i in inputs)
+
+ def after_generate(self, final_outputs: Tuple[torch.Tensor, ...], batch_index: int) -> None:
+ return
+
+ @classmethod
+ def from_config(cls, config: NetworkConfig):
+ return cls(config)
+
+ def __init__(self, config: NetworkConfig):
+ super(TestARM, self).__init__()
+ self._config = config
+ self.fc = nn.Linear(1, 1)
+
+ def forward(self, inputs):
+ if self.training:
+ if isinstance(inputs, (tuple, list)):
+ return tuple(self.fc(x.unsqueeze(-1)).squeeze() for x in inputs)
+ return self.fc(inputs.unsqueeze(-1)).squeeze()
+ else:
+ if isinstance(inputs, (tuple, list)):
+ return tuple(x[:, -1:] for x in inputs)
+ return inputs[:, -1:]
diff --git a/tests/test_wavenet.py b/tests/test_wavenet.py
new file mode 100644
index 0000000..a8e96cc
--- /dev/null
+++ b/tests/test_wavenet.py
@@ -0,0 +1,248 @@
+import os
+from typing import Tuple
+
+import pytest
+from assertpy import assert_that
+
+import torch
+from torch.nn import Sigmoid
+
+from mimikit import IOSpec, InputSpec, TargetSpec, LinearIO, Objective, \
+ FileToSignal, Normalize, TrainARMConfig, TrainARMLoop
+from mimikit.features.extractor import Extractor
+from mimikit.checkpoint import Checkpoint
+from mimikit.networks.wavenet_v2 import WNLayer, WaveNet
+
+from .test_utils import tmp_db
+
+
+def inputs_(b=8, t=32, d=16):
+ return torch.randn(b, d, t)
+
+
+@pytest.mark.parametrize(
+ "with_gate",
+ [True, False]
+)
+@pytest.mark.parametrize(
+ "feed_skips",
+ [True, False]
+)
+@pytest.mark.parametrize(
+ "given_input_dim",
+ [None, 7]
+)
+@pytest.mark.parametrize(
+ "given_pad",
+ [0, 1]
+)
+@pytest.mark.parametrize(
+ "given_residuals",
+ [None, 5, 7]
+)
+@pytest.mark.parametrize(
+ "given_skips",
+ [None, 34]
+)
+@pytest.mark.parametrize(
+ "given_1x1",
+ [(), (3,), (8, 2,), (4, 9, 64)]
+)
+@pytest.mark.parametrize(
+ "given_dil",
+ [(16,), (32,), (8,)]
+)
+def test_layer_should_support_various_graphs(
+ given_dil: Tuple[int], given_1x1: Tuple[int], given_skips, given_residuals,
+ given_pad, given_input_dim, feed_skips, with_gate
+):
+ # if given_residuals is not None and given_input_dim is not None
+ under_test = WNLayer(
+ input_dim=given_input_dim,
+ dims_dilated=given_dil,
+ dims_1x1=given_1x1,
+ skips_dim=given_skips,
+ residuals_dim=given_residuals,
+ pad_side=given_pad,
+ act_g=Sigmoid() if with_gate else None
+ )
+ B, T = 1, 8
+ # HOW INPUT_DIM WORKS:
+ if given_input_dim is None:
+ if given_residuals is None:
+ input_dim = given_dil[0]
+ else:
+ input_dim = given_residuals
+ else:
+ input_dim = given_input_dim
+
+ skips = None if not feed_skips or given_skips is None else inputs_(B, T, given_skips)
+
+ given_inputs = (
+ (inputs_(B, T, input_dim),), tuple(inputs_(B, T, d) for d in given_1x1), skips
+ )
+ # HOW OUTPUT DIM WORKS:
+ if given_residuals is not None:
+ if given_input_dim is not None and given_input_dim != given_residuals:
+ # RESIDUALS ARE SKIPPED!
+ expected_out_dim = given_dil[0]
+ else:
+ expected_out_dim = given_residuals
+ else:
+ expected_out_dim = given_dil[0]
+
+ outputs = under_test(*given_inputs)
+
+ assert_that(type(outputs)).is_equal_to(tuple)
+ assert_that(len(outputs)).is_equal_to(2)
+
+ assert_that(outputs[0].size(1)).is_equal_to(expected_out_dim)
+ if given_skips is not None:
+ assert_that(outputs[1].size(1)).is_equal_to(given_skips)
+
+ if bool(given_pad):
+ assert_that(outputs[0].size(-1)).is_equal_to(T)
+ if given_skips is not None:
+ assert_that(outputs[1].size(-1)).is_equal_to(T)
+ assert_that(outputs[1].size(-1)).is_equal_to(outputs[0].size(-1))
+ else:
+ assert_that(outputs[0].size(-1)).is_less_than(T)
+ if given_skips is not None:
+ assert_that(outputs[1].size(-1)).is_less_than(T)
+ assert_that(outputs[1].size(-1)).is_equal_to(outputs[0].size(-1))
+
+
+def test_should_instantiate_from_default_config():
+ given_config = WaveNet.Config(io_spec=IOSpec.mulaw_io(
+ IOSpec.MuLawIOConfig(input_module_type="embedding")
+ ))
+
+ under_test = WaveNet.from_config(given_config)
+
+ assert_that(type(under_test)).is_equal_to(WaveNet)
+ assert_that(len(under_test.layers)).is_equal_to(given_config.blocks[0])
+
+
+def test_should_load_when_saved(tmp_path_factory):
+ given_config = WaveNet.Config(io_spec=IOSpec.mulaw_io(
+ IOSpec.MuLawIOConfig(input_module_type="embedding")
+ ))
+ root = str(tmp_path_factory.mktemp("ckpt"))
+ wn = WaveNet.from_config(given_config)
+ ckpt = Checkpoint(id="123", epoch=1, root_dir=root)
+
+ ckpt.create(network=wn)
+ loaded = ckpt.network
+
+ assert_that(type(loaded)).is_equal_to(WaveNet)
+
+
+@pytest.mark.parametrize(
+ "given_temp",
+ [None, 0.5, (1.,)]
+)
+def test_generate(
+ given_temp
+):
+ given_config = WaveNet.Config(io_spec=IOSpec.mulaw_io(
+ IOSpec.MuLawIOConfig(input_module_type="embedding")
+ ))
+ q_levels = given_config.io_spec.inputs[0].elem_type.size
+ wn = WaveNet.from_config(given_config)
+
+ given_prompt = torch.randint(0, q_levels, (1, 128,))
+ wn.eval()
+ # For now, prompts are just Tuple[T] (--> Tuple[Tuple[T, ...]] for multi inputs!)
+ wn.before_generate((given_prompt,), batch_index=0)
+ output = wn.generate_step(
+ (given_prompt[:, -wn.rf:],),
+ t=given_prompt.size(1),
+ temperature=given_temp)
+ wn.after_generate(output, batch_index=0)
+
+ assert_that(type(output)).is_equal_to(tuple)
+ assert_that(output[0].size(0)).is_equal_to(given_prompt.size(0))
+ assert_that(output[0].ndim).is_equal_to(given_prompt.ndim)
+
+
+def test_should_support_multiple_io():
+ extractor = Extractor("signal", FileToSignal(16000))
+ given_io = IOSpec(
+ inputs=(
+ InputSpec(
+ extractor_name=extractor.name,
+ transform=Normalize(),
+ module=LinearIO()
+ ).bind_to(extractor),
+ InputSpec(
+ extractor_name=extractor.name,
+ transform=Normalize(),
+ module=LinearIO()
+ ).bind_to(extractor),
+ ),
+ targets=(
+ TargetSpec(
+ extractor_name=extractor.name,
+ transform=Normalize(),
+ module=LinearIO(),
+ objective=Objective("reconstruction")
+ ).bind_to(extractor),
+ TargetSpec(
+ extractor_name=extractor.name,
+ transform=Normalize(),
+ module=LinearIO(),
+ objective=Objective("reconstruction")
+ ).bind_to(extractor)), )
+ wn = WaveNet.from_config(WaveNet.Config(
+ io_spec=given_io,
+ dims_dilated=(128,),
+ dims_1x1=(44,)
+ ))
+
+ assert_that(wn).is_instance_of(WaveNet)
+
+ given_inputs = (
+ torch.randn(1, 32, 1),
+ torch.randn(1, 32, 1),
+ )
+
+ outputs = wn.forward(given_inputs)
+
+ assert_that(outputs).is_instance_of(tuple)
+ assert_that(outputs[0].size()).is_equal_to(outputs[1].size())
+
+
+@pytest.mark.parametrize(
+ "given_io",
+ [
+ IOSpec.magspec_io(IOSpec.MagSpecIOConfig()),
+ IOSpec.mulaw_io(IOSpec.MuLawIOConfig(input_module_type="embedding"))
+ ]
+)
+def test_should_train(tmp_db, tmp_path, given_io):
+ given_config = WaveNet.Config(io_spec=given_io, blocks=(3,))
+ wn = WaveNet.from_config(given_config)
+ db = tmp_db("train-loop.h5")
+ config = TrainARMConfig(
+ root_dir=str(tmp_path),
+ limit_train_batches=2,
+ batch_size=2,
+ batch_length=8,
+ max_epochs=2,
+ every_n_epochs=1,
+ CHECKPOINT_TRAINING=True,
+ MONITOR_TRAINING=True,
+ OUTPUT_TRAINING=True,
+ )
+
+ loop = TrainARMLoop.from_config(
+ config, dataset=db, network=wn
+ )
+
+ loop.run()
+
+ content = os.listdir(os.path.join(str(tmp_path), loop.hash_))
+ assert_that(content).contains("hp.yaml", "outputs", "epoch=1.ckpt")
+
+ outputs = os.listdir(os.path.join(str(tmp_path), loop.hash_, "outputs"))
+ assert_that([os.path.splitext(o)[-1] for o in outputs]).contains(".mp3")
\ No newline at end of file
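The "HOW INPUT_DIM WORKS" / "HOW OUTPUT DIM WORKS" branches in `test_layer_should_support_various_graphs` encode a small decision table for `WNLayer`'s channel dims. Restated as one function (inferred from the test's branching, not from `WNLayer` itself; `wn_layer_dims` is an illustrative name):

```python
from typing import Optional, Tuple


def wn_layer_dims(dims_dilated: Tuple[int, ...],
                  input_dim: Optional[int] = None,
                  residuals_dim: Optional[int] = None) -> Tuple[int, int]:
    """Return (expected input channels, expected output channels) of a layer."""
    # input: an explicit input_dim wins, else the residual width, else the dilated width
    if input_dim is None:
        in_dim = dims_dilated[0] if residuals_dim is None else residuals_dim
    else:
        in_dim = input_dim
    # output: residuals apply only when input_dim is absent or matches residuals_dim;
    # a conflicting input_dim means residuals are skipped and the dilated width wins
    if residuals_dim is not None and (input_dim is None or input_dim == residuals_dim):
        out_dim = residuals_dim
    else:
        out_dim = dims_dilated[0]
    return in_dim, out_dim
```

This mirrors the test's parametrization: e.g. `given_input_dim=7` with `given_residuals=5` yields input 7 and output `given_dil[0]`, the "RESIDUALS ARE SKIPPED" branch.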
diff --git a/todos.md b/todos.md
new file mode 100644
index 0000000..d86a43c
--- /dev/null
+++ b/todos.md
@@ -0,0 +1,89 @@
+# Todos
+
+## v0.4.0
+
+### necessary
+
+- TEST Notebooks
+ - FreqNet
+ - SampleRNN
+ - Seq2Seq
+ - Generate from Checkpoint
+ - Ensemble
+ - Clusterizer App
+- Colab TESTS
+
+### nice to have
+
+- TEST M1 Support
+- Ensemble / Generate nb
+ - Prompt UI
+- Ensemble
+ - working nearest neighbour generator
+ - Generator classes (average from several models?, ...)
+ - parameters (fade-in/out?, gain?, noise?, ...)
+- Eval checkpoint Notebook
+- Segmentation Notebook
+- PocoNet / poco Wavenet
+- Train notebook(s) with UI
+
+### long term...
+
+- Transformer
+- support for TBPTT in freq domain
+- SampleRNN in freq domain (tier_i ==> n_fft instead of frame_size)
+- flexible IO declaration (fft, signal+segments, learnable fft, ...)
+- more audio features
+ - class ClusterLabel(Feature):
+ - class SegmentLabel(Feature):
+      - from recurrence matrix
+ - KMer (seqPrior)
+ - BitVector
+ - Quantize / Digitize / Linearize
+ - [/] MelSpec
+ - [/] MFCC
+ - TimeIndex
+ - Scaler
+ - MinMax
+ - Normal
+ - Augmentation(functional, prob)
+ ...................................
+ - tuple or not tuple
+ - AR Feature vs. Fixture vs. Auxiliary Target (vs. kwargs)
+ - AR --> Input == Target --> shared data
+ prompt must be: prior_t data + n_steps blank
+ !! target interface must come from data
+ - Fixture --> no target
+ --> data is read
+ prompt must be: prior_t + n_steps data
+ !! this modifies the length of the Dataset!
+ --> data is transformed from (possibly AR) input
+ !! this DOESN'T modify the length
+ - Auxiliary --> no input --> output is just collected
+        prompt must be: prior_t + n_steps blank
+ - Batch Alignment for
+ - Multiple SR
+ - Multiple Domains
+ - Same Variable, different repr (e.g. x_0 -> Raw, MuLaw --> ?)
+
+- More Networks
+ - SampleGan (WaveGan with labeled segments?)
+ - Stable Diffusion Experiment
+ - FFT_RNN
+ ....
+- Loss Terms
+- Hooks for
+ - storing outputs
+ - modifying generate_step() implementation on the fly...
+- flowtorch
+- huggingface/diffusers/transformers
+- Multi-Checkpoint Models (stochastic averaging)
+- Resampler classes with `n_layers`
+- jitability / torch==2.0 compile()
+ - no `*tuple` expr...
+- Network Visualizer (UI)
+- Resume Training
+ - Optimizer in Checkpoint
+- Upgrade to python 3.9? (colab is 3.7.15...)
+
+
\ No newline at end of file