Parallel networks using flax #2862

Answered by cgarciae
mohamad-amin asked this question in Q&A
Hey @mohamad-amin, what you describe is certainly possible with pmap + Flax. It seems you want to create an ensemble and train its members in parallel across devices; check out the Ensembling on multiple devices guide for some pointers.

> but I'm not sure if there's any technical burdens (like flax models not being hashable, hypothetically, or ...).

The train_step is usually a vanilla JAX function (wrapped in jit or pmap), so there should be no issues as long as you follow the usual conventions, such as using flax.training.train_state.TrainState to carry the apply function, the parameters, and the optimizer state. For more info, check out our (recently updated) Quick Start guide.

Answer selected by mohamad-amin