-
Notifications
You must be signed in to change notification settings - Fork 67
/
transcribe.py
89 lines (61 loc) · 3.05 KB
/
transcribe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
import os
import sys
import numpy as np
import soundfile
from mir_eval.util import midi_to_hz
from onsets_and_frames import *
def load_and_process_audio(flac_path, sequence_length, device):
    """Load an audio file and return it as a normalized float tensor on ``device``.

    The file is read as 16-bit PCM via ``soundfile`` and must have sample rate
    ``SAMPLE_RATE``. When ``sequence_length`` is given, a window of at most
    ``sequence_length`` samples is cropped out. Its start offset is drawn from
    a ``RandomState`` seeded with 42, so the crop is the same on every run. The
    offset is snapped down to a multiple of ``HOP_LENGTH``.

    Args:
        flac_path: Path of the audio file to read.
        sequence_length: Number of samples to keep, or ``None`` to keep the
            whole recording. If the recording is not longer than this, the
            whole recording is returned (starting at sample 0).
        device: Torch device the returned tensor is moved to.

    Returns:
        A float tensor of the (possibly cropped) samples, scaled from int16 by
        1/32768 so values lie in [-1, 1).

    Raises:
        ValueError: If the file's sample rate differs from ``SAMPLE_RATE``.
    """
    rng = np.random.RandomState(seed=42)
    audio, sr = soundfile.read(flac_path, dtype='int16')
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if sr != SAMPLE_RATE:
        raise ValueError(
            f'{flac_path}: expected sample rate {SAMPLE_RATE}, got {sr}')
    # NOTE(review): soundfile returns a 2-D (frames, channels) array for
    # multi-channel files; the rest of the pipeline appears to assume mono —
    # confirm inputs are single-channel.
    audio = torch.ShortTensor(audio)

    if sequence_length is not None:
        max_begin = len(audio) - sequence_length
        # RandomState.randint(high) requires high > 0. When the clip is no
        # longer than the requested window, start at 0 and keep the whole clip
        # instead of crashing. The RNG draw is unchanged whenever max_begin > 0.
        step_begin = rng.randint(max_begin) // HOP_LENGTH if max_begin > 0 else 0
        begin = step_begin * HOP_LENGTH
        audio = audio[begin:begin + sequence_length]

    # int16 full scale is 32768, so this maps samples into [-1, 1).
    return audio.to(device).float().div_(32768.0)
def transcribe(model, audio):
    """Run ``model`` on raw ``audio`` and return its prediction maps.

    Args:
        model: Callable taking the mel input and returning a 5-tuple of
            tensors. The third element is discarded.
        audio: Waveform tensor. All leading dimensions are folded into one
            batch axis before feature extraction.

    Returns:
        dict with keys ``'onset'``, ``'offset'``, ``'frame'`` and
        ``'velocity'``. Each value is the corresponding model output reshaped
        to its dims 1 and 2, i.e. with the leading (batch) axis removed.
    """
    # Fold leading dims into a batch axis and drop the final sample. Then swap
    # the last two axes of the features; presumably this yields
    # (batch, time, mel) for the model — confirm against melspectrogram.
    batched = audio.reshape(-1, audio.shape[-1])
    features = melspectrogram(batched[:, :-1])
    mel = features.transpose(-1, -2)

    onset, offset, _unused, frame, velocity = model(mel)

    def _drop_batch_axis(pred):
        # Reshape to (shape[1], shape[2]); for a 3-D prediction this only
        # succeeds when the leading dimension has size 1.
        return pred.reshape((pred.shape[1], pred.shape[2]))

    return {
        'onset': _drop_batch_axis(onset),
        'offset': _drop_batch_axis(offset),
        'frame': _drop_batch_axis(frame),
        'velocity': _drop_batch_axis(velocity),
    }
def transcribe_file(model_file, flac_paths, save_path, sequence_length,
                    onset_threshold, frame_threshold, device):
    """Transcribe audio files and write a piano-roll image and MIDI file for each.

    For every path in ``flac_paths``, the outputs are written to ``save_path``
    as ``<basename>.pred.png`` (onset/frame piano roll) and
    ``<basename>.pred.mid`` (extracted notes).

    Args:
        model_file: Path of a checkpoint saved with ``torch.save(model)``.
            The whole model object is loaded, not just a state dict.
        flac_paths: Iterable of audio file paths to transcribe.
        save_path: Output directory. Created if missing.
        sequence_length: Passed to ``load_and_process_audio``; ``None`` keeps
            the full recording.
        onset_threshold: Onset threshold passed to ``extract_notes``.
        frame_threshold: Frame threshold passed to ``extract_notes``.
        device: Torch device used for the model and the audio.
    """
    # SECURITY: torch.load unpickles the stored object, which can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    model = torch.load(model_file, map_location=device).eval()
    summary(model)

    # Loop-invariant setup, hoisted out of the per-file loop.
    os.makedirs(save_path, exist_ok=True)
    # Converts frame indices to seconds, assuming HOP_LENGTH is in samples and
    # SAMPLE_RATE in Hz.
    scaling = HOP_LENGTH / SAMPLE_RATE

    for flac_path in flac_paths:
        print(f'Processing {flac_path}...', file=sys.stderr)
        audio = load_and_process_audio(flac_path, sequence_length, device)
        # Disable autograd here too, so calling this function outside the
        # script's no_grad() context neither builds a graph nor yields
        # grad-requiring tensors downstream.
        with torch.no_grad():
            predictions = transcribe(model, audio)

        p_est, i_est, v_est = extract_notes(
            predictions['onset'], predictions['frame'], predictions['velocity'],
            onset_threshold, frame_threshold)

        # reshape(-1, 2) keeps an (N, 2) interval array even when N == 0.
        i_est = (i_est * scaling).reshape(-1, 2)
        p_est = np.array([midi_to_hz(MIN_MIDI + midi) for midi in p_est])

        base_name = os.path.basename(flac_path)
        pred_path = os.path.join(save_path, base_name + '.pred.png')
        save_pianoroll(pred_path, predictions['onset'], predictions['frame'])
        midi_path = os.path.join(save_path, base_name + '.pred.mid')
        save_midi(midi_path, p_est, i_est, v_est)
if __name__ == '__main__':
    # Command-line entry point: parse arguments, then transcribe every input
    # file with autograd disabled.
    cli = argparse.ArgumentParser()
    cli.add_argument('model_file', type=str)
    cli.add_argument('flac_paths', type=str, nargs='+')
    cli.add_argument('--save-path', type=str, default='.')
    cli.add_argument('--sequence-length', default=None, type=int)
    for threshold_flag in ('--onset-threshold', '--frame-threshold'):
        cli.add_argument(threshold_flag, default=0.5, type=float)
    cli.add_argument('--device',
                     default='cuda' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        cli_args = cli.parse_args()
        transcribe_file(**vars(cli_args))