-
Notifications
You must be signed in to change notification settings - Fork 75
/
train.py
85 lines (69 loc) · 1.88 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import cv2
import numpy as np
from keras_squeezenet import SqueezeNet
from keras.optimizers import Adam
from keras.utils import np_utils
from keras.layers import Activation, Dropout, Convolution2D, GlobalAveragePooling2D
from keras.models import Sequential
import tensorflow as tf
import os
# Root folder that holds the training images, one sub-folder per class.
IMG_SAVE_PATH = 'image_data'

# Integer label for each gesture class, assigned in the order listed:
# rock -> 0, paper -> 1, scissors -> 2, none -> 3.
# Sub-folder names under IMG_SAVE_PATH are looked up in this map.
CLASS_MAP = {
    class_name: label
    for label, class_name in enumerate(("rock", "paper", "scissors", "none"))
}

# Number of output units of the classifier.
NUM_CLASSES = len(CLASS_MAP)
def mapper(val, class_map=None):
    """Return the integer label for the class name *val*.

    Args:
        val: class name, e.g. an image sub-folder name such as ``"rock"``.
        class_map: optional name -> label mapping; ``None`` (the default)
            uses the module-level ``CLASS_MAP``.

    Returns:
        The integer label mapped to *val*.

    Raises:
        KeyError: if *val* is not a known class. The message lists the
            valid names so a stray or misspelled folder is easy to spot.
    """
    mapping = CLASS_MAP if class_map is None else class_map
    if val not in mapping:
        raise KeyError(
            "unknown class {!r}; expected one of {}".format(val, list(mapping))
        )
    return mapping[val]
def get_model():
    """Build the (uncompiled) rock-paper-scissors classifier.

    The network is a SqueezeNet backbone created without its top layers and
    fed 227x227x3 inputs. It is followed by a SqueezeNet-style head: dropout,
    a 1x1 convolution with one filter per class, ReLU, global average
    pooling and a softmax over NUM_CLASSES outputs.

    Returns:
        A keras ``Sequential`` model; the caller is responsible for compiling it.
    """
    model = Sequential()
    # Backbone: SqueezeNet without its top (classification) layers.
    model.add(SqueezeNet(input_shape=(227, 227, 3), include_top=False))
    # Classification head.
    model.add(Dropout(0.5))
    model.add(Convolution2D(NUM_CLASSES, (1, 1), padding='valid'))
    model.add(Activation('relu'))
    # Collapse each per-class feature map to a single score.
    model.add(GlobalAveragePooling2D())
    model.add(Activation('softmax'))
    return model
def _load_dataset(root):
    """Load the training images found under *root*.

    Each sub-directory of *root* whose name is a key of CLASS_MAP is one
    class. Every image in it is read with OpenCV, converted from BGR to RGB
    and resized to 227x227.

    The following are skipped with a message instead of aborting the run:
    sub-directories not listed in CLASS_MAP, hidden files, and files that
    OpenCV cannot decode.

    Returns:
        A list of ``[image, class_name]`` pairs.
    """
    samples = []
    # sorted() makes the loading order reproducible across filesystems.
    for directory in sorted(os.listdir(root)):
        path = os.path.join(root, directory)
        if not os.path.isdir(path):
            continue
        if directory not in CLASS_MAP:
            print("Skipping directory {!r}: not one of {}".format(
                path, list(CLASS_MAP)))
            continue
        for item in sorted(os.listdir(path)):
            # to make sure no hidden files get in our way
            if item.startswith("."):
                continue
            file_path = os.path.join(path, item)
            img = cv2.imread(file_path)
            if img is None:
                # cv2.imread signals an unreadable/non-image file by
                # returning None rather than raising.
                print("Skipping unreadable image {!r}".format(file_path))
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (227, 227))
            samples.append([img, directory])
    return samples


# load images from the directory
# dataset = [
#     [<227x227x3 RGB image>, 'rock'],
#     [<227x227x3 RGB image>, 'paper'],
#     ...
# ]
dataset = _load_dataset(IMG_SAVE_PATH)
# Fail with a clear message instead of the opaque "not enough values to
# unpack" error that zip(*dataset) raises when no images were loaded.
if not dataset:
    raise ValueError(
        "No training images found in {!r}; expected sub-folders named {}".format(
            IMG_SAVE_PATH, ", ".join(CLASS_MAP)
        )
    )
data, labels = zip(*dataset)
labels = list(map(mapper, labels))

# one hot encode the labels, e.g. with the 4 classes in CLASS_MAP:
#   labels:          rock,      paper,     scissors,  none
#   one hot encoded: [1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1]
# Passing NUM_CLASSES fixes the width to the model's output size even when
# the highest-numbered class has no images in the dataset.
labels = np_utils.to_categorical(labels, NUM_CLASSES)
# define the model and configure how it is trained
model = get_model()
# NOTE(review): `lr` is the legacy Keras argument name; newer Keras releases
# expect `learning_rate` — confirm against the pinned Keras version.
optimizer = Adam(lr=0.0001)
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# start training on the whole dataset held in memory as NumPy arrays
features = np.array(data)
targets = np.array(labels)
model.fit(features, targets, epochs=10)

# save the model for later use
model.save("rock-paper-scissors-model.h5")