# model.py — forked from jason9693/MusicTransformer-pytorch
import torch
import torch.distributions as dist

from progress.bar import Bar
from tensorboardX import SummaryWriter

import utils
from custom.config import config
from custom.layers import Encoder


class MusicTransformer(torch.nn.Module):
    def __init__(self, embedding_dim=256, vocab_size=388 + 2, num_layer=6,
                 max_seq=2048, dropout=0.2, debug=False, loader_path=None,
                 dist=False, writer=None):
        # NB: the `dist` argument shadows the torch.distributions alias
        # inside __init__ only.
        super().__init__()
        self.infer = False
        if loader_path is not None:
            # Restore hyperparameters from a saved config file.
            self.load_config_file(loader_path)
        else:
            self._debug = debug
            self.max_seq = max_seq
            self.num_layer = num_layer
            self.embedding_dim = embedding_dim
            self.vocab_size = vocab_size

        self.dist = dist
        self.writer = writer
        # Despite the attribute name, self.Decoder is the decoder-only
        # transformer stack: the Encoder class is run with a look-ahead
        # (causal) mask.
        self.Decoder = Encoder(
            num_layers=self.num_layer, d_model=self.embedding_dim,
            input_vocab_size=self.vocab_size, rate=dropout, max_len=max_seq)
        self.fc = torch.nn.Linear(self.embedding_dim, self.vocab_size)

    def forward(self, x, length=None, writer=None):
        if self.training or not self.infer:
            # Teacher-forced pass: apply a causal (look-ahead) mask, then
            # project the decoder output onto the vocabulary.
            _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(
                self.max_seq, x, x, config.pad_token)
            decoder, w = self.Decoder(x, mask=look_ahead_mask)
            fc = self.fc(decoder)
            if getattr(self, '_debug', False):
                print(fc.shape, "inside forward")  # debug-only shape trace
            return (fc.contiguous() if self.training
                    else (fc.contiguous(),
                          [weight.contiguous() for weight in w]))
        else:
            # Inference: autoregressively decode `length` new tokens.
            return self.generate(x, length, None).contiguous().tolist()
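
    # In inference mode forward() hands off to generate() below, which
    # appends one sampled token per step; see the usage sketch after the
    # class for one assumed way to drive it.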
    def generate(self,
                 prior: torch.Tensor,
                 length=2048,
                 tf_board_writer: SummaryWriter = None):
        decode_array = prior
        result_array = prior
        for i in Bar('generating').iter(range(length)):
            # Bound the attention context: once the window reaches
            # config.threshold_len, drop the oldest token.
            if decode_array.size(1) >= config.threshold_len:
                decode_array = decode_array[:, 1:]
            _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(
                decode_array.size(1), decode_array, decode_array,
                pad_token=config.pad_token)

            # Apply the causal mask so generation matches training-time
            # attention.
            result, _ = self.Decoder(decode_array, mask=look_ahead_mask)
            result = self.fc(result)
            result = result.softmax(-1)

            if tf_board_writer:
                tf_board_writer.add_image("logits", result, global_step=i)

            # u is pinned to 0, so the stochastic branch below always runs;
            # set u > 1 to switch to greedy (argmax) decoding instead.
            u = 0
            if u > 1:
                result = result[:, -1].argmax(-1).to(decode_array.dtype)
                decode_array = torch.cat((decode_array, result.unsqueeze(-1)), -1)
            else:
                # Sample the next token from the categorical distribution at
                # the final position, then append it to both the rolling
                # window and the full output.
                pdf = dist.OneHotCategorical(probs=result[:, -1])
                result = pdf.sample().argmax(-1).unsqueeze(-1)
                decode_array = torch.cat((decode_array, result), dim=-1)
                result_array = torch.cat((result_array, result), dim=-1)
            del look_ahead_mask
        return result_array[0]

    def test(self):
        # Enter inference mode: eval() disables dropout, and self.infer
        # makes forward() dispatch to generate().
        self.eval()
        self.infer = True
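

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal way to run inference, assuming `config` provides `pad_token` and
# `threshold_len` as used above. The checkpoint path and seed token are
# hypothetical placeholders.
if __name__ == '__main__':
    model = MusicTransformer(embedding_dim=256, vocab_size=388 + 2,
                             num_layer=6, max_seq=2048, dropout=0.2)
    # model.load_state_dict(torch.load('model.pth'))  # hypothetical checkpoint
    model.test()  # eval mode + autoregressive inference
    prior = torch.tensor([[config.pad_token]])  # (batch=1, seq=1) seed sequence
    generated = model(prior, length=512)  # forward() dispatches to generate()
    print(len(generated), "tokens in output (seed + generated)")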