Pipeline Parallelism (Supported? How to?) #827

Open
casper-hansen opened this issue Nov 14, 2024 · 4 comments
Labels
enhancement New feature or request


@casper-hansen

🚀 Feature Request

Supporting TP and SP seems quite easy with the `replication` parameter:

replication = tp * sp
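Why this formula works can be sketched in a few lines. The sketch assumes a row-major (dp, tp, sp) device mesh with tp and sp as the innermost dimensions, and it models StreamingDataset's replication as placing each global rank in block `rank // replication`. The function name and mesh layout are illustrative assumptions, not part of the streaming library.

```python
import itertools


def contiguous_blocks_match_dp(dp: int, tp: int, sp: int) -> bool:
    """Check that contiguous blocks of tp * sp ranks coincide with dp groups.

    Assumes a hypothetical row-major (dp, tp, sp) mesh, so the global rank is
    (d * tp + t) * sp + s, and models contiguous replication as
    rank // replication.
    """
    replication = tp * sp
    for d, t, s in itertools.product(range(dp), range(tp), range(sp)):
        rank = (d * tp + t) * sp + s
        # Ranks in the same contiguous block would receive the same samples;
        # they should all share the same dp coordinate d.
        if rank // replication != d:
            return False
    return True


print(contiguous_blocks_match_dp(dp=4, tp=2, sp=2))  # True
```

Because tp and sp are the innermost dimensions in this layout, the ranks that must share data form exactly the contiguous blocks that `replication` expects.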

I have tried various ways to enable PP, without success. I tried adding pp into the equation when computing replication and num_canonical_nodes, but training does not behave normally: the loss is unexpectedly high.

Motivation

I want to use the mosaicml streaming library with 4D parallelism. Specifically, I rely on TorchTitan as my training tool and have simply swapped in the mosaicml streaming library by modifying the StreamingTextDataset implementation from LLM Foundry.

@ethantang-db
Contributor

We can look into this in more detail. Meanwhile, have you tried using mosaicml/composer for training? Are there specific features you rely on in TorchTitan?

@casper-hansen
Author

casper-hansen commented Nov 15, 2024

I would really appreciate it if you could look into it! TorchTitan uses torch.distributed.pipelining, most of which is only available from PyTorch 2.5.0 or in nightly builds.

There are many key features, like FSDP2, 4D parallelism, FP8, and torch.compile, that make LLaMa models scale well in pretraining. You also get full control over the training loop, which is desirable if you want to experiment.

@snarayan21
Collaborator

@casper-hansen StreamingDataset's replication argument assumes that the ranks sharing replicated samples form contiguous blocks of global rank indices. Concretely, suppose that on 16 GPUs I have a replication factor of 2. Then StreamingDataset will replicate the same samples on GPU ranks 0 and 1, 2 and 3, 4 and 5, and so on. In the 4D parallelism case, the ranks that should receive replicated samples are likely not contiguous (for example, on the same 16 GPUs you may want GPU ranks 0, 1, 8, and 9 to see the same samples).
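The rank grouping in this example can be reproduced with a small sketch. It assumes a hypothetical row-major (pp, dp, tp) mesh with pp outermost, and that every pipeline stage and tensor-parallel rank of one data-parallel replica must see the same samples. With pp=2, dp=4, tp=2 on 16 GPUs, the group for replica 0 is [0, 1, 8, 9], whereas contiguous blocks of size pp * tp give [0, 1, 2, 3]. Both function names are illustrative, not library APIs.

```python
def shared_sample_groups(pp: int, dp: int, tp: int) -> list[list[int]]:
    """Ranks that should see identical samples, one list per dp replica.

    Assumes a hypothetical row-major (pp, dp, tp) mesh:
    global rank = (p * dp + d) * tp + t.
    """
    return [
        sorted((p * dp + d) * tp + t for p in range(pp) for t in range(tp))
        for d in range(dp)
    ]


def contiguous_groups(world_size: int, replication: int) -> list[list[int]]:
    """Contiguous blocks of `replication` global ranks (the assumed layout)."""
    return [list(range(start, start + replication))
            for start in range(0, world_size, replication)]


pp, dp, tp = 2, 4, 2
print(shared_sample_groups(pp, dp, tp)[0])          # [0, 1, 8, 9]
print(contiguous_groups(pp * dp * tp, pp * tp)[0])  # [0, 1, 2, 3]
```

Under these assumptions, once both pp > 1 and dp > 1, any contiguous block containing ranks 0 and dp * tp also contains rank tp, which belongs to a different replica, so no single replication factor can reproduce these groups.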

We currently enable replication through the World object's replicate function (called here), which sets the correct global node and rank indices used to construct the sample partition and retrieve samples. If you want to try enabling 4D parallelism yourself, I would look at the replicate function here and have it create a new World object with the right information for your sharding and parallelism strategy.
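As a possible starting point for that change, each global rank can be mapped to the index of the data-parallel replica it belongs to; a replicated World could then, hypothetically, be built from that index and the replica count instead of from contiguous blocks. The helper below is not streaming's World API. It assumes the row-major (pp, dp, tp) rank layout that torch's `init_device_mesh` produces for a mesh of that shape.

```python
from typing import NamedTuple


class DataReplica(NamedTuple):
    index: int  # which data-parallel replica this global rank belongs to
    count: int  # total number of data-parallel replicas


def data_replica_for(global_rank: int, pp: int, dp: int, tp: int) -> DataReplica:
    """Hypothetical helper: map a global rank to its data-parallel replica.

    Assumes a row-major (pp, dp, tp) mesh, i.e.
    global rank = (p * dp + d) * tp + t. Ranks that differ only in their
    pp or tp coordinate map to the same replica, so they would be assigned
    identical samples.
    """
    world_size = pp * dp * tp
    if not 0 <= global_rank < world_size:
        raise ValueError(f"rank {global_rank} outside world of size {world_size}")
    # Drop the tp coordinate, then drop the pp coordinate.
    return DataReplica(index=(global_rank // tp) % dp, count=dp)


# Ranks 0, 1, 8 and 9 all land on replica 0 of 4 for pp=2, dp=4, tp=2.
print([data_replica_for(r, pp=2, dp=4, tp=2).index for r in (0, 1, 8, 9)])
```

Deriving the replica index from mesh coordinates, rather than from a single replication factor, keeps working when the replicated ranks are not contiguous.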

@cassanof

It would be great to integrate PyTorch's new DeviceMesh abstraction.
