[jena-weather] Add RNN training in Node.js; Use fitDataset() and tfvis.show.fitCallbacks() (tensorflow#206)

* To this end, the model creation and training logic is refactored out from `index.js` and moved into
  `models.js`, which is used by the browser environment and the backend (tfjs-node/tfjs-node-gpu)
  environment.
* The newly added `train-rnn.js` drives the training of RNNs in the backend.
* Replace custom `tf.Model.trainOnBatch()` calls with the more principled `tf.Model.fitDataset()`
* Replace custom plotting methods with `tfvis.show.fitCallbacks()`
caisq authored Jan 16, 2019
1 parent 772a6d9 commit 274758e
Showing 7 changed files with 2,482 additions and 1,037 deletions.
49 changes: 49 additions & 0 deletions jena-weather/README.md
@@ -15,3 +15,52 @@ This demo showcases

The data used in this demo is the
[Jena weather archive dataset](https://www.kaggle.com/pankrzysiu/weather-archive-jena).

This example also showcases the usage of the following important APIs in
TensorFlow.js:

- `tf.data.generator()`: How to create `tf.data.Dataset` objects from generator
functions.
- `tf.Model.fitDataset()`: How to use a `tf.data.Dataset` object to train a
`tf.Model` and use another `tf.data.Dataset` object to perform validation
of the model at the end of every training epoch.
- `tfvis.show.fitCallbacks()`: How to use this convenience method to plot
training-set and validation-set losses at the end of batches and epochs of
model training.

## Training RNNs

This example shows how to predict temperature using a few different types of
models, including linear regressors, multilayer perceptrons, and recurrent
neural networks (RNNs). While training of the first two types of models
happens in the browser, the training of RNNs is conducted in Node.js, due to
their heavier computational load and longer training time.

For example, to train a gated recurrent unit (GRU) model, run the following
shell commands:

```sh
yarn
yarn train-rnn
```

By default, training runs on the CPU using the Eigen ops from tfjs-node.
If you have a CUDA-enabled GPU and the necessary drivers and libraries (CUDA and
cuDNN) installed, you can train the model using the CUDA/cuDNN ops from
tfjs-node-gpu. To do so, add the `--gpu` flag:

```sh
yarn
yarn train-rnn --gpu
```
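
As a rough illustration of how a script might act on such a flag, the sketch below maps command-line arguments to one of the two npm package names mentioned above. The helper name `backendPackageFor` is hypothetical and is not taken from `train-rnn.js`; the actual script may select its backend differently.

```javascript
// Hypothetical helper (not from train-rnn.js): chooses which TensorFlow.js
// Node.js package to load, based on the presence of a `--gpu` argument.
function backendPackageFor(argv) {
  return argv.includes('--gpu') ?
      '@tensorflow/tfjs-node-gpu' :  // CUDA/cuDNN ops
      '@tensorflow/tfjs-node';       // CPU (Eigen) ops
}

console.log(backendPackageFor(['--gpu']));
console.log(backendPackageFor([]));
```

A script could then pass the returned name to `require()` before any model is created.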

You can also calculate the prediction error (mean absolute error) of a
commonsense baseline that involves no machine learning: simply predict the
temperature to be the latest temperature data point in the input features.
To do this, pass the special value `baseline` to the `--modelType` flag:

```sh
yarn
yarn train-rnn --modelType baseline
```
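
The baseline itself is simple enough to sketch in plain JavaScript. The example below assumes a simplified data layout, in which each example holds a window of past temperatures and a target temperature. That layout and the helper name `baselineMeanAbsoluteError` are illustrative assumptions, not the structures used by `train-rnn.js`.

```javascript
// Hypothetical helper: computes the mean absolute error (MAE) of the
// "predict the latest temperature" baseline.
// Assumed layout: each example is {temperatures: number[], target: number}.
function baselineMeanAbsoluteError(examples) {
  let totalAbsError = 0;
  for (const {temperatures, target} of examples) {
    // The baseline prediction is the last temperature in the input window.
    const prediction = temperatures[temperatures.length - 1];
    totalAbsError += Math.abs(prediction - target);
  }
  return totalAbsError / examples.length;
}

const examples = [
  {temperatures: [1.0, 2.0, 3.0], target: 3.5},  // |3.0 - 3.5| = 0.5
  {temperatures: [4.0, 4.0, 2.0], target: 1.0},  // |2.0 - 1.0| = 1.0
];
console.log(baselineMeanAbsoluteError(examples));  // 0.75
```

Any trained model should be judged against this number: a model whose validation MAE is not below the baseline has not learned anything useful.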

The training code is in the file [train-rnn.js](./train-rnn.js).
28 changes: 21 additions & 7 deletions jena-weather/data.js
@@ -21,6 +21,10 @@
  * The data used in this demo is the
  * [Jena weather archive
  * dataset](https://www.kaggle.com/pankrzysiu/weather-archive-jena).
+ *
+ * This file is used to load the Jena weather data in both
+ * - the browser: see [index.js](./index.js), and
+ * - the Node.js backend environment: see [train-rnn.js](./train-rnn.js).
  */

 import * as tf from '@tensorflow/tfjs';
@@ -46,12 +50,18 @@ export class JenaWeatherData {
    * URL (`JENA_WEATHER_CSV_PATH`).
    */
   async load() {
-    let response = await fetch(LOCAL_JENA_WEATHER_CSV_PATH);
-    if (response.statusCode === 200 || response.statusCode === 304) {
+    let response;
+    try {
+      response = await fetch(LOCAL_JENA_WEATHER_CSV_PATH);
+    } catch (err) {}
+
+    if (response != null &&
+        (response.statusCode === 200 || response.statusCode === 304)) {
       console.log('Loading data from local path');
     } else {
       response = await fetch(REMOTE_JENA_WEATHER_CSV_PATH);
-      console.log('Loading data from remote path');
+      console.log(
+          `Loading data from remote path: ${REMOTE_JENA_WEATHER_CSV_PATH}`);
     }
     const csvData = await response.text();

@@ -267,6 +277,7 @@ export class JenaWeatherData {

     function nextBatchFn() {
       const rowIndices = [];
+      let done = false;  // Indicates whether the dataset has ended.
       if (shuffle) {
         // If `shuffle` is `true`, start from randomly chosen rows.
         const range = maxIndex - (minIndex + lookBack);
@@ -276,10 +287,13 @@ export class JenaWeatherData {
         }
       } else {
         // If `shuffle` is `false`, the starting row indices will be sequential.
-        for (let r = startIndex; r < startIndex + batchSize && r < maxIndex;
-             ++r) {
+        let r = startIndex;
+        for (; r < startIndex + batchSize && r < maxIndex; ++r) {
           rowIndices.push(r);
         }
+        if (r >= maxIndex) {
+          done = true;
+        }
       }

       const numExamples = rowIndices.length;
@@ -320,8 +334,8 @@ export class JenaWeatherData {
       }
       return {
         value: [samples.toTensor(), targets.toTensor()],
-        done: false
-      };  // TODO(cais): Return done = true when done.
+        done
+      };
     }

     return nextBatchFn.bind(this);
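
The effect of the new `done` flag in the non-shuffled branch can be seen in isolation with a small, dependency-free sketch. The factory `makeSequentialBatchFn` is hypothetical, and how it advances `startIndex` between calls is an assumption, since that part of `data.js` is not shown in the hunks above. It returns row indices rather than tensors to stay self-contained.

```javascript
// Hypothetical sketch of the sequential (non-shuffled) batching logic.
// Returns a function that yields {value, done} objects, where `value` is
// the list of row indices in the batch.
function makeSequentialBatchFn(minIndex, maxIndex, batchSize) {
  let startIndex = minIndex;
  return function nextBatch() {
    const rowIndices = [];
    let r = startIndex;
    for (; r < startIndex + batchSize && r < maxIndex; ++r) {
      rowIndices.push(r);
    }
    // Assumption: the cursor moves forward by the rows just consumed.
    startIndex = r;
    // As in the diff, `done` becomes true once the rows are exhausted.
    return {value: rowIndices, done: r >= maxIndex};
  };
}

const next = makeSequentialBatchFn(0, 5, 2);
const first = next();   // {value: [0, 1], done: false}
const second = next();  // {value: [2, 3], done: false}
const third = next();   // {value: [4], done: true}
console.log(first, second, third);
```

Note that in this sketch the last batch is returned together with `done: true`, which mirrors the diff. Consumers that follow the standard JavaScript iterator protocol typically ignore `value` once `done` is `true`.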