wavenet_modules.py (forked from vincentherrmann/pytorch-wavenet)
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, Function
from torch.nn import Parameter


def dilate(x, dilation, init_dilation=1, pad_start=True):
    """
    :param x: Tensor of size (N, C, L), where N is the current (input) dilation, C is the number of channels, and L is the input length
    :param dilation: Target dilation. Will be the size of the first dimension of the output tensor.
    :param init_dilation: Dilation that the input tensor already has.
    :param pad_start: If the input length is not compatible with the specified dilation, zero padding is used. This parameter determines whether the zeros are added at the start or at the end.
    :return: The dilated tensor of size (dilation, C, N*L / dilation). The output might be zero padded at the start or the end.
    """
    [n, c, l] = x.size()
    dilation_factor = dilation / init_dilation
    if dilation_factor == 1:
        return x

    # zero padding so the length is divisible by the dilation factor
    new_l = int(np.ceil(l / dilation_factor) * dilation_factor)
    if new_l != l:
        l = new_l
        x = constant_pad_1d(x, new_l, pad_start=pad_start)

    # new length and batch size after folding the time axis into the batch axis
    l = math.ceil(l * init_dilation / dilation)
    n = math.ceil(n * dilation / init_dilation)

    # reshape according to dilation
    x = x.permute(1, 2, 0).contiguous()  # (n, c, l) -> (c, l, n)
    x = x.view(c, l, n)
    x = x.permute(2, 0, 1).contiguous()  # (c, l, n) -> (n, c, l)
    return x
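
# Illustrative usage sketch (not part of the original module): going from
# init_dilation=1 to dilation=2 folds the time axis into the batch axis,
# doubling the batch size and halving the length; the shapes below follow
# from the reshaping above and are assumptions for this example only.
#
#   x = torch.zeros(1, 16, 8)                    # (N=1, C=16, L=8)
#   y = dilate(Variable(x), dilation=2)          # -> size (2, 16, 4)
#   z = dilate(y, dilation=1, init_dilation=2)   # back to size (1, 16, 8)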


class DilatedQueue:
    def __init__(self, max_length, data=None, dilation=1, num_deq=1, num_channels=1, dtype=torch.FloatTensor):
        self.in_pos = 0
        self.out_pos = 0
        self.num_deq = num_deq
        self.num_channels = num_channels
        self.dilation = dilation
        self.max_length = max_length
        self.data = data
        self.dtype = dtype
        if data is None:
            self.data = Variable(dtype(num_channels, max_length).zero_())

    def enqueue(self, input):
        # write the new sample at the current input position of the circular buffer
        self.data[:, self.in_pos] = input
        self.in_pos = (self.in_pos + 1) % self.max_length

    def dequeue(self, num_deq=1, dilation=1):
        # circular buffer, e.g. with max_length 8:
        #
        #   |6|7|8|1|2|3|4|5|
        #
        # num_deq values, spaced by dilation, are read ending at out_pos
        start = self.out_pos - ((num_deq - 1) * dilation)
        if start < 0:
            # the requested slice wraps around the end of the buffer
            t1 = self.data[:, start::dilation]
            t2 = self.data[:, self.out_pos % dilation:self.out_pos + 1:dilation]
            t = torch.cat((t1, t2), 1)
        else:
            t = self.data[:, start:self.out_pos + 1:dilation]
        self.out_pos = (self.out_pos + 1) % self.max_length
        return t

    def reset(self):
        self.data = Variable(self.dtype(self.num_channels, self.max_length).zero_())
        self.in_pos = 0
        self.out_pos = 0
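
# Illustrative usage sketch (not part of the original module): during fast
# generation each layer keeps such a queue of its past outputs; one sample is
# enqueued per time step and dequeue returns the dilated slice the next layer
# needs. The sizes below are assumptions based on the default arguments.
#
#   queue = DilatedQueue(max_length=8, num_channels=16, dilation=2)
#   queue.enqueue(Variable(torch.ones(16)))      # write one time step
#   t = queue.dequeue(num_deq=2, dilation=2)     # -> size (16, 2)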


def constant_pad_1d(input, target_size, value=0, pad_start=False):
    """
    Pads the last dimension (dim 2) of the input with a constant value until
    it has length target_size.

    Input: (N, C, W_in)
    Output: (N, C, W_out) where W_out = target_size

    :param input: tensor of size (N, C, W_in)
    :param target_size: desired length W_out of the last dimension
    :param value: constant value used for padding
    :param pad_start: if True, pad at the start of the sequence, otherwise at the end
    :return: the padded tensor of size (N, C, target_size)
    """
    num_pad = target_size - input.size(2)
    assert num_pad >= 0, 'target size has to be greater than or equal to the input size'
    padding = (num_pad, 0) if pad_start else (0, num_pad)
    return torch.nn.ConstantPad1d(padding, value)(input)
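
# Illustrative usage sketch (not part of the original module): padding a
# length-3 sequence to length 5 with zeros at the start.
#
#   x = torch.arange(3.0).view(1, 1, 3)                     # [[[0., 1., 2.]]]
#   y = constant_pad_1d(x, target_size=5, pad_start=True)   # [[[0., 0., 0., 1., 2.]]]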