-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
ui.js
128 lines (113 loc) · 4.41 KB
/
ui.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs';
import * as util from './util';
/**
 * Shows a status message in the #status element and mirrors it to the
 * browser console.
 *
 * @param {string} statusText - Message to display and log.
 * @param {string} statusColor - CSS color assigned to the element's
 *     `style.color`. Callers in this file sometimes omit it, in which case
 *     `undefined` is assigned.
 */
export function status(statusText, statusColor) {
  console.log(statusText);
  const statusElement = document.getElementById('status');
  statusElement.textContent = statusText;
  statusElement.style.color = statusColor;
}
/**
 * Initializes the prediction/retraining UI.
 *
 * Steps:
 * 1. Installs the prediction handlers.
 * 2. Fills #image-input with the '5_1' test example and runs one prediction
 *    on it.
 * 3. Installs the retrain handler.
 * 4. Enables the #retrain button and the #test-image-select control.
 *
 * @param {function(string)} predict - Called with the image input's text.
 * @param {function()} retrain - Awaited when the retrain button is clicked.
 * @param {Object} testExamples - Test image vectors keyed by example id;
 *     must contain '5_1'.
 * @param {number} imageSize - Passed through to util.imageVectorToText.
 */
export function prepUI(predict, retrain, testExamples, imageSize) {
  setPredictFunction(predict, testExamples, imageSize);

  const imageInput = document.getElementById('image-input');
  const initialText = util.imageVectorToText(testExamples['5_1'], imageSize);
  imageInput.value = initialText;
  // Read back from the element so predict sees exactly what the input holds.
  predict(imageInput.value);

  setRetrainFunction(retrain);
  for (const controlId of ['retrain', 'test-image-select']) {
    document.getElementById(controlId).disabled = false;
  }
}
/**
 * Returns the raw text currently held by the #image-input element.
 *
 * @returns {string} The element's `value`.
 */
export function getImageInput() {
  const { value } = document.getElementById('image-input');
  return value;
}
/**
 * Reads the number of retraining epochs from the #epochs input.
 *
 * The value is parsed as a base-10 integer. The explicit radix keeps inputs
 * such as '0x10' from being interpreted as hexadecimal.
 *
 * @returns {number} The parsed epoch count, or NaN if the field does not
 *     start with a decimal integer.
 */
export function getEpochs() {
  return Number.parseInt(document.getElementById('epochs').value, 10);
}
/**
 * Hooks prediction up to the image text input and the test-image selector.
 *
 * - Whenever the text in #image-input changes, `predict` is called with the
 *   new text.
 * - Choosing an entry in #test-image-select replaces the input text with the
 *   selected test example (rendered by util.imageVectorToText) and runs a
 *   prediction on it.
 *
 * @param {function(string)} predict - Called with the current input text.
 * @param {Object} testExamples - Test image vectors keyed by the selector's
 *     option values.
 * @param {number} imageSize - Passed through to util.imageVectorToText.
 */
function setPredictFunction(predict, testExamples, imageSize) {
  const imageInput = document.getElementById('image-input');
  // 'input' fires on every change to the value, including context-menu paste
  // and drag-and-drop, which 'keyup' misses. It also skips keys that do not
  // edit the text (arrows, Shift), avoiding redundant predictions.
  imageInput.addEventListener('input', () => {
    predict(imageInput.value);
  });

  const testImageSelect = document.getElementById('test-image-select');
  testImageSelect.addEventListener('change', () => {
    // Programmatic assignment does not fire 'input', so predict runs once here.
    imageInput.value =
        util.imageVectorToText(testExamples[testImageSelect.value], imageSize);
    predict(imageInput.value);
  });
}
/**
 * Attaches the retraining handler to the #retrain button.
 *
 * The button is disabled while `retrain` runs so it cannot be triggered
 * twice. When `retrain` succeeds, the button is left disabled here. This
 * file's getProgressBarCallbackConfig onTrainEnd callback re-enables it,
 * presumably via the training run `retrain` starts — confirm with caller.
 *
 * If `retrain` rejects, the button is re-enabled so the user can try again.
 * The error is then rethrown so it still surfaces as an unhandled rejection,
 * as it did before.
 *
 * @param {function()} retrain - Awaited on each click.
 */
