# learner.py
from fastai.basics import *
from fastai.text.learner import LanguageLearner, get_language_model, _model_meta
from .model import *
from .transform import MusicItem
from ..numpy_encode import SAMPLE_FREQ
from ..utils.top_k_top_p import top_k_top_p
from ..utils.midifile import is_empty_midi
_model_meta[MusicTransformerXL] = _model_meta[TransformerXL] # copy over fastai's model metadata

def music_model_learner(data:DataBunch, arch=MusicTransformerXL, config:dict=None, drop_mult:float=1.,
                        pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner':
    "Create a `Learner` with a language model from `data` and `arch`."
    meta = _model_meta[arch]

    if pretrained_path:
        # Load the checkpoint first so its stored config can be reused to rebuild the model.
        state = torch.load(pretrained_path, map_location='cpu')
        if config is None: config = state['config']

    model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)
    learn = MusicLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)

    if pretrained_path:
        # Restore the weights (non-strict) and, if it loads, the optimizer state.
        get_model(model).load_state_dict(state['model'], strict=False)
        if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd)
        try: learn.opt.load_state_dict(state['opt'])
        except: pass
        del state
        gc.collect()

    return learn
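
# Usage sketch (illustrative only): `data` is assumed to be a DataBunch built elsewhere in the
# project, and the checkpoint path is hypothetical.
#
#   learn = music_model_learner(data, pretrained_path='models/example_pretrained.pth')
#   learn.fit_one_cycle(4)
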
# Predictions
from fastai import basic_train # for predictions

class MusicLearner(LanguageLearner):
    def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None):
        "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer)."
        out_path = super().save(file, return_path=True, with_opt=with_opt)
        if config and out_path:
            # Re-open the saved checkpoint and embed the model config so it can be restored later.
            state = torch.load(out_path)
            state['config'] = config
            torch.save(state, out_path)
            del state
            gc.collect()
        return out_path
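
    # Saving sketch (illustrative): passing `config` embeds the architecture settings in the
    # checkpoint so `music_model_learner` can rebuild the same model later; the file name and
    # `config` dict here are hypothetical.
    #
    #   learn.save('music_model', config=config)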

    def beam_search(self, xb:Tensor, n_words:int, top_k:int=10, beam_sz:int=10, temperature:float=1.):
        "Return the `n_words` that come after `xb` using beam search."
        self.model.reset()
        self.model.eval()
        xb_length = xb.shape[-1]
        if xb.shape[0] > 1: xb = xb[0][None]
        yb = torch.ones_like(xb)

        nodes = None
        xb = xb.repeat(top_k, 1)
        nodes = xb.clone()
        scores = xb.new_zeros(1).float()
        with torch.no_grad():
            for k in progress_bar(range(n_words), leave=False):
                out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)
                values, indices = out.topk(top_k, dim=-1)
                # Accumulate negative log-likelihoods and keep the `beam_sz` lowest-cost beams.
                scores = (-values + scores[:,None]).view(-1)
                indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), top_k).contiguous().view(-1)
                sort_idx = scores.argsort()[:beam_sz]
                scores = scores[sort_idx]
                nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),
                                   indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)
                nodes = nodes.view(-1, nodes.size(2))[sort_idx]
                # Keep the transformer's hidden state in sync with the surviving beams.
                self.model[0].select_hidden(indices_idx[sort_idx])
                xb = nodes[:,-1][:,None]
        if temperature != 1.: scores.div_(temperature)
        # Sample one of the surviving beams, weighted by its (temperature-scaled) likelihood.
        node_idx = torch.multinomial(torch.exp(-scores), 1).item()
        return [i.item() for i in nodes[node_idx][xb_length:]]
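
    # Beam-search sketch (illustrative): `xb` is expected to be a 2D batch of token ids, so a
    # seed MusicItem's tensor is unsqueezed to shape (1, seq_len); the midi path is hypothetical.
    #
    #   seed = MusicItem.from_file('seed.mid', learn.data.vocab)
    #   idxs = learn.beam_search(seed.to_tensor()[None], n_words=64, beam_sz=5)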

    def predict(self, item:MusicItem, n_words:int=128,
                temperatures:tuple=(1.0,1.0), min_bars=4,
                top_k=30, top_p=0.6):
        "Return the `n_words` that come after `item`."
        self.model.reset()
        new_idx = []
        vocab = self.data.vocab
        x, pos = item.to_tensor(), item.get_pos_tensor()
        last_pos = pos[-1] if len(pos) else 0
        y = torch.tensor([0])

        start_pos = last_pos

        sep_count = 0
        bar_len = SAMPLE_FREQ * 4 # assuming 4/4 time

        repeat_count = 0
        # Only pass positional encodings to models that support them.
        if hasattr(self.model[0], 'encode_position'):
            encode_position = self.model[0].encode_position
        else: encode_position = False

        for i in progress_bar(range(n_words), leave=True):
            with torch.no_grad():
                if encode_position:
                    batch = { 'x': x[None], 'pos': pos[None] }
                    logits = self.model(batch)[0][-1][-1]
                else:
                    logits = self.model(x[None])[0][-1][-1]

            prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx

            # Temperature
            # Use first temperatures value if last prediction was a duration
            temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1]
            # Raise the temperature when recent samples keep repeating, to encourage variety.
            repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature
            temperature += repeat_penalty
            if temperature != 1.: logits = logits / temperature

            # Filter
            # bar = 16 beats; suppress the BOS token (which ends generation below) until `min_bars` bars have been generated.
            filter_value = -float('Inf')
            if ((last_pos - start_pos) // 16) <= min_bars: logits[vocab.bos_idx] = filter_value

            logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value)
            logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value)

            # Sample
            probs = F.softmax(logits, dim=-1)
            idx = torch.multinomial(probs, 1).item()

            # Update repeat count
            num_choices = len(probs.nonzero().view(-1))
            if num_choices <= 2: repeat_count += 1
            else: repeat_count = repeat_count // 2

            if prev_idx == vocab.sep_idx:
                # The token sampled after a separator is a duration; use it to advance the running position.
                duration = idx - vocab.dur_range[0]
                last_pos = last_pos + duration

                bars_pred = (last_pos - start_pos) // 16
                abs_bar = last_pos // 16
                # if (bars % 8 == 0) and (bars_pred > min_bars): break
                # Stop near the requested length, but only on a 4-bar boundary.
                if (i / n_words > 0.80) and (abs_bar % 4 == 0): break

            if idx == vocab.bos_idx:
                print('Predicted BOS token. Returning prediction...')
                break

            new_idx.append(idx)
            x = x.new_tensor([idx])
            pos = pos.new_tensor([last_pos])

        pred = vocab.to_music_item(np.array(new_idx))
        full = item.append(pred)
        return pred, full
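
# Generation sketch (illustrative): `seed_item` is a MusicItem used to prime the model. Per the
# temperature logic in `predict`, the first temperature mainly governs note sampling and the
# second duration sampling.
#
#   pred, full = learn.predict(seed_item, n_words=256, temperatures=(1.1, 0.9))
#   # `pred` holds only the newly generated tokens; `full` is the seed plus the prediction.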

# High level prediction function from a midi file
def predict_from_midi(learn, midi=None, n_words=400,
                      temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs):
    "Generate a continuation of `midi` and return the combined `MusicItem`."
    vocab = learn.data.vocab
    # Seed from the midi file if it contains any notes, otherwise start from an empty item.
    seed = MusicItem.from_file(midi, vocab) if not is_empty_midi(midi) else MusicItem.empty(vocab)
    if seed_len is not None: seed = seed.trim_to_beat(seed_len)

    pred, full = learn.predict(seed, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs)
    return full
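
# End-to-end sketch (illustrative file path): seed from a midi file, optionally trimmed with
# `trim_to_beat`, and generate a continuation.
#
#   full = predict_from_midi(learn, midi='input.mid', n_words=400, seed_len=16)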

def filter_invalid_indexes(res, prev_idx, vocab, filter_value=-float('Inf')):
    "Mask token indexes that are invalid after `prev_idx`: durations after a duration/pad token, notes otherwise."
    if vocab.is_duration_or_pad(prev_idx):
        res[list(range(*vocab.dur_range))] = filter_value
    else:
        res[list(range(*vocab.note_range))] = filter_value
    return res