
Commit

wip
antoinedaurat committed Dec 20, 2023
1 parent 6226ffd commit ed9ed9b
Showing 22 changed files with 998 additions and 215 deletions.
36 changes: 29 additions & 7 deletions mimikit/checkpoint.py
@@ -23,7 +23,8 @@ def cached_property(f):

__all__ = [
    'Checkpoint',
    'CheckpointBank'
    'CheckpointBank',
    "CheckpointsMix"
]


@@ -64,7 +65,7 @@ def save(cls,
        net_dict = network.state_dict()
        opt_dict = optimizer.state_dict() if optimizer is not None else {}
        cls.network.set_ds_kwargs(net_dict)
        #if optimizer is not None:
        # if optimizer is not None:
        #     cls.optimizer.set_ds_kwargs(opt_dict)
        os.makedirs(os.path.split(filename)[0], exist_ok=True)

@@ -73,9 +74,9 @@ def save(cls,
        bank.network.add("state_dict", h5m.TensorDict.format(net_dict))

        if optimizer is not None:
            #bank.optimizer.add("state_dict", h5m.TensorDict.format(opt_dict))
            # bank.optimizer.add("state_dict", h5m.TensorDict.format(opt_dict))
            torch.save(opt_dict, os.path.splitext(filename)[0] + ".opt")

        if training_config is not None:
            bank.attrs["dataset"] = training_config.dataset.serialize()
            bank.attrs["training"] = training_config.training.serialize()
@@ -87,7 +88,7 @@ def save(cls,
                extractors=tuple(schema.values())).serialize()
        if trainer_state is not None:
            bank.attrs["trainer_state"] = OmegaConf.to_yaml(OmegaConf.structured(trainer_state))

        bank.flush()
        bank.close()
        return bank
@@ -98,6 +99,7 @@ class Checkpoint:
    id: str
    epoch: int
    root_dir: str = "./"
    weight: float = 1.

    def create(self,
               network: ConfigurableModule,
@@ -141,6 +143,13 @@ def training_config(self) -> TrainingConfig:
        bank = CheckpointBank(self.os_path, 'r')
        return Config.deserialize(bank.attrs["training"])

    @cached_property
    def network_state_dict(self):
        state_dict = self.bank.network.get("state_dict")
        if self.weight != 1.:
            state_dict = {k: v * self.weight for k, v in state_dict.items()}
        return state_dict

    @cached_property
    def network(self) -> ConfigurableModule:
        cfg: NetworkConfig = self.network_config
@@ -164,11 +173,24 @@ def optimizer_state(self):
        if os.path.isfile(opt_path):
            return torch.load(opt_path)
        return None

    @cached_property
    def trainer_state(self):
        state = self.bank.attrs.get('trainer_state', None)
        if state is not None:
            return OmegaConf.create(state)
        return None
    # TODO: method that adds state_dicts multiplied by weights -> def average(self, *others)


def CheckpointsMix(*checkpoints: Checkpoint):
    net_cfg = checkpoints[0].network_config
    net_cfg.io_spec.bind_to(checkpoints[0].dataset_config)
    cls = net_cfg.owner_class
    net_state_dict = checkpoints[0].network_state_dict
    for ckpt in checkpoints[1:]:
        other_dict = ckpt.network_state_dict
        for k, v in other_dict.items():
            net_state_dict[k] = net_state_dict[k] + v
    net = cls.from_config(net_cfg)
    net.load_state_dict(net_state_dict, strict=True)
    return net
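Since network_state_dict above pre-multiplies each checkpoint's parameters by its weight, CheckpointsMix amounts to a weighted sum of the networks; weights that sum to 1 give a parameter average. A minimal usage sketch (the ids, epochs and paths are hypothetical):

from mimikit.checkpoint import Checkpoint, CheckpointsMix

# blend two checkpoints of the same architecture 50/50
net = CheckpointsMix(
    Checkpoint(id="run_a", epoch=100, root_dir="./trainings", weight=0.5),
    Checkpoint(id="run_b", epoch=80, root_dir="./trainings", weight=0.5),
)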
34 changes: 4 additions & 30 deletions mimikit/extract/segment.py
@@ -1,12 +1,14 @@
import numpy as np
from librosa.util import peak_pick, localmax
from librosa.util import peak_pick
from librosa.sequence import dtw
from scipy.ndimage.filters import minimum_filter1d
from sklearn.metrics import pairwise_distances as pwd
from typing import List
from numba import njit, prange, float64, intp
import matplotlib.pyplot as plt
import matplotlib as mpl

from mimikit.features.functionals import pick_globally_sorted_maxes

mpl.rcParams['agg.path.chunksize'] = 10000

__all__ = [
@@ -132,34 +134,6 @@ def discontinuity_scores(
    return scores


def pick_globally_sorted_maxes(x, wait_before, wait_after, min_strength=0.02):
    mn = minimum_filter1d(
        x, wait_before + wait_after, mode='constant', cval=x.min()
    )
    glob_rg = x.max() - x.min()
    strength = (x - mn) / glob_rg
    # filter out peaks with too little contrast
    mx = localmax(x) & (strength >= min_strength)

    mx_indices = mx.nonzero()[0][np.argsort(-x[mx])]

    final_maxes = np.zeros_like(x, dtype=bool)

    for m in mx_indices:
        i, j = max(0, m - wait_before), min(x.shape[0], m + wait_after)
        if np.any(final_maxes[i:j]):
            continue
        else:
            # make sure the max dominates left and right,
            # i.e. we are not globally increasing/decreasing around it
            mu_l = x[i:m].mean()
            mu_r = x[m:j].mean()
            mx = x[m]
            if mx > mu_l and mx > mu_r:
                final_maxes[m] = True
    return final_maxes.nonzero()[0]


def from_recurrence_matrix(X,
                           kernel_sizes=[6],
                           min_dur=4,
3 changes: 3 additions & 0 deletions mimikit/features/dataset.py
@@ -37,6 +37,9 @@ def create(self, **kwargs) -> h5m.TypedFile:
            else:
                fixed_sources += [src]
        self.sources = tuple(fixed_sources)
        if "filename" in kwargs:
            self.filename = kwargs["filename"]
            kwargs.pop("filename")
        db = cls.create(self.filename, fixed_sources, **kwargs)
        db.attrs["config"] = self.serialize()
        return db
131 changes: 131 additions & 0 deletions mimikit/features/functionals.py
@@ -5,6 +5,8 @@
import torchaudio.functional as F
import torchaudio.transforms as T
import numpy as np
from librosa.util import localmax
from scipy.ndimage.filters import minimum_filter1d
from scipy.signal import lfilter
from scipy.interpolate import interp1d
from sklearn.decomposition import PCA as skPCA, \
@@ -34,6 +36,8 @@
    'MuLawExpand',
    'ALawCompress',
    'ALawExpand',
    "PBitsCompress",
    "PBitsExpand",
    'STFT',
    'ISTFT',
    'MagSpec',
@@ -50,13 +54,17 @@
    'derivative_torch',
    'Derivative',
    "AutoConvolve",
    "TopKFilter",
    "GlobalPeaks",
    "F0Filter",
    "NearestNeighborFilter",
    "PCA",
    "NMF",
    "FactorAnalysis",
    "pick_globally_sorted_maxes"
]


N_FFT = 2048
HOP_LENGTH = 512
SR = 22050
@@ -447,6 +455,50 @@ def inv(self):
        return ALawCompress(self.A, self.q_levels)


@dtc.dataclass
class PBitsCompress(Functional):

    q_levels: int = Q_LEVELS
    compression: float = 1.

    def __post_init__(self):
        self.width = int(np.log2(self.q_levels))
        self.compressor = MuLawCompress(self.q_levels, self.compression)

    @property
    def inv(self) -> "Functional":
        return PBitsExpand(self.q_levels, self.compression)

    def np_func(self, inputs):
        qx = self.compressor(inputs)
        # unpack each quantized index into `width` bit planes (MSB first)
        return np.not_equal(np.bitwise_and(qx[:, None, ...], 2 ** np.arange(self.width - 1, -1, -1)), 0.).astype(int)

    def torch_func(self, inputs):
        qx = self.compressor(inputs)
        return qx.unsqueeze(1).bitwise_and(2 ** torch.arange(self.width - 1, -1, -1)).ne(0.).int()


@dtc.dataclass
class PBitsExpand(Functional):

    q_levels: int = Q_LEVELS
    compression: float = 1.

    def __post_init__(self):
        self.width = int(np.log2(self.q_levels))
        self.expander = MuLawExpand(self.q_levels, self.compression)

    def np_func(self, inputs):
        # repack the bit planes into indices, then invert the mu-law
        return self.expander((inputs * (2 ** np.arange(self.width - 1, -1, -1))).sum(axis=1))

    def torch_func(self, inputs):
        return self.expander((inputs * (2 ** torch.arange(self.width - 1, -1, -1))).sum(dim=1))

    @property
    def inv(self) -> "Functional":
        return PBitsCompress(self.q_levels, self.compression)

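A quick round trip through the new bit-plane representation: PBitsCompress mu-law quantizes to q_levels and unpacks each index into log2(q_levels) bit planes along a new axis 1; PBitsExpand repacks them and inverts the compression. A sketch, calling np_func directly:

import numpy as np
from mimikit.features.functionals import PBitsCompress, PBitsExpand

x = np.sin(np.linspace(0, 4 * np.pi, 1024)).astype(np.float32)
bits = PBitsCompress(q_levels=256).np_func(x)    # shape (1024, 8), entries in {0, 1}
x_hat = PBitsExpand(q_levels=256).np_func(bits)  # mu-law round trip, x_hat ≈ x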

@dtc.dataclass
class STFT(Functional):
    n_fft: int = N_FFT
@@ -1036,6 +1088,57 @@ def inv(self) -> "Functional":
        return Identity()


@dtc.dataclass
class TopKFilter(Functional):

    k: int = 1

    @property
    def unit(self) -> Optional[Unit]:
        return None

    @property
    def elem_type(self) -> Optional[EventType]:
        return None

    def np_func(self, inputs):
        S = inputs
        S_hat = np.zeros_like(S)
        # indices of the k largest entries along the last axis
        idx = np.argpartition(S, S.shape[-1] - self.k, axis=-1)
        rg = np.arange(S.shape[0])[:, None]
        S_hat[rg, idx[..., -self.k:]] = S[rg, idx[..., -self.k:]]
        return S_hat

    def torch_func(self, inputs):
        # TODO
        pass

    @property
    def inv(self) -> "Functional":
        return Identity()

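torch_func is left as a TODO in this commit; a possible implementation, my sketch mirroring np_func with torch.topk rather than anything taken from the repo, could look like:

import torch

def topk_filter_torch(inputs: torch.Tensor, k: int) -> torch.Tensor:
    # keep the k largest entries along the last axis, zero out the rest
    values, idx = torch.topk(inputs, k, dim=-1)
    return torch.zeros_like(inputs).scatter_(-1, idx, values)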

@dtc.dataclass
class GlobalPeaks(Functional):

    window: int = 3
    min_strength: float = .02

    def np_func(self, inputs):
        S_hat = np.zeros_like(inputs)
        for i in range(inputs.shape[0]):
            peaks = pick_globally_sorted_maxes(inputs[i], self.window // 2, self.window // 2, self.min_strength)
            S_hat[i, peaks] = inputs[i, peaks]
        return S_hat

    def torch_func(self, inputs):
        pass

    @property
    def inv(self) -> "Functional":
        pass

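GlobalPeaks applies pick_globally_sorted_maxes (added at the bottom of this file) row by row, keeping only the maxima that dominate their surrounding window. A usage sketch on toy data:

import numpy as np
from mimikit.features.functionals import GlobalPeaks

S = np.abs(np.random.randn(4, 256))  # e.g. 4 frames of spectral magnitudes
S_peaks = GlobalPeaks(window=9, min_strength=0.05).np_func(S)
# S_peaks is zero except at each row's dominant local maxima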

@dtc.dataclass
class F0Filter(Functional):
    n_overtone: int = 4
@@ -1201,3 +1304,31 @@ def torch_func(self, inputs):
    @property
    def inv(self) -> "Functional":
        return Identity()


def pick_globally_sorted_maxes(x, wait_before, wait_after, min_strength=0.02):
    mn = minimum_filter1d(
        x, wait_before + wait_after, mode='constant', cval=x.min()
    )
    glob_rg = x.max() - x.min()
    strength = (x - mn) / glob_rg
    # filter out peaks with too little contrast
    mx = localmax(x) & (strength >= min_strength)

    mx_indices = mx.nonzero()[0][np.argsort(-x[mx])]

    final_maxes = np.zeros_like(x, dtype=bool)

    for m in mx_indices:
        i, j = max(0, m - wait_before), min(x.shape[0], m + wait_after)
        if np.any(final_maxes[i:j]):
            continue
        else:
            # make sure the max dominates left and right,
            # i.e. we are not globally increasing/decreasing around it
            mu_l = x[i:m].mean()
            mu_r = x[m:j].mean()
            mx = x[m]
            if mx > mu_l and mx > mu_r:
                final_maxes[m] = True
    return final_maxes.nonzero()[0]
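This is the helper moved out of mimikit/extract/segment.py above: candidate maxima are visited strongest first, and any candidate within wait_before/wait_after samples of an already accepted peak is dropped. A toy example:

import numpy as np
from mimikit.features.functionals import pick_globally_sorted_maxes

x = np.zeros(100)
x[[10, 13, 60]] = [1.0, 0.4, 0.8]
peaks = pick_globally_sorted_maxes(x, wait_before=5, wait_after=5)
print(peaks)  # array([10, 60]) -- 13 is shadowed by the stronger peak at 10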
(diff truncated; 18 of the 22 changed files not shown)