function setRetrainFunction(retrain) {
  const retrainButton = document.getElementById('retrain');
  retrainButton.addEventListener('click', async () => {
    retrainButton.disabled = true;
    try {
      await retrain();
    } catch (err) {
      // Otherwise a failed retrain would leave the button disabled for good.
      retrainButton.disabled = false;
      throw err;
    }
  });
}
/**
 * Builds training callbacks that report retraining progress in the UI.
 *
 * - onTrainBegin: records the start time (tf.util.now), shows a blue
 *   "please wait" status and resets the #trainProg value to 0.
 * - onTrainEnd: re-enables the #retrain button and reports the elapsed
 *   milliseconds in black.
 * - onEpochEnd: appends "(Epoch i of N)" to the wait message and sets
 *   #trainProg to the completed percentage.
 *
 * @param {number} epochs - Total epoch count, used in the status text and to
 *     compute the progress percentage.
 * @returns {Object} An object with async onTrainBegin, onTrainEnd and
 *     onEpochEnd callbacks.
 */
export function getProgressBarCallbackConfig(epochs) {
  const trainProg = document.getElementById('trainProg');
  const waitMessage =
      'Please wait and do NOT click anything while the model retrains...';
  let startTime;

  return {
    onTrainBegin: async (logs) => {
      startTime = tf.util.now();
      status(waitMessage, 'blue');
      trainProg.value = 0;
    },
    onTrainEnd: async (logs) => {
      // Training has finished, so the user may retrain again.
      document.getElementById('retrain').disabled = false;
      const elapsedMs = tf.util.now() - startTime;
      status(
          `Done retraining ${epochs} epochs ` +
              `(elapsed: ${elapsedMs.toFixed(1)} ms). Standing by.`,
          'black');
    },
    onEpochEnd: async (epoch, logs) => {
      // `epoch` is zero-based, so the count of finished epochs is one more.
      const epochsDone = epoch + 1;
      status(`${waitMessage} (Epoch ${epochsDone} of ${epochs})`);
      trainProg.value = epochsDone / epochs * 100;
    },
  };
}
/**
 * Replaces the prediction table with a single error cell.
 *
 * The message is inserted as text (via `textContent`), not HTML. Any markup
 * in `text`, for example echoed from user-entered image text, is displayed
 * literally instead of being parsed and executed.
 *
 * @param {string} text - Error description shown after "Error: ".
 */
export function setPredictError(text) {
  const predictHeader = document.getElementById('predict-header');
  const predictValues = document.getElementById('predict-values');
  const errorCell = document.createElement('td');
  errorCell.textContent = `Error: ${text}`;
  predictHeader.replaceChildren(errorCell);
  // Clear any previous score cells.
  predictValues.replaceChildren();
}
/**
 * Renders prediction scores into the results table and shows the winner.
 *
 * The header row is always the digit classes 5-9. The values row gets one
 * cell per element of `predictOut`, each formatted with six decimals.
 *
 * @param {Iterable<number>} predictOut - Scores, in header-column order
 *     (any iterable of numbers, e.g. an array or typed array).
 * @param {*} winner - Value written as the text of the #winner element.
 */
export function setPredictResults(predictOut, winner) {
  const headerRow = document.getElementById('predict-header');
  const valuesRow = document.getElementById('predict-values');
  headerRow.innerHTML =
      '<td>5</td><td>6</td><td>7</td><td>8</td><td>9</td>';
  // Array.from (not .map) so typed arrays yield strings rather than NaNs.
  const scoreCells =
      Array.from(predictOut, (score) => `<td>${score.toFixed(6)}</td>`);
  valuesRow.innerHTML = scoreCells.join('');
  document.getElementById('winner').textContent = winner;
}
/**
 * Hides the remote and local "load pretrained model" buttons by setting
 * their CSS `display` to 'none'.
 */
export function disableLoadModelButtons() {
  const buttonIds = ['load-pretrained-remote', 'load-pretrained-local'];
  buttonIds.forEach((buttonId) => {
    document.getElementById(buttonId).style.display = 'none';
  });
}
/**
 * Returns the retraining mode currently chosen in the #training-mode control.
 *
 * @returns {string} The control's `value`.
 */
export function getTrainingMode() {
  const modeControl = document.getElementById('training-mode');
  return modeControl.value;
}