# [Enhance] Support ViT for TensorParallel (#155)
## Description

I added support for ViT in TensorParallel by appending a ViT entry to `_TensorParallelMapping`. Unlike the `Embedding` layer, the `PatchEmbed` layer in ViT does not have a `weight` parameter, so I replaced the `weight` parameter with a dummy value to prevent an `AttributeError` (see the illustrative sketch at the end of this description). Any feedback is welcome.

### Memory usage

| mode | world_size=1 | world_size=2 | world_size=4 | world_size=8 |
|-|-|-|-|-|
| 1D | 1760MiB | 1126MiB | 789MiB | |
| 2D | | | 589MiB | |
| 2.5D (d=1) | | | 589MiB | |
| 2.5D (d=2) | | | | 586MiB |
| 3D | | | | |

### TODO

- [ ] Benchmark with `world_size=8`
- [ ] Refactor slicing patch embedding
- [ ] Fix slicing logic to return the same value as `TensorParallel1D`

<details><summary>code for testing</summary>
<p>

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn, optim
from transformers import ViTConfig, ViTForImageClassification, ViTModel

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel import TensorParallel


def setup(rank, world_size):
    # Minimal single-node rendezvous config for torch.distributed.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12340"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    print(f"Running oslo TP example on rank {rank}.")
    setup(rank, world_size)

    parallel_context = ParallelContext.from_torch(
        tensor_parallel_size=world_size,
        tensor_parallel_mode=ParallelMode.TENSOR_1D,  # or TENSOR_2D / TENSOR_2P5D
    )
    model = ViTForImageClassification(ViTConfig(num_labels=1000)).to(rank)
    model = TensorParallel(model, parallel_context)
    optimizer = optim.SGD(model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()
    oslo.ready(model, parallel_context)

    for _ in range(100):
        model.zero_grad()
        logits = model(pixel_values=torch.ones(8, 3, 224, 224).to(rank)).logits
        labels = torch.ones(8, 1000).to(rank) * 100
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

    print(logits)
    print(torch.cuda.max_memory_allocated() / 1024**2)  # MB
    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(4)
```

</p>
</details>

## Linked Issues

Related to #152
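For readers unfamiliar with the mapping, here is a minimal, hedged sketch of the two ideas in this PR: a Megatron-style column/row split spec for ViT's linear layers, and a dummy `weight` alias on the patch embedding. The dict schema, the `HYPOTHETICAL_VIT_MAPPING` name, and the choice of alias are illustrative assumptions, not OSLO's actual `_TensorParallelMapping` format or the exact dummy value used in this commit.

```python
# Hedged sketch, NOT OSLO's real _TensorParallelMapping: the dict layout and
# key names below are assumptions chosen to illustrate the idea.
from transformers import ViTConfig, ViTForImageClassification

# Megatron-style 1D tensor parallelism typically splits attention/MLP input
# projections column-wise and output projections row-wise.
HYPOTHETICAL_VIT_MAPPING = {
    "ViTForImageClassification": {
        "column_parallel": ["query", "key", "value", "intermediate.dense"],
        "row_parallel": ["attention.output.dense", "output.dense"],
    }
}

model = ViTForImageClassification(ViTConfig(num_labels=1000))

# ViT's patch embedding wraps an nn.Conv2d (`projection`), so the module has
# no top-level `weight` attribute the way nn.Embedding does. Aliasing one on
# (here, the Conv2d weight) keeps generic embedding-handling code from raising
# AttributeError; the PR describes substituting a dummy value for the same reason.
patch_embed = model.vit.embeddings.patch_embeddings
if not hasattr(patch_embed, "weight"):
    patch_embed.weight = patch_embed.projection.weight
```

With such an alias in place, code that reads `module.weight` can treat `PatchEmbed` uniformly with `Embedding`, which is the gist of the workaround described above.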