# -*- coding: utf-8 -*-
"""Music_Reformer_TPU_Edition.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1riJjgG_98nZXPT6MkV1HlIA_z8TI5EUn
# Music Reformer (v.1.5): TPU Edition
### This is a work in progress so please check back for updates and improvements.
***
### Based on the official Reformer Google Colab and code.
https://github.com/google/trax/tree/master/trax/models/reformer
***
Project Los Angeles
Tegridy Code 2021
***
# Setup the environment
### Please note that you may need to run the cells below several times, and you may also need to restart the Colab runtime several times, to resolve all dependency conflicts.
"""
# Commented out IPython magic to ensure Python compatibility.
#@title Install the dependencies
# Install dependencies
!git clone https://github.com/asigalov61/tegridy-tools
# %cd /content/tegridy-tools/tegridy-tools/
# %cd /content/
#!wget https://github.com/asigalov61/Music-Reformer/raw/main/Dataset/Music-Reformer_TXT_Dataset.zip
#!unzip Music-Reformer_TXT_Dataset.zip
!pip install --upgrade -q jax
!pip install --upgrade -q jaxlib
!pip install --upgrade -q trax==1.3.6
!pip install --upgrade -q sentencepiece
!pip install --upgrade -q gin
# Commented out IPython magic to ensure Python compatibility.
#@title Import modules
print('Loading needed modules. Please wait...')
# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1
# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
import gin
import numpy as np
import torch
from scipy.special import softmax
import tqdm
from tqdm import auto
# NLP Vocab Generation
import sentencepiece as spm
# %cd /content/tegridy-tools/tegridy-tools
import TMIDI
# %cd /content/
if not os.path.exists('/content/Dataset'):
  os.makedirs('/content/Dataset')
# Zipping and downloading files
from google.colab import files
import shutil
# Trax
import jax
import trax
from trax.data import inputs
import jax.numpy as jnp
"""# Prep the dataset"""
# Commented out IPython magic to ensure Python compatibility.
#@title Download special Tegridy Piano MIDI dataset
#@markdown Works best stand-alone/as-is for the optimal results
# %cd /content/Dataset/
!wget 'https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Tegridy-Piano-CC-BY-NC-SA.zip'
!unzip -j '/content/Dataset/Tegridy-Piano-CC-BY-NC-SA.zip'
!rm '/content/Dataset/Tegridy-Piano-CC-BY-NC-SA.zip'
# %cd /content/
#@title Process MIDIs to special MIDI dataset with Tegridy MIDI Processor
#@markdown NOTES:
#@markdown 1) Dataset MIDI file names are used as song names. Feel free to change them to anything you like.
#@markdown 2) Best results are achieved with single-track, single-channel, single-instrument MIDI Type 0 files with plain English names (avoid special, system, or foreign characters).
#@markdown 3) MIDI Channel = -1 means all MIDI channels will be processed; MIDI Channel = 16 also processes all channels. Otherwise, only the single indicated MIDI channel will be processed.
file_name_to_output_dataset_to = "/content/Music-Reformer_TXT_Dataset" #@param {type:"string"}
desired_MIDI_channel_to_process = 0 #@param {type:"slider", min:-1, max:15, step:1}
encode_velocities = True #@param {type:"boolean"}
chordify_input_MIDIs = False #@param {type:"boolean"}
time_denominator = 10 #@param {type:"slider", min:1, max:20, step:1}
chars_encoding_offset = 33 #@param {type:"number"}
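# Illustration only (this is an assumption about TMIDI's encoding scheme, not
# a call into it): a character offset such as 33 shifts small integer event
# values into printable ASCII, e.g. chr(0 + 33) == '!', which is what keeps
# the TXT dataset human-readable.
for value in (0, 60, 90):  # hypothetical event values
  print('event value', value, '->', repr(chr(value + chars_encoding_offset)))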
print('TMIDI Processor')
print('Starting up...')
###########
average_note_pitch = 0
min_note = 127
max_note = 0
files_count = 0
ev = 0
chords_list_f = []
melody_list_f = []
chords_list = []
chords_count = 0
melody_chords = []
melody_count = 0
TXT_String = 'DATASET=Optimus-Virtuoso-Music-Dataset' + chr(10)
TXT = ''
melody = []
chords = []
###########
print('Loading MIDI files...')
print('This may take a while on a large dataset in particular.')
dataset_addr = "/content/Dataset/"
os.chdir(dataset_addr)
filez = os.listdir(dataset_addr)
print('Processing MIDI files. Please wait...')
for f in tqdm.auto.tqdm(filez):
  try:
    files_count += 1
    TXT, melody, chords = TMIDI.Optimus_MIDI_TXT_Processor(f, chordify_TXT=chordify_input_MIDIs, output_MIDI_channels=False, char_offset=chars_encoding_offset, dataset_MIDI_events_time_denominator=time_denominator, output_velocity=encode_velocities, MIDI_patch=range(0, 127))
    melody_list_f += melody
    chords_list_f += chords
    TXT_String += TXT
  except Exception:
    print('Bad MIDI:', f)
    continue
print('Task complete :)')
print('==================================================')
print('Number of processed dataset MIDI files:', files_count)
print('Number of MIDI chords recorded:', len(chords_list_f))
print('First chord event:', chords_list_f[0], 'Last chord event:', chords_list_f[-1])
print('Number of recorded melody events:', len(melody_list_f))
print('First melody event:', melody_list_f[0], 'Last melody event:', melody_list_f[-1])
print('Total number of MIDI events recorded:', len(chords_list_f) + len(melody_list_f))
# Writing dataset to TXT file
with open(file_name_to_output_dataset_to + '.txt', 'wb') as f:
  f.write(TXT_String.encode('utf-8', 'replace'))
# Dataset
MusicDataset = [chords_list_f, melody_list_f]
# Writing dataset to pickle file
TMIDI.Tegridy_Pickle_File_Writer(MusicDataset, file_name_to_output_dataset_to)
#@title Process the TXT MIDI dataset to TXT INT dataset
full_path_to_TXT_dataset = "/content/Music-Reformer_TXT_Dataset.txt" #@param {type:"string"}
with open(full_path_to_TXT_dataset, 'r') as file:
  z = file.read()
Z = z.encode('utf8')
Y = list(Z)
string = '\n'.join([str(item) for item in Y])
with open('/content/Music-Reformer_INT_Dataset.txt', 'w') as file:
  file.write(string)
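# Optional sanity check (a sketch, not in the original notebook): the INT
# dataset is one UTF-8 byte value per line, so a bytes() round-trip should
# reconstruct the original TXT dataset exactly.
with open('/content/Music-Reformer_INT_Dataset.txt', 'r') as file:
  ints = [int(i) for i in file.read().split('\n')]
assert bytes(ints).decode('utf8') == z
print('INT dataset round-trip OK:', len(ints), 'values')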
# Commented out IPython magic to ensure Python compatibility.
#@title Create a tokenizer and its model
#@markdown NOTE: Less tokenizer words seem to work better
# %cd /content/
full_path_to_INT_dataset = "/content/Music-Reformer_INT_Dataset.txt" #@param {type:"string"}
tokenizer_vocabulary_size_in_words = 321 #@param {type:"integer"}
# Train a BPE model on the dataset
spm.SentencePieceTrainer.train(input=full_path_to_INT_dataset,
                               model_prefix='Music-Reformer-Tokenizer',
                               vocab_size=tokenizer_vocabulary_size_in_words,
                               model_type='bpe')
# Load BPE vocabulary
TOKENIZER = spm.SentencePieceProcessor()
TOKENIZER.load('Music-Reformer-Tokenizer.model')
# Load the dataset
with open(full_path_to_INT_dataset, 'r') as f:
  text = f.read(512 * 3072)
IDS = TOKENIZER.EncodeAsIds(text)
IDS = np.asarray(IDS, dtype=np.int32)
PAD_AMOUNT = 512 * 1024 - len(IDS)
print("Number of tokens:", IDS.shape[0])
#@title Split the dataset
train_validation_split_ratio = 0.9 #@param {type:"slider", min:0.05, max:0.95, step:0.05}
# Split the raw INT (byte) values, capped at 512 * 1024 to match the model's max length.
# NOTE: the Trax trainer below consumes the tokenized IDS via my_inputs;
# this byte-level split is kept for reference and inspection.
trX, vaX = np.split(np.asarray(Y[:512 * 1024], dtype=np.int64),
                    [int(len(Y[:512 * 1024]) * train_validation_split_ratio)])
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
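# Quick check (optional, not in the original notebook): confirm the split
# landed on the requested ratio.
print('Train values:', len(data_train), '| Validation values:', len(data_val))
print('Actual ratio:', round(len(data_train) / (len(data_train) + len(data_val)), 3))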
"""# Setup the Reformer model and functions"""
#@title Initialize the functions and procedures for training
# Set up the data pipeline.
def my_inputs(n_devices):
  while True:
    inputs = []
    mask = []
    pad_amounts = np.random.choice(PAD_AMOUNT, n_devices)
    for i in range(n_devices):
      # Pad IDS by a different random amount for each device
      inputs.append(np.pad(IDS, (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                           mode='constant'))
      mask.append(np.pad(np.ones_like(IDS, dtype=np.float32),
                         (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                         mode='constant'))
    inputs = np.stack(inputs)
    mask = np.stack(mask)
    yield (inputs, inputs, mask)
print("(device count, tokens per device) = ",
next(my_inputs(trax.fastmath.device_count()))[0].shape)
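# Optional sanity check (a sketch, not in the original notebook): every padded
# sequence should be exactly 512 * 1024 tokens long, and the loss mask should
# cover exactly len(IDS) real tokens per device.
check_inputs, _, check_mask = next(my_inputs(trax.fastmath.device_count()))
assert check_inputs.shape[-1] == 512 * 1024
assert int(check_mask.sum()) == len(IDS) * check_inputs.shape[0]
print('Pipeline check OK:', check_inputs.shape)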
#@title Configure hyperparameters
# Configure hyperparameters.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib
# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# Model will have 6 layers, alternating between the LSH attention
# and local attention within a certain context window.
n_layers = 6
attn_type = [
@trax.layers.SelfAttention,
@LSHSelfAttention,
@trax.layers.SelfAttention,
@LSHSelfAttention,
@trax.layers.SelfAttention,
@LSHSelfAttention,
]
share_qk = False # LSH attention ignores this flag and always shares q & k
n_heads = 2
attn_kv = 64
dropout = 0.05
n_tokens = 524288
# Parameters for multifactor:
# ==============================================================================
multifactor.constant = 0.01
multifactor.factors = 'constant * linear_warmup * cosine_decay'
multifactor.warmup_steps = 100
multifactor.steps_per_cycle = 900
# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate=0.0
Adam.b1 = 0.86
Adam.b2 = 0.92
Adam.eps = 1e-9
# Parameters for SelfAttention:
# ==============================================================================
trax.layers.SelfAttention.attention_dropout = 0.05
trax.layers.SelfAttention.chunk_len = 64
trax.layers.SelfAttention.n_chunks_before = 1
trax.layers.SelfAttention.n_parallel_heads = 1
# Parameters for LSHSelfAttention:
# ==============================================================================
LSHSelfAttention.attention_dropout = 0.0
LSHSelfAttention.chunk_len = 64
LSHSelfAttention.n_buckets = [64, 128]
LSHSelfAttention.n_chunks_after = 0
LSHSelfAttention.n_chunks_before = 1
LSHSelfAttention.n_hashes = 1
LSHSelfAttention.n_parallel_heads = 1
LSHSelfAttention.predict_drop_len = 128
LSHSelfAttention.predict_mem_len = 1024
# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 256
ReformerLM.d_ff = 512
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = %n_tokens
ReformerLM.mode = 'train'
ReformerLM.n_heads = %n_heads
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 320  # NOTE: should be >= the tokenizer vocabulary size (321 above)
ReformerLM.axial_pos_shape = (512, 1024)
ReformerLM.d_axial_pos_embs = (64, 192)
""")
#@title Setup the model and the trainer routines
# Trainer.
output_dir = os.path.expanduser('model')
#!rm -f ~/model/model.pkl.gz # Remove old model
trainer = trax.supervised.Trainer(
    model=trax.models.ReformerLM,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.multifactor(),
    inputs=trax.data.inputs.Inputs(my_inputs),
    output_dir=output_dir)
"""# Train"""
#@title Train the model
# Train Model
#@markdown This cell takes about 10 minutes to produce first output. Please wait...
import tqdm
print('=' * 50)
print('JITing NN...')
trainer.train_epoch(n_steps=1, n_eval_steps=1)
print('=' * 50)
print('Continuing last run to the max...')
trainer.train_epoch(n_steps=9, n_eval_steps=1)
print('=' * 50)
print('Running main training loop')
# 59 epochs of 10 steps each: 600 total training steps, counting the 10 warmup steps above
for _ in tqdm.auto.tqdm(range(59)):
  trainer.train_epoch(n_steps=10, n_eval_steps=1)
#@title Zip and download your trained model checkpoint here
# Zip directory contents
shutil.make_archive("project", "zip", ".")
# Download zipped directory
files.download('project.zip')
"""# Generate Music"""
#@title Increase hashing rounds number for better quality here
# In the Reformer paper, increasing the number of hashing rounds helps with quality.
# The number of hashing rounds can be increased at evaluation time only.
gin.parse_config("""LSHSelfAttention.n_hashes = 4""")
#@title Load the trained Reformer in 'predict' mode
# Load the trained Reformer in 'predict' mode
model = trax.models.ReformerLM(mode='predict')
output_dir = os.path.expanduser('model')
model.init_from_file(os.path.join(output_dir, 'model.pkl.gz'),
                     weights_only=True)
#@title Generate and decode music from the model
# Sample from ReformerLM
output_token_ids = trax.supervised.decoding.autoregressive_sample(
    model, temperature=0.8, max_length=2048, batch_size=1)
# Decode token IDs
# The Reformer outputs a batch with one item, so access it using [0]
# tolist() converts from int64 to int, the type SentencePiece expects
generated_text = TOKENIZER.DecodeIds(output_token_ids[0].tolist())
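# Optional preview (a sketch, not in the original notebook): report how much
# text the model produced before handing it to the MIDI converter below.
print('Generated', len(output_token_ids[0]), 'tokens ->', len(generated_text), 'characters')
print('Preview:', generated_text[:80])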
#@title Convert generated output to MIDI.
# Run the cells below to convert the generated output to MIDI.
# If you get errors/halts, regenerate the output.
# The model must be sufficiently trained. Recommended: 0.90+ accuracy for the output to make sense and pass error control.
#TXT = TMIDI.Tegridy_INT_String_to_TXT_Converter(generated_text, line_by_line_input=False)
SONG = TMIDI.Tegridy_Optimus_TXT_to_Notes_Converter(generated_text, has_MIDI_channels=False, char_encoding_offset=30000, simulate_velocity=True, dataset_MIDI_events_time_denominator=1, line_by_line_dataset=False)
stats = TMIDI.Tegridy_SONG_to_MIDI_Converter(SONG=SONG[0], output_file_name='/content/Music-Reformer_MIDI')
print(stats)
"""# Congrats!!! You did it!!! :)"""