From 36cbb4834af10ad8293a666be093b31a3df47104 Mon Sep 17 00:00:00 2001
From: Junhwa Song
Date: Tue, 18 Apr 2023 11:29:57 +0900
Subject: [PATCH] [Enhance] Support ViT for TensorParallel (#155)

## Description

I added support for ViT in TensorParallel by appending a `ViTForImageClassification` config to `_TensorParallelMapping`. Unlike the `Embedding` layer, the `PatchEmbed` layer in ViT does not have a `weight` parameter, so I replaced the `weight` parameter with a dummy value to prevent an `AttributeError` (see the snippet below). Any feedback is welcome.

### Memory usage

mode | world_size=1 | world_size=2 | world_size=4 | world_size=8
-|-|-|-|-
1D | 1760MiB | 1126MiB | 789MiB |
2D | | | 589MiB |
2.5D (d=1) | | | 589MiB |
2.5D (d=2) | | | | 586MiB
3D | | | |

### TODO

- [ ] Benchmark with `world_size=8`
- [ ] Refactor the patch embedding slicing
- [ ] Fix the slicing logic to return the same value as `TensorParallel1D`
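For reference, the dummy-value workaround lands in `_resize_vocab_size` (see the `tensor_parallel.py` hunk in the diff at the end of this patch); the guard is essentially:

```python
module = model.get_input_embeddings()

# ViT's patch embedding exposes no `weight` attribute, so assign a dummy value
# and skip vocab resizing instead of letting the attribute access raise.
if not hasattr(module, "weight"):
    module.weight = None
    return model
```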
Code for testing:

```python
import os

import torch.multiprocessing as mp
import torch
from torch import nn
from torch import optim
import torch.distributed as dist
from transformers import ViTModel, ViTForImageClassification, ViTConfig

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel import TensorParallel


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12340"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    print(f"Running oslo TP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(
        tensor_parallel_size=world_size,
        tensor_parallel_mode=ParallelMode.TENSOR_1D,  # TENSOR_2D or TENSOR_2P5D
    )
    model = ViTForImageClassification(ViTConfig(num_labels=1000)).to(rank)
    model = TensorParallel(model, parallel_context)
    optimizer = optim.SGD(model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()
    oslo.ready(model, parallel_context)

    for _ in range(100):
        model.zero_grad()
        logits = model(pixel_values=torch.ones(8, 3, 224, 224).to(rank)).logits
        labels = torch.ones(8, 1000).to(rank) * 100
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

    print(logits)
    print(torch.cuda.max_memory_allocated() / 1024**2)  # MB
    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(4)
```
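To exercise the new 2D path instead of 1D, the parallel context would presumably be created as below. This is a sketch only: it assumes 4 GPUs forming a 2x2 grid and that the mode is named `ParallelMode.TENSOR_2D`, as the comment in the script above hints.

```python
# Hypothetical variant of the context creation above for the 2D mode.
parallel_context = ParallelContext.from_torch(
    tensor_parallel_size=4,  # must form a square grid (2 x 2) for 2D
    tensor_parallel_mode=ParallelMode.TENSOR_2D,
)
```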

## Linked Issues

Related to #152

---
 oslo/torch/nn/modules/embedding.py            | 62 +++++++++++++++++++
 .../parallel/tensor_parallel/_2d/_wrapper.py  | 60 ++++++++++++++++++
 .../tensor_parallel/tensor_parallel.py        |  4 ++
 oslo/transformers/constants.py                |  1 +
 oslo/transformers/mapping_utils.py            | 11 +++-
 5 files changed, 136 insertions(+), 2 deletions(-)

diff --git a/oslo/torch/nn/modules/embedding.py b/oslo/torch/nn/modules/embedding.py
index b403204d..52436582 100644
--- a/oslo/torch/nn/modules/embedding.py
+++ b/oslo/torch/nn/modules/embedding.py
@@ -4,6 +4,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
+from transformers.models.vit.modeling_vit import ViTConfig, ViTEmbeddings
 
 from oslo.torch.distributed import ParallelContext, ParallelMode
 
@@ -256,6 +257,67 @@ def forward(self, input: Tensor) -> Tensor:
         return output
 
 
+class ViTEmbedding2D(ViTEmbeddings):
+    def __init__(
+        self, config: ViTConfig, use_mask_token: bool = False, parallel_context=None
+    ) -> None:
+        assert parallel_context is not None, "parallel_context must be provided"
+        self.parallel_context = parallel_context
+        self.summa_dim = self.parallel_context.get_world_size(
+            ParallelMode.TENSOR_2D_COL
+        )
+
+        super().__init__(config, use_mask_token=use_mask_token)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> torch.Tensor:
+        from oslo.torch.distributed.nn.functional import all_gather
+
+        batch_size, num_channels, height, width = pixel_values.shape
+        embeddings = self.patch_embeddings(
+            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
+        )
+
+        if bool_masked_pos is not None:
+            seq_length = embeddings.shape[1]
+            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        # add the [CLS] token to the embedded patch tokens
+        cls_token = all_gather(
+            self.cls_token,
+            dim=-1,
+            parallel_context=self.parallel_context,
+            parallel_mode=ParallelMode.TENSOR_2D_COL,
+        )
+        cls_tokens = cls_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+        # add positional encoding to each token
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(
+                embeddings, height, width
+            )
+        else:
+            position_embeddings = all_gather(
+                self.position_embeddings,
+                dim=-1,
+                parallel_context=self.parallel_context,
+                parallel_mode=ParallelMode.TENSOR_2D_COL,
+            )
+            embeddings = embeddings + position_embeddings
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
 class Embedding2p5D(nn.Embedding):
     def __init__(
         self,
diff --git a/oslo/torch/nn/parallel/tensor_parallel/_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_2d/_wrapper.py
index 7eaff5a6..6a9abfb2 100644
--- a/oslo/torch/nn/parallel/tensor_parallel/_2d/_wrapper.py
+++ b/oslo/torch/nn/parallel/tensor_parallel/_2d/_wrapper.py
@@ -11,6 +11,7 @@
     VocabParallelEmbedding2D,
     Embedding2D,
     VocabUtility,
+    ViTEmbedding2D,
 )
 from oslo.torch.nn.modules.layer_norm import (
     LayerNorm2D,
@@ -101,11 +102,15 @@ def _update_mp_arguments(self):
                     setattr(module, elem.name, reduced_arg)
 
     def _parallelize_embedding(self):
+        from transformers.models.vit.modeling_vit import ViTEmbeddings
+
         for module_name, module in self.module.named_modules():
             if isinstance(module, nn.Embedding):
                 self._slice_embedding(
                     module=module,
                 )
+            elif isinstance(module, ViTEmbeddings):
+                self._slice_patch_embedding(module=module)
 
     def _parallelize_linear(self):
         for module_name, module in self.module.named_modules():
@@ -216,6 +221,61 @@ def _slice_embedding(self, module):
             ParallelMode.TENSOR_2D_COL: col_rank,
         }
 
+    def _slice_patch_embedding(self, module):
+        summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL)
+        row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_ROW)
+        col_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_COL)
+
+        patch_embed = module.patch_embeddings.projection
+        weight_list = patch_embed.weight.data.chunk(summa_dim, dim=0)
+        weight_list = [weight.chunk(summa_dim, dim=0) for weight in weight_list]
+        patch_embed.weight.data = weight_list[row_rank][col_rank].contiguous()
+        if patch_embed.bias is not None:
+            bias_list = patch_embed.bias.data.chunk(summa_dim, dim=0)
+            bias_list = [bias.chunk(summa_dim, dim=0) for bias in bias_list]
+            patch_embed.bias.data = bias_list[row_rank][col_rank].contiguous()
+
+        def module_forward(input):
+            from oslo.torch.distributed.nn.functional import all_gather
+
+            weight = all_gather(
+                patch_embed.weight,
+                dim=0,
+                parallel_context=self.parallel_context,
+                parallel_mode=ParallelMode.TENSOR_2D_COL,
+            )
+
+            bias = None
+            if patch_embed.bias is not None:
+                bias = all_gather(
+                    patch_embed.bias,
+                    dim=0,
+                    parallel_context=self.parallel_context,
+                    parallel_mode=ParallelMode.TENSOR_2D_COL,
+                )
+
+            return patch_embed._conv_forward(input, weight, bias)
+
+        patch_embed.forward = module_forward
+
+        pos_embed = module.position_embeddings
+        param_list = pos_embed.data.chunk(summa_dim, dim=-1)
+        param_list = [param.chunk(summa_dim, dim=-1) for param in param_list]
+        pos_embed.data = param_list[row_rank][col_rank].contiguous()
+
+        cls_token = module.cls_token
+        param_list = cls_token.data.chunk(summa_dim, dim=-1)
+        param_list = [param.chunk(summa_dim, dim=-1) for param in param_list]
+        cls_token.data = param_list[row_rank][col_rank].contiguous()
+
+        _update_module_arguments(
+            module=module,
+            parallel_context=self.parallel_context,
+            summa_dim=summa_dim,
+            orig_module=copy.deepcopy(module.__class__),
+        )
+        module.__class__ = ViTEmbedding2D
+
     def _slice_linear(self, module, reversed, fusion_degree, slice_bias):
         summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL)
         row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_ROW)
diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py
index 9b45a548..4aaf948c 100644
--- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py
+++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py
@@ -101,6 +101,10 @@ def _resize_vocab_size(model, parallel_context):
     ), "model object must have `get_input_embeddings` method."
     module = model.get_input_embeddings()
 
+    if not hasattr(module, "weight"):
+        module.weight = None
+        return model
+
     vocab_size, embedding_dim = module.weight.size()
     new_vocab_size = vocab_size
diff --git a/oslo/transformers/constants.py b/oslo/transformers/constants.py
index 9ec13b78..a15240c1 100644
--- a/oslo/transformers/constants.py
+++ b/oslo/transformers/constants.py
@@ -9,6 +9,7 @@
     "decoder_token_type_ids": 0,
     "decoder_position_ids": 0,
     "decoder_inputs_embeds": 0,
+    "pixel_values": 0,
 }
 
 BATCH_DIMENSIONS_PP = {
diff --git a/oslo/transformers/mapping_utils.py b/oslo/transformers/mapping_utils.py
index de34217a..ae3f5839 100644
--- a/oslo/transformers/mapping_utils.py
+++ b/oslo/transformers/mapping_utils.py
@@ -35,7 +35,7 @@ def _load_hf_class_by_name(model_name):
         transformers = importlib.import_module("transformers")
         cls = getattr(transformers, f"{model_name}PreTrainedModel", None)
         if cls is None:
-            cls = getattr(transformers, f"{model_name}PretrainedModel", None)
+            cls = getattr(transformers, model_name, None)
         return cls
     except ImportError:
         return None
@@ -55,7 +55,7 @@ def _load_oslo_class_by_name(model_name):
         transformers = importlib.import_module("oslo.transformers")
         cls = getattr(transformers, f"{model_name}PreTrainedModel", None)
         if cls is None:
-            cls = getattr(transformers, f"{model_name}PretrainedModel", None)
+            cls = getattr(transformers, model_name, None)
         return cls
     except ImportError:
         return None
@@ -233,6 +233,13 @@ class _TensorParallelMapping(_ParallelMapping):
                 gather_output=True,
             ),
         ],
+        "ViTForImageClassification": [
+            Column("query", "key", "value", "intermediate.dense"),
+            Row("output.dense"),
+            Other("position_embeddings", "cls_token", gather_output=True),
+            Update("num_attention_heads", "all_head_size"),
+            Head("classifier", gather_output=True),
+        ],
     }
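As a side note for reviewers, here is a minimal standalone sketch (not part of the patch) of how `_slice_patch_embedding` partitions the patch projection weight on a 2x2 tensor-parallel grid; the shapes assume the default `ViTConfig` (hidden_size=768, 3 channels, 16x16 patches):

```python
import torch

# Minimal illustration (hypothetical shapes) of the double-chunk partitioning
# used by _slice_patch_embedding on a 2x2 (summa_dim=2) grid.
summa_dim = 2
hidden_size = 768  # default ViTConfig hidden size

# ViT patch projection weight: [hidden_size, num_channels, patch, patch]
weight = torch.randn(hidden_size, 3, 16, 16)

# Chunk along dim=0 into summa_dim pieces, then chunk each piece again,
# mirroring the two .chunk(summa_dim, dim=0) calls in the patch.
weight_list = weight.chunk(summa_dim, dim=0)
weight_list = [w.chunk(summa_dim, dim=0) for w in weight_list]

for row_rank in range(summa_dim):
    for col_rank in range(summa_dim):
        shard = weight_list[row_rank][col_rank].contiguous()
        print(row_rank, col_rank, tuple(shard.shape))  # (192, 3, 16, 16)
```

Each rank keeps `hidden_size / summa_dim**2` of the output channels, and the wrapped forward all-gathers the weight along `ParallelMode.TENSOR_2D_COL` before calling `_conv_forward`; as noted in the TODO, this slicing logic still needs to be verified against `TensorParallel1D`.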