[Enhance] Support ViT for TensorParallel (#155)
## Description

I added support for ViT in TensorParallel by adding a ViT entry to
`_TensorParallelMapping`.
Unlike the `Embedding` layer, the `PatchEmbed` layer in ViT has no `weight`
parameter, so I set `weight` to a dummy value to prevent an `AttributeError`.
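
For context, ViT's input embedding is a patch-projection module rather than an `nn.Embedding`, which is why the dummy `weight` is needed. A minimal illustration (not part of this PR; only the public `transformers` API is used):

```python
from transformers import ViTConfig, ViTModel

model = ViTModel(ViTConfig())
patch_embed = model.get_input_embeddings()  # the patch-embedding module, not an nn.Embedding
print(hasattr(patch_embed, "weight"))       # False -> AttributeError without the guard below
```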

Any feedback is welcome.

### Memory usage
| mode       | world_size=1 | world_size=2 | world_size=4 | world_size=8 |
|------------|--------------|--------------|--------------|--------------|
| 1D         | 1760MiB      | 1126MiB      | 789MiB       |              |
| 2D         |              |              | 589MiB       |              |
| 2.5D (d=1) |              |              | 589MiB       |              |
| 2.5D (d=2) |              |              |              | 586MiB       |
| 3D         |              |              |              |              |

### TODO
- [ ] Benchmark with `world_size=8`
- [ ] Refactor slicing patch embedding
- [ ] Fix slicing logic to return the same value as `TensorParallel1D`

<details><summary>code for testing</summary>
<p>

```python
import os
import torch.multiprocessing as mp

import torch
from torch import nn
from torch import optim
import torch.distributed as dist
from transformers import ViTModel, ViTForImageClassification, ViTConfig

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel import TensorParallel


def setup(rank, world_size):
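    # ParallelContext.from_torch (called in train) is expected to read these
    # env vars and initialize torch.distributed.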
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12340"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    print(f"Running oslo TP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(
        tensor_parallel_size=world_size,
        tensor_parallel_mode=ParallelMode.TENSOR_1D,
    )  # or ParallelMode.TENSOR_2D / ParallelMode.TENSOR_2P5D

    model = ViTForImageClassification(ViTConfig(num_labels=1000)).to(rank)
    model = TensorParallel(model, parallel_context)
    optimizer = optim.SGD(model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()

    oslo.ready(model, parallel_context)

    for _ in range(100):
        model.zero_grad()
        logits = model(pixel_values=torch.ones(8, 3, 224, 224).to(rank)).logits
        labels = torch.ones(8, 1000).to(rank) * 100
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        print(logits)
        print(torch.cuda.max_memory_allocated() / 1024**2)  # MiB

    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(4)
```

</p>
</details> 

## Linked Issues

Related to #152
KKIEEK authored and dyanos committed Jun 8, 2023
1 parent 72f7018 commit 36cbb48
Showing 5 changed files with 136 additions and 2 deletions.
62 changes: 62 additions & 0 deletions oslo/torch/nn/modules/embedding.py
@@ -4,6 +4,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers.models.vit.modeling_vit import ViTConfig, ViTEmbeddings

from oslo.torch.distributed import ParallelContext, ParallelMode

@@ -256,6 +257,67 @@ def forward(self, input: Tensor) -> Tensor:
        return output


class ViTEmbedding2D(ViTEmbeddings):
    def __init__(
        self, config: ViTConfig, use_mask_token: bool = False, parallel_context=None
    ) -> None:
        assert parallel_context is not None, "parallel_context must be provided"
        self.parallel_context = parallel_context
        self.summa_dim = self.parallel_context.get_world_size(
            ParallelMode.TENSOR_2D_COL
        )

        super().__init__(config, use_mask_token=use_mask_token)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        from oslo.torch.distributed.nn.functional import all_gather

        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add the [CLS] token to the embedded patch tokens
        cls_token = all_gather(
            self.cls_token,
            dim=-1,
            parallel_context=self.parallel_context,
            parallel_mode=ParallelMode.TENSOR_2D_COL,
        )
        cls_tokens = cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width
            )
        else:
            position_embeddings = all_gather(
                self.position_embeddings,
                dim=-1,
                parallel_context=self.parallel_context,
                parallel_mode=ParallelMode.TENSOR_2D_COL,
            )
            embeddings = embeddings + position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings


class Embedding2p5D(nn.Embedding):
    def __init__(
        self,
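
For intuition, `cls_token` and `position_embeddings` are sliced along the hidden dimension by the 2D wrapper (next file) and reassembled with `all_gather` in `ViTEmbedding2D.forward` above. A single-process sketch of that round trip, with plain `torch` ops standing in for the collective (`summa_dim` is the `TENSOR_2D_COL` group size, 2 here for illustration):

```python
import torch

summa_dim = 2
hidden_size = 768
cls_token = torch.randn(1, 1, hidden_size)

# What the wrapper keeps on each rank: one hidden-dim shard.
shards = cls_token.chunk(summa_dim, dim=-1)  # each (1, 1, hidden_size // summa_dim)

# What forward() reassembles via all_gather before expanding over the batch.
restored = torch.cat(shards, dim=-1)
assert torch.equal(restored, cls_token)
```
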
60 changes: 60 additions & 0 deletions oslo/torch/nn/parallel/tensor_parallel/_2d/_wrapper.py
@@ -11,6 +11,7 @@
    VocabParallelEmbedding2D,
    Embedding2D,
    VocabUtility,
    ViTEmbedding2D,
)
from oslo.torch.nn.modules.layer_norm import (
    LayerNorm2D,

@@ -101,11 +102,15 @@ def _update_mp_arguments(self):
                    setattr(module, elem.name, reduced_arg)

    def _parallelize_embedding(self):
        from transformers.models.vit.modeling_vit import ViTEmbeddings

        for module_name, module in self.module.named_modules():
            if isinstance(module, nn.Embedding):
                self._slice_embedding(
                    module=module,
                )
            elif isinstance(module, ViTEmbeddings):
                self._slice_patch_embedding(module=module)

    def _parallelize_linear(self):
        for module_name, module in self.module.named_modules():

@@ -216,6 +221,61 @@ def _slice_embedding(self, module):
                ParallelMode.TENSOR_2D_COL: col_rank,
            }

    def _slice_patch_embedding(self, module):
        summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL)
        row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_ROW)
        col_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_COL)

        patch_embed = module.patch_embeddings.projection
        weight_list = patch_embed.weight.data.chunk(summa_dim, dim=0)
        weight_list = [weight.chunk(summa_dim, dim=0) for weight in weight_list]
        patch_embed.weight.data = weight_list[row_rank][col_rank].contiguous()
        if patch_embed.bias is not None:
            bias_list = patch_embed.bias.data.chunk(summa_dim, dim=0)
            bias_list = [bias.chunk(summa_dim, dim=0) for bias in bias_list]
            patch_embed.bias.data = bias_list[row_rank][col_rank].contiguous()

        def module_forward(input):
            from oslo.torch.distributed.nn.functional import all_gather

            weight = all_gather(
                patch_embed.weight,
                dim=0,
                parallel_context=self.parallel_context,
                parallel_mode=ParallelMode.TENSOR_2D_COL,
            )

            bias = None
            if patch_embed.bias is not None:
                bias = all_gather(
                    patch_embed.bias,
                    dim=0,
                    parallel_context=self.parallel_context,
                    parallel_mode=ParallelMode.TENSOR_2D_COL,
                )

            return patch_embed._conv_forward(input, weight, bias)

        patch_embed.forward = module_forward

        pos_embed = module.position_embeddings
        param_list = pos_embed.data.chunk(summa_dim, dim=-1)
        param_list = [param.chunk(summa_dim, dim=-1) for param in param_list]
        pos_embed.data = param_list[row_rank][col_rank].contiguous()

        cls_token = module.cls_token
        param_list = cls_token.data.chunk(summa_dim, dim=-1)
        param_list = [param.chunk(summa_dim, dim=-1) for param in param_list]
        cls_token.data = param_list[row_rank][col_rank].contiguous()

        _update_module_arguments(
            module=module,
            parallel_context=self.parallel_context,
            summa_dim=summa_dim,
            orig_module=copy.deepcopy(module.__class__),
        )
        module.__class__ = ViTEmbedding2D

    def _slice_linear(self, module, reversed, fusion_degree, slice_bias):
        summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL)
        row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_ROW)
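
The patched `module_forward` above applies the same shard-then-gather pattern to the patch projection's `Conv2d`: weight and bias are sliced along output channels and gathered back before `_conv_forward`. A single-process sketch of the underlying math, with concatenation standing in for `all_gather` (shapes are ViT-Base-like and purely illustrative; whether the distributed version matches the 1D result is exactly what the third TODO item still has to verify):

```python
import torch
from torch import nn

world_size = 2
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # ViT-Base-style patch projection
x = torch.randn(1, 3, 224, 224)

# Shard weight/bias along output channels (dim 0), as _slice_patch_embedding does.
weight_shards = conv.weight.data.chunk(world_size, dim=0)
bias_shards = conv.bias.data.chunk(world_size, dim=0)

# Stand-in for all_gather: concatenate the shards back before the convolution.
full_weight = torch.cat(weight_shards, dim=0)
full_bias = torch.cat(bias_shards, dim=0)

out = conv._conv_forward(x, full_weight, full_bias)
assert torch.allclose(out, conv(x))
```
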
4 changes: 4 additions & 0 deletions oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py
@@ -101,6 +101,10 @@ def _resize_vocab_size(model, parallel_context):
        ), "model object must have `get_input_embeddings` method."

        module = model.get_input_embeddings()
        if not hasattr(module, "weight"):
            module.weight = None
            return model

        vocab_size, embedding_dim = module.weight.size()
        new_vocab_size = vocab_size

1 change: 1 addition & 0 deletions oslo/transformers/constants.py
@@ -9,6 +9,7 @@
    "decoder_token_type_ids": 0,
    "decoder_position_ids": 0,
    "decoder_inputs_embeds": 0,
    "pixel_values": 0,
}

BATCH_DIMENSIONS_PP = {

11 changes: 9 additions & 2 deletions oslo/transformers/mapping_utils.py
@@ -35,7 +35,7 @@ def _load_hf_class_by_name(model_name):
            transformers = importlib.import_module("transformers")
            cls = getattr(transformers, f"{model_name}PreTrainedModel", None)
            if cls is None:
                cls = getattr(transformers, f"{model_name}PretrainedModel", None)
                cls = getattr(transformers, model_name, None)
            return cls
        except ImportError:
            return None
@@ -55,7 +55,7 @@ def _load_oslo_class_by_name(model_name):
            transformers = importlib.import_module("oslo.transformers")
            cls = getattr(transformers, f"{model_name}PreTrainedModel", None)
            if cls is None:
                cls = getattr(transformers, f"{model_name}PretrainedModel", None)
                cls = getattr(transformers, model_name, None)
            return cls
        except ImportError:
            return None
@@ -233,6 +233,13 @@ class _TensorParallelMapping(_ParallelMapping):
                gather_output=True,
            ),
        ],
        "ViTForImageClassification": [
            Column("query", "key", "value", "intermediate.dense"),
            Row("output.dense"),
            Other("position_embeddings", "cls_token", gather_output=True),
            Update("num_attention_heads", "all_head_size"),
            Head("classifier", gather_output=True),
        ],
    }
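
For reference, the `Column`/`Row` declarations in the new mapping entry follow the usual tensor-parallel linear split: columns of the weight for query/key/value and the intermediate projection, rows for the output projection. A single-process sketch of the arithmetic behind that split (illustrative only; the actual slicing and reductions are done by `TensorParallel`):

```python
import torch

world_size = 2
hidden = 768
x = torch.randn(4, hidden)
w = torch.randn(hidden, hidden)  # weight of a dense layer, applied as x @ w

# Column parallel: split the output dimension; each rank holds a slice of the output.
col_out = torch.cat([x @ shard for shard in w.chunk(world_size, dim=1)], dim=1)

# Row parallel: split the input dimension; partial products are summed (all-reduce).
row_out = sum(xs @ ws for xs, ws in zip(x.chunk(world_size, dim=1), w.chunk(world_size, dim=0)))

assert torch.allclose(col_out, x @ w, atol=1e-5)
assert torch.allclose(row_out, x @ w, atol=1e-5)
```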


