
Commit 96b4f01
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

Signed-off-by: NeMo-Aligner CI <[email protected]>
pre-commit-ci[bot] committed Nov 7, 2024
1 parent d8a639a commit 96b4f01
Showing 3 changed files with 22 additions and 14 deletions.
2 changes: 1 addition & 1 deletion examples/nlp/gpt/train_gpt_dpo.py
@@ -20,7 +20,7 @@
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate
-from nemo_aligner.data.nlp.builders import build_dataloader, identity_collate, build_train_valid_test_dpo_datasets
+from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets, identity_collate
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
14 changes: 8 additions & 6 deletions nemo_aligner/algorithms/dpo.py
@@ -15,23 +15,23 @@
import math
from collections import defaultdict
from statistics import mean
+from typing import Any, Protocol

import torch
import torch.distributed
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm
-from typing import Protocol, Any

from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
MegatronPretrainingRandomBatchSampler,
)
from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids
from nemo.utils import logging
+from nemo_aligner.utils import parallel_state
from nemo_aligner.utils.distributed import SyncTimer
from nemo_aligner.utils.train_utils import clip_gradients
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch
from nemo_aligner.utils.utils import clear_memory
-from nemo_aligner.utils import parallel_state


class DistributedCollateFunction(Protocol):
@@ -47,12 +47,12 @@ def dpo_custom_collate(
eod_mask_loss: bool = False,
pad_length_to_multiple_of: int | None = None,
) -> dict[str, torch.Tensor]:
-'''
+"""
Transposes minibatch from list[dict] -> dict[Tensor] and also pads
This collate happens outside of the torch data loader and is not compatible with the multiprocessing
logic due to requiring communication collectives.
-'''
+"""
chosen_tokens = [item["chosen"] for item in batch]
rejected_tokens = [item["rejected"] for item in batch]
chosen_lengths = torch.LongTensor([item["chosen_length"] for item in batch])
@@ -72,8 +72,10 @@ def dpo_custom_collate(
if pad_length_to_multiple_of:
# Assumes both chosen and rejected match
max_seq_len = torch.tensor(chosen_tokens.shape[1], device=torch.cuda.current_device())
-torch.distributed.all_reduce(max_seq_len, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_data_parallel_group())
+torch.distributed.all_reduce(
+    max_seq_len, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_data_parallel_group()
+)

padded_max_len = math.ceil(max_seq_len / pad_length_to_multiple_of) * pad_length_to_multiple_of
chosen_tokens = torch.nn.functional.pad(
chosen_tokens, (0, padded_max_len - chosen_tokens.shape[1]), mode="constant", value=eos_id
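The rewrapped call above belongs to the collate's pad-to-multiple logic: each data-parallel rank reduces the local maximum sequence length with MAX, rounds it up to the configured multiple, and pads to that size. A minimal single-process sketch of the round-up arithmetic, assuming a hypothetical pad_to_multiple helper that is not part of NeMo-Aligner (the production code additionally all_reduces the local maximum across the data-parallel group first, so every rank pads to the same length):

import math

import torch

def pad_to_multiple(tokens: torch.Tensor, multiple: int, pad_value: int) -> torch.Tensor:
    # Round the current sequence length up to the next multiple of `multiple`,
    # then right-pad the sequence dimension with `pad_value`.
    seq_len = tokens.shape[1]
    padded_len = math.ceil(seq_len / multiple) * multiple
    return torch.nn.functional.pad(tokens, (0, padded_len - seq_len), mode="constant", value=pad_value)

batch = torch.ones(2, 37, dtype=torch.long)  # [batch, seq]; 37 is not a multiple of 16
assert pad_to_multiple(batch, multiple=16, pad_value=0).shape == (2, 48)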
20 changes: 13 additions & 7 deletions tests/test_datasets.py
@@ -16,16 +16,18 @@
def llama3_tokenizer():
return AutoTokenizer("meta-llama/Meta-Llama-3-8b")


@pytest.fixture
def str_to_list_tokenizer():
class StringToListTokenizer:
eos_id: int = -1

def text_to_ids(self, text: str) -> list[int]:
return [int(x) for x in text.split()]

return StringToListTokenizer()


@pytest.fixture
def make_tmp_jsonl():
with TemporaryDirectory() as tmp_dir:
@@ -266,11 +268,13 @@ def test_dpo_loader_pad_to_multiple(init_model_parallel, make_tmp_jsonl, str_to_

num_mini_batches = 0
for mbatch in train_dataloader:
-chosen_lengths = [len(x['chosen']) for x in mbatch]
-rejected_lengths = [len(x['rejected']) for x in mbatch]
+chosen_lengths = [len(x["chosen"]) for x in mbatch]
+rejected_lengths = [len(x["rejected"]) for x in mbatch]
assert chosen_lengths == rejected_lengths

-assert len(set(chosen_lengths)) == len(chosen_lengths), f"Lengths should be unique in this test: {chosen_lengths=}"
+assert len(set(chosen_lengths)) == len(
+    chosen_lengths
+), f"Lengths should be unique in this test: {chosen_lengths=}"

mbatch = distributed_collate_fn(mbatch)
assert mbatch["chosen"].shape[1] % expected_seq_len_multiple == 0
@@ -282,8 +286,10 @@ def test_dpo_loader_pad_to_multiple(init_model_parallel, make_tmp_jsonl, str_to_
assert mbatch["position_ids"].shape[1] % expected_seq_len_multiple == 0

# Check that all ranks have the same length
-max_chosen_seq_length = torch.tensor(mbatch["chosen"].shape[1], device='cuda')
-torch.distributed.all_reduce(max_chosen_seq_length, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_data_parallel_group())
+max_chosen_seq_length = torch.tensor(mbatch["chosen"].shape[1], device="cuda")
+torch.distributed.all_reduce(
+    max_chosen_seq_length, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_data_parallel_group()
+)
assert mbatch["chosen"].shape[1] == max_chosen_seq_length.item()

num_mini_batches += 1
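The test's final all_reduce assertion checks the invariant the collate guarantees: after reducing with MAX across the data-parallel group, every rank pads to the same length. A toy single-process sketch of that invariant, with made-up per-rank lengths standing in for what torch.distributed would reduce:

import math

local_max_lens = [37, 52, 41]  # hypothetical per-rank max sequence lengths
multiple = 16

# all_reduce(MAX) leaves every rank holding the same global maximum...
global_max = max(local_max_lens)
# ...so every rank computes the same rounded-up padded length.
padded_len = math.ceil(global_max / multiple) * multiple

assert padded_len % multiple == 0
assert all(padded_len >= n for n in local_max_lens)
assert padded_len == 64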
