Parallel networks using flax #2862
-
Is it possible to create multiple instances of a network and optimize all of them in parallel on the same data? I'm trying to run a small experiment many times and was wondering if flax + optax + … could help here. Thanks!
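For anyone landing on this thread: a minimal single-device sketch of the pattern being asked about (not from the thread itself) uses `jax.vmap` to hold one set of parameters per ensemble member and update all members on the same batch. The toy `Model`, shapes, and hyperparameters below are illustrative assumptions:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(nn.relu(nn.Dense(32)(x)))

model = Model()
tx = optax.adam(1e-3)

def init_member(rng):
    return model.init(rng, jnp.ones([1, 8]))['params']

# A batch of parameter pytrees: one ensemble member per leading-axis slice.
num_members = 8
params = jax.vmap(init_member)(jax.random.split(jax.random.PRNGKey(0), num_members))
opt_state = jax.vmap(tx.init)(params)

def loss_fn(p, x, y):
    preds = model.apply({'params': p}, x)
    return jnp.mean((preds - y) ** 2)

@jax.jit
def train_step(params, opt_state, x, y):
    # vmap over the ensemble axis of params/opt_state; x and y are shared,
    # so every member trains on the same data.
    def one_member(p, s):
        grads = jax.grad(loss_fn)(p, x, y)
        updates, s = tx.update(grads, s, p)
        return optax.apply_updates(p, updates), s
    return jax.vmap(one_member)(params, opt_state)

x, y = jnp.ones([16, 8]), jnp.ones([16, 1])
params, opt_state = train_step(params, opt_state, x, y)
```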
-
Hey @mohamad-amin, what you describe is certainly possible with `pmap` + Flax. It seems you want to create an ensemble and train it in parallel; check out the Ensembling on multiple devices guide for some pointers.

The `train_step` is usually a vanilla JAX function (`jit`/`pmap`), so there should be no issues if you follow conventions like using `flax.training.train_state.TrainState` to pass the `apply` function and stuff like that. For more info, check out our (recently updated) Quick Start guide.
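To make that concrete, here is a minimal multi-device sketch in the spirit of the Ensembling on multiple devices guide (not the guide's exact code): each device initializes its own ensemble member from a different RNG key via `pmap`, and `flax.jax_utils.replicate` sends the same batch to every device. The toy `Model` and shapes are assumptions:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax import jax_utils
from flax.training import train_state

class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(nn.relu(nn.Dense(32)(x)))

def create_train_state(rng):
    model = Model()
    params = model.init(rng, jnp.ones([1, 8]))['params']
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optax.adam(1e-3))

# One ensemble member per device: different init seeds, shared data.
num_devices = jax.local_device_count()
states = jax.pmap(create_train_state)(
    jax.random.split(jax.random.PRNGKey(0), num_devices))

@jax.pmap
def train_step(state, batch):
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, batch['x'])
        return jnp.mean((preds - batch['y']) ** 2)
    grads = jax.grad(loss_fn)(state.params)
    # No lax.pmean here: each device updates its own independent member.
    return state.apply_gradients(grads=grads)

batch = {'x': jnp.ones([16, 8]), 'y': jnp.ones([16, 1])}
states = train_step(states, jax_utils.replicate(batch))  # same batch everywhere
```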
-
Hi @cgarciae, I'm also curious about the parallel training scheme in the Ensembling on multiple devices guide; specifically, I don't understand why `replicate` is used to propagate the batched data here. Shouldn't a different mini-batch be trained on each device (instead of identical replicas)? Thanks!
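For context on that question: in the ensembling setup each device holds *different* parameters, so replicating the *same* batch is intentional, since every ensemble member should see identical data. Sharding the batch across devices is the classic data-parallel pattern, where all devices hold the same parameters and gradients are averaged with `jax.lax.pmean`. A minimal sketch contrasting the two (shapes are illustrative, and the shard reshape assumes the batch divides evenly across devices):

```python
import jax
import jax.numpy as jnp
from flax import jax_utils

num_devices = jax.local_device_count()
x = jnp.ones([16, 8])  # toy batch: 16 examples, 8 features

# Ensembling: same data, different per-device params -> replicate.
# A leading device axis is added: [num_devices, 16, 8].
x_replicated = jax_utils.replicate(x)

# Data parallelism: different data, same params -> shard the batch.
# Each device gets 16 // num_devices examples (assumes even division);
# the train_step would then average grads with jax.lax.pmean.
x_sharded = x.reshape(num_devices, 16 // num_devices, 8)
```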