Sequence parallelism #2412

Open · wants to merge 30 commits into base: main

Conversation

@djsaunde (Contributor) commented on Mar 13, 2025

Description

This PR implements sequence parallelism (SP) via ring-flash-attn. Specifically, its hf_adapter.py module is used to patch transformers' flash attention with llama3_flash_attn_varlen_func, the SP implementation from the Llama 3 tech report. This technically isn't ring attention, but it is the most performant SP variant in most cases.
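For orientation, here is a minimal sketch of how such a patch could be wired up, creating one process group per SP group and handing it to the hf_adapter patch. The substitute_hf_flash_attn name and its heads_k_stride argument are assumptions about ring-flash-attn's interface, and register_sequence_parallel_attn is a hypothetical helper, not code from this PR:

```python
# Hypothetical sketch (not this PR's code). Assumes torch.distributed is already
# initialized and that ring-flash-attn exposes substitute_hf_flash_attn from its
# hf_adapter module; both the name and the heads_k_stride argument are assumptions.
import torch.distributed as dist
from ring_flash_attn import substitute_hf_flash_attn  # assumed re-export of hf_adapter.py


def register_sequence_parallel_attn(sequence_parallel_degree: int, heads_k_stride: int = 1):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % sequence_parallel_degree == 0

    # Partition ranks into contiguous SP groups, e.g. degree 4 on 8 GPUs -> {0..3}, {4..7}.
    # dist.new_group must be called by every rank for every group, hence the full loop.
    my_group = None
    for start in range(0, world_size, sequence_parallel_degree):
        ranks = list(range(start, start + sequence_parallel_degree))
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group = group

    # Patch HF flash attention so each rank computes attention over its local chunk
    # of the sequence and exchanges K/V with the other ranks in its SP group.
    substitute_hf_flash_attn(my_group, heads_k_stride)
```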

I think these changes are sufficient to cover both cases, since the batch API (non-sample-packing case) is a special case of the varlen API (sample-packing case), but this should be validated with tests.

Motivation and Context

SP is necessary for long-context post-training where a single sequence no longer fits in the VRAM of a single card and causes an OOM. Users with more than one GPU can run longer-context post-training by enabling this option.

Attention is distributed across GPUs according to the configured sequence_parallel_degree (e.g., with sequence_parallel_degree = 4, each sequence is split into 4 equal-length chunks). Each GPU computes attention over its own chunk, and inter-GPU communication within the SP group completes the attention computation across chunks.
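As a concrete illustration of the chunking and position ID handling (a sketch, not the PR's collator; the function name and the divisibility assumption are mine):

```python
import torch


def shard_sequence(input_ids: torch.Tensor, position_ids: torch.Tensor,
                   sp_rank: int, sp_degree: int):
    """Return this rank's equal-length chunk of a (batch, seq_len) batch."""
    seq_len = input_ids.size(1)
    assert seq_len % sp_degree == 0, "pad or trim so sp_degree divides seq_len"
    chunk = seq_len // sp_degree
    start, end = sp_rank * chunk, (sp_rank + 1) * chunk
    # Keep the global position IDs for the slice so rotary embeddings still see
    # each token's absolute position within the full sequence.
    return input_ids[:, start:end], position_ids[:, start:end]


# e.g. with sequence_parallel_degree = 4, rank 1 of the SP group gets
# tokens 1024..2047 of a length-4096 sequence.
```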

How has this been tested?

pytest coverage (not super comprehensive) and functional tests.

Types of changes

  • ring-flash-attn hf_adapter.py integration
  • Data collation changes (sequence splitting, position ID adjustment)
  • AxolotlTrainer sampler and dataloader changes
    • Refactored the multipack sampler logic into a helper method
    • DistributedSampler for the SP case (see the sketch after this list)
      • Setting rank = SP group ID lets us shard data per SP group rather than per GPU
    • In the SP case, the dataloader is not prepared for distributed training by the accelerator object
      • Distribution is already handled by the DistributedSampler
  • Bonus: added random_init flag to load model without pretrained weights
  • Bonus: a bit of cleanup
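
A rough sketch of the sampler idea referenced above, using a plain DistributedSampler as a stand-in for the actual multipack-aware sampler; it assumes torch.distributed is initialized and that the world size is divisible by the SP degree:

```python
import torch.distributed as dist
from torch.utils.data import Dataset, DistributedSampler


def build_sp_sampler(dataset: Dataset, sequence_parallel_degree: int) -> DistributedSampler:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    num_sp_groups = world_size // sequence_parallel_degree
    sp_group_id = rank // sequence_parallel_degree  # ranks 0..3 -> group 0, 4..7 -> group 1, ...

    # Treat each SP group as one "replica": every rank in a group receives the same
    # data shard, and each rank then keeps only its chunk of every sequence.
    return DistributedSampler(dataset, num_replicas=num_sp_groups, rank=sp_group_id)
```

Because all ranks in an SP group draw the same shard, the dataloader itself does not need to be prepared for distributed training by the accelerator, matching the bullet above.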

@djsaunde self-assigned this on Mar 13, 2025
@@ -548,6 +553,14 @@ def apply_patches(self) -> None:

patch_self_attn_lora(self.cfg)

if self.cfg.sequence_parallel_degree > 1:
Collaborator review comment:
Suggested change:

-if self.cfg.sequence_parallel_degree > 1:
+if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:

This should fix the NoneType comparison exception in the e2e tests.
