forked from tensorflow/tfjs-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ui.js
122 lines (102 loc) · 4.07 KB
/
ui.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs';
// Class labels indexed by class id; entry i pairs with CONTROL_CODES[i].
const CONTROLS = [
  'up',
  'down',
  'left',
  'right',
];
// Key codes handed to google.pacman.keyPressed() for each class id
// (the standard DOM keyCode values for the Up, Down, Left and Right arrows).
const CONTROL_CODES = [
  38,
  40,
  37,
  39,
];
/**
 * Switches the page into its ready state: reveals the element with id
 * "controller" (by clearing its inline display style) and hides the status
 * element.
 */
export function init() {
  const controller = document.getElementById('controller');
  controller.style.display = '';
  statusElement.style.display = 'none';
}
// Element that shows training progress text (written by trainStatus()).
const trainStatusElement = document.getElementById('train-status');

// Hyper-parameter inputs. Each getter re-reads its element's current value and
// converts it to a number, so edits made in the UI apply on the next call.
const learningRateElement = document.getElementById('learningRate');
const batchSizeFractionElement = document.getElementById('batchSizeFraction');
const epochsElement = document.getElementById('epochs');
const denseUnitsElement = document.getElementById('dense-units');

export const getLearningRate = () => Number(learningRateElement.value);
export const getBatchSizeFraction = () =>
    Number(batchSizeFractionElement.value);
export const getEpochs = () => Number(epochsElement.value);
export const getDenseUnits = () => Number(denseUnitsElement.value);

// Element toggled by init(), isPredicting() and donePredicting().
const statusElement = document.getElementById('status');
/** Starts the game through the page-global `google.pacman` object. */
export function startPacman() {
  const { pacman } = google;
  pacman.startGameplay();
}
/**
 * Acts on a predicted class: sends the matching key code to the game and
 * records the class name in the `data-active` attribute of <body>.
 * @param {number} classId - index into CONTROLS / CONTROL_CODES.
 */
export function predictClass(classId) {
  const keyCode = CONTROL_CODES[classId];
  const controlName = CONTROLS[classId];
  google.pacman.keyPressed(keyCode);
  document.body.setAttribute('data-active', controlName);
}
// Sets the CSS visibility of the status element.
function setStatusVisibility(visibility) {
  statusElement.style.visibility = visibility;
}

/** Makes the status element visible while predictions are running. */
export function isPredicting() {
  setStatusVisibility('visible');
}

/** Hides the status element once predictions have stopped. */
export function donePredicting() {
  setStatusVisibility('hidden');
}
/**
 * Shows a training progress message.
 * @param {string} status - text placed in the train-status element.
 */
export function trainStatus(status) {
  const target = trainStatusElement;
  target.innerText = status;
}
/**
 * Callback invoked with a class label on every frame while a control button
 * is held down (see handler()). Undefined until setExampleHandler() is called.
 */
export let addExampleHandler;

/**
 * Installs the per-example callback.
 * @param {function(number)} handler - stored as addExampleHandler.
 */
export function setExampleHandler(handler) {
  addExampleHandler = handler;
}
// True while a control button is held down; handler() keeps collecting
// examples until a mouseup listener resets it to false.
let mouseDown = false;

// Number of examples collected so far, indexed by class id.
const totals = Array.from({ length: 4 }, () => 0);

// Control buttons in class-id order: up, down, left, right.
const [upButton, downButton, leftButton, rightButton] =
    ['up', 'down', 'left', 'right'].map((id) => document.getElementById(id));

// Per-label flags consulted by drawThumb().
const thumbDisplayed = {};
/**
 * Collects training examples for one class while its control button is held.
 *
 * On every animation frame it calls addExampleHandler(label), marks the class
 * as active on <body>, and increments the on-screen total for that class. The
 * loop ends once a mouseup listener sets mouseDown back to false.
 *
 * @param {number} label - class id; index into CONTROLS and totals.
 */
async function handler(label) {
  mouseDown = true;
  const className = CONTROLS[label];
  const total = document.getElementById(`${className}-total`);
  try {
    while (mouseDown) {
      addExampleHandler(label);
      document.body.setAttribute('data-active', className);
      total.innerText = ++totals[label];
      await tf.nextFrame();
    }
  } finally {
    // Clear the highlight even if addExampleHandler throws, so the page does
    // not stay stuck showing this class as active.
    document.body.removeAttribute('data-active');
  }
}
// Wire each control button to its class id (0 = up, 1 = down, 2 = left,
// 3 = right): pressing starts example collection, releasing stops it.
[upButton, downButton, leftButton, rightButton].forEach((button, label) => {
  button.addEventListener('mousedown', () => handler(label));
  button.addEventListener('mouseup', () => {
    mouseDown = false;
  });
});
// A mouseup is dispatched to whatever element is under the pointer, so a
// button only sees it if the pointer is still over that button. Also listen on
// the document: otherwise releasing the mouse after dragging off a button
// leaves mouseDown true and handler()'s capture loop never stops.
document.addEventListener('mouseup', () => {
  mouseDown = false;
});
/**
 * Draws the thumbnail preview for a class the first time it is called for
 * that label. Later calls for the same label do nothing, which avoids a
 * dataSync() readback inside draw() on every captured example.
 *
 * @param {tf.Tensor} img - image tensor handed straight to draw().
 * @param {number} label - class id; selects the `<control>-thumb` canvas.
 */
export function drawThumb(img, label) {
  if (thumbDisplayed[label] != null) {
    return;
  }
  const thumbCanvas = document.getElementById(`${CONTROLS[label]}-thumb`);
  draw(img, thumbCanvas);
  // Only mark the thumbnail as shown after draw() succeeds, so a failed draw
  // is retried on the next example.
  thumbDisplayed[label] = true;
}
/**
 * Renders an RGB image tensor onto a canvas, anchored at the top-left corner.
 *
 * Each channel value v is mapped to (v + 1) * 127 before being written as a
 * byte. NOTE(review): this presumably inverts a v = x / 127 - 1 normalization
 * done by the code that produces the tensor — confirm before changing 127.
 *
 * @param {tf.Tensor} image - tensor whose trailing dims are [height, width, 3]
 *     (a leading batch dim such as [1, h, w, 3] is fine). If it has no shape
 *     or fewer than 3 dims, 224x224 is assumed, as before.
 * @param {HTMLCanvasElement} canvas - canvas to draw into.
 */
export function draw(image, canvas) {
  // Take the size from the tensor instead of hard-coding 224x224, so other
  // image sizes render correctly. slice(-3, -1) yields [height, width].
  const shape = image.shape;
  const [height, width] =
      shape && shape.length >= 3 ? shape.slice(-3, -1) : [224, 224];
  const ctx = canvas.getContext('2d');
  const imageData = new ImageData(width, height);
  const pixels = imageData.data;
  const data = image.dataSync();
  for (let i = 0; i < height * width; ++i) {
    const src = i * 3; // RGB source
    const dst = i * 4; // RGBA destination
    pixels[dst] = (data[src] + 1) * 127;
    pixels[dst + 1] = (data[src + 1] + 1) * 127;
    pixels[dst + 2] = (data[src + 2] + 1) * 127;
    pixels[dst + 3] = 255; // fully opaque
  }
  ctx.putImageData(imageData, 0, 0);
}