-
I am trying to replicate the implementation below from the book "Deep Learning with Python" by François Chollet regarding timeseries data:
import keras
from keras import layers
# Reference model: each input sample has shape (sequence_length, num_features).
sequence_length, num_features = 120, 14
inputs = keras.Input(shape=(sequence_length, num_features))

# Name each layer first so the forward wiring below reads as a pipeline.
lstm_layer = layers.LSTM(32, recurrent_dropout=0.25)
dropout_layer = layers.Dropout(0.5)
output_layer = layers.Dense(1)

x = lstm_layer(inputs)
x = dropout_layer(x)
outputs = output_layer(x)
model = keras.Model(inputs, outputs) My attempt below: class RecurrentDropout(nnx.Module):
"""Recurrent Dropout Layer.
This module implements recurrent dropout, a regularization technique
specifically designed for recurrent neural networks (RNNs). It generates a
fixed dropout mask during the first call and applies it consistently to the
hidden state across all time steps. This ensures the same units are dropped
throughout the sequence, preventing co-adaptation of hidden units and
improving generalization performance.
Attributes:
rate: The dropout probability (not the keep rate).
rngs: An instance of nnx.Rngs.
Example Usage:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> rngs = nnx.Rngs(0)
>>> dropout = RecurrentDropout(rngs=rngs, rate=0.2)
>>> x = (jnp.ones((1, 10)), jnp.ones((1, 10))) # Example hidden state and cell state
>>> output = dropout(x)
>>> print(output[0].shape) # Output shape remains the same
(1, 10)
"""
def __init__(self, *, rngs: nnx.Rngs, rate: float):
self.rate = rate
self.rngs = rngs
def __call__(self, x):
"""Applies a fixed dropout mask to the input.
Args:
x: the input to which the dropout mask will be applied.
Returns:
The input with the dropout mask applied.
"""
if not hasattr(self, 'mask'):
key = self.rngs.dropout()
state = nnx.state(self).flat_state()
state['mask'] = nnx.VariableState(
type=nnx.Param, value=jax.random.bernoulli(key, 1 - self.rate, x[0].shape))
h, c = x
return h * state['mask'].value, c
class LSTMWithRecurrentDropout(nnx.OptimizedLSTMCell):
    """LSTM cell with recurrent dropout.

    Extends nnx.OptimizedLSTMCell so that the hidden state handed to the next
    time step is passed through ``RecurrentDropout`` first. The cell state and
    the per-step output emitted by the parent cell are left untouched, so the
    dropout only affects the recurrent connection.

    The carry follows the parent cell's convention, ``(c, h)``, and
    ``__call__`` returns ``(new_carry, output)`` exactly like the parent.

    Attributes:
        rngs: An instance of nnx.Rngs.
        in_features: The number of input features.
        hidden_features: The number of features in the hidden state of the LSTM.
        dropout_rate: The dropout probability for the recurrent connections.

    Example usage:
        >>> from flax import nnx
        >>> import jax.numpy as jnp
        >>> rngs = nnx.Rngs(0)
        >>> lstm = LSTMWithRecurrentDropout(rngs=rngs, in_features=10, hidden_features=20, dropout_rate=0.2)
        >>> c = jnp.ones((1, 20))  # (batch_size, hidden_features)
        >>> h = jnp.ones((1, 20))  # (batch_size, hidden_features)
        >>> x = jnp.ones((1, 10))  # (batch_size, in_features)
        >>> (new_c, new_h), out = lstm((c, h), x)
        >>> print(new_h.shape, new_c.shape)
        (1, 20) (1, 20)
    """

    def __init__(self, *, rngs: nnx.Rngs, in_features: int, hidden_features: int, dropout_rate: float, **kwargs):
        super().__init__(in_features=in_features, hidden_features=hidden_features, rngs=rngs, **kwargs)
        self.recurrent_dropout = RecurrentDropout(rate=dropout_rate, rngs=rngs)

    def __call__(self, carry, x):
        """Runs one LSTM step and masks the hidden state that is carried on.

        Args:
            carry: the ``(c, h)`` state pair from the previous step.
            x: the input for this time step.

        Returns:
            ``((new_c, dropped_h), output)``: the carry for the next step and
            the parent cell's un-dropped output for this step.
        """
        # The parent returns (new_carry, output) with new_carry == (c, h);
        # unpack it explicitly so the mask cannot land on the cell state.
        (new_c, new_h), output = super().__call__(carry, x)
        # RecurrentDropout takes (h, c) and returns (masked_h, c).
        dropped_h, new_c = self.recurrent_dropout((new_h, new_c))
        return (new_c, dropped_h), output
