
How to unroll a model end to end? #4507

Open
JoaoAparicio opened this issue Jan 27, 2025 · 3 comments

Comments

@JoaoAparicio

Hello

What's the correct way to unroll a model that contains an LSTM?

e.g. Suppose my model has 3 blocks from top to bottom:

spatial_block (not recurrent)
lstm
mlp (not recurrent)

I know how to unroll the LSTM N times; there's a module for that:

lstm_unrolled = nnx.RNN(lstm, unroll=N, ...)
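For concreteness, a minimal sketch of what I mean (the cell sizes, shapes, and rngs plumbing here are placeholders I made up):

import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(0)
# wrap an LSTM cell in nnx.RNN so it gets scanned over the time axis;
# unroll controls how many scan steps are unrolled per loop iteration
lstm = nnx.LSTMCell(in_features=8, hidden_features=16, rngs=rngs)
lstm_unrolled = nnx.RNN(lstm, unroll=5, rngs=rngs)

x = jnp.ones((2, 5, 8))  # [batch, time, features]
y = lstm_unrolled(x)     # [batch, time, hidden] == (2, 5, 16)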

But then, how do I do the same for the rest of the non-recurrent parts of my model?

It occurred to me that nnx.RNN might be generic enough to work on the full model end to end, but the fact that it's annotated as taking an nnx.RNNCellBase suggests this is probably not the case.

Instead I tried using nnx.RNN to unroll the LSTM and then doing the rest manually.
For example, I managed to make this work (schematically):

# map the non-recurrent blocks over the extra time axis with nnx.vmap,
# and let nnx.RNN handle the recurrent part
spatialblock_w_timesteps = nnx.vmap(SpatialBlock.__call__, in_axes=(None, 0), out_axes=0)
lstm_unrolled = nnx.RNN(model.lstmcell, unroll=rnn_unroll, rngs=rngs)
mlp_w_timesteps = nnx.vmap(MLP.__call__, in_axes=(None, 0), out_axes=0)

x = spatialblock_w_timesteps(model.spatial_block, x)  # per-timestep, stateless
x = lstm_unrolled(x)                                  # recurrent over time
x = mlp_w_timesteps(model.mlp, x)                     # per-timestep, stateless

This works. But because it has to be done by hand on a model-by-model basis, it scales badly to more complex models and it's error prone.
Is there a way of unrolling the full model as easily as the LSTM component?

@cgarciae
Collaborator

Hi @JoaoAparicio, layers like nnx.Linear, which I imagine is the basis of your MLP, broadcast over all leading dimensions; in other words, they are applied to all time steps. E.g.

class Model(nnx.Module):
  def __init__(self, rngs):
    self.rnn = nnx.RNN(...)
    self.linear = nnx.Linear(..., rngs=rngs)
  
  def __call__(self, x):
    # x: [batch, time, features]
    x = self.rnn(x) # applies recurrently to all timesteps
    x = self.linear(x) # applies in parallel to all timesteps
    return x
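A minimal runnable version of the same idea (the concrete layer sizes here are just for illustration):

import jax.numpy as jnp
from flax import nnx

class Model(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    cell = nnx.LSTMCell(in_features=8, hidden_features=16, rngs=rngs)
    self.rnn = nnx.RNN(cell, rngs=rngs)
    self.linear = nnx.Linear(16, 4, rngs=rngs)

  def __call__(self, x):
    # x: [batch, time, features]
    x = self.rnn(x)     # recurrent over the time axis -> [batch, time, 16]
    x = self.linear(x)  # applied to every (batch, time) position -> [batch, time, 4]
    return x

model = Model(nnx.Rngs(0))
y = model(jnp.ones((2, 5, 8)))  # y.shape == (2, 5, 4)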

@JoaoAparicio
Author

JoaoAparicio commented Jan 29, 2025

Hey, thank you for taking the time :-)

Couple of follow up questions!

In the code you presented, the comment says that self.rnn(x) applies recurrently to all timesteps. But just above it there's a comment about the shape of x: [batch, time, features]. So does self.rnn also apply in parallel over the batch dimension? As you said, self.linear does, but what about self.rnn?

And following on from the above, is the intended design that modules should be written to expect only the minimum number of feature dimensions they require, and to assume that any additional outer dimensions are to be parallelized over?
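To make the question concrete, is this the kind of behaviour I can rely on in general? (A small check I put together; the sizes are arbitrary.)

import jax.numpy as jnp
from flax import nnx

linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))

# the layer only consumes the last (features) dimension...
print(linear(jnp.ones((3,))).shape)       # (4,)
# ...and any additional leading dimensions (batch, time, ...) are broadcast over
print(linear(jnp.ones((2, 5, 3))).shape)  # (2, 5, 4)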

@JoaoAparicio
Author

Oh and quick question: in the timesteps dimension, which direction does time go?

This isn't in the RNN docs, but from this code
https://github.com/google/flax/blob/main/flax/nnx/nn/recurrent.py#L749-L781

>>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
Array([[1, 0, 0],
       [3, 2, 0],
       [6, 5, 4]], dtype=int32)

should I infer that, for an input with dimensions (time, features), inputs[-1, :] contains the features at time t, inputs[-2, :] the features at time t-1, and so on?
