fairscale is a PyTorch extension library for high-performance and large-scale training.
fairscale supports:
- pipeline parallelism (fairscale.nn.Pipe)
- optimizer state sharding (fairscale.optim.OSS)
Example: run a 4-layer model on 2 GPUs. The first two layers run on cuda:0 and the last two layers run on cuda:1.
import torch
import fairscale

# Four example layers; with balance=[2, 2], a and b run on cuda:0, c and d on cuda:1.
a, b = torch.nn.Linear(10, 10), torch.nn.ReLU()
c, d = torch.nn.Linear(10, 10), torch.nn.ReLU()
model = torch.nn.Sequential(a, b, c, d)
# chunks=8 splits each input batch into 8 micro-batches for pipelining.
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
Requirements:
- PyTorch >= 1.4
Normal installation:
pip install .
Development mode:
pip install -e .
See the CONTRIBUTING file for how to help out.
fairscale is licensed under the BSD-3-Clause License.