class RNNWithRecurrentDropout(nnx.Module):
"""Recurrent Neural Network (RNN) with recurrent dropout.
This module implements an RNN with an LSTM cell and recurrent dropout for
processing sequential data. It applies dropout to the hidden state of the
LSTM at each time step, helping to prevent overfitting and improve
generalization.
Attributes:
rngs: An instance of nnx.Rngs.
in_features: The number of input features.
hidden_features: The number of features in the hidden state of the LSTM.
dropout_rate: The dropout probability for the output layer.
recurrent_dropout_rate: The dropout probability for the recurrent
connections in the LSTM.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> rngs = nnx.Rngs(0)
>>> model = RNNWithRecurrentDropout(rngs=rngs, in_features=10, hidden_features=20)
>>> x = jnp.ones((1, 3, 10)) # (batch_size, sequence_length, in_features)
>>> output = model(x)
>>> print(output.shape)
(1, 1)
"""
def __init__(self, *, rngs: nnx.Rngs, in_features: int,
hidden_features: int = 32, dropout_rate: float = 0.5, recurrent_dropout_rate: float = 0.25):
cell = LSTMWithRecurrentDropout(in_features=in_features, hidden_features=hidden_features,
rngs=rngs, dropout_rate=recurrent_dropout_rate)
self.lstm = nnx.RNN(cell)
self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)
self.dense = nnx.Linear(in_features=hidden_features, out_features=1, rngs=rngs)
def __call__(self, x):
x = self.lstm(x)
x = self.dropout(x)
x = x[:, -1, :] # Use only the final hidden state
return self.dense(x) I am not sure it works. I get a validation MAE of approx. 2.62 degrees while the Keras implementation achieved a validation MAE as low as 2.36 degrees. I guess that my implementation of the recurrent dropout is not correct. |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 5 replies
-
Hey @sgkouzias, you can implement recurrent dropout by selecting a class LSTMWithRecurrentDropout(nnx.OptimizedLSTMCell):
def __init__(
self,
*,
rngs: nnx.Rngs,
in_features: int,
hidden_features: int,
dropout_rate: float,
**kwargs,
):
super().__init__(
in_features=in_features,
hidden_features=hidden_features,
rngs=rngs,
**kwargs,
)
self.recurrent_dropout = nnx.Dropout(
rate=dropout_rate, rng_collection='recurrent_dropout', rngs=rngs
)
def __call__(self, carry, x):
h, c = carry
new_h, new_c = super().__call__((h, c), x)
new_h = jax.tree.map(self.recurrent_dropout, new_h)
return new_h, new_c
class RNNWithRecurrentDropout(nnx.Module):
    """LSTM regressor with recurrent dropout and a dropout-regularised head.

    Wraps an ``LSTMWithRecurrentDropout`` cell in ``nnx.RNN`` with
    ``broadcast_rngs='recurrent_dropout'``, keeps the last time step of the
    output sequence, applies ordinary dropout and projects to one value.

    Args:
        rngs: An instance of nnx.Rngs.
        in_features: The number of input features.
        hidden_features: The size of the LSTM hidden state.
        dropout_rate: The dropout probability applied before the linear head.
        recurrent_dropout_rate: Forwarded to the cell as ``dropout_rate``.
    """

    def __init__(
        self,
        *,
        rngs: nnx.Rngs,
        in_features: int,
        hidden_features: int = 32,
        dropout_rate: float = 0.5,
        recurrent_dropout_rate: float = 0.25,
    ):
        cell = LSTMWithRecurrentDropout(
            in_features=in_features,
            hidden_features=hidden_features,
            rngs=rngs,
            dropout_rate=recurrent_dropout_rate,
        )
        self.lstm = nnx.RNN(cell, broadcast_rngs='recurrent_dropout')
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)
        self.dense = nnx.Linear(
            in_features=hidden_features, out_features=1, rngs=rngs
        )

    def __call__(self, x):
        """Maps a batch of sequences to one prediction per sequence.

        Args:
            x: input batch; the indexing below assumes time is on axis 1.

        Returns:
            The linear layer's output for the last time step of each sequence.
        """
        hidden_seq = self.lstm(x)
        # Slice before dropout: only the last step reaches the head, so
        # masking the whole sequence would be wasted work.
        last_hidden = hidden_seq[:, -1, :]  # Use only the final hidden state
        return self.dense(self.dropout(last_hidden))
model = RNNWithRecurrentDropout(
in_features=32,
hidden_features=64,
dropout_rate=0.2,
recurrent_dropout_rate=0.1,
rngs=nnx.Rngs(0, recurrent_dropout=1),
)
x = jnp.ones((8, 10, 32))
y = model(x)
print(y.shape) To test it locally you can install pip install git+https://github.com/google/flax@rnn-broadcast-rngs |
Beta Was this translation helpful? Give feedback.
Hey @sgkouzias, you can implement recurrent dropout by selecting an `rng_collection` different from `'dropout'` in the `Dropout` constructor, e.g. `'recurrent_dropout'`, and then broadcasting the state for that RNG stream during `scan` (this is done internally by `RNN`). However, the current `RNN` API is missing some options, so I've created #4407 to address this. Here's a demo of the working solution: