diff --git a/src/Layer.js b/src/Layer.js index e6661fdb..83956da1 100644 --- a/src/Layer.js +++ b/src/Layer.js @@ -71,4 +71,11 @@ export default class Layer { this.output = x return this.output } + + /** + * Reset presistent state in the mode. Only applies with stateful rnns. + */ + resetStates() { + return + } } diff --git a/src/Model.js b/src/Model.js index 9cfb585d..6b087962 100644 --- a/src/Model.js +++ b/src/Model.js @@ -615,6 +615,15 @@ export default class Model { return layer.call(x) } + /** + * Reset presistent state in the mode. Only applies with stateful rnns. + */ + resetStates() { + this.modelLayersMap.forEach(layer => { + layer.resetStates() + }) + } + /** * Cleanup - important for memory management */ diff --git a/src/layers/recurrent/GRU.js b/src/layers/recurrent/GRU.js index 09d4292f..c2f47b1c 100644 --- a/src/layers/recurrent/GRU.js +++ b/src/layers/recurrent/GRU.js @@ -525,4 +525,8 @@ export default class GRU extends Layer { this.output.transferFromGLTexture() } } + + resetStates() { + this.currentHiddenState = null + } } diff --git a/src/layers/recurrent/LSTM.js b/src/layers/recurrent/LSTM.js index d14b4525..055aa189 100644 --- a/src/layers/recurrent/LSTM.js +++ b/src/layers/recurrent/LSTM.js @@ -634,4 +634,9 @@ export default class LSTM extends Layer { this.output.transferFromGLTexture() } } + + resetStates() { + this.previousCandidate = null + this.currentHiddenState = null + } } diff --git a/src/layers/recurrent/SimpleRNN.js b/src/layers/recurrent/SimpleRNN.js index 372c2833..6544cbcd 100644 --- a/src/layers/recurrent/SimpleRNN.js +++ b/src/layers/recurrent/SimpleRNN.js @@ -295,4 +295,8 @@ export default class SimpleRNN extends Layer { this.output.transferFromGLTexture() } } + + resetStates() { + this.currentHiddenState = null + } }