Skip to content

Commit

Permalink
Move Python scripts for building examples into tfjs-examples (tensorf…
Browse files Browse the repository at this point in the history
…low#38)

* WIP

* Add mnist-transfer-cnn and update build files

* Update build scripts

* bugfixes

* yarn.lock updates

* Remove http-server in favor of yarn watch; fix licenses

* Serve local pretrained model with http-server

* Let user choose local vs. remote pretrained model.
  • Loading branch information
davidsoergel authored Mar 30, 2018
1 parent d320351 commit e98ea7f
Show file tree
Hide file tree
Showing 47 changed files with 2,819 additions and 266 deletions.
2 changes: 1 addition & 1 deletion deploy.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env bash
# Copyright 2018 Google Inc. All Rights Reserved.
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
61 changes: 61 additions & 0 deletions iris/build-resources.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env bash

# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

# Builds resources for the Iris demo.
# Note this is not necessary to run the demo, because we already provide hosted
# pre-built resources.
# Usage example: do this from the 'iris' directory:
# ./build-resources.sh

set -e

DEMO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

TRAIN_EPOCHS=100
while true; do
if [[ "$1" == "--epochs" ]]; then
TRAIN_EPOCHS=$2
shift 2
elif [[ -z "$1" ]]; then
break
else
echo "ERROR: Unrecognized argument: $1"
exit 1
fi
done

RESOURCES_ROOT="${DEMO_DIR}/dist/resources"
rm -rf "${RESOURCES_ROOT}"
mkdir -p "${RESOURCES_ROOT}"

# Run Python script to generate the pretrained model and weights files.
# Make sure you install the tensorflowjs pip package first.

python "${DEMO_DIR}/python/iris.py" \
--epochs "${TRAIN_EPOCHS}" \
--artifacts_dir "${RESOURCES_ROOT}"

cd ${DEMO_DIR}
yarn
yarn build

echo
echo "-----------------------------------------------------------"
echo "Resources written to ${RESOURCES_ROOT}."
echo "You can now run the demo with 'yarn watch'."
echo "-----------------------------------------------------------"
echo
3 changes: 2 additions & 1 deletion iris/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ <h1>TensorFlow.js Layers: Iris Demo</h1>
</div>

<div class="create-model">
<button id="load-pretrained">Load pretrained model</button>
<button id="load-pretrained-remote" style="display:none">Load hosted pretrained model</button>
<button id="load-pretrained-local" style="display:none">Load local pretrained model</button>
</div>

<div>
Expand Down
85 changes: 46 additions & 39 deletions iris/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,12 @@

import * as tf from '@tensorflow/tfjs';

import {getIrisData, IRIS_CLASSES, IRIS_NUM_CLASSES} from './data';
import {clearEvaluateTable, getManualInputData, loadTrainParametersFromUI, plotAccuracies, plotLosses, renderEvaluateTable, renderLogitsForManualInput, setManualInputWinnerMessage, status, wireUpEvaluateTableCallbacks} from './ui';
import * as data from './data';
import * as loader from './loader';
import * as ui from './ui';

let model;

/**
* Load pretrained model stored at a remote URL.
*
* @return An instance of `tf.Model` with model topology and weights loaded.
*/
async function loadHostedPretrainedModel() {
const HOSTED_MODEL_JSON_URL =
'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json';
status('Loading pretrained model from ' + HOSTED_MODEL_JSON_URL);
try {
model = await tf.loadModel(HOSTED_MODEL_JSON_URL);
status('Done loading pretrained model.');
} catch (err) {
status('Loading pretrained model failed.');
}
}

/**
* Train a `tf.Model` to recognize Iris flower type.
*
Expand All @@ -53,9 +37,9 @@ async function loadHostedPretrainedModel() {
* @returns The trained `tf.Model` instance.
*/
async function trainModel(xTrain, yTrain, xTest, yTest) {
status('Training model... Please wait.');
ui.status('Training model... Please wait.');

const params = loadTrainParametersFromUI();
const params = ui.loadTrainParametersFromUI();

// Define the topology of the model: two dense layers.
const model = tf.sequential();
Expand All @@ -79,16 +63,16 @@ async function trainModel(xTrain, yTrain, xTest, yTest) {
callbacks: {
onEpochEnd: async (epoch, logs) => {
// Plot the loss and accuracy values at the end of every training epoch.
plotLosses(lossValues, epoch, logs.loss, logs.val_loss);
plotAccuracies(accuracyValues, epoch, logs.acc, logs.val_acc);
ui.plotLosses(lossValues, epoch, logs.loss, logs.val_loss);
ui.plotAccuracies(accuracyValues, epoch, logs.acc, logs.val_acc);

// Await web page DOM to refresh for the most recently plotted values.
await tf.nextFrame();
},
}
});

status('Model training complete.');
ui.status('Model training complete.');
return model;
}

