-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate-non-overlapping-splits.py
120 lines (94 loc) · 3.38 KB
/
create-non-overlapping-splits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Python script for creating text files with relative paths
# to train/valid/test files. We use these paths to
# create datasets.
#
# With special thanks to my Bachelor thesis supervisor.
# Author: Ing. Lukáš Marták
import shutil
import argparse
import fnmatch
import random
import os
# Synthesizer/instrument directory names whose recordings form the test split.
test_synthnames = {
    'ENSTDkCl',
    'ENSTDkAm',
}

# Synthesizer/instrument directory names whose recordings form the train split.
train_synthnames = {
    'StbgTGd2',
    'SptkBGCl',
    'SptkBGAm',
    'AkPnStgb',
    'AkPnCGdD',
    'AkPnBsdf',
    'AkPnBcht',
}
def ensure_empty_directory_exists(dirname):
    """Leave *dirname* as a freshly created, empty directory.

    Anything already present at that path is removed recursively with
    ``shutil.rmtree`` first; missing parent directories are created by
    ``os.makedirs``.
    """
    already_present = os.path.exists(dirname)
    if already_present:
        # Destructive: wipes the previous contents of the directory.
        shutil.rmtree(dirname)
    os.makedirs(dirname)
def desugar(c):
    """Pull the piece id out of a file name.

    Returns a ``(prefix, last, pid)`` tuple where *prefix* is the constant
    ``'MAPS_MUS-'``, *last* is how many characters follow the final ``'_'``
    in *c* (``-1`` when *c* has no underscore) and *pid* is the text between
    the first ``len(prefix)`` characters and that final underscore.

    The prefix is not verified; the leading characters are simply skipped.
    Presumably names look like ``MAPS_MUS-<piece>_<synth>.<ext>`` -- confirm
    against the dataset layout.
    """
    prefix = 'MAPS_MUS-'
    underscore_at = c.rfind('_')
    if underscore_at < 0:
        # Same sentinel the reversed-string search would produce.
        last = -1
    else:
        last = len(c) - 1 - underscore_at
    # With last == -1 the stop index is 0, which yields an empty pid.
    stop = -(last + 1)
    pid = c[len(prefix):stop]
    return prefix, last, pid
def collect_all_piece_ids(base_dir, synthnames):
    """Return the set of piece ids found under the given synth directories.

    For every name in *synthnames* the tree ``base_dir/<name>`` is walked,
    and each file whose name matches ``*MUS*`` contributes the piece id
    extracted by ``desugar``.
    """
    piece_ids = set()
    for name in synthnames:
        synth_root = os.path.join(base_dir, name)
        for _base, _dirs, files in os.walk(synth_root):
            for match in fnmatch.filter(files, '*MUS*'):
                piece_ids.add(desugar(match)[2])
    return piece_ids
def collect_all_filenames(base_dir, synthnames, include):
    """List extension-less paths of MUS files whose piece id is in *include*.

    Each tree ``base_dir/<synthname>`` is walked; files matching ``*MUS*``
    are kept when the piece id from ``desugar`` is contained in *include*.
    The extension is stripped, so files differing only by extension collapse
    into a single entry. The result is a list built from a set, so its order
    is unspecified.
    """
    found = set()
    for name in synthnames:
        synth_root = os.path.join(base_dir, name)
        for base, _dirs, files in os.walk(synth_root):
            for match in fnmatch.filter(files, '*MUS*'):
                if desugar(match)[2] not in include:
                    continue
                stem = os.path.splitext(match)[0]
                found.add(os.path.join(base, stem))
    return list(found)
def write_pairs(filename, lines):
    """Write one ``<path>.wav,<path>.mid`` record per entry of *lines*.

    Args:
        filename: Path of the text file to create (overwritten if present).
        lines: Iterable of extension-less paths; each becomes one
            newline-terminated ``wav,mid`` pair.

    An empty *lines* produces an empty file rather than a file holding a
    single blank line, so readers never see an empty record.
    """
    pairs = ['{}.wav,{}.mid\n'.format(line, line) for line in lines]
    with open(filename, 'w') as f:
        # writelines expects an iterable of lines; passing it one joined
        # string would make it iterate character by character.
        f.writelines(pairs)
def main():
    """Build the train/valid/test split files in ``./non-overlapping``.

    Reads the dataset base directory from the command line, collects piece
    ids for the train and test synth groups, drops from the train split every
    piece that also occurs in the test split, samples a small validation
    subset from the train files and writes the three lists as ``wav,mid``
    pair files named ``train``, ``valid`` and ``test``.

    The output directory is deleted and recreated on every run.
    """
    # Fixed seed plus sorted file lists keep the validation sample stable
    # between runs over the same set of files.
    random.seed(155853)

    parser = argparse.ArgumentParser(description='create non-overlapping splits')
    parser.add_argument('maps_base_directory', help='path must be relative to the working directory')
    args = parser.parse_args()

    train_pids = collect_all_piece_ids(args.maps_base_directory, train_synthnames)
    test_pids = collect_all_piece_ids(args.maps_base_directory, test_synthnames)

    print('len(train_pids)', len(train_pids))
    print('len(test_pids)', len(test_pids))

    # Pieces that also appear in the test split are excluded from training.
    train_filenames = sorted(collect_all_filenames(
        args.maps_base_directory,
        train_synthnames,
        train_pids - test_pids
    ))
    test_filenames = sorted(collect_all_filenames(
        args.maps_base_directory,
        test_synthnames,
        test_pids
    ))

    # we're validating on a subset of the trainset!
    # this is going to tell us **how close we are to learning the trainset by heart**...
    # ... and be a **bad estimate of generalization error** ...
    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the number of available train files.
    valid_size = min(10, len(train_filenames))
    valid_filenames = random.sample(train_filenames, valid_size)

    print('len(train_filenames)', len(train_filenames))
    print('len(valid_filenames)', len(valid_filenames))
    print('len(test_filenames)', len(test_filenames))

    dirname = 'non-overlapping'
    ensure_empty_directory_exists(dirname)

    write_pairs(os.path.join(dirname, 'train'), train_filenames)
    write_pairs(os.path.join(dirname, 'valid'), valid_filenames)
    write_pairs(os.path.join(dirname, 'test'), test_filenames)


# Run the split generation only when executed as a script.
if __name__ == '__main__':
    main()