// index.js (188 lines) - forked from tensorflow/tfjs-examples.
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs';
import {ControllerDataset} from './controller_dataset';
import * as ui from './ui';
import {Webcam} from './webcam';
// How many classes the classifier predicts. This example uses 4 classes:
// up, down, left, and right.
const NUM_CLASSES = 4;

// Webcam helper, built around the page's 'webcam' element, that generates
// Tensors from the webcam images.
const webcamElement = document.getElementById('webcam');
const webcam = new Webcam(webcamElement);

// Dataset object used to store the collected activations.
const controllerDataset = new ControllerDataset(NUM_CLASSES);

// Module-level model handles: `model` is assigned by train() and
// `truncatedMobileNet` is assigned by init().
let model;
let truncatedMobileNet;
/**
 * Loads MobileNet and builds a model that shares its inputs but outputs an
 * internal activation, which we use as the input to our classifier model.
 *
 * @returns {Promise<Object>} Resolves with the model created by `tf.model`,
 *     whose output is the output of the 'conv_pw_13_relu' layer.
 */
async function loadTruncatedMobileNet() {
  const mobilenetUrl =
      'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';
  const activationLayerName = 'conv_pw_13_relu';

  const fullNet = await tf.loadLayersModel(mobilenetUrl);

  // Cut the network at the internal activation layer.
  const {output} = fullNet.getLayer(activationLayerName);
  return tf.model({inputs: fullNet.inputs, outputs: output});
}
// Example collection: every time one of the UI buttons is pressed, grab a
// frame from the webcam and store it under the class label of that button.
// Labels 0, 1, 2, 3 correspond to up, down, left, right respectively.
const collectExample = (label) => {
  tf.tidy(() => {
    const frame = webcam.capture();
    const activation = truncatedMobileNet.predict(frame);
    controllerDataset.addExample(activation, label);

    // Render the captured frame as the preview thumbnail for this label.
    ui.drawThumb(frame, label);
  });
};
ui.setExampleHandler(collectExample);
/**
 * Sets up and trains the classifier on the examples collected so far.
 *
 * A new sequential model (flatten + two dense layers) is created, stored in
 * the module-level `model`, compiled with an Adam optimizer and
 * categorical cross-entropy loss, and fitted on `controllerDataset.xs` /
 * `controllerDataset.ys`. The loss of every batch is reported through
 * `ui.trainStatus`.
 *
 * @returns {Promise<void>} Resolves once `model.fit()` has completed. Rejects
 *     with an Error when no examples were added or when the computed batch
 *     size is 0 or NaN, and with whatever `model.fit()` rejects with.
 */
async function train() {
  if (controllerDataset.xs == null) {
    throw new Error('Add some examples before training!');
  }

  // Creates a 2-layer fully connected model. By creating a separate model,
  // rather than adding layers to the mobilenet model, we "freeze" the weights
  // of the mobilenet model, and only train weights from the new model.
  model = tf.sequential({
    layers: [
      // Flattens the input to a vector so we can use it in a dense layer. While
      // technically a layer, this only performs a reshape (and has no training
      // parameters).
      tf.layers.flatten(
          {inputShape: truncatedMobileNet.outputs[0].shape.slice(1)}),
      // Layer 1.
      tf.layers.dense({
        units: ui.getDenseUnits(),
        activation: 'relu',
        kernelInitializer: 'varianceScaling',
        useBias: true
      }),
      // Layer 2. The number of units of the last layer should correspond
      // to the number of classes we want to predict.
      tf.layers.dense({
        units: NUM_CLASSES,
        kernelInitializer: 'varianceScaling',
        useBias: false,
        activation: 'softmax'
      })
    ]
  });

  // Creates the optimizer which drives training of the model.
  const optimizer = tf.train.adam(ui.getLearningRate());
  // We use categoricalCrossentropy, the loss function for categorical
  // classification, which measures the error between our predicted
  // probability distribution over classes (probability that an input is of
  // each class) and the label (100% probability in the true class).
  model.compile({optimizer: optimizer, loss: 'categoricalCrossentropy'});

  // We parameterize batch size as a fraction of the entire dataset because the
  // number of examples that are collected depends on how many examples the user
  // collects. This allows us to have a flexible batch size.
  const batchSize =
      Math.floor(controllerDataset.xs.shape[0] * ui.getBatchSizeFraction());
  // Written as !(x > 0) so that NaN is rejected as well as 0.
  if (!(batchSize > 0)) {
    throw new Error(
        `Batch size is 0 or NaN. Please choose a non-zero fraction.`);
  }

  // Train the model! Model.fit() will shuffle xs & ys so we don't have to.
  // Awaiting the fit keeps train()'s promise pending until training is done
  // and lets a fit failure reject that promise instead of being lost.
  await model.fit(controllerDataset.xs, controllerDataset.ys, {
    batchSize,
    epochs: ui.getEpochs(),
    callbacks: {
      onBatchEnd: async (batch, logs) => {
        ui.trainStatus(`Loss: ${logs.loss.toFixed(5)}`);
      }
    }
  });
}
// Flag controlling the prediction loop in predict(): the loop keeps running
// while this is true. It is set to true by the 'predict' button handler and
// to false by the 'train' button handler.
let isPredicting = false;

