Training a diffusion transformer with flow matching
This is a minimal example of diffusion transformer (DiT) training with flow matching, implemented in JAX/Flax.
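The core of flow matching training is a regression loss on a velocity field. Below is a minimal sketch. It assumes the linear (conditional optimal-transport) path x_t = (1 - t) x0 + t x1 from Gaussian noise x0 to data x1, as described in the Flow Matching Guide. Along this path the target velocity is simply x1 - x0. The names `flow_matching_loss` and `velocity_fn` are hypothetical, and the actual code may sample t or define the path differently.

```python
import jax
import jax.numpy as jnp

def flow_matching_loss(velocity_fn, params, x1, key):
    """Conditional flow matching loss on a linear noise-to-data path (sketch).

    x1: batch of data samples, shape (B, ...).
    velocity_fn(params, x_t, t) -> predicted velocity with the same shape as x_t.
    """
    key_noise, key_t = jax.random.split(key)
    # x0 ~ N(0, I): one source (noise) sample per data point
    x0 = jax.random.normal(key_noise, x1.shape)
    # t ~ U[0, 1]: one time value per example
    t = jax.random.uniform(key_t, (x1.shape[0],))
    # broadcast t over the non-batch dimensions
    t_b = t.reshape((-1,) + (1,) * (x1.ndim - 1))
    # point on the straight line from noise (t = 0) to data (t = 1)
    x_t = (1.0 - t_b) * x0 + t_b * x1
    # d x_t / dt along this path is constant: x1 - x0
    target = x1 - x0
    pred = velocity_fn(params, x_t, t)
    # mean squared error between predicted and target velocity
    return jnp.mean((pred - target) ** 2)
```

In the real setup, `velocity_fn` would be the DiT forward pass taking the noisy input and the time t.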
800 steps of single-batch training on ImageNet with a 350M-parameter DiT.
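The single-batch setup above can be imitated at toy scale with a loop that trains on one fixed batch for 800 steps. This is a sketch under loud assumptions. A tiny MLP velocity model (`init_mlp`, `velocity_fn`, both hypothetical) stands in for the 350M DiT, and 2-D toy points stand in for ImageNet. Plain SGD is used because the original optimizer is not stated. A simple Euler sampler then integrates the learned velocity from noise (t = 0) to data (t = 1).

```python
import jax
import jax.numpy as jnp

def init_mlp(key, in_dim, hidden, out_dim):
    # hypothetical stand-in for the DiT: a one-hidden-layer MLP; +1 input for time t
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim + 1, hidden)) / jnp.sqrt(in_dim + 1),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, out_dim)) / jnp.sqrt(hidden),
        "b2": jnp.zeros(out_dim),
    }

def velocity_fn(params, x_t, t):
    # x_t: (B, D) flattened inputs, t: (B,); time is appended as an extra feature
    h = jnp.concatenate([x_t, t[:, None]], axis=-1)
    h = jax.nn.gelu(h @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

def loss_fn(params, x1, key):
    # linear path x_t = (1 - t) x0 + t x1 with target velocity x1 - x0
    key_noise, key_t = jax.random.split(key)
    x0 = jax.random.normal(key_noise, x1.shape)
    t = jax.random.uniform(key_t, (x1.shape[0],))
    x_t = (1.0 - t[:, None]) * x0 + t[:, None] * x1
    return jnp.mean((velocity_fn(params, x_t, t) - (x1 - x0)) ** 2)

@jax.jit
def train_step(params, x1, key, lr):
    # one plain-SGD update (assumption: the original optimizer is unknown)
    loss, grads = jax.value_and_grad(loss_fn)(params, x1, key)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

def train_single_batch(x1, num_steps=800, lr=1e-2, seed=0):
    # the same batch x1 is reused every step; only the noise and t are resampled
    key = jax.random.PRNGKey(seed)
    key, init_key = jax.random.split(key)
    params = init_mlp(init_key, x1.shape[-1], 64, x1.shape[-1])
    losses = []
    for _ in range(num_steps):
        key, step_key = jax.random.split(key)
        params, loss = train_step(params, x1, step_key, lr)
        losses.append(float(loss))
    return params, losses

def sample_euler(params, key, shape, num_steps=50):
    # integrate dx/dt = v(x, t) from t = 0 (noise) to t = 1 (data) with Euler steps
    x = jax.random.normal(key, shape)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = jnp.full((shape[0],), i * dt)
        x = x + dt * velocity_fn(params, x, t)
    return x
```

Overfitting a single batch like this is mainly a sanity check. If the loss does not fall, the model or the loss is probably wired incorrectly.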
- Flow Matching Guide and Code => research paper.
- facebookresearch/flow_matching => research codebase for flow matching (PyTorch).
- Rami's dit-vs-unet => diffusion training implementation with JAX.