forked from jason9693/MusicTransformer-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate.py
54 lines (42 loc) · 1.36 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import custom
from custom import criterion
from custom.layers import *
from custom.config import config
from model import MusicTransformer
from data import Data
import utils
from midi_processor_fixed.processor import decode_midi, encode_midi
import datetime
import argparse
from tensorboardX import SummaryWriter
# Parse the command-line options, then restore the saved model configuration
# from the model directory they point at.
parser = custom.get_argument_parser()
args = parser.parse_args()
config.load(args.model_dir, args.configs, initialize=True)

# Select the compute device: the GPU when CUDA is available, else the CPU.
# The choice is stored on the shared config object.
config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Each generation run logs to its own timestamped TensorBoard directory.
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
gen_log_dir = f'logs/mt_decoder/generate_{current_time}/generate'
gen_summary_writer = SummaryWriter(gen_log_dir)

# Number of leading tokens of the encoded condition file used as the prime.
CONDITION_PRIME_LENGTH = 500
# Prime token sequence used when no condition file is configured.
DEFAULT_PRIME = [24, 28, 31]

try:
    # Build the network with dropout disabled for generation.
    mt = MusicTransformer(
        embedding_dim=config.embedding_dim,
        vocab_size=config.vocab_size,
        num_layer=config.num_layers,
        max_seq=config.max_seq,
        dropout=0,
        debug=False)
    # map_location=config.device lets a checkpoint that was saved from a GPU
    # be loaded on a CPU-only machine (config.device is 'cpu' there).
    checkpoint_path = args.model_dir + '/final.pth'
    mt.load_state_dict(torch.load(checkpoint_path, map_location=config.device))
    # NOTE(review): neither the model nor the inputs are moved to
    # config.device in this script; confirm MusicTransformer handles
    # device placement itself before relying on GPU generation.
    mt.test()

    print(config.condition_file)
    if config.condition_file is not None:
        # Prime generation with the beginning of the configured MIDI file.
        prime = encode_midi(config.condition_file)[:CONDITION_PRIME_LENGTH]
    else:
        prime = DEFAULT_PRIME
    # Pin the token dtype to int64: np.array's default integer dtype is
    # platform-dependent (int32 on Windows with NumPy < 2), and the ids
    # presumably feed an embedding lookup, which in older PyTorch versions
    # only accepts int64 indices.
    inputs = torch.from_numpy(np.array([prime], dtype=np.int64))

    # Sampling needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        result = mt(inputs, config.length, gen_summary_writer)

    for token in result:
        print(token)

    decode_midi(result, file_path=config.save_path)
finally:
    # Flush and close the TensorBoard writer even if generation fails.
    gen_summary_writer.close()