/**
 * Runs the prediction loop: while `isPredicting` is true, captures a webcam
 * frame, classifies it and reports the predicted class through
 * `ui.predictClass`, yielding to the browser between frames with
 * `tf.nextFrame()`.
 *
 * `ui.isPredicting()` is called when the loop starts. When the loop ends,
 * whether normally or because an iteration threw, `isPredicting` is reset
 * to false and `ui.donePredicting()` is called.
 *
 * @returns {Promise<void>} Resolves when the loop has stopped; rejects with
 *     the error of the failing iteration, if any (for example when no
 *     classifier model has been trained yet).
 */
async function predict() {
  ui.isPredicting();
  try {
    while (isPredicting) {
      const predictedClass = tf.tidy(() => {
        // Capture the frame from the webcam.
        const img = webcam.capture();
        // Make a prediction through mobilenet, getting the internal activation
        // of the mobilenet model, i.e., "embeddings" of the input images.
        const embeddings = truncatedMobileNet.predict(img);
        // Make a prediction through our newly-trained model using the
        // embeddings from mobilenet as input.
        const predictions = model.predict(embeddings);
        // Returns the index with the maximum probability. This number
        // corresponds to the class the model thinks is the most probable given
        // the input.
        return predictions.as1D().argMax();
      });

      let classId;
      try {
        classId = (await predictedClass.data())[0];
      } finally {
        // The tensor returned from tf.tidy() is not cleaned up by tidy, so it
        // is disposed here even when reading its data fails.
        predictedClass.dispose();
      }

      ui.predictClass(classId);
      await tf.nextFrame();
    }
  } finally {
    // Leave the flag and the UI in the "not predicting" state even when an
    // iteration threw; the error itself still propagates to the caller.
    isPredicting = false;
    ui.donePredicting();
  }
}
// 'train' button: show a status message, stop the prediction loop and train
// the classifier. A training failure is logged and shown as the status text.
document.getElementById('train').addEventListener('click', async () => {
  ui.trainStatus('Training...');
  // Wait two frames before the training work starts.
  await tf.nextFrame();
  await tf.nextFrame();
  isPredicting = false;
  try {
    await train();
  } catch (err) {
    // train() rejects when there are no examples or the batch size is
    // invalid. Without this handler the rejection would go unhandled and the
    // status would keep saying 'Training...'.
    console.error(err);
    ui.trainStatus(err instanceof Error ? err.message : String(err));
  }
});
// 'predict' button: call ui.startPacman(), raise the isPredicting flag and
// start the prediction loop. The promise returned by predict() is not awaited.
const predictButton = document.getElementById('predict');
predictButton.addEventListener('click', function onPredictClick() {
  ui.startPacman();
  isPredicting = true;
  predict();
});
/**
 * Initializes the application: sets up the webcam, loads the truncated
 * MobileNet into the module-level `truncatedMobileNet`, runs one warm-up
 * prediction and finally calls `ui.init()`.
 *
 * If `webcam.setup()` fails, the 'no-webcam' element is shown and
 * initialization continues.
 *
 * @returns {Promise<void>} Resolves when initialization has finished.
 */
async function init() {
  try {
    await webcam.setup();
  } catch (e) {
    // Webcam setup failed: reveal the 'no-webcam' element.
    document.getElementById('no-webcam').style.display = 'block';
  }
  truncatedMobileNet = await loadTruncatedMobileNet();

  // Warm up the model. This uploads weights to the GPU and compiles the WebGL
  // programs so the first time we collect data from the webcam it will be
  // quick.
  // The callback deliberately returns nothing: tf.tidy() does not clean up
  // tensors returned from its callback, so returning the warm-up activation
  // would leak it.
  tf.tidy(() => {
    truncatedMobileNet.predict(webcam.capture());
  });

  ui.init();
}
// Initialize the application. init() is async; the promise it returns is not
// awaited here, so the module finishes loading while initialization runs.
init();