From 938e5c8a0c96fe5037aa54c269ce536e03e9a70b Mon Sep 17 00:00:00 2001 From: Tyler Poon Date: Sat, 23 Nov 2024 08:53:43 -0800 Subject: [PATCH] ADLR/megatron-lm!2289 - pp > 1 online evaluation Co-authored-by: Tyler Poon --- examples/multimodal/run_text_generation.py | 58 +++++++++++++++---- .../core/models/multimodal/llava_model.py | 3 + megatron/core/parallel_state.py | 13 +++++ .../text_generation/communication.py | 45 +++++++++----- .../inference/text_generation/forward_step.py | 34 +++++++---- tests/unit_tests/models/test_llava_model.py | 2 + 6 files changed, 118 insertions(+), 37 deletions(-) diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py index 1da2e71646..fd35966e27 100644 --- a/examples/multimodal/run_text_generation.py +++ b/examples/multimodal/run_text_generation.py @@ -22,7 +22,8 @@ from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings from megatron.inference.text_generation.api import generate_and_post_process from megatron.inference.text_generation.forward_step import ForwardStep -from megatron.training import get_args, get_model +from megatron.inference.text_generation.communication import broadcast_int_list +from megatron.training import get_args, get_model, get_tokenizer, print_rank_0 from megatron.training.checkpointing import load_checkpoint from megatron.training.initialize import initialize_megatron @@ -156,7 +157,7 @@ def generate_samples(model, config: EvaluationConfig, print_output): conv = get_conversation(config.task, question) - forward_step = partial(VLMForwardStep, num_img_embeddings_per_tile, imgs, num_tiles) + forward_step = partial(VLMForwardStep, num_img_embeddings_per_tile, imgs, num_tiles, args.decoder_seq_length) if is_first_rank(): resp_sentences, _, _, _ = generate_and_post_process( @@ -316,6 +317,7 @@ def __init__( num_img_embeddings_per_tile, images, num_tiles, + decoder_seq_length, model, max_batch_size, max_sequence_length, @@ -327,6 +329,18 @@ def __init__( super().__init__(model, max_batch_size, max_sequence_length + num_img_embeddings) self._images = images self._num_tiles = num_tiles + self._num_img_embeddings = num_img_embeddings + self.decoder_seq_length = decoder_seq_length + + self._recv_only_vision_embeds = False + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + # Checks if the previous stage only has a vision encoder, and that the current stage has part of the LM decoder. + # In this case, the current stage should only receive vision embeddings. + if pp_rank > 0: + self._recv_only_vision_embeds = parallel_state.is_inside_encoder(pp_rank - 1) and (not parallel_state.is_inside_decoder(pp_rank - 1)) and parallel_state.is_inside_decoder() + + # Checks if the current stage only has a vision encoder + self._encoder_only = parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder() def _forward(self, tokens, position_ids, attention_mask): return self.model( @@ -340,20 +354,44 @@ def _forward(self, tokens, position_ids, attention_mask): ) def __call__(self, tokens, position_ids, attention_mask): - output = super().__call__(tokens, position_ids, attention_mask) + num_image_tokens = (tokens == self.model.image_token_index).sum().item() + num_tokens = tokens.size(1) + recv_buffer_seq_length = None + if num_image_tokens > 0: + # When there are image tokens and this stage only receives vision embeddings, adjust the recv buffer seq length to match the image embeddings sequence length. 
+ # If there are image tokens and this stage receives full embeddings, make sure we compensate for expansion of image tokens. + # Note that this will set a recv_buffer_seq_length for the encoder stage, this length is irrelevant since that recv buffer is never allocated. + if self._recv_only_vision_embeds: + recv_buffer_seq_length = self._num_img_embeddings + else: + recv_buffer_seq_length = min(self._num_img_embeddings + num_tokens - num_image_tokens, self.decoder_seq_length) + elif self._recv_only_vision_embeds: + # If this stage only receives vision embeddings and there are no image tokens we won't run the encoder and therefore shouldn't try to recv. + recv_buffer_seq_length = 0 + + # If the pipeline stage only has a vision encoder, then it only needs to run when there are image tokens + if not (self._encoder_only and num_image_tokens == 0): + output = super().__call__(tokens, position_ids, attention_mask, recv_buffer_seq_length=recv_buffer_seq_length) + else: + output = None if isinstance(output, tuple): - logits = output[0] + logits, _ = output else: logits = output # On the first inference iteration, we compute image tokens. - # Update the sequence length offset by the number of image tokens. - num_image_tokens = (tokens == self.model.module.image_token_index).sum().item() - num_tokens = tokens.size(1) + # On every PP stage(although inference params should only matter for decoder), + # update the sequence length offset by the number of image tokens. if num_tokens > 1 and num_image_tokens > 0: - self.inference_params.sequence_len_offset += ( - self.inference_params.key_value_memory_dict["image_tokens_count"] - num_image_tokens - ) + if "image_tokens_count" not in self.inference_params.key_value_memory_dict: + self.inference_params.key_value_memory_dict["image_tokens_count"] = self._num_img_embeddings + + if self._num_img_embeddings + num_tokens - num_image_tokens > self.decoder_seq_length: + self.inference_params.sequence_len_offset += self.decoder_seq_length - num_tokens + else: + self.inference_params.sequence_len_offset += ( + self.inference_params.key_value_memory_dict["image_tokens_count"] - num_image_tokens + ) return logits diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py index 1f6da2f4f6..3b46487f87 100644 --- a/megatron/core/models/multimodal/llava_model.py +++ b/megatron/core/models/multimodal/llava_model.py @@ -272,6 +272,7 @@ def _preprocess_data( loss_mask, labels, use_inference_kv_cache, + inference_params, image_token_index, num_image_tiles, attention_mask, @@ -351,6 +352,7 @@ def _preprocess_data( if ( self._language_is_pipeline_parallel and max_seq_len < self._language_max_sequence_length + and inference_params is None ): max_seq_len = self._language_max_sequence_length @@ -696,6 +698,7 @@ def forward( loss_mask, labels, use_inference_kv_cache, + inference_params, image_token_index if image_token_index is not None else self.image_token_index, num_image_tiles, attention_mask, diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 500c06e17a..f6bd0e3109 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -74,6 +74,10 @@ # the first local rank in the tensor model parallel group _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None +# A list of global ranks for each model parallel group to ease calculation of +# the first local rank in the model parallel group +_MODEL_PARALLEL_GLOBAL_RANKS = None + # Context parallel group that the current rank belongs to 
_CONTEXT_PARALLEL_GROUP = None # A list of global ranks for each context parallel group to ease calculation of the @@ -739,6 +743,7 @@ def generator_wrapper(group_type, **kwargs): # Build the model-parallel groups. global _MODEL_PARALLEL_GROUP + global _MODEL_PARALLEL_GLOBAL_RANKS assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized' for ranks in generator_wrapper('tp-pp'): group = torch.distributed.new_group( @@ -746,6 +751,7 @@ def generator_wrapper(group_type, **kwargs): ) if rank in ranks: _MODEL_PARALLEL_GROUP = group + _MODEL_PARALLEL_GLOBAL_RANKS = ranks # Build the model-parallel groups with expert parallel global _MODEL_AND_EXPERT_PARALLEL_GROUP @@ -1386,6 +1392,13 @@ def get_tensor_model_parallel_src_rank(): return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0] +def get_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the model parallel group.""" + assert _MODEL_PARALLEL_GLOBAL_RANKS is not None, "Model parallel group is not initialized" + return _MODEL_PARALLEL_GLOBAL_RANKS[0] + + def get_data_parallel_src_rank(with_context_parallel=False): """Calculate the global rank corresponding to the first local rank in the data parallel group.""" diff --git a/megatron/inference/text_generation/communication.py b/megatron/inference/text_generation/communication.py index a67e0a5e42..c3d5dfefbe 100644 --- a/megatron/inference/text_generation/communication.py +++ b/megatron/inference/text_generation/communication.py @@ -9,7 +9,6 @@ from megatron.core import mpu - # TODO: use functions from megatron/p2p def recv_from_prev_pipeline_rank_(recv_buffer=None): """Receive from previous pipeline stage and update the @@ -25,8 +24,6 @@ def recv_from_prev_pipeline_rank_(recv_buffer=None): # To protect against race condition when using batch_isend_irecv(). torch.cuda.synchronize() - - # TODO: use functions from megatron/p2p def send_to_next_pipeline_rank(tensor=None): """Send output to the next pipeline stage.""" @@ -80,6 +77,29 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): return tensor +def _send_and_recv_from_last_to_first_pipeline_stage(tensor=None): + is_last_stage = mpu.is_pipeline_last_stage() + is_first_stage = mpu.is_pipeline_first_stage() + + if is_last_stage or is_first_stage: + if is_first_stage: + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor, + mpu.get_pipeline_model_parallel_last_rank()) + reqs = torch.distributed.batch_isend_irecv([recv_prev_op]) + elif is_last_stage: + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor, + mpu.get_pipeline_model_parallel_first_rank()) + reqs = torch.distributed.batch_isend_irecv([send_next_op]) + + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() + + return tensor + def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None): """Broadcast tensor values from last stage into the first stage.""" @@ -98,10 +118,7 @@ def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None): tensor = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) - src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_embedding_group() - # Broadcast from last stage into the first stage. 
- torch.distributed.broadcast(tensor, src, group) + tensor = _send_and_recv_from_last_to_first_pipeline_stage(tensor) else: tensor = None @@ -123,8 +140,6 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None): if is_last_stage or is_first_stage: _is_cuda(tensor) is_contiguous = tensor.is_contiguous() - src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_embedding_group() if is_contiguous: tensor_ = tensor else: @@ -134,8 +149,7 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None): tensor_ = torch.empty(size, dtype=dtype, device=torch.cuda.current_device()) - # Broadcast from last stage into the first stage. - torch.distributed.broadcast(tensor_, src, group) + tensor_ = _send_and_recv_from_last_to_first_pipeline_stage(tensor_) # Update the first stage tensor if is_first_stage and not is_contiguous: tensor[...] = tensor_ @@ -150,7 +164,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0, data_parallel=False): data_parallel (bool): Broadcast across a single data parallel model replica. """ if data_parallel: - rank = parallel_state.get_tensor_model_parallel_src_rank() + rank = parallel_state.get_model_parallel_src_rank() if torch.distributed.get_rank() == rank: _is_cuda_contiguous(tensor) @@ -161,7 +175,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0, data_parallel=False): group = None if data_parallel: - group = parallel_state.get_tensor_model_parallel_group() + group = parallel_state.get_model_parallel_group() torch.distributed.broadcast(tensor, rank, group=group) @@ -179,12 +193,11 @@ def broadcast_list(size, dtype, list_values=None, rank=0, data_parallel=False): tensor = None if data_parallel: - src_rank = parallel_state.get_data_parallel_src_rank() - if src_rank == 0: + if parallel_state.get_model_parallel_src_rank() == torch.distributed.get_rank(): tensor = torch.tensor(list_values, dtype=dtype, device=torch.cuda.current_device()) - rank = parallel_state.get_tensor_model_parallel_src_rank() + rank = parallel_state.get_model_parallel_src_rank() else: if torch.distributed.get_rank() == rank: tensor = torch.tensor(list_values, dtype=dtype, diff --git a/megatron/inference/text_generation/forward_step.py b/megatron/inference/text_generation/forward_step.py index 5340e44da9..0a89936ed2 100644 --- a/megatron/inference/text_generation/forward_step.py +++ b/megatron/inference/text_generation/forward_step.py @@ -39,7 +39,7 @@ def __init__(self, model, max_batch_size, max_sequence_length): def _forward(self, tokens, position_ids, attention_mask): return self.model(tokens, position_ids, attention_mask, inference_params=self.inference_params) - def __call__(self, tokens, position_ids, attention_mask): + def __call__(self, tokens, position_ids, attention_mask, recv_buffer_seq_length=None): """Invocation of the forward methods. Note that self.inference_params is being modified by the forward step.""" # Pipelining case. @@ -47,18 +47,25 @@ def __call__(self, tokens, position_ids, attention_mask): # and requires setting args.pipeline_model_parallel > 1. The batch will be split into # smaller microbatches to be pipelined through the stages. 
if self.pipeline_size_larger_than_one: - current_batch_x_seqlen = tokens.size(0) * tokens.size(1) + seq_len = tokens.size(1) if recv_buffer_seq_length is None else recv_buffer_seq_length + current_batch_x_seqlen = tokens.size(0) * seq_len if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen: micro_batch_size = \ - max(1, self.pipelining_batch_x_seqlen // tokens.size(1)) + max(1, self.pipelining_batch_x_seqlen // seq_len) return self._with_pipelining_forward_step(tokens, position_ids, attention_mask, - micro_batch_size) - # Do not pipeline the batch; the entire batch will be passed through all at once. + micro_batch_size, + recv_buffer_seq_length=recv_buffer_seq_length) + + recv_buffer = None + if recv_buffer_seq_length is not None: + recv_buffer = _allocate_recv_buffer(tokens.size(0), recv_buffer_seq_length) + return self._no_pipelining_forward_step(tokens, position_ids, - attention_mask) + attention_mask, + recv_buffer=recv_buffer) def _forward_step_helper(self, tokens, position_ids, attention_mask, recv_buffer=None): @@ -66,15 +73,20 @@ def _forward_step_helper(self, tokens, position_ids, attention_mask, recv_buffer only the first time the memory is allocated.""" batch_size = tokens.size(0) sequence_length = tokens.size(1) + if recv_buffer is None: recv_buffer = _allocate_recv_buffer(batch_size, sequence_length) # Receive from previous stage. - recv_from_prev_pipeline_rank_(recv_buffer) + if recv_buffer is not None and torch.numel(recv_buffer) > 0: + recv_from_prev_pipeline_rank_(recv_buffer) # Forward pass through the model. - self.model.set_input_tensor(recv_buffer) + if not mpu.is_pipeline_first_stage(): + self.model.set_input_tensor(recv_buffer) output_tensor = self._forward(tokens, position_ids, attention_mask) + if isinstance(output_tensor, tuple): + output_tensor = output_tensor[0] # Send output to the next stage. send_to_next_pipeline_rank(output_tensor) @@ -99,10 +111,10 @@ def _no_pipelining_forward_step(self, tokens, position_ids, attention_mask, return logits - def _with_pipelining_forward_step(self, tokens, position_ids, attention_mask, micro_batch_size): + def _with_pipelining_forward_step(self, tokens, position_ids, attention_mask, micro_batch_size, recv_buffer_seq_length=None): """No interleaving is supported.""" - sequence_length = tokens.size(1) batch_size = tokens.size(0) + sequence_length = tokens.size(1) if recv_buffer_seq_length is None else recv_buffer_seq_length # Divide the batch dimension into micro batches. num_micro_batches, last_chunk = divmod(batch_size, @@ -143,7 +155,7 @@ def _with_pipelining_forward_step(self, tokens, position_ids, attention_mask, mi # Once we are done with all the micro-batches, we can # adjust the sequence length offset. 
- self.inference_params.sequence_len_offset += sequence_length + self.inference_params.sequence_len_offset += tokens.size(1) # and reset the batch size offset self.inference_params.batch_size_offset = 0 diff --git a/tests/unit_tests/models/test_llava_model.py b/tests/unit_tests/models/test_llava_model.py index 6101835db6..2b31bf18a0 100644 --- a/tests/unit_tests/models/test_llava_model.py +++ b/tests/unit_tests/models/test_llava_model.py @@ -126,6 +126,7 @@ def test_preprocess_data(self): use_inference_kv_cache = False attention_mask = None + inference_params = None embeddings, labels, loss_mask, attention_mask = self.model._preprocess_data( image_embeddings, @@ -134,6 +135,7 @@ def test_preprocess_data(self): loss_mask, labels, use_inference_kv_cache, + inference_params, image_token_index, num_image_tiles, attention_mask,
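
Reviewer note: the recv-buffer sizing rule introduced in VLMForwardStep.__call__ above can be summarized by the standalone sketch below. This is a simplified illustration only; the helper name and plain-integer arguments are hypothetical, and the real code derives these values from the tokens tensor, parallel_state, and the constructor arguments.

from typing import Optional


def recv_buffer_seq_length(
    num_image_tokens: int,          # image tokens present in the current prompt
    num_tokens: int,                # total prompt tokens in this step
    num_img_embeddings: int,        # embeddings produced by the vision encoder
    decoder_seq_length: int,        # language-model max sequence length
    recv_only_vision_embeds: bool,  # prev stage is vision-only, this stage holds decoder layers
) -> Optional[int]:
    """Sketch of the buffer-length decision made in VLMForwardStep.__call__ (illustrative only)."""
    if num_image_tokens > 0:
        if recv_only_vision_embeds:
            # This stage receives only the vision embeddings.
            return num_img_embeddings
        # Full embeddings: image tokens expand into image embeddings, clamped to
        # the decoder's maximum sequence length.
        return min(num_img_embeddings + num_tokens - num_image_tokens, decoder_seq_length)
    if recv_only_vision_embeds:
        # No image tokens: the vision encoder does not run, so there is nothing to receive.
        return 0
    # None keeps the pre-existing behaviour in ForwardStep.__call__, which
    # allocates the receive buffer from tokens.size(1).
    return None

A returned value of 0 is what lets a decoder stage that sits directly after a vision-encoder-only stage skip recv_from_prev_pipeline_rank_ entirely on text-only prompts, matching the new guard in _forward_step_helper.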