Expand All @@ -99,25 +83,25 @@ async function trainModel(xTrain, yTrain, xTest, yTest) {
*/
async function predictOnManualInput(model) {
if (model == null) {
setManualInputWinnerMessage('ERROR: Please load or train model first.');
ui.setManualInputWinnerMessage('ERROR: Please load or train model first.');
return;
}

// Use a `tf.tidy` scope to make sure that WebGL memory allocated for the
// `predict` call is released at the end.
tf.tidy(() => {
// Prepare input data as a 2D `tf.Tensor`.
const inputData = getManualInputData();
const inputData = ui.getManualInputData();
const input = tf.tensor2d([inputData], [1, 4]);

// Call `model.predict` to get the prediction output as probabilities for
// the Iris flower categories.

const predictOut = model.predict(input);
const logits = Array.from(predictOut.dataSync());
const winner = IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
setManualInputWinnerMessage(winner);
renderLogitsForManualInput(logits);
const winner = data.IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
ui.setManualInputWinnerMessage(winner);
ui.renderLogitsForManualInput(logits);
});
}

Expand All @@ -130,39 +114,62 @@ async function predictOnManualInput(model) {
* [numTestExamples, 3].
*/
async function evaluateModelOnTestData(model, xTest, yTest) {
clearEvaluateTable();
ui.clearEvaluateTable();

tf.tidy(() => {
const xData = xTest.dataSync();
const yTrue = yTest.argMax(-1).dataSync();
const predictOut = model.predict(xTest);
const yPred = predictOut.argMax(-1);
renderEvaluateTable(xData, yTrue, yPred.dataSync(), predictOut.dataSync());
ui.renderEvaluateTable(
xData, yTrue, yPred.dataSync(), predictOut.dataSync());
});

predictOnManualInput(model);
}

const LOCAL_MODEL_JSON_URL = 'http://localhost:1235/resources/model.json';
const HOSTED_MODEL_JSON_URL =
'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json';

/**
* The main function of the Iris demo.
*/
async function iris() {
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);
const [xTrain, yTrain, xTest, yTest] = data.getIrisData(0.15);

document.getElementById('train-from-scratch')
.addEventListener('click', async () => {
model = await trainModel(xTrain, yTrain, xTest, yTest);
evaluateModelOnTestData(model, xTest, yTest);
});

document.getElementById('load-pretrained')
.addEventListener('click', async () => {
clearEvaluateTable();
await loadHostedPretrainedModel();
predictOnManualInput(model);
});
if (await loader.urlExists(HOSTED_MODEL_JSON_URL)) {
ui.status('Model available: ' + HOSTED_MODEL_JSON_URL);
const button = document.getElementById('load-pretrained-remote');
button.addEventListener('click', async () => {
ui.clearEvaluateTable();
model = await loader.loadHostedPretrainedModel(HOSTED_MODEL_JSON_URL);
predictOnManualInput(model);
});
// button.style.visibility = 'visible';
button.style.display = 'inline-block';
}

if (await loader.urlExists(LOCAL_MODEL_JSON_URL)) {
ui.status('Model available: ' + LOCAL_MODEL_JSON_URL);
const button = document.getElementById('load-pretrained-local');
button.addEventListener('click', async () => {
ui.clearEvaluateTable();
model = await loader.loadHostedPretrainedModel(LOCAL_MODEL_JSON_URL);
predictOnManualInput(model);
});
// button.style.visibility = 'visible';
button.style.display = 'inline-block';
}

wireUpEvaluateTableCallbacks(() => predictOnManualInput(model));
ui.status('Standing by.');
ui.wireUpEvaluateTableCallbacks(() => predictOnManualInput(model));
}

iris();
53 changes: 53 additions & 0 deletions iris/loader.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs';
import * as ui from './ui';

/**
* Test whether a given URL is retrievable.
*/
export async function urlExists(url) {
ui.status('Testing url ' + url);
try {
const response = await fetch(url, {method: 'HEAD'});
return response.ok;
} catch (err) {
return false;
}
}

/**
* Load pretrained model stored at a remote URL.
*
* @return An instance of `tf.Model` with model topology and weights loaded.
*/
export async function loadHostedPretrainedModel(url) {
ui.status('Loading pretrained model from ' + url);
try {
const model = await tf.loadModel(url);
ui.status('Done loading pretrained model.');
// We can't load a model twice due to
// https://github.com/tensorflow/tfjs/issues/34
// Therefore we remove the load buttons to avoid user confusion.
ui.disableLoadModelButtons();
return model;
} catch (err) {
console.error(err);
ui.status('Loading pretrained model failed.');
}
}
7 changes: 4 additions & 3 deletions iris/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
"vega-embed": "^3.0.0"
},
"scripts": {
"watch": "NODE_ENV=development parcel --no-hmr --open index.html ",
"build": "NODE_ENV=production parcel build index.html --no-minify --public-url ./"
"watch": "./serve.sh",
"build": "NODE_ENV=production parcel build index.html --no-minify --public-url /"
},
"devDependencies": {
"babel-plugin-transform-runtime": "~6.23.0",
"babel-polyfill": "~6.26.0",
"babel-preset-env": "~1.6.1",
"clang-format": "~1.2.2",
"parcel-bundler": "~1.6.2"
"http-server": "~0.10.0",
"parcel-bundler": "~1.7.0"
},
"babel": {
"presets": [
Expand Down
18 changes: 18 additions & 0 deletions iris/python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Loading

0 comments on commit e98ea7f

Please sign in to comment.