Add nnx.shard_map #4261

Open
cgarciae opened this issue Oct 7, 2024 · 6 comments
Labels
Priority: P0 (urgent) Response within 1 business day. Resolution as soon as possible. (Assignee required)

Comments

@cgarciae
Collaborator

cgarciae commented Oct 7, 2024

No description provided.

@cgarciae cgarciae added the Priority: P0 (urgent) Response within 1 business day. Resolution as soon as possible. (Assignee required) label Oct 7, 2024
@cgarciae cgarciae self-assigned this Oct 7, 2024
@8bitmp3
Collaborator

8bitmp3 commented Oct 8, 2024

Thanks @cgarciae 🚀

TODO: once this is merged, we can update the Transforms guide.

@carlesoctav

Any updates?

@marcelroed

@cgarciae This is P0 to me and I imagine to many others too. Let me know if I can do anything to help this along.

@cgarciae
Collaborator Author

@marcelroed Thanks for the ping. I'll try to get it in soon; all the pieces are ready.

@marcelroed

In the meantime, are there any issues with doing

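```python
from functools import partial
from flax import nnx
from jax.experimental.shard_map import shard_map

# `mesh` and the in/out specs are defined elsewhere
@partial(shard_map, mesh=mesh, in_specs=..., out_specs=...)
def forward(graph_def, state, x):
    # Rebuild the module from its graph definition and state inside shard_map
    model = nnx.merge(graph_def, state)
    return model(x)
```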

other than mutations not propagating outside of the traced function (as with JIT)?

@cgarciae
Collaborator Author

@marcelroed Using split/merge is perfectly valid. If you want to propagate state updates, return the new state from the function and apply it outside with `nnx.update`.


4 participants