Merge branch 'kunlunl/native_fp8_2' into 'main'
Add native-fp8

See merge request ADLR/megatron-lm!1669
ko3n1g committed Sep 5, 2024
2 parents 3396356 + 033d8b0 commit 01945b9
Showing 31 changed files with 816 additions and 228 deletions.
6 changes: 3 additions & 3 deletions .gitlab/stages/01.tests.yml
@@ -10,7 +10,7 @@ include:
- template: Security/Secret-Detection.gitlab-ci.yml

build_image:
tags:
tags:
- ${TAG}
image: docker:26.1.4-dind
timeout: 45m
@@ -90,7 +90,7 @@ unit_tests:
parallel:
matrix:
- TAG: latest
- TAG: f2d356582247e1df5a4c0f7c426d33096a394dc1
- TAG: f6ee2ebaf2c8a3bfa091a8327452078ecd89fc3a
tags: [8xL40S]
variables:
GIT_STRATEGY: clone
@@ -171,4 +171,4 @@ secret_detection:
echo "Atleast one vulnerability has been found"
cat gl-secret-detection-report.json | jq '.'
exit 1
fi
fi
2 changes: 1 addition & 1 deletion megatron/core/distributed/__init__.py
@@ -3,4 +3,4 @@
from .distributed_data_parallel import DistributedDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .finalize_model_grads import finalize_model_grads
from .param_and_grad_buffer import ParamAndGradBuffer, shard_buffer
from .param_and_grad_buffer import ParamAndGradBuffer, partition_buckets, shard_buffer
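With this change, partition_buckets is exported from megatron.core.distributed alongside ParamAndGradBuffer and shard_buffer. A minimal import sketch, assuming the package path implied by the file location above:

from megatron.core.distributed import (
    ParamAndGradBuffer,
    partition_buckets,
    shard_buffer,
)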
93 changes: 71 additions & 22 deletions megatron/core/distributed/distributed_data_parallel.py
@@ -10,9 +10,9 @@
from ..config_logger import has_config_logger_enabled, log_config_to_disk
from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig
from ..utils import log_single_rank
from ..utils import is_float8tensor, log_single_rank
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import ParamAndGradBuffer
from .param_and_grad_buffer import BucketGroup, ParamAndGradBuffer, partition_buckets

logger = logging.getLogger(__name__)

@@ -78,7 +78,7 @@ def __init__(
self.bucket_size = None

self.module = module
self.param_to_buffer = {}
self.param_to_bucket_group = {}

# Group parameters by their gradient type.
param_to_name = {}
@@ -100,19 +100,50 @@ def allocate_buffers_for_parameters(
input_params, data_parallel_group, gradient_scaling_factor
):
param_and_grad_dtype_to_params = {}
param_and_grad_dtype_to_offsets = {}
param_and_grad_dtype_to_indices = {}

# Group parameters by their gradient type.
for param in input_params:
if not param.requires_grad:
continue

param_dtype = param.dtype
if is_float8tensor(param):
# Currently TE's Float8Tensor is a wrapper of torch.Tensor. It has a "fake"
# dtype (usually a higher precision dtype such as bfloat16), but its actual
# data is stored in the form of a torch uint8 tensor within the Float8Tensor's
# ".data" attribute. Therefore, when creating the param buffer for fp8 params,
# it is necessary to use torch.uint8, not the "fake" dtype got from
# "param.dtype".
param_dtype = torch.uint8
grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype

params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
params.append(param)
param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params

# Get the index of each param among the params with same dtype, if a param is fp8,
# use its "fake" high precision dtype to find which params have same dtype with it.
# For example:
# Case 1:
# params = [p1(bf16), p2(bf16), p3(bf16), p4(bf16)]
# param_and_grad_dtype_to_indices = {
# (torch.bfloat16, torch.float32): [0, 1, 2, 3],
# }
# Case 2:
# params = [p1(bf16), p2(fp8), p3(fp8), p4(bf16)]
# param_and_grad_dtype_to_indices = {
# (torch.bfloat16, torch.float32): [0, 3],
# (torch.uint8, torch.float32): [1, 2],
# }
# We need these indices to load a non-native-fp8 checkpoint in native-fp8 mode.
offset = param_and_grad_dtype_to_offsets.get((param.dtype, grad_dtype), 0)
param_and_grad_dtype_to_offsets[(param.dtype, grad_dtype)] = offset + 1
indices = param_and_grad_dtype_to_indices.get((param_dtype, grad_dtype), [])
indices.append(offset)
param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] = indices

if not config.calculate_per_token_loss:
target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size(
with_context_parallel=True
@@ -140,12 +171,26 @@ def allocate_buffers_for_parameters(
self.bucket_size,
param_to_name,
gradient_scaling_factor,
param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)],
)
)
for param in params:
self.param_to_buffer[param] = buffers[-1]

return buffers
# In some scenarios, we want to put buckets from different buffers into a group so that
# their communication can be aggregated. For example, when there are both fp8 buffers
# and bf16 buffers in the model and vpp is enabled, each model chunk will have an fp8
# bucket and a bf16 bucket, which doubles the number of communication kernels, and
# because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back
# communications will prevent the overlap of the communication kernels with computation
# kernels.
bucket_groups = partition_buckets(buffers)

# Create map from param to BucketGroup, used in pre_hook.
for bucket_group in bucket_groups:
for bucket in bucket_group.buckets:
for param in bucket.params_list:
self.param_to_bucket_group[param] = bucket_group

return buffers, bucket_groups

if config.calculate_per_token_loss:
gradient_scaling_factor = 1.0
@@ -164,17 +209,19 @@ def allocate_buffers_for_parameters(
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size

# Allocate the param+grad buffers for dense params' grads.
self.buffers = allocate_buffers_for_parameters(
self.buffers, self.bucket_groups = allocate_buffers_for_parameters(
dense_params,
parallel_state.get_data_parallel_group(with_context_parallel=True),
gradient_scaling_factor=gradient_scaling_factor,
)

# Allocate separate param+grad buffers for expert parallel params' grads.
self.expert_parallel_buffers = allocate_buffers_for_parameters(
expert_parallel_params,
parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True),
gradient_scaling_factor=expert_gradient_scaling_factor,
self.expert_parallel_buffers, self.expert_parallel_bucket_groups = (
allocate_buffers_for_parameters(
expert_parallel_params,
parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True),
gradient_scaling_factor=expert_gradient_scaling_factor,
)
)

# Delete references to weight_tensor if they exist since we don't want two parameter copies
@@ -200,7 +247,7 @@ def unmap_weight_tensor(m):
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param, self.param_to_buffer))
grad_acc.register_hook(self._make_param_hook(param, self.param_to_bucket_group))
self.grad_accs.append(grad_acc)

def forward(self, *inputs, **kwargs):
@@ -212,7 +259,7 @@ def forward(self, *inputs, **kwargs):
def _make_param_hook(
self,
param: torch.nn.Parameter,
param_to_buffer: Dict[torch.nn.Parameter, ParamAndGradBuffer],
param_to_bucket_group: Dict[torch.nn.Parameter, BucketGroup],
):
"""
Creates the all-reduce / reduce-scatter hook for backprop.
@@ -231,7 +278,7 @@ def param_hook(*unused):
param.grad = None

if self.ddp_config.overlap_grad_reduce:
param_to_buffer[param].register_grad_ready(param)
param_to_bucket_group[param].register_grad_ready(param)

return param_hook

@@ -240,13 +287,13 @@ def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.is_last_microbatch = False
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.is_last_microbatch = False
try:
yield
finally:
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.is_last_microbatch = True
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.is_last_microbatch = True

def start_grad_sync(self, *unused):
"""
@@ -257,8 +304,8 @@ def start_grad_sync(self, *unused):
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.start_grad_sync()
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.start_grad_sync()

def scale_gradients(self, scaling_factor: float) -> None:
"""Scale all gradients inside the buffers by `scaling_factor`."""
@@ -274,8 +321,8 @@ def finish_grad_sync(self):
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.finish_grad_sync()
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.finish_grad_sync()

def zero_grad_buffer(self):
"""
@@ -287,6 +334,8 @@ def zero_grad_buffer(self):
param.grad_added_to_main_grad = False
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.reset()
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.reset()

def broadcast_params(self):
"""
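The hunks above pick the buffer dtype per parameter: a TE Float8Tensor reports a "fake" high-precision dtype but stores uint8 data, so its param buffer is allocated as torch.uint8, and each parameter's index among parameters sharing the same fake dtype is recorded so a non-native-fp8 checkpoint can be remapped onto fp8 buffers. A standalone sketch of that bookkeeping, assuming a hypothetical group_params helper, with looks_like_fp8 standing in for the is_float8tensor check imported above:

import torch

def group_params(params, looks_like_fp8, grad_reduce_in_fp32=True):
    # Params keyed by the dtypes their buffers will actually be allocated with.
    dtype_to_params = {}
    # Index of each param among params that share its "fake" (apparent) dtype.
    dtype_to_indices = {}
    offsets = {}
    for param in params:
        if not param.requires_grad:
            continue
        # An fp8 param stores a uint8 payload behind a higher-precision "fake"
        # dtype, so its param buffer must be allocated as torch.uint8.
        param_dtype = torch.uint8 if looks_like_fp8(param) else param.dtype
        grad_dtype = torch.float if grad_reduce_in_fp32 else param.dtype
        dtype_to_params.setdefault((param_dtype, grad_dtype), []).append(param)
        # Offsets are keyed by the apparent dtype (param.dtype), mirroring the
        # diff, so fp8 and bf16 params share one index space per fake dtype.
        offset = offsets.get((param.dtype, grad_dtype), 0)
        offsets[(param.dtype, grad_dtype)] = offset + 1
        dtype_to_indices.setdefault((param_dtype, grad_dtype), []).append(offset)
    return dtype_to_params, dtype_to_indices

# For [p1(bf16), p2(fp8, bf16 fake dtype), p3(fp8, bf16 fake dtype), p4(bf16)]
# this yields indices {(torch.bfloat16, torch.float32): [0, 3],
# (torch.uint8, torch.float32): [1, 2]}, matching Case 2 in the comments above.

The partition_buckets call in the same hunk then merges buckets from the resulting uint8 and bf16 buffers into shared bucket groups, so that with virtual pipeline parallelism and CUDA_DEVICE_MAX_CONNECTIONS=1 the number of back-to-back grad-reduce launches does not double per model chunk.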
4 changes: 4 additions & 0 deletions megatron/core/distributed/distributed_data_parallel_config.py
@@ -30,3 +30,7 @@ class DistributedDataParallelConfig:
average_in_collective: bool = False
"""If true, compute average in collective directly, as opposed to dividing by the
dp_size first and then computing sum in the collective."""

fp8_param_gather: bool = False
"""If true, keep the compute param in fp8 (do not use any other intermediate dtype) and
perform the param all-gather in fp8."""
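A minimal usage sketch of the new flag, assuming the fields referenced elsewhere in this diff (grad_reduce_in_fp32, overlap_grad_reduce) keep their usual roles:

from megatron.core.distributed import DistributedDataParallelConfig

# Keep compute params in fp8 and all-gather them in fp8, while still reducing
# gradients in fp32 and overlapping the reduction with backprop.
ddp_config = DistributedDataParallelConfig(
    grad_reduce_in_fp32=True,
    overlap_grad_reduce=True,
    fp8_param_gather=True,
)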