-
Notifications
You must be signed in to change notification settings - Fork 0
/
collect_data.py
123 lines (116 loc) · 3.04 KB
/
collect_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import serial
import time
import threading
import matplotlib.pyplot as plt
import pandas as pd
import os
import sys
# Serial device to read from and its baud rate.
SERIAL_PATH = "/dev/ttyUSB0"
BAUD_RATE = 115200

# Output directories, relative to the current working directory. File names
# are appended to these by plain string concatenation, so keep the trailing "/".
DATA_PATH = "../data/"
FIGURES_PATH = "figures/"

# Gesture names, indexed by the digits entered at the "Gesture" prompt.
GESTURES = [
    "Kick",
    "Hihat",
    "Snare",
    "Tom",
    "Crash",
]
def read_serial():
    """Read lines from the serial port until ``event`` is cleared.

    Intended as the target of the background reader thread. Every received
    line is echoed to stdout, prefixed with ``[<seconds>]`` once the
    module-level ``seconds`` counter is non-zero. Lines that yield exactly
    five numbers are appended, together with a ``time.time()`` timestamp, to
    the module-level ``data`` buffers. The port is closed when the loop
    exits, including when it exits with an exception.
    """
    channels = ("index", "middle", "ring", "pinky", "accX")
    try:
        while event.is_set():
            # Read one line; empty when nothing arrived before the port timeout.
            raw = serial.readline()
            if not raw:
                continue
            try:
                text = raw.decode()
                # `seconds` is bound by the main thread, which may not have
                # done so yet when the first line arrives.
                elapsed = globals().get("seconds", 0)
                if not elapsed:
                    print(text, end="")
                else:
                    print(f"[{elapsed}] {text}", end="")
                # The last comma-separated field is discarded (presumably the
                # remainder after a trailing comma -- TODO confirm firmware format).
                values = [float(field) for field in text.split(',')[:-1]]
            except (UnicodeDecodeError, ValueError):
                # Garbled or partial line (e.g. a line cut mid-stream): skip it.
                continue
            if len(values) != 5:
                continue
            data["time"].append(time.time())
            for channel, value in zip(channels, values):
                data[channel].append(value)
    finally:
        serial.close()
def plot_data():
    """Plot the recorded sensor channels for the current gesture.

    Reads the module-level ``data`` buffers and the ``gesture`` label
    (e.g. ``"0-1"``). The title shows the two gesture names and the mean
    sampling frequency. The figure is saved as ``<FIGURES_PATH>/<gesture>.png``
    only when FIGURES_PATH exists. The figure is closed afterwards.
    """
    timestamps = data["time"]
    # The frequency is only defined with at least two samples spanning a non-zero time.
    if len(timestamps) >= 2:
        trial_duration = timestamps[-1] - timestamps[0]
    else:
        trial_duration = 0.0
    freq = len(data["accX"]) / trial_duration if trial_duration > 0 else 0.0

    fig = plt.figure(figsize=(15, 6))
    for channel in ("index", "middle", "ring", "pinky", "accX"):
        plt.plot(data[channel], label=channel)

    # Labels look like "<a>-<b>" with integer indices into GESTURES. Anything
    # else (non-numeric, wrong shape, out of range) falls back to GESTURES[0].
    try:
        g1, g2 = [int(part) for part in gesture.split('-')]
        name1, name2 = GESTURES[g1], GESTURES[g2]
    except (ValueError, IndexError):
        name1 = name2 = GESTURES[0]

    plt.ylim((400, 900))
    plt.title(f"Gesture data for {name1}-{name2} (Freq = {freq:.3f}Hz)")
    plt.legend()
    if os.path.exists(FIGURES_PATH):
        plt.savefig(os.path.join(FIGURES_PATH, gesture + ".png"))
    # The figure is never shown, so release it instead of leaving it registered in pyplot.
    plt.close(fig)
# Check paths before opening the serial port. sys.exit(message) prints the
# message to stderr and exits with status 1, so failures are visible to shells/scripts.
if not os.path.isdir(DATA_PATH):
    sys.exit(f"[ERROR] Data path {DATA_PATH} does not exist. Make directory if needed or change DATA_PATH.")
if not os.path.exists(SERIAL_PATH):
    sys.exit(f"[ERROR] Path to USB {SERIAL_PATH} is wrong. Check again.")
# Main code
# Flow: open the port, ask for the gesture label, then read in a background
# thread until Ctrl-C (or 90 s for a fixed trial). Finally plot and save a CSV.

# Per-column sample buffers appended to by read_serial (running in thread `t`);
# the "gesture" column is filled in just before the CSV is written.
data = {
    "gesture": [],
    "time": [],
    "index": [],
    "middle": [],
    "ring": [],
    "pinky": [],
    "accX": [],
}
event = threading.Event()  # set while the reader thread should keep reading
t = threading.Thread(target=read_serial)
# NOTE: rebinds `serial` from the pyserial module to the open port object;
# read_serial reads from and closes the port through this global name.
serial = serial.Serial(SERIAL_PATH, BAUD_RATE, timeout=1)

gesture = input("Gesture (e.g. 01): ")
while len(gesture) < 2:
    # Two characters are indexed below; re-prompt instead of crashing.
    gesture = input("Gesture needs two characters (e.g. 01): ")
# A repeated gesture (e.g. "00") can be recorded as a fixed (NOBEATS) trial,
# which stops on its own after 90 s.
NOBEATS = False
if gesture[0] == gesture[1]:
    if input("Fixed (y/n)?: ") == "y":
        NOBEATS = True
gesture = gesture[0] + '-' + gesture[1]

# read_serial uses `seconds` as its log prefix, so bind it before the thread starts.
seconds = 0
event.set()
t.start()
print("Reading...(Ctrl-C once you've done 100 gestures)")
try:
    start = time.time()
    while True:
        time.sleep(.1)
        if not NOBEATS:
            continue  # regular trials run until Ctrl-C
        elapsed = time.time() - start
        if elapsed > 90:
            break
        if elapsed > seconds:
            seconds += 1
except KeyboardInterrupt:
    pass  # Ctrl-C is the normal way to end a trial
finally:
    # Always stop and join the reader so the port gets closed and `data` stops
    # changing. This also runs when an unexpected error propagates.
    event.clear()
    t.join()

plot_data()
data["gesture"] = [gesture] * len(data["accX"])
# Fixed trials get a distinct default name so they never overwrite a regular
# recording of the same gesture.
default_name = gesture + ("-NOBEATS.csv" if NOBEATS else ".csv")
filename = input(f"Save data as (if left empty, filename will be {default_name}): ")
if not filename:
    filename = default_name
df = pd.DataFrame(data)
df.to_csv(DATA_PATH + filename)
print("Saved data as " + filename)