Commit 58a7bb8: wavenet cleanup

antoinedaurat committed Dec 22, 2023
1 parent ed9ed9b
Showing 7 changed files with 220 additions and 45 deletions.
2 changes: 1 addition & 1 deletion mimikit/modules/__init__.py
@@ -5,6 +5,6 @@
from .resamplers import *
from .io import *
from .targets import *

from .logistic_mix import *

__all__ = [_ for _ in dir() if not _.startswith("_")]
151 changes: 151 additions & 0 deletions mimikit/modules/logistic_mix.py
@@ -0,0 +1,151 @@
import torch
from torch.nn import functional as F
import numpy as np


__all__ = [
"discretized_mix_logistic_loss",
"sample_from_discretized_mix_logistic"
]


def log_sum_exp(x):
""" numerically stable log_sum_exp implementation that prevents overflow """
# TF ordering
axis = len(x.size()) - 1
m, _ = torch.max(x, dim=axis)
m2, _ = torch.max(x, dim=axis, keepdim=True)
return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))


def discretized_mix_logistic_loss(y_hat, y, num_classes=256,
log_scale_min=-7.0, reduce=True):
"""Discretized mixture of logistic distributions loss
Note that it is assumed that input is scaled to [-1, 1].
Args:
y_hat (Tensor): Predicted output (B x C x T)
y (Tensor): Target (B x T x 1).
num_classes (int): Number of classes
log_scale_min (float): Log scale minimum value
reduce (bool): If True, the loss is summed over all elements; otherwise a per-timestep loss of shape (B x T x 1) is returned.
Returns:
Tensor: loss
"""
assert y_hat.dim() == 3
assert y_hat.size(1) % 3 == 0
nr_mix = y_hat.size(1) // 3

# (B x T x C)
y_hat = y_hat.transpose(1, 2)

# unpack parameters. (B, T, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix:2 * nr_mix]
log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)

# B x T x 1 -> B x T x num_mixtures
y = y.expand_as(means)

centered_y = y - means
inv_stdv = torch.exp(-log_scales)
plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
cdf_plus = torch.sigmoid(plus_in)
min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
cdf_min = torch.sigmoid(min_in)

# log probability for edge case of 0 (before scaling)
# equivalent: torch.log(torch.sigmoid(plus_in))
log_cdf_plus = plus_in - F.softplus(plus_in)

# log probability for edge case of 255 (before scaling)
# equivalent: (1 - torch.sigmoid(min_in)).log()
log_one_minus_cdf_min = -F.softplus(min_in)

# probability for all other cases
cdf_delta = cdf_plus - cdf_min

mid_in = inv_stdv * centered_y
# log probability in the center of the bin, used below when cdf_delta
# becomes too small (<= 1e-5)
log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

# tf equivalent
"""
log_probs = tf.where(x < -0.999, log_cdf_plus,
tf.where(x > 0.999, log_one_minus_cdf_min,
tf.where(cdf_delta > 1e-5,
tf.log(tf.maximum(cdf_delta, 1e-12)),
log_pdf_mid - np.log(127.5))))
"""
# TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
# for num_classes=65536 case? 1e-7? not sure..
inner_inner_cond = (cdf_delta > 1e-5).float()

inner_inner_out = inner_inner_cond * \
torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
(1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
inner_cond = (y > 0.999).float()
inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
cond = (y < -0.999).float()
log_probs = cond * log_cdf_plus + (1. - cond) * inner_out

log_probs = log_probs + F.log_softmax(logit_probs, -1)

if reduce:
return -torch.sum(log_sum_exp(log_probs))
else:
return -log_sum_exp(log_probs).unsqueeze(-1)


def to_one_hot(tensor, n, fill_with=1.):
# we perform one-hot encoding with respect to the last axis
one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
if tensor.is_cuda:
one_hot = one_hot.cuda()
one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
return one_hot


def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0,
clamp_log_scale=False):
"""
Sample from discretized mixture of logistic distributions
Args:
y (Tensor): B x C x T
log_scale_min (float): Log scale minimum value
Returns:
Tensor: sample in range of [-1, 1].
"""
assert y.size(1) % 3 == 0
nr_mix = y.size(1) // 3

# B x T x C
y = y.transpose(1, 2)
logit_probs = y[:, :, :nr_mix]

# sample mixture indicator from softmax
temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
temp = logit_probs.data - torch.log(- torch.log(temp))
_, argmax = temp.max(dim=-1)

# (B, T) -> (B, T, nr_mix)
one_hot = to_one_hot(argmax, nr_mix)
# select logistic parameters
means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
log_scales = torch.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1)
if clamp_log_scale:
log_scales = torch.clamp(log_scales, min=log_scale_min)
# sample from logistic & clip to interval
# we don't actually round to the nearest 8bit value when sampling
u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))

x = torch.clamp(torch.clamp(x, min=-1.), max=1.)

return x
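
These two functions pack the mixture weights, means and log scales of each logistic component along the channel axis (C = 3 * num_mixtures). A minimal usage sketch, not part of the commit, with shapes taken from the docstrings above and placeholder tensors:

import torch
from mimikit.modules.logistic_mix import (
    discretized_mix_logistic_loss, sample_from_discretized_mix_logistic
)

B, T, nr_mix = 2, 100, 10
y_hat = torch.randn(B, 3 * nr_mix, T)       # network output: mixture logits, means, log scales
y = torch.empty(B, T, 1).uniform_(-1., 1.)  # targets scaled to [-1, 1]

loss = discretized_mix_logistic_loss(y_hat, y, num_classes=256)         # scalar, summed over the batch
per_step = discretized_mix_logistic_loss(y_hat, y, reduce=False)        # (B, T, 1)
samples = sample_from_discretized_mix_logistic(y_hat, clamp_log_scale=True)  # (B, T), values in [-1, 1]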
2 changes: 1 addition & 1 deletion mimikit/modules/targets.py
@@ -47,7 +47,7 @@ def forward(self, logits, *, temperature=None):
if self.training:
return logits
if temperature is None:
return logits.argmax(dim=-1, keepdim=True)
return logits.argmax(dim=-1, keepdim=False)
temperature = as_tensor(temperature, logits)
logits = logits / temperature
logits = logits - logits.logsumexp(-1, keepdim=True)
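The visible change here appears to swap keepdim=True for keepdim=False on the greedy (no-temperature) path, so the argmax output drops its trailing singleton dimension. A quick illustration with assumed (batch, time, classes) logits, not taken from the repository:

import torch

logits = torch.randn(2, 100, 256)                 # assumed (batch, time, classes)
old_out = logits.argmax(dim=-1, keepdim=True)     # shape (2, 100, 1) -- before this commit
new_out = logits.argmax(dim=-1, keepdim=False)    # shape (2, 100)    -- after this commit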
66 changes: 38 additions & 28 deletions mimikit/networks/wavenet_v2.py
@@ -47,12 +47,12 @@ def __init__(
stride: int = 1,
bias: bool = True,
dilation: int = 1,
with_affine_residuals: bool = False,
# TODO
act_skips: Optional[nn.Module] = None,
act_residuals: Optional[nn.Module] = None,
skips_groups: int = 1,
residuals_groups: int = 1,
layer_norm: Optional[nn.Module] = None,
# TODO
dropout: float = 0.,
):
super(WNLayer, self).__init__()
@@ -76,7 +76,6 @@ def __init__(
self.has_gated_units = act_g is not None
self.has_skips = skips_dim is not None
self.has_residuals = residuals_dim is not None and (input_dim is None or input_dim == residuals_dim)
self.has_affine_residuals = with_affine_residuals

if residuals_dim is None:
main_inner_dim = main_outer_dim = dims_dilated[0]
@@ -87,7 +86,7 @@
if self.has_residuals:
# cannot be False here; the assert is just a reminder:
assert input_dim is None or input_dim == residuals_dim, "input_dim and residuals_dim must be equal if both are not None"
in_dim = main_outer_dim if input_dim is None else input_dim
in_dim = main_outer_dim

kwargs_dil = dict(kernel_size=(kernel_size,), dilation=dilation, stride=stride, bias=bias, groups=groups)
kwargs_1x1 = dict(kernel_size=(1,), stride=stride, bias=bias)
@@ -114,13 +113,16 @@
self.conv_1x1 = nn.ModuleList([
nn.Conv1d(d, main_inner_dim, **kwargs_1x1) for d in dims_1x1
])
if self.has_skips:
self.conv_skip = nn.Conv1d(main_inner_dim, skips_dim, **kwargs_1x1, groups=skips_groups)
if self.has_residuals:
self.conv_res = nn.Conv1d(main_inner_dim, main_outer_dim, **kwargs_1x1, groups=residuals_groups)
if self.has_affine_residuals:
self.aff_res = ParametrizedLinear(in_dim, in_dim, as_1x1_conv=True)

if act_residuals is not None:
self.conv_res = nn.Sequential(self.conv_res, act_residuals)
if self.has_skips:
self.conv_skip = nn.Conv1d(main_outer_dim, skips_dim, **kwargs_1x1, groups=skips_groups)
if act_skips is not None:
self.conv_skip = nn.Sequential(self.conv_skip, act_skips)
if layer_norm is not None:
self.norm = layer_norm.clone()
# print("***********************")
# print(f"in_dim={in_dim} main_inner={main_inner_dim} main_outer={main_outer_dim}")
# for name, mod in self.named_modules():
@@ -145,34 +147,32 @@ def forward(self,
y_f, y_g = conv(x)
cond_f += y_f
cond_g += y_g
if self.has_affine_residuals:
inputs_dilated = (self.aff_res(inputs_dilated[0]), *inputs_dilated[1:])
x_f, x_g = self.conv_dil[0](inputs_dilated[0])
y = self.act_f(x_f + cond_f) * self.act_g(x_g + cond_g)
else:
cond = 0
for conv, x in zip(self.conv_1x1, inputs_1x1):
if not self.needs_padding:
x = self.trim_cause(x)
if self.has_affine_residuals:
x = self.aff_res(x) + x
cond += conv(x)
if self.has_affine_residuals:
inputs_dilated = (self.aff_res(inputs_dilated[0]), *inputs_dilated[1:])
y = self.conv_dil[0](inputs_dilated[0])
y = self.act_f(y + cond)

if self.has_residuals:
y = self.conv_res(y)
if self.apply_residuals:
# either x has been padded, or y is shorter -> we need to trim x!
x = self.trim_cause(inputs_dilated[0])
y = x + y
if self.norm is not None:
y = self.norm(y)
if self.has_skips:
if not self.needs_padding:
skips = self.trim_cause(skips) if skips is not None else skips
if skips is None:
skips = self.conv_skip(y)
else:
skips = self.conv_skip(y) + skips
if self.has_residuals:
# either x has been padded, or y is shorter -> we need to trim x!
x = self.trim_cause(inputs_dilated[0])
y = x + self.conv_res(y)
return y, skips

def trim_cause(self, x):
@@ -193,45 +193,57 @@ class Config(NetworkConfig):
residuals_dim: Optional[int] = None
apply_residuals: bool = False
skips_dim: Optional[int] = None
with_affine_residuals: bool = False
groups: int = 1
skips_groups: int = 1
residuals_groups: int = 1
act_f: ActivationEnum = "Tanh"
act_g: Optional[ActivationEnum] = "Sigmoid"
act_skips: Optional[ActivationEnum] = None
act_residuals: Optional[ActivationEnum] = None
pad_side: int = 0
stride: int = 1
bias: bool = True
layer_norm: bool = False
use_fast_generate: bool = False
tie_io_weights: bool = False
layerwise_inputs: bool = False
reverse_layer_order: bool = False

@classmethod
def get_layers(cls, config: "WaveNet.Config") -> List[WNLayer]:
kernel_sizes, dilation = cls.get_kernels_and_dilation(config.kernel_sizes, config.blocks)
last_layer = sum(config.blocks) - 1
return [
WNLayer(
input_dim=config.dims_dilated[0],
dims_dilated=config.dims_dilated, dims_1x1=config.dims_1x1,
# no residuals for last layer
residuals_dim=config.residuals_dim if n != sum(config.blocks) - 1 else None,
apply_residuals=config.apply_residuals and n != 0,
# no residuals for last layer if there are skips
residuals_dim=config.residuals_dim if n != last_layer and config.skips_dim is not None else None,
apply_residuals=config.apply_residuals if n != last_layer and config.skips_dim is not None else None,
skips_dim=config.skips_dim,
kernel_size=k,
groups=config.groups,
skips_groups=config.skips_groups,
residuals_groups=config.residuals_groups,
act_f=ActivationConfig(str(config.act_f)).get(),
act_g=ActivationConfig(str(config.act_g)).get() if config.act_g is not None else None,
act_skips=ActivationConfig(str(config.act_skips)).get() if config.act_skips is not None else None,
act_residuals=ActivationConfig(str(config.act_residuals)).get() if config.act_residuals is not None else None,
pad_side=config.pad_side,
stride=config.stride, bias=config.bias,
dilation=d,
with_affine_residuals=config.with_affine_residuals
layer_norm=nn.GroupNorm(1, config.dims_dilated[0] if config.residuals_dim is None else config.residuals_dim)\
if config.layer_norm else None,
)
for n, (k, d) in enumerate(zip(kernel_sizes, dilation))
]

@classmethod
def from_config(cls, config: "WaveNet.Config") -> "WaveNet":
layers = cls.get_layers(config)
all_dims = [*config.dims_dilated, *config.dims_1x1]
if config.residuals_dim is not None:
all_dims = [config.residuals_dim, *config.dims_1x1]
else:
all_dims = [*config.dims_dilated, *config.dims_1x1]
# set Inner Connection
input_modules = [spec.module.copy()
.set(out_dim=h_dim)
@@ -280,8 +292,6 @@ def forward(self, inputs: Tuple, **parameters):
dilated, skips = layer.forward(
inputs_dilated=(dilated,), inputs_1x1=in_1x1, skips=skips
)
if self._config.layerwise_inputs:
dilated = dilated + inputs[0][..., -dilated.size(-1):]
if not layer.needs_padding:
in_1x1 = tuple(layer.trim_cause(x) for x in in_1x1)
if self.has_skips:
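
As far as the WNLayer diff above shows, the residual 1x1 convolution and the new optional GroupNorm are now applied before the skip projection, and the skip convolution now reads from the residual (outer) width. A simplified, stand-alone sketch of that ordering, not part of the commit: padding/trimming and the 1x1 conditioning inputs are omitted, a single conv split in two stands in for the layer's gated dilated convolution, and all dimensions are illustrative.

import torch
from torch import nn

dim, skips_dim, T = 64, 32, 100
conv_dil = nn.Conv1d(dim, 2 * dim, kernel_size=2, dilation=4)  # filter and gate halves in one conv
conv_res = nn.Conv1d(dim, dim, kernel_size=1)
conv_skip = nn.Conv1d(dim, skips_dim, kernel_size=1)
norm = nn.GroupNorm(1, dim)  # what layer_norm=True builds in WaveNet.Config.get_layers

x = torch.randn(1, dim, T)
x_f, x_g = conv_dil(x).chunk(2, dim=1)     # gated activation units
y = torch.tanh(x_f) * torch.sigmoid(x_g)
y = x[..., -y.size(-1):] + conv_res(y)     # residual 1x1 + add (input trimmed to match)
y = norm(y)                                # norm sits between the residual add and the skip projection
skips = conv_skip(y)                       # skip contributions are accumulated across layers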
18 changes: 18 additions & 0 deletions tests/test_logistic_mix.py
@@ -0,0 +1,18 @@
from mimikit.modules.logistic_mix import *
import pytest
import torch
import matplotlib.pyplot as plt


def test_nothing():
B, C, T, num_classes = 1, 3, 64, 32
y_hat, y = torch.zeros(B, C, T), torch.linspace(-1, 1, B*T).view(B, T, 1)
y_hat[:, 1] = torch.linspace(1., -1, T).expand_as(y_hat[:, 1])
y_hat[:, 2] = .1
probs = discretized_mix_logistic_loss(y_hat, y, num_classes, reduce=False)

plt.figure()
plt.plot(y_hat[:, 1].squeeze().detach().cpu().numpy())
plt.plot(y[0].squeeze().detach().cpu().numpy())
plt.plot((-probs).squeeze().detach().cpu().numpy())
plt.show()
16 changes: 6 additions & 10 deletions tests/test_wavenet.py
@@ -58,7 +58,7 @@ def test_layer_should_support_various_graphs(
):
# if given_residuals is not None and given_input_dim is not None
under_test = WNLayer(
input_dim=given_input_dim,
input_dim=given_input_dim if given_residuals is None else given_residuals,
dims_dilated=given_dil,
dims_1x1=given_1x1,
skips_dim=given_skips,
@@ -68,13 +68,13 @@
)
B, T = 1, 8
# HOW INPUT_DIM WORKS:
if given_input_dim is None:
if given_residuals is None:
if given_residuals is None:
if given_input_dim is None:
input_dim = given_dil[0]
else:
input_dim = given_residuals
input_dim = given_input_dim
else:
input_dim = given_input_dim
input_dim = given_residuals

skips = None if not feed_skips or given_skips is None else inputs_(B, T, given_skips)

@@ -83,11 +83,7 @@
)
# HOW OUTPUT DIM WORKS:
if given_residuals is not None:
if given_input_dim is not None and given_input_dim != given_residuals:
# RESIDUALS ARE SKIPPED!
expected_out_dim = given_dil[0]
else:
expected_out_dim = given_residuals
expected_out_dim = given_residuals
else:
expected_out_dim = given_dil[0]

Expand Down
Loading

0 comments on commit 58a7bb8

Please sign in to comment.