Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Support ViT for TensorParallel #155

Merged
merged 6 commits into from
Apr 18, 2023

Conversation

KKIEEK
Copy link
Contributor

@KKIEEK KKIEEK commented Mar 18, 2023

Description

I added support for ViT in TensorParallel by appending config to _TensorParallelMapping.
PatchEmbed layer in ViT does not have the weight parameter unlike Embedding layer, so I replaced the weight parameter with a dummy value to prevent an AttributeError.

Any feedback is welcome.

Memory usage

mode world_size=1 world_size=2 world_size=4 world_size=8
1D 1760MiB 1126MiB 789MiB
2D 589MiB
2.5D (d=1) 589MiB
2.5D (d=2) 586MiB
3D

TODO

  • Benchmark with world_size=8
  • Refactor slicing patch embedding
  • Fix slicing logic to return the same value as TensorParallel1D
code for testing

import os
import torch.multiprocessing as mp

import torch
from torch import nn
from torch import optim
import torch.distributed as dist
from transformers import ViTModel, ViTForImageClassification, ViTConfig

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel import TensorParallel


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12340"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    print(f"Running oslo TP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(
        tensor_parallel_size=world_size,
        tensor_parallel_mode=ParallelMode.TENSOR_1D,
    )  # TENSOR2D or TENSOR_2P5D

    model = ViTForImageClassification(ViTConfig(num_labels=1000)).to(rank)
    model = TensorParallel(model, parallel_context)
    optimizer = optim.SGD(model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()

    oslo.ready(model, parallel_context)

    for _ in range(100):
        model.zero_grad()
        logits = model(pixel_values=torch.ones(8, 3, 224, 224).to(rank)).logits
        labels = torch.ones(8, 1000).to(rank) * 100
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        print(logits)
        print(torch.cuda.max_memory_allocated() / 1024**2)  # MB

    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(4)

Linked Issues

Related to #152

@KKIEEK KKIEEK requested a review from hyunwoongko as a code owner March 18, 2023 13:05
@KKIEEK
Copy link
Contributor Author

KKIEEK commented Mar 18, 2023

Currently, TensorParallelism for ViT does not support 2D and 2.5.
Could anyone let me know why LayerNorm is not working with 2D and 2.5?

Traceback (most recent call last):
  File "/admin/home/.local/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/admin/home/vit/test_tp.py", line 46, in train
    logits = model(pixel_values=torch.ones(8, 3, 224, 224).to(rank)).logits
  File "/admin/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/admin/home/vit/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py", line 95, in forward
    return self.module_forward(*args, **kwargs)
  File "/admin/home/vit/oslo/torch/nn/parallel/tensor_parallel/_2p5d/_wrapper.py", line 71, in forward
    return self.module_forward(*args, **kwargs)
  File "/admin/home/.local/lib/python3.8/site-packages/transformers/models/vit/modeling_vit.py", line 789, in forward
    outputs = self.vit(
  File "/admin/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/admin/home/.local/lib/python3.8/site-packages/transformers/models/vit/modeling_vit.py", line 579, in forward
    encoder_outputs = self.encoder(
  File "/admin/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/admin/home/.local/lib/python3.8/site-packages/transformers/models/vit/modeling_vit.py", line 409, in forward
    layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
  File "/admin/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/admin/home/.local/lib/python3.8/site-packages/transformers/models/vit/modeling_vit.py", line 349, in forward
    self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
  File "/admin/home/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/admin/home/vit/oslo/torch/nn/modules/layer_norm.py", line 317, in forward
    outputs = torch.addcmul(bias, scale, outputs)
RuntimeError: The size of tensor a (384) must match the size of tensor b (768) at non-singleton dimension 2

@KKIEEK KKIEEK changed the title Support ViT for TensorParallel [Enhance] Support ViT for TensorParallel Mar 18, 2023
@hyunwoongko
Copy link
Member

I think there's something wrong config on ViT mapping.
@bzantium @jason9693 could you check it on thursday together?

@hyunwoongko
Copy link
Member

@KKIEEK I suggest you to join model parallel team meeting on thursday. let's discuss together.

@KKIEEK
Copy link
Contributor Author

KKIEEK commented Mar 20, 2023

@hyunwoongko Okay, I see.

@bzantium
Copy link
Contributor

I think there's something wrong config on ViT mapping.
@bzantium @jason9693 could you check it on thursday together?

Sure! I will check

@hyunwoongko
Copy link
Member

@KKIEEK please let me know when you are done :)

@KKIEEK
Copy link
Contributor Author

KKIEEK commented Apr 5, 2023

@hyunwoongko
I roughly have fixed the ViTEmbeddings to work well with 2D TensorParallelism, but I don't think this implementation is elegant in current TensorParallel scheme. (it is too overly complicated and dependent on original vit modeling code)
Do you have any suggestions for improving it?

@KKIEEK
Copy link
Contributor Author

KKIEEK commented Apr 17, 2023

@hyunwoongko Can you review this PR for me? If you don't get a chance to do, I will close this PR.

@hyunwoongko hyunwoongko merged commit ca97853 into EleutherAI:main Apr 18, 2023
yhna940 added a commit to yhna940/oslo that referenced this pull request Apr 18, 2023
* import ParallelMode (EleutherAI#166)

## fix typo on tensor parallel tutorial

- `from oslo import ParallelContext, ParallelMode`

* [Fix] zero param check (EleutherAI#164)

## Title

- [Fix] zero param check

## Description

- ZeRO checks the redundancy of parameters to calculate the norm. There
is a minor bug in checking the TP and needs to be fixed.

## Linked Issues

- N/A

* [Fix] zero optimizer w/ tensor parallel test (EleutherAI#167)

## Title

- [Fix] zero optimizer w/ tensor parallel test

## Description

- ZeRO was not running in tensor parallel mode, so I fixed this by
switching to a model from `transformers`.

## Linked Issues

- N/A

* Add restarting model from saved model and fix bug (EleutherAI#171)

## Description

- load a model
- start training again from a saved point
- fix bug that training_arg not saved with nccl error. It was because of
parallel_context, and it was removed before saving training_arg and
re-attached again
- test load and restart with oslo TP

* Make decoder-only models to be able to generate with `inputs_embeds` (EleutherAI#172)

## Title
Make decoder-only models to be able to generate with `inputs_embeds`

## Description
Synchronize GPT2 code with Hugging Face transformers—GPT2 can generate
with `input_embeds`.

>Accepting `.generate()` calls with `inputs_embeds` on decoder-only
models is a long-standing request
(huggingface/transformers#6535) -- see
huggingface/transformers#6535 (comment)
particular and its reacts.
>
>It has to be added on a per-model basis, and this PR adds the necessary
changes for GPT2. Other models will throw an informative exception if
the user passes `inputs_embeds`, asking them to check this PR and
implement the same pattern on the model they want to use it with 🤗
>
>Please note that it is still expected that the user passes `input_ids`,
i.e.

```python
outputs = model.generate(input_ids, inputs_embeds=inputs_embeds)
```

>This is because decoder-only models expect the prompt to be present in
the output, and this is the only way to preserve it! input_ids can also
be omitted and, in that case, the output won't contain the prompt.

For more details, please check out [this
PR](huggingface/transformers#21405).

* Wrong import in zero (EleutherAI#169)

## Title

Prevent from using torch 2.0

## Description

- Some of feature have changed in torch 2.0. and oslo has dependency on
torch._six which no longer support by torch 2.0.

olso Dependency
-
https://github.com/EleutherAI/oslo/blob/910c789e7f46d2876b964c221d31984b7924974f/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/_utils.py#L19

other issues
- microsoft/DeepSpeed#2845

## Linked Issues

- resolved #00

* [Fix] Support gradient accumulation for DDP (EleutherAI#173)

## Description

In order to support gradient accumulation, I removed `free_storage`
function that can cause `CUDA error: an illegal memory access was
encountered` in many case. (but this change may lead to an increase in
memory consumption)
What do you guys think about this PR? @nijkah @jinwonkim93

* [Fix] minor bug for single output in _DistributedDataParallel (EleutherAI#177)

## Title

- Fix minor bug for single output in _DistributedDataParallel

## Description

- This PR addresses a minor bug in the `_DistributedDataParallel` class
when handling single output tensors. The changes include:

1. Update the `forward` method in `_DistributedDataParallel` to
correctly handle single output tensors.
2. Add new test cases in
`tests_deprecated/torch/nn/parallel/data_parallel/data_parallel.py` to
ensure the correct behavior for models with various output types (single
tensor, multiple tensors, and dictionary of tensors).

These updates will ensure that the `_DistributedDataParallel` class
works correctly with various output types, providing a more robust
solution for users.

## Linked Issues

- N/A

* [Enhance] Support ViT for TensorParallel (EleutherAI#155)

## Description

I added support for ViT in TensorParallel by appending config to
`_TensorParallelMapping`.
`PatchEmbed` layer in ViT does not have the `weight` parameter unlike
`Embedding` layer, so I replaced the `weight` parameter with a dummy
value to prevent an `AttributeError`.

Any feedback is welcome.

### Memory usage
mode | world_size=1 | world_size=2 | world_size=4 | world_size=8
-|-|-|-|-
1D | 1760MiB | 1126MiB | 789MiB |
2D | | | 589MiB |
2.5D (d=1) | | | 589MiB |
2.5D (d=2) | | | | 586MiB
3D | | | |

### TODO
- [ ] Benchmark with `world_size=8`
- [ ] Refactor slicing patch embedding
- [ ] Fix slicing logic to return the same value as `TensorParallel1D`

<details><summary>code for testing</summary>
<p>

```python
import os
import torch.multiprocessing as mp

import torch
from torch import nn
from torch import optim
import torch.distributed as dist
from transformers import ViTModel, ViTForImageClassification, ViTConfig

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel import TensorParallel


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12340"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    print(f"Running oslo TP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(
        tensor_parallel_size=world_size,
        tensor_parallel_mode=ParallelMode.TENSOR_1D,
    )  # TENSOR2D or TENSOR_2P5D

    model = ViTForImageClassification(ViTConfig(num_labels=1000)).to(rank)
    model = TensorParallel(model, parallel_context)
    optimizer = optim.SGD(model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()

    oslo.ready(model, parallel_context)

    for _ in range(100):
        model.zero_grad()
        logits = model(pixel_values=torch.ones(8, 3, 224, 224).to(rank)).logits
        labels = torch.ones(8, 1000).to(rank) * 100
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        print(logits)
        print(torch.cuda.max_memory_allocated() / 1024**2)  # MB

    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(4)
```

</p>
</details> 

## Linked Issues

Related to EleutherAI#152

---------

Co-authored-by: Minho Ryu <[email protected]>
Co-authored-by: Hansol Park <[email protected]>
Co-authored-by: Ingyu Seong <[email protected]>
Co-authored-by: whooray <[email protected]>
Co-authored-by: Junhwa Song <[email protected]>
dyanos pushed a commit that referenced this pull request Jun 8, 2023
## Description

I added support for ViT in TensorParallel by appending config to
`_TensorParallelMapping`.
`PatchEmbed` layer in ViT does not have the `weight` parameter unlike
`Embedding` layer, so I replaced the `weight` parameter with a dummy
value to prevent an `AttributeError`.

Any feedback is welcome.

### Memory usage
mode | world_size=1 | world_size=2 | world_size=4 | world_size=8
-|-|-|-|-
1D | 1760MiB | 1126MiB | 789MiB |
2D | | | 589MiB |
2.5D (d=1) | | | 589MiB |
2.5D (d=2) | | | | 586MiB
3D | | | |

### TODO
- [ ] Benchmark with `world_size=8`
- [ ] Refactor slicing patch embedding
- [ ] Fix slicing logic to return the same value as `TensorParallel1D`

<details><summary>code for testing</summary>
<p>

```python
import os
import torch.multiprocessing as mp

import torch
from torch import nn
from torch import optim
import torch.distributed as dist
from transformers import ViTModel, ViTForImageClassification, ViTConfig

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel import TensorParallel


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12340"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    print(f"Running oslo TP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(
        tensor_parallel_size=world_size,
        tensor_parallel_mode=ParallelMode.TENSOR_1D,
    )  # TENSOR2D or TENSOR_2P5D

    model = ViTForImageClassification(ViTConfig(num_labels=1000)).to(rank)
    model = TensorParallel(model, parallel_context)
    optimizer = optim.SGD(model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()

    oslo.ready(model, parallel_context)

    for _ in range(100):
        model.zero_grad()
        logits = model(pixel_values=torch.ones(8, 3, 224, 224).to(rank)).logits
        labels = torch.ones(8, 1000).to(rank) * 100
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        print(logits)
        print(torch.cuda.max_memory_allocated() / 1024**2)  # MB

    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(4)
```

</p>
</details> 

## Linked Issues

Related to #152
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants