From 8f89dddd38375311e1d815462b0803f42630d249 Mon Sep 17 00:00:00 2001
From: Jan Buethe
Date: Thu, 28 Sep 2023 15:16:16 +0200
Subject: [PATCH] added auto-regressive LaVoce implementation

---
 dnn/torch/osce/models/__init__.py             |   2 +
 dnn/torch/osce/models/lavoce_400_ar.py        | 285 ++++++++++++++++++
 dnn/torch/osce/utils/layers/ar_filter.py      | 226 ++++++++++++++
 .../utils/layers/limited_adaptive_comb1d.py   |  33 +-
 .../utils/layers/limited_adaptive_conv1d.py   |  21 +-
 dnn/torch/osce/utils/layers/td_shaper.py      |  25 +-
 6 files changed, 576 insertions(+), 16 deletions(-)
 create mode 100644 dnn/torch/osce/models/lavoce_400_ar.py
 create mode 100644 dnn/torch/osce/utils/layers/ar_filter.py

diff --git a/dnn/torch/osce/models/__init__.py b/dnn/torch/osce/models/__init__.py
index 2a33f1f3f..e6bbbc363 100644
--- a/dnn/torch/osce/models/__init__.py
+++ b/dnn/torch/osce/models/__init__.py
@@ -32,6 +32,7 @@
 from .lavoce import LaVoce
 from .lavoce_cont import LaVoceCont
 from .lavoce_400 import LaVoce400
+from .lavoce_400_ar import LaVoce400AR
 from .fd_discriminator import TFDMultiResolutionDiscriminator as FDMResDisc
 
 model_dict = {
@@ -41,4 +42,5 @@
     'lavocecont': LaVoceCont,
     'lavoce400': LaVoce400,
     'fdmresdisc': FDMResDisc,
+    'lavoce400ar': LaVoce400AR
 }

diff --git a/dnn/torch/osce/models/lavoce_400_ar.py b/dnn/torch/osce/models/lavoce_400_ar.py
new file mode 100644
index 000000000..4bbb6a171
--- /dev/null
+++ b/dnn/torch/osce/models/lavoce_400_ar.py
@@ -0,0 +1,285 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import numpy as np
+
+from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
+from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
+from utils.layers.ar_filter import ARFilter
+from utils.layers.td_shaper import TDShaper
+from utils.layers.noise_shaper import NoiseShaper
+from utils.complexity import _conv1d_flop_count
+from utils.endoscopy import write_data
+
+from models.nns_base import NNSBase
+from models.lpcnet_feature_net import LPCNetFeatureNet
+from .scale_embedding import ScaleEmbedding
+
+class LaVoce400AR(nn.Module):
+    """ Linear-Adaptive VOCodEr """
+    FEATURE_FRAME_SIZE=160
+    FRAME_SIZE=40
+
+    def __init__(self,
+                 num_features=20,
+                 pitch_embedding_dim=64,
+                 cond_dim=256,
+                 pitch_max=300,
+                 kernel_size=15,
+                 preemph=0.85,
+                 comb_gain_limit_db=-6,
+                 global_gain_limits_db=[-6, 6],
+                 conv_gain_limits_db=[-6, 6],
+                 norm_p=2,
+                 avg_pool_k=4,
+                 pulses=False):
+
+        super().__init__()
+
+
+        self.num_features = num_features
+        self.cond_dim = cond_dim
+        self.pitch_max = pitch_max
+        self.pitch_embedding_dim = pitch_embedding_dim
+        self.kernel_size = kernel_size
+        self.preemph = preemph
+        self.pulses = pulses
+
+        assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
+        self.upsamp_factor = self.FEATURE_FRAME_SIZE // self.FRAME_SIZE
+
+        # pitch embedding
+        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
+
+        # feature net
+        self.feature_net = LPCNetFeatureNet(num_features + pitch_embedding_dim, cond_dim, self.upsamp_factor)
+
+        # noise shaper
+        self.noise_shaper = NoiseShaper(cond_dim, self.FRAME_SIZE)
+
+        # comb filters
+        left_pad = self.kernel_size // 2
+        right_pad = self.kernel_size - 1 - left_pad
+        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+
+        self.cf_ar = ARFilter(5, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=20, padding=[2, 2], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, norm_p=norm_p)
+
+        self.af_prescale = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af_mix = LimitedAdaptiveConv1d(3, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+
+        # spectral shaping
+        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+
+        # non-linear transforms
+        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k, innovate=True)
+        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
+        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
+
+        # combinators
+        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+
+        # feature transforms
+        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)
+
+
+    def create_phase_signals(self, periods):
+
+        batch_size = periods.size(0)
+        progression = torch.arange(1, self.FRAME_SIZE + 1, dtype=periods.dtype, device=periods.device).view((1, -1))
+        progression = torch.repeat_interleave(progression, batch_size, 0)
+
+        phase0 = torch.zeros(batch_size, dtype=periods.dtype, device=periods.device).unsqueeze(-1)
+        chunks = []
+        for sframe in range(periods.size(1)):
+            f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)
+
+            if self.pulses:
+                alpha = torch.cos(f).view(batch_size, 1, 1)
+                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
+                pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
+                pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)
+
+                chunk = torch.cat((pulse_a, pulse_b), dim = 1)
+            else:
+                chunk_sin = torch.sin(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
+                chunk_cos = torch.cos(f * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
+
+                chunk = torch.cat((chunk_sin, chunk_cos), dim = 1)
+
+            phase0 = phase0 + self.FRAME_SIZE * f
+
+            chunks.append(chunk)
+
+        phase_signals = torch.cat(chunks, dim=-1)
+
+        return phase_signals
+
+    def flop_count(self, rate=16000, verbose=False):
+
+        frame_rate = rate / self.FRAME_SIZE
+
+        # feature net
+        feature_net_flops = self.feature_net.flop_count(frame_rate)
+        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
+        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate) + self.af_mix.flop_count(rate)
+        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
+                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))
+
+        if verbose:
+            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
+            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
+            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
+            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")
+
+        return feature_net_flops + comb_flops + af_flops + feature_flops
+
+    def feature_transform(self, f, layer):
+        f = f.permute(0, 2, 1)
+        f = F.pad(f, [1, 0])
+        f = torch.tanh(layer(f))
+        return f.permute(0, 2, 1)
+
+    def forward(self, features, periods, signal=None, debug=False):
+
+        periods = periods.squeeze(-1)
+        pitch_embedding = self.pitch_embedding(periods)
+
+        if signal is not None:
+            nb_pre_frames = signal.size(-1) // self.FRAME_SIZE
+            if len(signal.shape) < 3:
+                signal = signal.unsqueeze(1)
+        else:
+            nb_pre_frames = 0
+
+        full_features = torch.cat((features, pitch_embedding), dim=-1)
+        cf = self.feature_net(full_features)
+        cf1 = self.feature_transform(cf, self.post_af2)
+        cf2 = self.feature_transform(cf1, self.post_af3)
+        cf3 = self.feature_transform(cf2, self.post_cf1)
+        cf4 = self.feature_transform(cf3, self.post_cf2)
+        cf5 = self.feature_transform(cf4, self.post_af1)
+
+
+        # upsample periods
+        periods = torch.repeat_interleave(periods, self.upsamp_factor, 1)
+        periods_ar = torch.where(periods > 42, periods, 2*periods)
+
+        num_frames = periods.size(1)
+
+        # pre-net
+        ref_phase = torch.tanh(self.create_phase_signals(periods))
+        x = self.af_prescale(ref_phase, cf)
+        noise = self.noise_shaper(cf)
+        prior = torch.cat((x, noise), dim=1)
+
+        # states
+        state_cf_ar = None
+        state_af_mix = None
+        state_tdshape1 = None
+        state_tdshape2 = None
+        state_cf1 = None
+        state_cf2 = None
+        state_af1 = None
+        state_af2 = None
+        state_af3 = None
+        state_tdshape3 = None
+        state_af4 = None
+        last_frame = torch.zeros((features.size(0), 1, self.FRAME_SIZE), device=features.device)
+
+        frames = []
+
+        for i in range(num_frames):
+            y, state_cf_ar = self.cf_ar(last_frame, cf[:, i:i+1], periods_ar[:, i:i+1], state=state_cf_ar, return_state=True)
+            y = torch.cat((y, prior[..., i * self.FRAME_SIZE : (i+1) * self.FRAME_SIZE]), dim=1)
+            y, state_af_mix = self.af_mix(y, cf[:, i:i+1], state=state_af_mix, return_state=True)
+
+            # temporal shaping + innovating
+            y1 = y[:, 0:1, :]
+            y2, state_tdshape1 = self.tdshape1(y[:, 1:2, :], cf[:, i:i+1], state=state_tdshape1, return_state=True)
+            y = torch.cat((y1, y2), dim=1)
+            y, state_af2 = self.af2(y, cf[:, i:i+1], state=state_af2, return_state=True, debug=debug)
+
+            # second temporal shaping
+            y1 = y[:, 0:1, :]
+            y2, state_tdshape2 = self.tdshape2(y[:, 1:2, :], cf1[:, i:i+1], state=state_tdshape2, return_state=True)
+            y = torch.cat((y1, y2), dim=1)
+            y, state_af3 = self.af3(y, cf1[:, i:i+1], state=state_af3, return_state=True, debug=debug)
+
+            # spectral shaping
+            y, state_cf1 = self.cf1(y, cf2[:, i:i+1], periods[:, i:i+1], state=state_cf1, return_state=True, debug=debug)
+            y, state_cf2 = self.cf2(y, cf3[:, i:i+1], periods[:, i:i+1], state=state_cf2, return_state=True, debug=debug)
+            y, state_af1 = self.af1(y, cf4[:, i:i+1], state=state_af1, return_state=True, debug=debug)
+
+            # final temporal env adjustment
+            y1 = y[:, 0:1, :]
+            y2, state_tdshape3 = self.tdshape3(y[:, 1:2, :], cf5[:, i:i+1], state=state_tdshape3, return_state=True)
+            y = torch.cat((y1, y2), dim=1)
+            y, state_af4 = self.af4(y, cf5[:, i:i+1], state=state_af4, return_state=True, debug=debug)
+
+            if i < nb_pre_frames:
+                y = signal[:, :, i * self.FRAME_SIZE : (i + 1) * self.FRAME_SIZE]
+
+            last_frame = y
+            frames.append(y)
+
+        return torch.cat(frames, dim=-1)
+
+    def process(self, features, periods, debug=False):
+
+        self.eval()
+        device = next(iter(self.parameters())).device
+        with torch.no_grad():
+
+            # run model
+            f = features.unsqueeze(0).to(device)
+            p = periods.unsqueeze(0).to(device)
+
+            y = self.forward(f, p, debug=debug).squeeze()
+
+            # deemphasis
+            if self.preemph > 0:
+                for i in range(len(y) - 1):
+                    y[i + 1] += self.preemph * y[i]
+
+            # clip to valid range
+            out = torch.clip((2**15) * y, -2**15, 2**15 - 1).short()
+
+            return out
\ No newline at end of file
diff --git a/dnn/torch/osce/utils/layers/ar_filter.py b/dnn/torch/osce/utils/layers/ar_filter.py
new file mode 100644
index 000000000..9c0e09f72
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/ar_filter.py
@@ -0,0 +1,226 @@
+"""
+/* Copyright (c) 2023 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.endoscopy import write_data
+
+class ARFilter(nn.Module):
+    COUNTER = 1
+
+    def __init__(self,
+                 kernel_size,
+                 feature_dim,
+                 frame_size=160,
+                 overlap_size=40,
+                 padding=None,
+                 max_lag=256,
+                 name=None,
+                 gain_limit_db=10,
+                 norm_p=2):
+        """
+
+        Parameters:
+        -----------
+
+        feature_dim : int
+            dimension of features from which kernels, biases and gains are computed
+
+        frame_size : int, optional
+            frame size, defaults to 160
+
+        overlap_size : int, optional
+            overlap size for filter cross-fade. Cross-fade is done on the first overlap_size samples of every frame, defaults to 40
+
+        padding : List[int, int], optional
+            left and right padding. Defaults to [(kernel_size - 1) // 2, kernel_size - 1 - (kernel_size - 1) // 2]
+
+        max_lag : int, optional
+            maximal pitch lag, defaults to 256
+
+        name: str or None, optional
+            specifies a name attribute for the module. If None the name is auto generated as ar_filter_COUNT, where COUNT is an instance counter for ARFilter
+
+        """
+
+        super().__init__()
+
+        self.in_channels = 1
+        self.out_channels = 1
+        self.feature_dim = feature_dim
+        self.kernel_size = kernel_size
+        self.frame_size = frame_size
+        self.overlap_size = overlap_size
+        self.max_lag = max_lag
+        self.limit_db = gain_limit_db
+        self.norm_p = norm_p
+
+        if name is None:
+            self.name = "ar_filter_" + str(self.COUNTER)
+            self.COUNTER += 1
+        else:
+            self.name = name
+
+        # network for generating convolution weights
+        self.conv_kernel = nn.Linear(feature_dim, kernel_size)
+
+        # comb filter gain
+        self.filter_gain = nn.Linear(feature_dim, 1)
+        self.log_gain_limit = gain_limit_db * 0.11512925464970229
+        with torch.no_grad():
+            self.filter_gain.bias[:] = max(0.1, 0.7 + self.log_gain_limit)
+
+        if type(padding) == type(None):
+            self.padding = [kernel_size // 2, kernel_size - 1 - kernel_size // 2]
+        else:
+            self.padding = padding
+
+        self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
+
+    def forward(self, x, features, lags, state=None, return_state=False, debug=False):
+        """ adaptive 1d convolution
+
+
+        Parameters:
+        -----------
+        x : torch.tensor
+            input signal of shape (batch_size, in_channels, num_samples)
+
+        features : torch.tensor
+            frame-wise features of shape (batch_size, num_frames, feature_dim)
+
+        lags: torch.LongTensor
+            frame-wise lags for comb-filtering
+
+        """
+
+        batch_size = x.size(0)
+        num_frames = features.size(1)
+        num_samples = x.size(2)
+        frame_size = self.frame_size
+        overlap_size = self.overlap_size
+        kernel_size = self.kernel_size
+        win1 = torch.flip(self.overlap_win, [0])
+        win2 = self.overlap_win
+
+        if num_samples // self.frame_size != num_frames:
+            raise ValueError('non matching sizes in ARFilter.forward')
+
+        conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
+        conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))
+
+
+        conv_gains = torch.exp(- torch.relu(self.filter_gain(features).permute(0, 2, 1)) + self.log_gain_limit)
+
+
+        if debug and batch_size == 1:
+            key = self.name + "_gains"
+            write_data(key, conv_gains.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+            key = self.name + "_kernels"
+            write_data(key, conv_kernels.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+            key = self.name + "_lags"
+            write_data(key, lags.detach().squeeze().cpu().numpy(), 16000 // self.frame_size)
+
+
+        # frame-wise convolution with overlap-add
+        output_frames = []
+        overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
+        if state is not None:
+            last_kernel, last_gain, last_lag, last_x = state
+            conv_kernels = torch.cat((last_kernel, conv_kernels), dim=1)
+            conv_gains = torch.cat((last_gain, conv_gains), dim=-1)
+            lags = torch.cat((last_lag, lags), dim=-1)
+
+            x = torch.cat((last_x, x), dim=-1)
+
+            x = F.pad(x, [0, self.padding[1] + self.overlap_size])
+            num_frames += 1
+
+        else:
+            x = F.pad(x, self.padding)
+            x = F.pad(x, [self.max_lag, self.overlap_size])
+
+
+        new_state = (conv_kernels[:, -1:, ...], conv_gains[..., -1:], lags[..., -1:], x[..., -(self.max_lag + self.padding[0] + self.frame_size + self.padding[1] + self.overlap_size) : -(self.padding[1] + self.overlap_size)])
+
+        idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
+        idx = torch.repeat_interleave(idx, batch_size, 0)
+        idx = torch.repeat_interleave(idx, self.in_channels, 1)
+
+
+        for i in range(num_frames):
+
+            cidx = idx + i * frame_size + self.max_lag - (lags[..., i].view(batch_size, 1, 1) - frame_size)
+            xx = torch.gather(x, -1, cidx).reshape((1, batch_size * self.in_channels, -1))
+
+            new_chunk = torch.conv1d(xx, conv_kernels[:, i, ...].reshape((batch_size * self.out_channels, self.in_channels, self.kernel_size)), groups=batch_size).reshape(batch_size, self.out_channels, -1) * conv_gains[:, :, i : i + 1]
+
+            # overlapping part
+            output_frames.append(new_chunk[:, :, : overlap_size] * win1 + overlap_mem * win2)
+
+            # non-overlapping part
+            output_frames.append(new_chunk[:, :, overlap_size : frame_size])
+
+            # mem for next frame
+            overlap_mem = new_chunk[:, :, frame_size :]
+
+        # concatenate chunks
+        output = torch.cat(output_frames, dim=-1)
+
+        if state is not None:
+            output = output[..., self.frame_size:]
+
+        if return_state:
+            return output, new_state
+        else:
+            return output
+
+
+    def flop_count(self, rate):
+        frame_rate = rate / self.frame_size
+        overlap = self.overlap_size
+        overhead = overlap / self.frame_size
+
+        count = 0
+
+        # kernel computation and filtering
+        count += 2 * (frame_rate * self.feature_dim * self.kernel_size)
+        count += 2 * (self.in_channels * self.out_channels * self.kernel_size * (1 + overhead) * rate)
+        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+
+        # a0 computation
+        count += 2 * (frame_rate * self.feature_dim * self.out_channels) + rate * (1 + overhead) * self.out_channels
+
+        # windowing
+        count += overlap * frame_rate * 3 * self.out_channels
+
+        return count
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
index b146240e6..cbd7b6667 100644
--- a/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_comb1d.py
@@ -122,7 +122,7 @@ def __init__(self,
 
         self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
 
-    def forward(self, x, features, lags, debug=False):
+    def forward(self, x, features, lags, state=None, return_state=False, debug=False):
         """ adaptive 1d convolution
 
 
@@ -149,7 +149,7 @@ def forward(self, x, features, lags, debug=False):
         win2 = self.overlap_win
 
         if num_samples // self.frame_size != num_frames:
-            raise ValueError('non matching sizes in AdaptiveConv1d.forward')
+            raise ValueError('non matching sizes in AdaptiveComb1d.forward')
 
         conv_kernels = self.conv_kernel(features).reshape((batch_size, num_frames, self.out_channels, self.in_channels, self.kernel_size))
         conv_kernels = conv_kernels / (1e-6 + torch.norm(conv_kernels, p=self.norm_p, dim=-1, keepdim=True))
@@ -175,8 +175,24 @@ def forward(self, x, features, lags, debug=False):
         # frame-wise convolution with overlap-add
         output_frames = []
         overlap_mem = torch.zeros((batch_size, self.out_channels, self.overlap_size), device=x.device)
-        x = F.pad(x, self.padding)
-        x = F.pad(x, [self.max_lag, self.overlap_size])
+        if state is not None:
+            last_kernel, last_gain, last_global_gain, last_lag, last_x = state
+            conv_kernels = torch.cat((last_kernel, conv_kernels), dim=1)
+            conv_gains = torch.cat((last_gain, conv_gains), dim=-1)
+            global_conv_gains = torch.cat((last_global_gain, global_conv_gains), dim=-1)
+            lags = torch.cat((last_lag, lags), dim=-1)
+
+            x = torch.cat((last_x, x), dim=-1)
+
+            x = F.pad(x, [0, self.padding[1] + self.overlap_size])
+            num_frames += 1
+
+        else:
+            x = F.pad(x, self.padding)
+            x = F.pad(x, [self.max_lag, self.overlap_size])
+
+
+        new_state = (conv_kernels[:, -1:, ...], conv_gains[..., -1:], global_conv_gains[..., -1:], lags[..., -1:], x[..., -(self.max_lag + self.padding[0] + self.frame_size + self.padding[1] + self.overlap_size) : -(self.padding[1] + self.overlap_size)])
 
         idx = torch.arange(frame_size + kernel_size - 1 + overlap_size).to(x.device).view(1, 1, -1)
         idx = torch.repeat_interleave(idx, batch_size, 0)
         idx = torch.repeat_interleave(idx, self.in_channels, 1)
@@ -209,7 +225,14 @@ def forward(self, x, features, lags, debug=False):
         # concatenate chunks
         output = torch.cat(output_frames, dim=-1)
 
-        return output
+        if state is not None:
+            output = output[..., self.frame_size:]
+
+        if return_state:
+            return output, new_state
+        else:
+            return output
+
 
     def flop_count(self, rate):
         frame_rate = rate / self.frame_size
diff --git a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
index 073ea1b14..2dc9778a2 100644
--- a/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
+++ b/dnn/torch/osce/utils/layers/limited_adaptive_conv1d.py
@@ -30,6 +30,7 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+import math as m
 
 from utils.endoscopy import write_data
 
@@ -121,6 +122,8 @@ def __init__(self,
 
         self.overlap_win = nn.Parameter(.5 + .5 * torch.cos((torch.arange(self.overlap_size) + 0.5) * torch.pi / overlap_size), requires_grad=False)
 
+        self.fft_size = 256
+
     def flop_count(self, rate):
 
         frame_rate = rate / self.frame_size
@@ -146,7 +149,7 @@ def flop_count(self, rate):
 
         return count
 
-    def forward(self, x, features, debug=False):
+    def forward(self, x, features, state=None, return_state=False, debug=False):
         """ adaptive 1d convolution
 
 
@@ -197,9 +200,21 @@ def forward(self, x, features, debug=False):
 
         conv_kernels = conv_kernels * conv_gains.view(batch_size, num_frames, self.out_channels, 1, 1)
 
+        if state is not None:
+            last_kernel, last_x_frame = state
+            conv_kernels = torch.cat((last_kernel, conv_kernels), dim=1)
+            x = torch.cat((last_x_frame, x), dim=-1)
+
+        new_state = (conv_kernels[:, -1:, :, :, :], x[:, :, -self.frame_size:])
+
         conv_kernels = conv_kernels.permute(0, 2, 3, 1, 4)
 
-        output = adaconv_kernel(x, conv_kernels, win1, fft_size=256)
+        output = adaconv_kernel(x, conv_kernels, win1, fft_size=self.fft_size)
 
+        if state is not None:
+            output = output[..., self.frame_size:]
 
-        return output
\ No newline at end of file
+        if return_state:
+            return output, new_state
+        else:
+            return output
\ No newline at end of file
diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py
index 73d66bd52..100a6cffb 100644
--- a/dnn/torch/osce/utils/layers/td_shaper.py
+++ b/dnn/torch/osce/utils/layers/td_shaper.py
@@ -87,7 +87,7 @@ def envelope_transform(self, x):
 
         return x
 
-    def forward(self, x, features, debug=False):
+    def forward(self, x, features, state=None, return_state=False, debug=False):
         """ innovate signal parts with temporal shaping
 
 
@@ -101,28 +101,31 @@
 
         """
 
+        batch_size = x.size(0)
         num_frames = features.size(1)
         num_samples = x.size(2)
-        frame_size = self.frame_size
 
         # generate temporal envelope
         tenv = self.envelope_transform(x)
 
         # feature path
-        f = torch.cat((features, tenv), dim=-1)
-        f = F.pad(f.permute(0, 2, 1), [1, 0])
+        f = torch.cat((features, tenv), dim=-1).permute(0, 2, 1)
+        if state is not None:
+            f = torch.cat((state, f), dim=-1)
+        else:
+            f = F.pad(f, [2, 0])
 
         alpha = F.leaky_relu(self.feature_alpha1(f), 0.2)
-        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
+        alpha = torch.exp(self.feature_alpha2(alpha))
         alpha = alpha.permute(0, 2, 1)
 
         if self.innovate:
             inno_alpha = F.leaky_relu(self.feature_alpha1b(f), 0.2)
-            inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
+            inno_alpha = torch.exp(self.feature_alpha2b(inno_alpha))
             inno_alpha = inno_alpha.permute(0, 2, 1)
 
             inno_x = F.leaky_relu(self.feature_alpha1c(f), 0.2)
-            inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
+            inno_x = torch.tanh(self.feature_alpha2c(inno_x))
             inno_x = inno_x.permute(0, 2, 1)
 
         # signal path
@@ -132,4 +135,10 @@
         if self.innovate:
             y = y + inno_alpha * inno_x
 
-        return y.reshape(batch_size, 1, num_samples)
+        y = y.reshape(batch_size, 1, num_samples)
+
+        if return_state:
+            new_state = f[..., -2:]
+            return y, new_state
+        else:
+            return y
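
Usage sketch (not taken from the patch itself): a minimal, hypothetical driver for the new model, assuming the osce convention of looking models up in model_dict, 20-dimensional input features at 100 feature frames per second, and integer pitch lags below pitch_max; the feature and period tensors here are random placeholders.

import torch
from models import model_dict

# build the autoregressive vocoder with its default configuration
model = model_dict['lavoce400ar'](num_features=20, pitch_max=300)
print(f"complexity: {model.flop_count(16000) / 1e6:.2f} MFLOPS")

num_frames = 100                                     # 1 s of speech at 100 feature frames/s
features = torch.randn(num_frames, 20)               # placeholder acoustic features
periods = torch.randint(32, 300, (num_frames,))      # placeholder pitch lags in samples

# process() batches the inputs, runs the frame-by-frame AR loop in forward(),
# applies de-emphasis and returns 16-bit PCM (160 output samples per feature frame)
pcm = model.process(features, periods)

During training, forward() additionally accepts a signal argument: its first nb_pre_frames frames replace the generated frames and are fed back through last_frame into cf_ar, so the autoregressive loop can be seeded with ground-truth audio; at inference the feedback simply starts from the zero frame.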