diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index f2751673e4..936ac1edf7 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -272,6 +272,27 @@ class ModelParallelConfig: encoder and decoder (e.g., T5). Ignored if None. """ + overlap_p2p_comm_warmup_flush: bool = False + """If true, overlap communication and computation in the warmup and flush phases. + Only valid when overlap_p2p_comm is True and batch_p2p_comm is False. + Defaults to False. + """ + + microbatch_group_size_per_vp_stage: Optional[int] = None + """This value specifies the number of micro-batches that are executed + at a time for a given virtual stage (both forward and backward). + Defaults (in the __post_init__() method below) to pipeline_parallel_size, + which specifies a depth-first schedule. + Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2, + num_microbatches = 4, we have + rank 0 | 0 1 0 1 2 3 2 3 + rank 1 | 0 1 0 1 2 3 2 3 + When microbatch_group_size_per_vp_stage=3, num_microbatches = 5, + we have + rank 0 | 0 1 2 0 1 2 3 4 3 4 + rank 1 | 0 1 2 0 1 2 3 4 3 4 + """ + ################### # CPU Offloading ################### @@ -339,6 +360,16 @@ def __post_init__(self): if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1: if self.sequence_parallel is False: raise ValueError( - "When using expert parallelism and tensor parallelism, sequence parallelism " - "must be used" + "When using expert parallelism and tensor parallelism, " + "sequence parallelism must be used" + ) + + if self.microbatch_group_size_per_vp_stage is None: + self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size + + if self.overlap_p2p_comm_warmup_flush: + if not self.overlap_p2p_comm or self.batch_p2p_comm: + raise ValueError( + "Pipeline parallel communication overlapping in warmup and flush is only " + "compatible with overlap_p2p_comm and requires batch_p2p_comm to be disabled" ) diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index 3e33e7c2f8..88aee8987a 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -1,8 +1,6 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-import operator -from functools import reduce -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch @@ -166,8 +164,7 @@ def _p2p_ops( prev_pipeline_rank: int, next_pipeline_rank: int, ): - reqs = [] - rank = get_pipeline_model_parallel_rank() + reqs = {} even_send_odd_recv_group = group if get_pipeline_model_parallel_world_size() == 2: # Use the global process group for one of the two p2p communications @@ -183,50 +180,50 @@ def _p2p_ops( send_next_req = torch.distributed.isend( tensor=tensor_send_next, dst=next_pipeline_rank, group=even_send_odd_recv_group ) - reqs.append(send_next_req) + reqs["send_next"] = send_next_req if tensor_recv_prev is not None: recv_prev_req = torch.distributed.irecv( tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_recv_odd_send_group ) - reqs.append(recv_prev_req) + reqs["recv_prev"] = recv_prev_req if tensor_send_prev is not None: send_prev_req = torch.distributed.isend( tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_send_odd_recv_group ) - reqs.append(send_prev_req) + reqs["send_prev"] = send_prev_req if tensor_recv_next is not None: recv_next_req = torch.distributed.irecv( tensor=tensor_recv_next, src=next_pipeline_rank, group=even_recv_odd_send_group ) - reqs.append(recv_next_req) + reqs["recv_next"] = recv_next_req else: if tensor_recv_prev is not None: recv_prev_req = torch.distributed.irecv( tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_send_odd_recv_group ) - reqs.append(recv_prev_req) + reqs["recv_prev"] = recv_prev_req if tensor_send_next is not None: send_next_req = torch.distributed.isend( tensor=tensor_send_next, dst=next_pipeline_rank, group=even_recv_odd_send_group ) - reqs.append(send_next_req) + reqs["send_next"] = send_next_req if tensor_recv_next is not None: recv_next_req = torch.distributed.irecv( tensor=tensor_recv_next, src=next_pipeline_rank, group=even_send_odd_recv_group ) - reqs.append(recv_next_req) + reqs["recv_next"] = recv_next_req if tensor_send_prev is not None: send_prev_req = torch.distributed.isend( tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_recv_odd_send_group ) - reqs.append(send_prev_req) + reqs["send_prev"] = send_prev_req return reqs @@ -349,7 +346,10 @@ def _ring_exchange_wrapper(**kwargs): assert not isinstance(prev_rank, list) prev_rank = [prev_rank] - reqs = [] + if config.use_ring_exchange_p2p or config.batch_p2p_comm: + reqs = [] + else: + reqs = {} tensor_recv_prev_list = [] tensor_recv_next_list = [] @@ -366,20 +366,22 @@ def _ring_exchange_wrapper(**kwargs): else: tensor_recv_next = None - reqs.extend( - p2p_func( - tensor_send_prev=tensor_send_prev, - tensor_recv_prev=tensor_recv_prev, - tensor_send_next=tensor_send_next, - tensor_recv_next=tensor_recv_next, - group=group, - prev_pipeline_rank=pr, - next_pipeline_rank=nr, - ) + p2p_reqs = p2p_func( + tensor_send_prev=tensor_send_prev, + tensor_recv_prev=tensor_recv_prev, + tensor_send_next=tensor_send_next, + tensor_recv_next=tensor_recv_next, + group=group, + prev_pipeline_rank=pr, + next_pipeline_rank=nr, ) + if isinstance(p2p_reqs, list): + reqs.extend(p2p_reqs) + else: + reqs.update(p2p_reqs) if wait_on_reqs and len(reqs) > 0: - for req in reqs: + for req in reqs if isinstance(reqs, list) else reqs.values(): req.wait() reqs = None diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 90c4a87947..fcfb407451 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ 
b/megatron/core/pipeline_parallel/schedules.py @@ -557,6 +557,16 @@ def forward_backward_pipelining_with_interleaving( communication between pipeline stages as needed. Returns dictionary with losses if the last stage, empty dict otherwise.""" + + # Convention used in this function: + # num_microbatches for number of microbatches per pipeline stage; + # num_model_chunks for virtual pipeline size; + # then total_num_microbatches = num_microbatches * num_model_chunks. + # Their corresponding index variables are + # microbatch_id in [0, num_microbatches) + # model_chunk_id in [0, num_model_chunks) + # virtual_microbatch_id in [0, total_num_microbatches) + assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking" assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking" assert isinstance( @@ -632,10 +642,26 @@ def enable_grad_sync(): pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size() pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank() - if num_microbatches % pipeline_parallel_size != 0: - msg = f'number of microbatches ({num_microbatches}) is not divisible by ' - msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) ' - msg += 'when using interleaved schedule' + if ( + config.microbatch_group_size_per_vp_stage > num_microbatches + or config.microbatch_group_size_per_vp_stage < pipeline_parallel_size + ): + msg = ( + 'The number of contiguous micro-batches in a virtual pipeline stage ' + f'should range in [PP={pipeline_parallel_size}, M={num_microbatches}]' + ) + raise ValueError(msg) + + # If the final micro-batch group has fewer micro-batches than pipeline-parallel size, + # the pipeline will have dependency bubbles. + final_microbatch_group_size = num_microbatches % config.microbatch_group_size_per_vp_stage + if 0 < final_microbatch_group_size < pipeline_parallel_size: + msg = 'The remainder of M (the total micro-batches) divided by N (number of ' + msg += 'contiguous micro-batches in a virtual pipeline stage) should be 0, ' + msg += 'or larger than or equal to the pipeline-parallel size, but it is ' + msg += f'{final_microbatch_group_size}. ' + msg += 'Otherwise, it introduces dependency bubbles in the pipeline ' + msg += 'and reduces throughput.' raise RuntimeError(msg) model_type = get_model_type(model[0]) @@ -659,19 +685,17 @@ def enable_grad_sync(): if forward_only: num_warmup_microbatches = total_num_microbatches else: - # Run all forward passes and then all backward passes if number of - # microbatches is just the number of pipeline stages. - # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on + # Run (num_model_chunks-1)*config.microbatch_group_size_per_vp_stage on # all workers, followed by more microbatches after depending on # stage ID (more forward passes for earlier stages, later stages can # immediately start with 1F1B).
- if num_microbatches == pipeline_parallel_size: + num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches += ( + num_model_chunks - 1 + ) * config.microbatch_group_size_per_vp_stage + if num_warmup_microbatches >= total_num_microbatches: num_warmup_microbatches = total_num_microbatches all_warmup_microbatches = True - else: - num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 - num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size - num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches) num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches # Checkpoint the activations of partial Transformer layers in a number of micro-batches @@ -691,10 +715,55 @@ def enable_grad_sync(): config.param_sync_func[0](model[0].parameters()) config.param_sync_func[1](model[1].parameters()) - def get_model_chunk_id(microbatch_id, forward): + # Create a tunable schedule lookup table. + # The schedule lookup table uses the virtual_microbatch_id to find the corresponding + # microbatch_id and model_chunk_id. For example, the tunable schedule table for + # PP2 N3M5 with VP2 is constructed as below: + # virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9 + # microbatch_id | 0 1 2 0 1 2 3 4 3 4 + # model_chunk_id | 0 0 0 1 1 1 0 0 1 1 + schedule_table = [] + for min_microbatch_id_in_group in range( + 0, num_microbatches, config.microbatch_group_size_per_vp_stage + ): + if ( + min_microbatch_id_in_group + config.microbatch_group_size_per_vp_stage + >= num_microbatches + ): + # Construct schedule for the last microbatch group + schedule_table.extend( + [ + (microbatch_id, model_chunk_id) + for model_chunk_id in range(len(model)) + for microbatch_id in range(min_microbatch_id_in_group, num_microbatches) + ] + ) + else: + # Construct schedule for other microbatch groups + schedule_table.extend( + [ + (microbatch_id, model_chunk_id) + for model_chunk_id in range(len(model)) + for microbatch_id in range( + min_microbatch_id_in_group, + min_microbatch_id_in_group + config.microbatch_group_size_per_vp_stage, + ) + ] + ) + + # Decouple individual lookup table for microbatch_id and model_chunk_id. + # For example, the micro-batch table for PP2 N3M5 with VP2 is + # virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9 + # microbatch_id | 0 1 2 0 1 2 3 4 3 4 + # Similarly, the model chunk table is + # virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9 + # model_chunk_id | 0 0 0 1 1 1 0 0 1 1 + # Both tables are indexed with virtual_microbatch_id. 
+ microbatch_id_table, model_chunk_id_table = zip(*schedule_table) + + def get_model_chunk_id(virtual_microbatch_id, forward): """Helper method to get the model chunk ID given the iteration number.""" - microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) - model_chunk_id = microbatch_id_in_group // pipeline_parallel_size + model_chunk_id = model_chunk_id_table[virtual_microbatch_id % total_num_microbatches] if not forward: model_chunk_id = num_model_chunks - model_chunk_id - 1 return model_chunk_id @@ -702,38 +771,93 @@ def get_model_chunk_id(microbatch_id, forward): def get_microbatch_id_in_model_chunk(iteration_id, forward): """Helper method to get the microbatch_id within model chunk given the iteration number.""" assert forward - iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks) - microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + ( - iteration_id % pipeline_parallel_size - ) + microbatch_id_in_model_chunk = microbatch_id_table[iteration_id] return microbatch_id_in_model_chunk - def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: + def num_released_microbatches(virtual_microbatch_id, model_chunk_id): + """Helper method to count number of released (i.e. popped from input_tensors) + microbatches for a model chunk.""" + if forward_only: # Micro-batch is released after forward prop. + return model_chunk_id_table[:virtual_microbatch_id].count(model_chunk_id) + else: # Micro-batch is released after backward prop. + # Zero backward prop in warmup. + if virtual_microbatch_id < num_warmup_microbatches: + return 0 + else: + backward_microbatch_id = virtual_microbatch_id - num_warmup_microbatches + model_chunk_id = num_model_chunks - model_chunk_id - 1 + return model_chunk_id_table[:backward_microbatch_id].count(model_chunk_id) + + def is_first_microbatch_for_model_chunk(virtual_microbatch_id: int) -> bool: """Check if an iteration is the first for a model chunk.""" - microbatch_group_size = pipeline_parallel_size * num_model_chunks - microbatch_group_id = microbatch_id // microbatch_group_size - microbatch_id_in_group = microbatch_id % microbatch_group_size - if microbatch_group_id == 0: - return microbatch_id_in_group % pipeline_parallel_size == 0 + if virtual_microbatch_id < total_num_microbatches: + return microbatch_id_table[virtual_microbatch_id] == 0 else: return False - def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool: + def is_last_microbatch_for_model_chunk(virtual_microbatch_id: int) -> bool: """Check if an iteration is the last for a model chunk.""" - microbatch_group_size = pipeline_parallel_size * num_model_chunks - num_microbatch_groups = total_num_microbatches // microbatch_group_size - microbatch_group_id = microbatch_id // microbatch_group_size - microbatch_id_in_group = microbatch_id % microbatch_group_size - if microbatch_group_id == num_microbatch_groups - 1: - return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1 + if virtual_microbatch_id < total_num_microbatches: + return microbatch_id_table[virtual_microbatch_id] == num_microbatches - 1 else: return False - def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activations_microbatch): + def recv_tensor_from_previous_stage(virtual_microbatch_id, forward): + """Determine if peers are sending, and where in data structure + to put received tensors. 
+ Return a boolean if the pipeline stage expects to recv from peers, and the + corresponding model_chunk_id for the received tensor. + """ + recv = True + # The leading pipeline stage is the first rank in fwd and the last rank in bwd. + is_leading_pipeline_stage = ( + parallel_state.is_pipeline_first_stage(ignore_virtual=True) + if forward + else parallel_state.is_pipeline_last_stage(ignore_virtual=True) + ) + + last_model_chunk = (num_model_chunks - 1) if forward else 0 + + if is_leading_pipeline_stage: + # The leading pipeline stage is ahead of the ending pipeline stage + # (i.e. last rank in fwd and first rank in bwd) by (pipeline_parallel_size - 1). + # Let's consider bwd as an example with PP 4: + # 0 1 2 3 ... + # 0 1 2 3 ... + # 0 1 2 3 ... + # 0 1 2 3 ... + if virtual_microbatch_id < (pipeline_parallel_size - 1): + # The ending stage has not produced any tensors, so no recv will be initiated. + recv = False + next_model_chunk_id = get_model_chunk_id(virtual_microbatch_id + 1, forward) + else: + # Find the model chunk of the aligned microbatches in the ending stage. + # For example, microbatch 0 in the ending stage is aligned with microbatch 3 + # in the leading stage. + next_model_chunk_id = get_model_chunk_id( + virtual_microbatch_id - (pipeline_parallel_size - 1), forward + ) + # Last model chunk in the final stage does not produce tensors. + if next_model_chunk_id == last_model_chunk: + recv = False + if forward: + # Model chunk id increases in forward. + next_model_chunk_id += 1 + else: + # Model chunk id decreases in backward. + next_model_chunk_id -= 1 + else: + next_model_chunk_id = get_model_chunk_id(virtual_microbatch_id + 1, forward) + + return recv, next_model_chunk_id + + def forward_step_helper( + virtual_microbatch_id, microbatch_id, checkpoint_activations_microbatch + ): """Helper method to run forward step with model split into chunks (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()).""" - model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) + model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=True) parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) # launch param synchronization for next model chunk @@ -742,12 +866,14 @@ def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activation # asynchronous communication at the same time across the # pipeline-parallel group. 
if config.param_sync_func is not None: - param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank + param_sync_virtual_microbatch_id = virtual_microbatch_id + pipeline_parallel_rank if ( - param_sync_microbatch_id < total_num_microbatches - and is_first_microbatch_for_model_chunk(param_sync_microbatch_id) + param_sync_virtual_microbatch_id < total_num_microbatches + and is_first_microbatch_for_model_chunk(param_sync_virtual_microbatch_id) ): - param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1 + param_sync_chunk_id = ( + get_model_chunk_id(param_sync_virtual_microbatch_id, forward=True) + 1 + ) if 1 < param_sync_chunk_id < num_model_chunks: config.param_sync_func[param_sync_chunk_id]( model[param_sync_chunk_id].parameters() @@ -757,7 +883,14 @@ def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activation if parallel_state.is_pipeline_first_stage(): if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]): input_tensors[model_chunk_id].append(None) - input_tensor = input_tensors[model_chunk_id][-1] + + # For non-depth-first pipeline schedules, the first rank buffers multiple received + # activation tensors for a model chunk until they are accessed during warmup. + # This input buffering is needed to overlap the computation with the receipt of + # the next inputs. To index the proper buffered input for forward_step, we use + # microbatch_id offset by the number of released microbatches that have completed backprop. + offset = num_released_microbatches(virtual_microbatch_id, model_chunk_id) + input_tensor = input_tensors[model_chunk_id][microbatch_id - offset] output_tensor, num_tokens = forward_step( forward_step_func, @@ -770,31 +903,37 @@ def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activation collect_non_loss_data, checkpoint_activations_microbatch, check_first_val_step( - first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id) + first_val_step, + forward_only, + is_first_microbatch_for_model_chunk(virtual_microbatch_id), ), - current_microbatch=current_microbatch, + current_microbatch=microbatch_id, ) + output_tensors[model_chunk_id].append(output_tensor) nonlocal total_num_tokens total_num_tokens += num_tokens.item() - # if forward-only, no need to save tensors for a backward pass + # If forward-only, no need to save tensors for a backward pass. if forward_only: - input_tensors[model_chunk_id].pop() + # Release the tensor that has completed its forward step.
+ input_tensors[model_chunk_id].pop(0) output_tensors[model_chunk_id].pop() return output_tensor - def backward_step_helper(microbatch_id): + def backward_step_helper(virtual_microbatch_id): """Helper method to run backward step with model split into chunks (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()).""" - model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) + model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False) parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) # launch grad synchronization (default) - if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id): + if config.grad_sync_func is None and is_last_microbatch_for_model_chunk( + virtual_microbatch_id + ): enable_grad_sync() synchronized_model_chunks.add(model_chunk_id) @@ -804,6 +943,7 @@ def backward_step_helper(microbatch_id): input_tensor = input_tensors[model_chunk_id].pop(0) output_tensor = output_tensors[model_chunk_id].pop(0) output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) + input_tensor_grad = backward_step( input_tensor, output_tensor, output_tensor_grad, model_type, config ) @@ -814,11 +954,13 @@ def backward_step_helper(microbatch_id): # asynchronous communication at the same time across the # pipeline-parallel group. if config.grad_sync_func is not None: - grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank - if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk( - grad_sync_microbatch_id + grad_sync_virtual_microbatch_id = virtual_microbatch_id - pipeline_parallel_rank + if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk( + grad_sync_virtual_microbatch_id ): - grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False) + grad_sync_chunk_id = get_model_chunk_id( + grad_sync_virtual_microbatch_id, forward=False + ) enable_grad_sync() config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters()) synchronized_model_chunks.add(grad_sync_chunk_id) @@ -831,15 +973,66 @@ def backward_step_helper(microbatch_id): input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config)) fwd_wait_handles = None + fwd_wait_recv_handles = None bwd_wait_handles = None + bwd_wait_recv_handles = None + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + fwd_recv_buffer_size = ( + config.microbatch_group_size_per_vp_stage - pipeline_parallel_size + 1 + ) + else: + fwd_recv_buffer_size = 1 + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + bwd_recv_buffer_size = ( + config.microbatch_group_size_per_vp_stage - pipeline_parallel_size + 1 + ) + else: + bwd_recv_buffer_size = 1 + fwd_recv_buffer = [None] * fwd_recv_buffer_size + bwd_recv_buffer = [None] * bwd_recv_buffer_size + recv_prev_wait_handles = [] + send_next_wait_handle = None + send_prev_wait_handle = None + recv_next_wait_handles = [] for k in range(num_warmup_microbatches): + cur_model_chunk_id = get_model_chunk_id(k, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id) + + if config.overlap_p2p_comm_warmup_flush: + if not parallel_state.is_pipeline_first_stage() and k != 0: + assert recv_prev_wait_handles, ( + f'pp rank {pipeline_parallel_rank}, iteration {k},' + 'should have registered recv handle' + ) + recv_prev_wait_handle = recv_prev_wait_handles.pop(0) + recv_prev_wait_handle.wait() - if fwd_wait_handles is not None: - for req in fwd_wait_handles: - req.wait() + # 
Determine if tensor should be received from previous stage. + recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(k, forward=True) - # Decide to checkpoint all layers' activations of the current micro-batch + # No receive in last iteration when recv iteration k+1. + if k == (total_num_microbatches - 1): + recv_prev = False + + # Prefetch recv for iteration k+1 for non-first ranks. + if config.overlap_p2p_comm_warmup_flush and not parallel_state.is_pipeline_first_stage( + ignore_virtual=True + ): + fwd_recv_buffer[k % fwd_recv_buffer_size], fwd_wait_recv_handles = ( + p2p_communication.send_forward_recv_forward( + output_tensor=None, # No output_tensor to send. + recv_prev=recv_prev, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + ) + + if fwd_wait_recv_handles: + recv_prev_wait_handles.append(fwd_wait_recv_handles.pop("recv_prev")) + + # Decide to checkpoint all layers' activations of the current micro-batch. if max_outstanding_backprops is not None: checkpoint_activations_microbatch = ( k % max_outstanding_backprops @@ -848,19 +1041,8 @@ def backward_step_helper(microbatch_id): else: checkpoint_activations_microbatch = None - current_microbatch = get_microbatch_id_in_model_chunk(k, forward=True) - output_tensor = forward_step_helper( - k, current_microbatch, checkpoint_activations_microbatch - ) - - # Determine if tensor should be received from previous stage. - next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - if next_forward_model_chunk_id == 0: - recv_prev = False - if k == (total_num_microbatches - 1): - recv_prev = False + microbatch_id = get_microbatch_id_in_model_chunk(k, forward=True) + output_tensor = forward_step_helper(k, microbatch_id, checkpoint_activations_microbatch) # Don't send tensor downstream if on last stage. if parallel_state.is_pipeline_last_stage(): @@ -868,9 +1050,10 @@ def backward_step_helper(microbatch_id): # Send and receive tensors as appropriate (send tensors computed # in this iteration; receive tensors for next iteration). - if not config.overlap_p2p_comm: + if not config.overlap_p2p_comm_warmup_flush: if ( k == (num_warmup_microbatches - 1) + and not config.overlap_p2p_comm and not forward_only and not all_warmup_microbatches ): @@ -893,16 +1076,46 @@ def backward_step_helper(microbatch_id): input_tensor = p2p_communication.send_forward_recv_forward( output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config ) - input_tensors[next_forward_model_chunk_id].append(input_tensor) + if recv_prev: + input_tensors[next_forward_model_chunk_id].append(input_tensor) + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) else: - input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward( - output_tensor, - recv_prev=recv_prev, - tensor_shape=tensor_shape, - config=config, - overlap_p2p_comm=True, - ) + if not parallel_state.is_pipeline_first_stage(ignore_virtual=True): + # Send only since recv prefetched. + _, fwd_wait_handles = p2p_communication.send_forward_recv_forward( + output_tensor, + recv_prev=False, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + else: # No prefetch for first rank, so both send and recv initiated. 
+ fwd_recv_buffer[k % fwd_recv_buffer_size], fwd_wait_handles = ( + p2p_communication.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + ) + if send_next_wait_handle is not None: + send_next_wait_handle.wait() + if fwd_wait_handles is not None: + send_next_wait_handle = ( + fwd_wait_handles.pop("send_next") if "send_next" in fwd_wait_handles else None + ) + if "recv_prev" in fwd_wait_handles: + recv_prev_wait_handles.append(fwd_wait_handles.pop("recv_prev")) + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) + if recv_prev: + input_tensors[next_forward_model_chunk_id].append( + fwd_recv_buffer[k % fwd_recv_buffer_size] + ) + fwd_recv_buffer[(k + 1) % fwd_recv_buffer_size] = None + + if config.overlap_p2p_comm: if ( k == (num_warmup_microbatches - 1) and not forward_only @@ -913,7 +1126,7 @@ def backward_step_helper(microbatch_id): if parallel_state.is_pipeline_last_stage(ignore_virtual=True): recv_next = False - (output_tensor_grad, bwd_wait_handles) = ( + (bwd_recv_buffer[-1], bwd_wait_handles) = ( p2p_communication.send_backward_recv_backward( input_tensor_grad, recv_next=recv_next, @@ -922,18 +1135,26 @@ def backward_step_helper(microbatch_id): overlap_p2p_comm=True, ) ) + if send_prev_wait_handle is not None: + send_prev_wait_handle.wait() + if bwd_wait_handles is not None: + send_prev_wait_handle = ( + bwd_wait_handles.pop("send_prev") + if "send_prev" in bwd_wait_handles + else None + ) + if "recv_next" in bwd_wait_handles: + recv_next_wait_handles.append(bwd_wait_handles.pop("recv_next")) - output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) - input_tensors[next_forward_model_chunk_id].append(input_tensor) - - deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) + if recv_next: + output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1]) # Run 1F1B in steady state. for k in range(num_microbatches_remaining): # Forward pass. forward_k = k + num_warmup_microbatches - # Decide to checkpoint all layers' activations of the current micro-batch + # Decide to checkpoint all layers' activations of the current micro-batch. 
if max_outstanding_backprops is not None: checkpoint_activations_microbatch = ( forward_k % max_outstanding_backprops @@ -942,16 +1163,27 @@ def backward_step_helper(microbatch_id): else: checkpoint_activations_microbatch = None - current_microbatch = get_microbatch_id_in_model_chunk(forward_k, forward=True) + cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id) + microbatch_id = get_microbatch_id_in_model_chunk(forward_k, forward=True) if config.overlap_p2p_comm: - if fwd_wait_handles is not None: - for req in fwd_wait_handles: - req.wait() + if not parallel_state.is_pipeline_first_stage(): + if config.overlap_p2p_comm_warmup_flush: + assert recv_prev_wait_handles, ( + f'pp rank {pipeline_parallel_rank}, fwd iteration {forward_k}, ' + 'should have registered recv handle' + ) + recv_prev_wait_handle = recv_prev_wait_handles.pop(0) + recv_prev_wait_handle.wait() + else: + if recv_prev_wait_handles is not None and recv_prev_wait_handles: + recv_prev_wait_handle = recv_prev_wait_handles.pop(0) + recv_prev_wait_handle.wait() deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) output_tensor = forward_step_helper( - forward_k, current_microbatch, checkpoint_activations_microbatch + forward_k, microbatch_id, checkpoint_activations_microbatch ) # Determine if current stage has anything to send in either direction, @@ -959,23 +1191,13 @@ def backward_step_helper(microbatch_id): forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) - # Last virtual stage no activation tensor to send + # Last virtual stage no activation tensor to send. if parallel_state.is_pipeline_last_stage(): output_tensor = None - # Determine if peers are sending, and where in data structure to put - # received tensors. - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id( - forward_k - (pipeline_parallel_size - 1), forward=True - ) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) + recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage( + forward_k, forward=True + ) # If last iteration, don't receive; we already received one extra # before the start of the for loop. 
@@ -984,54 +1206,85 @@ def backward_step_helper(microbatch_id): # Send activation tensor to the next stage and receive activation tensor from the # previous stage - input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward( - output_tensor, - recv_prev=recv_prev, - tensor_shape=tensor_shape, - config=config, - overlap_p2p_comm=True, + fwd_recv_buffer[forward_k % fwd_recv_buffer_size], fwd_wait_handles = ( + p2p_communication.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) ) + if send_next_wait_handle is not None: + send_next_wait_handle.wait() + if fwd_wait_handles is not None: + send_next_wait_handle = ( + fwd_wait_handles.pop("send_next") if "send_next" in fwd_wait_handles else None + ) + if "recv_prev" in fwd_wait_handles: + recv_prev_wait_handles.append(fwd_wait_handles.pop("recv_prev")) # assert fwd_wait_handles is not None - if bwd_wait_handles is not None: - for req in bwd_wait_handles: - req.wait() - # Backward pass. backward_k = k - input_tensor_grad = backward_step_helper(backward_k) - backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + if not parallel_state.is_pipeline_last_stage(): + if config.overlap_p2p_comm_warmup_flush: + assert recv_next_wait_handles, ( + f'pp rank {pipeline_parallel_rank}, bwd iteration {backward_k}, ' + 'should have registered recv next handle' + ) + recv_next_wait_handle = recv_next_wait_handles.pop(0) + recv_next_wait_handle.wait() + else: + if recv_next_wait_handles is not None and recv_next_wait_handles: + recv_next_wait_handle = recv_next_wait_handles.pop(0) + recv_next_wait_handle.wait() + + input_tensor_grad = backward_step_helper(backward_k) - # First virtual stage no activation gradient tensor to send + # First virtual stage no activation gradient tensor to send. if parallel_state.is_pipeline_first_stage(): input_tensor_grad = None - # Determine if the current virtual stage has an activation gradient tensor to receive - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). 
- next_backward_model_chunk_id = get_model_chunk_id( - backward_k - (pipeline_parallel_size - 1), forward=False - ) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) - - output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward( - input_tensor_grad, - recv_next=recv_next, - tensor_shape=tensor_shape, - config=config, - overlap_p2p_comm=True, + recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage( + backward_k, forward=False ) - else: # no p2p overlap + (bwd_recv_buffer[backward_k % bwd_recv_buffer_size], bwd_wait_handles) = ( + p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + ) + if send_prev_wait_handle is not None: + send_prev_wait_handle.wait() + if bwd_wait_handles is not None: + send_prev_wait_handle = ( + bwd_wait_handles.pop("send_prev") if "send_prev" in bwd_wait_handles else None + ) + if "recv_next" in bwd_wait_handles: + recv_next_wait_handles.append(bwd_wait_handles.pop("recv_next")) + + # Put input_tensor and output_tensor_grad in data structures in the + # right location. + if recv_prev: + input_tensors[next_forward_model_chunk_id].append( + fwd_recv_buffer[forward_k % fwd_recv_buffer_size] + ) + fwd_recv_buffer[(forward_k + 1) % fwd_recv_buffer_size] = None + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append( + bwd_recv_buffer[backward_k % bwd_recv_buffer_size] + ) + bwd_recv_buffer[(backward_k + 1) % bwd_recv_buffer_size] = None + else: # No p2p overlap. output_tensor = forward_step_helper( - forward_k, current_microbatch, checkpoint_activations_microbatch + forward_k, microbatch_id, checkpoint_activations_microbatch ) # Backward pass. @@ -1053,31 +1306,13 @@ def backward_step_helper(microbatch_id): if parallel_state.is_pipeline_first_stage(): input_tensor_grad = None - # Determine if peers are sending, and where in data structure to put - # received tensors. - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id( - forward_k - (pipeline_parallel_size - 1), forward=True - ) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) + recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage( + forward_k, forward=True + ) - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). - next_backward_model_chunk_id = get_model_chunk_id( - backward_k - (pipeline_parallel_size - 1), forward=False - ) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) + recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage( + backward_k, forward=False + ) # If last iteration, don't receive; we already received one extra # before the start of the for loop. 
@@ -1097,39 +1332,117 @@ def backward_step_helper(microbatch_id): ) deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) - # Put input_tensor and output_tensor_grad in data structures in the - # right location. - if recv_prev: - input_tensors[next_forward_model_chunk_id].append(input_tensor) - if recv_next: - output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad) + # Put input_tensor and output_tensor_grad in data structures in the + # right location. + if recv_prev: + input_tensors[next_forward_model_chunk_id].append(input_tensor) + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad) deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) # Run cooldown backward passes (flush out pipeline). if not forward_only: - if config.overlap_p2p_comm and bwd_wait_handles is not None: - for wait_handle in bwd_wait_handles: - wait_handle.wait() + if bwd_wait_handles is not None: + for bwd_wait_handle in bwd_wait_handles.values(): + bwd_wait_handle.wait() if all_warmup_microbatches: output_tensor_grads[num_model_chunks - 1].append( p2p_communication.recv_backward(tensor_shape, config=config) ) for k in range(num_microbatches_remaining, total_num_microbatches): - input_tensor_grad = backward_step_helper(k) - next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - if next_backward_model_chunk_id == (num_model_chunks - 1): - recv_next = False + cur_model_chunk_id = get_model_chunk_id(k, forward=False) + parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id) + if not parallel_state.is_pipeline_last_stage() and k != 0: + if config.overlap_p2p_comm_warmup_flush: + assert recv_next_wait_handles, ( + f'pp rank {pipeline_parallel_rank}, backward iteration {k}, ' + 'should have registered recv next handle' + ) + recv_next_wait_handle = recv_next_wait_handles.pop(0) + recv_next_wait_handle.wait() + else: + if recv_next_wait_handles is not None and recv_next_wait_handles: + recv_next_wait_handle = recv_next_wait_handles.pop(0) + recv_next_wait_handle.wait() + + recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage( + k, forward=False + ) + if k == (total_num_microbatches - 1): recv_next = False - output_tensor_grads[next_backward_model_chunk_id].append( - p2p_communication.send_backward_recv_backward( + + # Prefetch recv for backward iteration k+1 for non last ranks. + if config.overlap_p2p_comm_warmup_flush and not parallel_state.is_pipeline_last_stage( + ignore_virtual=True + ): + bwd_recv_buffer[k % bwd_recv_buffer_size], bwd_wait_recv_handles = ( + p2p_communication.send_backward_recv_backward( + input_tensor_grad=None, # No input_tensor_grad to send. + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + ) + + if bwd_wait_recv_handles: + recv_next_wait_handles.append(bwd_wait_recv_handles.pop("recv_next")) + + input_tensor_grad = backward_step_helper(k) + + # First virtual stage no activation gradient tensor to send. 
+ if parallel_state.is_pipeline_first_stage(): + input_tensor_grad = None + + if config.overlap_p2p_comm_warmup_flush: + if not parallel_state.is_pipeline_last_stage(ignore_virtual=True): + _, bwd_wait_handles = p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=False, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + else: + bwd_recv_buffer[k % bwd_recv_buffer_size], bwd_wait_handles = ( + p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + ) + + if send_prev_wait_handle is not None: + send_prev_wait_handle.wait() + if bwd_wait_handles is not None: + send_prev_wait_handle = ( + bwd_wait_handles.pop("send_prev") + if "send_prev" in bwd_wait_handles + else None + ) + if "recv_next" in bwd_wait_handles: + recv_next_wait_handles.append(bwd_wait_handles.pop("recv_next")) + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append( + bwd_recv_buffer[k % bwd_recv_buffer_size] + ) + bwd_recv_buffer[(k + 1) % bwd_recv_buffer_size] = None + + else: + output_tensor_grad = p2p_communication.send_backward_recv_backward( input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config ) - ) + + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad) + + if send_prev_wait_handle is not None: + send_prev_wait_handle.wait() # Launch any remaining grad reductions. enable_grad_sync() @@ -1139,6 +1452,13 @@ def backward_step_helper(microbatch_id): config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters()) synchronized_model_chunks.add(model_chunk_id) + assert ( + not recv_prev_wait_handles + ), 'recv_prev_wait_handles should be cleared at the end of a step' + assert ( + not recv_next_wait_handles + ), 'recv_next_wait_handles should be cleared at the end of a step' + if config.finalize_model_grads_func is not None and not forward_only: # If defer_embedding_wgrad_compute is enabled we need to do the @@ -1208,7 +1528,7 @@ def get_tensor_shapes( def recv_forward(tensor_shapes, config): - """recv forward.""" + """Wrapper for p2p_communication.recv_forward used with non-interleaving schedule.""" input_tensors = [] for tensor_shape in tensor_shapes: if tensor_shape is None: @@ -1219,7 +1539,7 @@ def recv_forward(tensor_shapes, config): def recv_backward(tensor_shapes, config): - """recv backward.""" + """Wrapper for p2p_communication.recv_backward used with non-interleaving schedule.""" output_tensor_grads = [] for tensor_shape in tensor_shapes: if tensor_shape is None: @@ -1230,7 +1550,7 @@ def recv_backward(tensor_shapes, config): def send_forward(output_tensors, tensor_shapes, config): - """send forward.""" + """Wrapper for p2p_communication.send_forward used with non-interleaving schedule.""" if not isinstance(output_tensors, list): output_tensors = [output_tensors] for output_tensor, tensor_shape in zip(output_tensors, tensor_shapes): @@ -1240,7 +1560,7 @@ def send_forward(output_tensors, tensor_shapes, config): def send_backward(input_tensor_grads, tensor_shapes, config): - """send backward.""" + """Wrapper for p2p_communication.send_backward used with non-interleaving schedule.""" if not isinstance(input_tensor_grads, list): input_tensor_grads = [input_tensor_grads] for input_tensor_grad, tensor_shape in zip(input_tensor_grads, tensor_shapes): @@ -1250,7 +1570,8 @@ def send_backward(input_tensor_grads, tensor_shapes, config): def 
send_forward_recv_backward(output_tensors, tensor_shapes, config): - """send forward and recv backward.""" + """Wrapper for p2p_communication.send_forward_recv_backward used + with non-interleaving schedule.""" if not isinstance(output_tensors, list): output_tensors = [output_tensors] output_tensor_grads = [] @@ -1266,7 +1587,8 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, config): def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config): - """send backward and recv forward.""" + """Wrapper for p2p_communication.send_backward_recv_forward used + with non-interleaving schedule.""" if not isinstance(input_tensor_grads, list): input_tensor_grads = [input_tensor_grads] input_tensors = [] diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 9fad373c9a..a48d95129a 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1543,9 +1543,15 @@ def _add_distributed_args(parser): '--tensor-model-parallel-size instead.') group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, help='Number of layers per virtual pipeline stage') + group.add_argument('--microbatch-group-size-per-virtual-pipeline-stage', type=int, default=None, + help='Number of contiguous microbatches per virtual pipeline stage', + dest='microbatch_group_size_per_vp_stage') group.add_argument('--no-overlap-p2p-communication', action='store_false', - help='overlap pipeline parallel communication with forward and backward chunks', + help='overlap pipeline parallel communication with forward and backward chunks in 1F1B', dest='overlap_p2p_comm') + group.add_argument('--overlap-p2p-communication-warmup-flush', action='store_true', + default=False, help='if set, overlap pipeline parallel communication in warmup and flush', + dest='overlap_p2p_comm_warmup_flush') group.add_argument('--distributed-backend', default='nccl', choices=['nccl', 'gloo'], help='Which backend to use for distributed training.') diff --git a/tests/functional_tests/jet_recipes/gpt.yaml b/tests/functional_tests/jet_recipes/gpt.yaml index 0615032d35..957db69326 100644 --- a/tests/functional_tests/jet_recipes/gpt.yaml +++ b/tests/functional_tests/jet_recipes/gpt.yaml @@ -57,6 +57,7 @@ products: - gpt3_mr_mcore_te_tp1_pp4_vp1_decoupled_lr_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_uneven_pipeline_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_dgx_a100_1N8G + - gpt3_mr_mcore_te_tp1_pp4_vp1_tunable_overlap_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_param_gather_overlap_optimizer_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_dist_optimizer_overlap_grad_reduce_untied_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_decoupled_lr_dgx_a100_1N8G @@ -65,6 +66,7 @@ products: - gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_param_gather_dgx_a100_1N8G - gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_untied_dgx_a100_1N8G + - gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_tunable_overlap_dgx_a100_1N8G - gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G - gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_groupedGEMM_dgx_a100_1N8G - gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_top2router_dgx_a100_1N8G diff --git 
a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_tunable_overlap_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_tunable_overlap_dgx_a100_1N8G/model_config.yaml new file mode 100644 index 0000000000..49bd5f94c5 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_tunable_overlap_dgx_a100_1N8G/model_config.yaml @@ -0,0 +1,54 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Tree + CUBLAS_WORKSPACE_CONFIG: :4096:8 + N_REPEATS: 5 +MODEL_ARGS: + --num-layers: 12 + --hidden-size: 512 + --num-attention-heads: 8 + --log-params-norm: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --micro-batch-size: 4 + --global-batch-size: 40 + --seq-length: 1024 + --max-position-embeddings: 1024 + --train-iters: 100 + --timing-log-level: 2 + --lr-decay-iters: 320000 + --save: ${CHECKPOINT_PATH} + --load: ${CHECKPOINT_PATH} + --data-path: ${DATA_PATH}/my-gpt3_00_text_document + --vocab-file: ${DATA_PATH}/bpe/vocab.json + --merge-file: ${DATA_PATH}/bpe/merges.txt + --split: 949,50,1 + --distributed-backend: nccl + --lr: 0.00015 + --lr-decay-style: cosine + --min-lr: 1.0e-5 + --weight-decay: 1e-2 + --clip-grad: 1.0 + --lr-warmup-fraction: .01 + --log-interval: 1 + --save-interval: 50 + --eval-interval: 1000 + --eval-iters: 10 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 1 + --pipeline-model-parallel-size: 4 + --num-layers-per-virtual-pipeline-stage: 1 + --overlap-p2p-communication-warmup-flush: true + --microbatch-group-size-per-virtual-pipeline-stage: 5 + --deterministic-mode: true + --no-gradient-accumulation-fusion: true + --attention-softmax-in-fp32: true + --use-checkpoint-opt_param-scheduler: true + --use-mcore-models: true + --ckpt-format: torch_dist + --data-cache-path: ${DATA_CACHE_PATH} + --bf16: true +TEST_TYPE: ckpt-resume diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_tunable_overlap_dgx_a100_1N8G/golden_values_dev.json b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_tunable_overlap_dgx_a100_1N8G/golden_values_dev.json new file mode 100644 index 0000000000..a03d56c822 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_tunable_overlap_dgx_a100_1N8G/golden_values_dev.json @@ -0,0 +1 @@ +{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.81184, 10.84052, 10.8763, 10.79906, 10.68214, 10.59702, 10.49258, 10.11236, 10.12393, 9.98165]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1118.0, 1331.0, 1230.0, 1085.0, 1180.0, 1245.0, 1454.0, 1330.0, 1752.0, 1851.0]}, "iteration-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [17.24286, 0.35341, 0.35187, 0.35028, 0.34941, 0.35093, 0.3488, 0.35179, 0.34905, 0.34684]}} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_tunable_overlap_dgx_a100_1N8G/golden_values_lts.json b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_tunable_overlap_dgx_a100_1N8G/golden_values_lts.json new file mode 100644 index 0000000000..91c3ae6977 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_tunable_overlap_dgx_a100_1N8G/golden_values_lts.json @@ -0,0 +1 @@ +{"lm 
loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.81184, 10.84052, 10.87624, 10.79904, 10.68212, 10.59698, 10.49257, 10.11232, 10.12396, 9.98163]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1125.0, 1304.0, 1252.0, 1102.0, 1201.0, 1200.0, 1489.0, 1395.0, 1677.0, 1867.0]}, "num-zeros vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1125.0, 1304.0, 1252.0, 1102.0, 1201.0, 1200.0, 1489.0, 1395.0, 1677.0, 1867.0]}, "iteration-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [22.22011, 0.36082, 0.35927, 0.35627, 0.35901, 0.35008, 0.34828, 0.34774, 0.35145, 0.35141]}} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_tunable_overlap_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_tunable_overlap_dgx_a100_1N8G/model_config.yaml new file mode 100644 index 0000000000..ee9b7ec957 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_tunable_overlap_dgx_a100_1N8G/model_config.yaml @@ -0,0 +1,53 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Tree + CUBLAS_WORKSPACE_CONFIG: :4096:8 + N_REPEATS: 5 +MODEL_ARGS: + --num-layers: 12 + --hidden-size: 512 + --num-attention-heads: 8 + --log-params-norm: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --micro-batch-size: 4 + --global-batch-size: 40 + --seq-length: 1024 + --max-position-embeddings: 1024 + --train-iters: 50 + --timing-log-level: 2 + --lr-decay-iters: 320000 + --save: ${CHECKPOINT_PATH} + --load: ${CHECKPOINT_PATH} + --data-path: ${DATA_PATH}/my-gpt3_00_text_document + --vocab-file: ${DATA_PATH}/bpe/vocab.json + --merge-file: ${DATA_PATH}/bpe/merges.txt + --split: 949,50,1 + --distributed-backend: nccl + --lr: 0.00015 + --lr-decay-style: cosine + --min-lr: 1.0e-5 + --weight-decay: 1e-2 + --clip-grad: 1.0 + --lr-warmup-fraction: .01 + --log-interval: 1 + --save-interval: 10000 + --eval-interval: 1000 + --eval-iters: 10 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 1 + --pipeline-model-parallel-size: 4 + --num-layers-per-virtual-pipeline-stage: 1 + --overlap-p2p-communication-warmup-flush: true + --microbatch-group-size-per-virtual-pipeline-stage: 5 + --deterministic-mode: true + --no-gradient-accumulation-fusion: true + --attention-softmax-in-fp32: true + --use-mcore-models: true + --ckpt-format: torch_dist + --data-cache-path: ${DATA_CACHE_PATH} + --bf16: true +TEST_TYPE: regular diff --git a/tests/unit_tests/pipeline_parallel/test_helpers.py b/tests/unit_tests/pipeline_parallel/test_helpers.py new file mode 100644 index 0000000000..a20c3a5401 --- /dev/null +++ b/tests/unit_tests/pipeline_parallel/test_helpers.py @@ -0,0 +1,124 @@ +def compare_helpers(pipeline_parallel_size, num_microbatches, num_model_chunks): + total_num_microbatches = num_microbatches * num_model_chunks + + # Baseline helpers + def baseline_get_model_chunk_id(microbatch_id, forward): + """Helper method to get the model chunk ID given the iteration number.""" + microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) + model_chunk_id = microbatch_id_in_group // pipeline_parallel_size + if not forward: + model_chunk_id = num_model_chunks - model_chunk_id - 1 + return model_chunk_id + + def 
baseline_get_microbatch_id_in_model_chunk(iteration_id, forward): + """Helper method to get the microbatch_id within model chunk given the iteration number.""" + assert forward + iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks) + microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + ( + iteration_id % pipeline_parallel_size + ) + return microbatch_id_in_model_chunk + + def baseline_is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: + """Check if an iteration is the first for a model chunk.""" + microbatch_group_size = pipeline_parallel_size * num_model_chunks + microbatch_group_id = microbatch_id // microbatch_group_size + microbatch_id_in_group = microbatch_id % microbatch_group_size + if microbatch_group_id == 0: + return microbatch_id_in_group % pipeline_parallel_size == 0 + else: + return False + + def baseline_is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool: + """Check if an iteration is the last for a model chunk.""" + microbatch_group_size = pipeline_parallel_size * num_model_chunks + num_microbatch_groups = total_num_microbatches // microbatch_group_size + microbatch_group_id = microbatch_id // microbatch_group_size + microbatch_id_in_group = microbatch_id % microbatch_group_size + if microbatch_group_id == num_microbatch_groups - 1: + return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1 + else: + return False + + # Create schedule table prior to new helper methods + schedule_table = [] + for min_microbatch_id_in_group in range(0, num_microbatches, pipeline_parallel_size): + if min_microbatch_id_in_group + pipeline_parallel_size >= num_microbatches: + # Construct schedule for the last microbatch group + schedule_table.extend( + [ + (microbatch_id, model_chunk_id) + for model_chunk_id in range(num_model_chunks) + for microbatch_id in range(min_microbatch_id_in_group, num_microbatches) + ] + ) + else: + # Construct schedule for other microbatch groups + schedule_table.extend( + [ + (microbatch_id, model_chunk_id) + for model_chunk_id in range(num_model_chunks) + for microbatch_id in range( + min_microbatch_id_in_group, + min_microbatch_id_in_group + pipeline_parallel_size, + ) + ] + ) + + microbatch_id_table, model_chunk_id_table = zip(*schedule_table) + + # New helper methods that indexes schedule table + def new_get_model_chunk_id(virtual_microbatch_id, forward): + """Helper method to get the model chunk ID given the iteration number.""" + model_chunk_id = model_chunk_id_table[virtual_microbatch_id % total_num_microbatches] + if not forward: + model_chunk_id = num_model_chunks - model_chunk_id - 1 + return model_chunk_id + + def new_get_microbatch_id_in_model_chunk(iteration_id, forward): + """Helper method to get the microbatch_id within model chunk given the iteration number.""" + assert forward + microbatch_id_in_model_chunk = microbatch_id_table[iteration_id] + return microbatch_id_in_model_chunk + + def new_is_first_microbatch_for_model_chunk(virtual_microbatch_id: int) -> bool: + """Check if an iteration is the first for a model chunk.""" + if virtual_microbatch_id < total_num_microbatches: + return microbatch_id_table[virtual_microbatch_id] == 0 + else: + return False + + def new_is_last_microbatch_for_model_chunk(virtual_microbatch_id: int) -> bool: + """Check if an iteration is the last for a model chunk.""" + if virtual_microbatch_id < total_num_microbatches: + return microbatch_id_table[virtual_microbatch_id] == num_microbatches - 1 + else: + return False 
+ + for i in range(total_num_microbatches): + # Test both forward and backward + assert baseline_get_model_chunk_id(i, forward=False) == new_get_model_chunk_id( + i, forward=False + ) + assert baseline_get_model_chunk_id(i, forward=True) == new_get_model_chunk_id( + i, forward=True + ) + + # Only used in forward + assert baseline_get_microbatch_id_in_model_chunk( + i, forward=True + ) == new_get_microbatch_id_in_model_chunk(i, forward=True) + + assert baseline_is_first_microbatch_for_model_chunk( + i + ) == new_is_first_microbatch_for_model_chunk(i) + assert baseline_is_last_microbatch_for_model_chunk( + i + ) == new_is_last_microbatch_for_model_chunk(i) + + +def test_helpers(): + for pp in [2, 4, 8]: + for m in [pp, 2 * pp, 4 * pp, 8 * pp]: + for vp in range(2, 13): + compare_helpers(pipeline_parallel_size=pp, num_microbatches=m, num_model_chunks=vp)
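
Reviewer note: the sketch below is a minimal, standalone illustration (not part of the diff) of the tunable schedule table and the warmup-size formula introduced in schedules.py. The helper name build_schedule_table and the demo values are chosen for illustration only; the construction mirrors the logic added in forward_backward_pipelining_with_interleaving() and reproduces the "PP2 N3M5 with VP2" table from the comments above.

# schedule_table_sketch.py -- illustrative only, assumes no Megatron imports.


def build_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
    """Map virtual_microbatch_id -> (microbatch_id, model_chunk_id)."""
    schedule_table = []
    for min_microbatch_id_in_group in range(
        0, num_microbatches, microbatch_group_size_per_vp_stage
    ):
        # The last group may contain fewer than microbatch_group_size_per_vp_stage micro-batches.
        max_microbatch_id_in_group = min(
            min_microbatch_id_in_group + microbatch_group_size_per_vp_stage, num_microbatches
        )
        schedule_table.extend(
            (microbatch_id, model_chunk_id)
            for model_chunk_id in range(num_model_chunks)
            for microbatch_id in range(min_microbatch_id_in_group, max_microbatch_id_in_group)
        )
    return schedule_table


if __name__ == "__main__":
    # PP=2, VP=2, N=3 (group size), M=5 micro-batches:
    # virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
    # microbatch_id         | 0 1 2 0 1 2 3 4 3 4
    # model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
    pipeline_parallel_size, num_model_chunks = 2, 2
    num_microbatches, group_size = 5, 3
    table = build_schedule_table(num_microbatches, num_model_chunks, group_size)
    microbatch_id_table, model_chunk_id_table = zip(*table)
    print("microbatch_id :", list(microbatch_id_table))
    print("model_chunk_id:", list(model_chunk_id_table))

    # Warmup size per rank under the tunable schedule (formula from the diff):
    # (PP - rank - 1) * 2 + (VP - 1) * microbatch_group_size_per_vp_stage, capped at the total.
    total_num_microbatches = num_microbatches * num_model_chunks
    for rank in range(pipeline_parallel_size):
        warmup = (pipeline_parallel_size - rank - 1) * 2 + (num_model_chunks - 1) * group_size
        print(f"rank {rank}: num_warmup_microbatches = {min(warmup, total_num_microbatches)}")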