@nicolo-domyn commented Oct 14, 2025

Important

The Update branch button must only be pressed on very rare occasions.
An outdated branch never blocks the merge of a PR.
Please reach out to the automation team before pressing that button.

What does this PR do?

Implement sequence packing for chat-like data and increase its efficiency for large-scale datasets.

Collection: llm, nlp, utils

Changelog

  • Support packing in llm.gpt.data.chat.ChatDataModule (methods _create_data and prepare_data).
  • Modify llm.gpt.data.core.GPTSFTChatDataset: the collate_fn seemed to assume that the sequences in the batch came prepadded when context parallel is active (see this comment). However, when running the data pipeline this did not seem to be the case: sequences arrived unpadded. The proposed change therefore treats the input sequences as unpadded and performs the padding inside collate_fn.
  • Modify utils.sequence_packing_utils.create_hist: for the same reason as above, the sequences entering the function are not padded. If padding is applied after packing, the packed sequence may exceed its maximum allowed length because of the unaccounted padding. Introduce a divisibility_factor argument in create_hist to account for this. Padding is required when packing with context parallel, since sequence lengths must be divisible by 2*CP.
  • Fix nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel.get_batch_on_this_context_parallel_rank: it now splits some input tensors across CP ranks (tokens, labels, loss_mask, etc.) but leaves others untouched, since cu_seqlens do not need to be split across ranks.
  • Make the packing pipeline more scalable, as needed for large-scale datasets (10M sequences or more). In utils.sequence_packing_utils:
    • fill_packing_strategy is now roughly 10x faster (from about 1 minute down to 6 seconds on such datasets) while producing equivalent results
    • the newly introduced first_fit_shuffle_with_heap avoids the quadratic complexity of first_fit: in local tests it runs in around 30 s instead of roughly 14 h. It behaves very similarly to first_fit_shuffle, and the packing efficiency is equivalent (around 96% in local tests); a sketch of the idea follows this list.
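
For illustration, here is a minimal sketch of the idea behind a heap-based packer. It is not the code added in this PR: the function and variable names are invented, it uses worst-fit placement (always trying the emptiest open bin) rather than a literal first-fit scan, and the length rounding stands in for what the new divisibility_factor argument of create_hist accounts for.

import heapq
import random


def pack_with_heap(seq_lens, pack_size, divisibility_factor=1, seed=1234):
    """Sketch: shuffle the lengths, then drop each one into the open bin with the
    most remaining capacity (tracked in a heap), opening a new bin when none fits.
    Lengths are rounded up to a multiple of divisibility_factor so that padding
    applied later (e.g. to a multiple of 2*CP) cannot overflow pack_size."""
    lengths = [-(-l // divisibility_factor) * divisibility_factor for l in seq_lens]
    random.Random(seed).shuffle(lengths)

    heap = []  # entries: (-remaining_capacity, bin_index); heapq is a min-heap
    bins = []  # bins[i] holds the (rounded) sequence lengths packed into bin i
    for length in lengths:
        if heap and -heap[0][0] >= length:
            # the emptiest open bin can take this sequence
            neg_remaining, idx = heapq.heappop(heap)
            bins[idx].append(length)
            heapq.heappush(heap, (neg_remaining + length, idx))
        else:
            # nothing fits: open a new bin
            bins.append([length])
            heapq.heappush(heap, (-(pack_size - length), len(bins) - 1))
    return bins


# toy example: bins of size 16, lengths padded up to multiples of 4
print(pack_with_heap([5, 9, 3, 7, 4, 8, 2, 6], pack_size=16, divisibility_factor=4))

Each placement costs O(log num_bins) instead of a scan over all open bins, which is the difference between hours and seconds at the 10M-sequence scale.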

Usage

Here you may need to temporarily override FineTuningDataModule.setup (where it defines self.max_train_samples) and FineTuningDataModule._create_dataloader (where it defines self.init_global_step); a quick alternative is sketched right after the script.

from pathlib import Path

from tqdm import tqdm

from nemo.collections.llm.gpt.data.chat import ChatDataModule
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# placeholders: point these to your own tokenizer model and dataset directory
TOKENIZER_PATH = "/path/to/tokenizer.model"
DATASET_ROOT = Path("/path/to/dataset")  # must contain training.jsonl and validation.jsonl

if __name__ == "__main__":
    packed_sequence_specs = PackedSequenceSpecs(
        packed_sequence_size=32768, 
        tokenizer_model_name="MY_TOKENIZER", 
        pad_cu_seqlens=True,  
    )

    tokenizer = get_nmt_tokenizer(
        library="sentencepiece",
        tokenizer_model=TOKENIZER_PATH,  # path to tokenizer model
    )

    dm = ChatDataModule(
        dataset_root=DATASET_ROOT,  # contains training.jsonl and validation.jsonl
        seq_length=32768,
        tokenizer=tokenizer,
        micro_batch_size=1,
        global_batch_size=4, 
        seed=1234,
        num_workers=1,
        pin_memory=True,
        persistent_workers=False,
        packed_sequence_specs=packed_sequence_specs,
        dataset_kwargs={
            "return_cu_seqlen": True,
            "answer_only_loss": True,
            "pad_to_max_length": True,
        },  
    )
    dm.prepare_data()
    dm.setup("fit")
    train_loader = dm.train_dataloader()
    for batch in tqdm(train_loader, desc="Training Batches"):
        print(batch)
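
As an alternative to temporarily editing those two methods, the snippet below (a hack, not part of this PR) attaches a minimal trainer-like stub before dm.prepare_data(). It assumes that, when the data module is driven outside a Lightning Trainer, setup only needs trainer.max_steps and _create_dataloader only needs trainer.global_step; the attribute name and values here are assumptions.

from types import SimpleNamespace

# hypothetical workaround: run before dm.prepare_data() above so that
# max_train_samples and init_global_step can be derived without a real Trainer
dm.trainer = SimpleNamespace(max_steps=100, global_step=0)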

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open a "Draft" PR.

Who can review?

NLP collection reviewers: @MaximumEntropy, @ericharper, @ekmb, @yzhang123, @VahidooX, @vladgets, @okuchaiev

Additional Information

  • The modifications above do not yet make the pipeline completely scalable for large datasets. The main bottleneck is still the online tokenization in the main tokenization function; I estimate around 2 days for tokenization alone on large datasets. Parallelizing the tokenization would require a separate script, but supporting pre-tokenized data loading would already help. I have a local implementation that I can add here if necessary.
  • Similarly, another big bottleneck is that the packing function goes OOM on large datasets. Moving to a .bin/.idx dataset format (instead of holding all the data in memory and then saving a single .npy) solves the problem; see the sketch after this list. Once again, I have a local implementation that I can add to the PR if needed.
  • Finally, there is a separate issue reported in Megatron: "[BUG] Context parallel nan loss" NVIDIA/Megatron-LM#1764. Without the fix proposed there, training with packing and context parallel yields NaNs; with it, training runs correctly.
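
The following is only a toy sketch of the .bin/.idx idea from the second bullet, not the local implementation mentioned there (and much simpler than the Megatron-style indexed format): packed sequences are streamed to a flat binary file plus an offset index and read back through a memmap, so memory usage stays flat regardless of dataset size. All names here are invented for illustration.

import numpy as np


def write_bin_idx(packed_sequences, prefix):
    # stream each packed sequence to <prefix>.bin and record element offsets in <prefix>.idx
    offsets = [0]
    with open(f"{prefix}.bin", "wb") as bin_file:
        for seq in packed_sequences:  # packed_sequences can be a generator
            arr = np.asarray(seq, dtype=np.int32)
            arr.tofile(bin_file)
            offsets.append(offsets[-1] + arr.size)
    np.asarray(offsets, dtype=np.int64).tofile(f"{prefix}.idx")


def read_sequence(prefix, i):
    # random access through a memmap: only the requested slice is materialized
    offsets = np.fromfile(f"{prefix}.idx", dtype=np.int64)
    data = np.memmap(f"{prefix}.bin", dtype=np.int32, mode="r")
    return data[offsets[i]:offsets[i + 1]]


write_bin_idx([[1, 2, 3], [4, 5], [6, 7, 8, 9]], "packed_toy")
print(read_sequence("packed_toy", 2))  # -> [6 7 8 9]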

All in all, with these modifications (and a parallelized tokenization script) I can run packing on 10M chat samples across CP ranks.

This allows treating sequences as already padded (although they are not), so that when we pad in downstream tasks (which should be done inside `GPTSFTPackedDataset.collate_fn`) the packed sequence does not exceed the maximum length due to unaccounted padding.

…subsequences in a packed sequence to be a length divisible by `self.pad_seq_length_to_mult`.

The functions `pad_thd_sequences_for_cp` and `generate_positional_ids_for_cp` are copied from transformer_engine, since they are not available in the official NeMo container venvs.
