Merge branch 'rerun_step' into 'main'
Add functionality to re-run iterations

See merge request ADLR/megatron-lm!2282
ko3n1g committed Dec 8, 2024
2 parents d677ca3 + cf84356 commit 43fa44c
Showing 10 changed files with 1,319 additions and 56 deletions.
22 changes: 12 additions & 10 deletions megatron/core/distributed/param_and_grad_buffer.py
@@ -2,14 +2,15 @@
 
 import logging
 import math
-import os
 from contextlib import nullcontext
 from enum import Enum
 from typing import Dict, List, Optional
 
 import torch
 from torch.distributed import _coalescing_manager
 
+from megatron.core.rerun_state_machine import get_rerun_state_machine
+
 from ..utils import is_float8tensor, is_torch_min_version, log_on_each_pipeline_stage
 from .distributed_data_parallel_config import DistributedDataParallelConfig
@@ -153,15 +154,16 @@ def check_for_nan_in_grad(self):
         Make sure norm of grads in bucket are not NaN prior to data-parallel
         all-reduce / reduce-scatter.
         """
-        global_rank = torch.distributed.get_rank()
-        norm_is_nan = self.buckets[0].grad_data.norm(p=2).isnan()
-        for i in range(1, len(self.buckets)):
-            norm_is_nan.logical_or_(self.buckets[i].grad_data.norm(p=2).isnan())
-        assert not norm_is_nan, (
-            f'Rank {global_rank}: found NaN in local grad norm in '
-            f'backward pass before data-parallel communication collective. '
-            f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
-        )
+        rerun_state_machine = get_rerun_state_machine()
+        for i in range(len(self.buckets)):
+            rerun_state_machine.validate_result(
+                result=self.buckets[i].grad_data.norm(p=2),
+                rejection_func=torch.isnan,
+                message=f"found NaN in local grad norm for bucket #{i} "
+                f"in backward pass before data-parallel communication collective",
+                tolerance=0.001,  # 0.1% tolerance to account for non-deterministic FA backward
+                fatal=True,
+            )
 
     def start_param_sync(self, force_sync: bool = False):
         """
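In short, the hard assert that aborted training on any NaN grad norm is replaced by a call into the new rerun state machine: validate_result() records the per-bucket grad norm, and on a rejected value the state machine can re-run the iteration to check whether the failure is reproducible (for example, a transient hardware fault versus a genuine divergence) before treating it as fatal. As a rough sketch of how a caller might guard another per-iteration value with the same API shown in this diff -- the function name validate_loss and the exact keyword values here are illustrative assumptions, not part of this commit:

import torch

from megatron.core.rerun_state_machine import get_rerun_state_machine


def validate_loss(loss: torch.Tensor) -> None:
    # Illustrative sketch: guard a scalar loss the same way this commit guards
    # per-bucket grad norms. The rerun state machine decides whether a rejected
    # value should trigger a re-run of the iteration before failing.
    rerun_state_machine = get_rerun_state_machine()
    rerun_state_machine.validate_result(
        result=loss,
        rejection_func=torch.isnan,  # reject NaN losses
        message="found NaN in loss before optimizer step",
        tolerance=0.0,  # assumption: require an exact match when comparing re-runs
        fatal=True,
    )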