diff --git a/examples/inference/README.md b/examples/inference/README.md index bd8e738e55..b4b07cbc6a 100644 --- a/examples/inference/README.md +++ b/examples/inference/README.md @@ -1,5 +1,5 @@ ### Megatron Core Inference Documentation -This guide will walk you through how you can use megatron core for inference on your models. +This guide provides an example of using Megatron Core for model inference. ### Contents - [Megatron Core Inference Documentation](#megatron-core-inference-documentation) @@ -18,21 +18,21 @@ This guide will walk you through how you can use megatron core for inference on
#### 1. Quick Start -This will walk you through the flow of running batch inference on a GPT model trained using megatron core. The file can be found at [simple_gpt_batch_inference.py](./gpt/simple_gpt_batch_inference.py) +This example runs batch inference on a GPT model trained using Megatron Core. The entry point is [gpt_batch_inference.py](./gpt/gpt_batch_inference.py)
-##### 1.1 Understanding The Code -***STEP 1 - We initialize model parallel and other default arguments*** -We can default micro batch size to be 1, since for TP models it is not used, and for PP models it is calculated during runtime. +##### 1.1 Code Walkthrough +***STEP 1 - Initialize model parallel and other default arguments*** +The micro batch size is set to 1, since it is not used by tensor-parallel-only models and is calculated at runtime for pipeline-parallel models. ```python initialize_megatron( args_defaults={'no_load_rng': True, 'no_load_optim': True, 'micro_batch_size': 1} ) ``` -***STEP 2 - We load the model using the model_provider_function*** -NOTE: The model provider function in the script supports MCore and Legacy models. +***STEP 2 - Load the model using the model_provider_function*** +NOTE: The model provider function supports both MCore and Legacy models. ```python model = get_model(model_provider, wrap_with_ddp=False) @@ -41,10 +41,10 @@ NOTE: The model provider function in the script supports MCore and Legacy models ``` ***STEP 3 - Choose an engine*** -One of the important elements of the generate function is an inference engine. In this example we will be choosing the [megatron core engine](../../megatron/core/inference/engine/mcore_engine.py) with a [simple text generation controller](../../megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py), the default engine. Other engines that will be supported in the future are TRTLLMEngine. +Text generation requires an inference engine, which includes a scheduler. The default engine is the [Megatron Core engine](../../megatron/core/inference/engine/mcore_engine.py) with a simple [text generation controller](../../megatron/core/inference/text_generation_controllers/text_generation_controller.py). TRTLLMEngine will be supported in the future. ```python inference_wrapped_model = GPTInferenceWrapper(model, args) - text_generation_controller = SimpleTextGenerationController( + text_generation_controller = TextGenerationController( inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer ) @@ -53,12 +53,12 @@ One of the important elements of the generate function is an inference engine. I ) ``` -***STEP 4 - Run the generate function and display results*** -We use default values for the [common inference params](../../megatron/core/inference/common_inference_params.py). Customize this if you want to change top_p, top_k, number of tokens to generate etc. -*Note that the result is returned as a list of [InferenceRequests](../../megatron/core/inference/inference_request.py)* +***STEP 4 - Run text generation*** +[SamplingParams](../../megatron/core/inference/sampling_params.py) provides reasonable defaults. Customize it to change top_p, top_k, the number of tokens to generate, and other settings. +*Note: The result is returned as a list of [InferenceRequests](../../megatron/core/inference/inference_request.py)* ```python results: List[InferenceRequest] = inference_engine.generate( - prompts=args.prompts, common_inference_params=common_inference_params + prompts=args.prompts, sampling_params=sampling_params ) if torch.distributed.get_rank() == 0: @@ -76,12 +76,12 @@ We use default values for the [common inference params](../../megatron/core/infe
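+
+For reference, the `sampling_params` object passed to `generate()` above can be constructed directly. The values below are purely illustrative, not recommended settings:
+
+```python
+from megatron.core.inference.sampling_params import SamplingParams
+
+# Request-level sampling settings; any field left out falls back to the dataclass default.
+sampling_params = SamplingParams(
+    temperature=1.0,           # >1.0 flattens, <1.0 sharpens the output distribution
+    top_k=1,                   # top_k=1 is greedy decoding
+    top_p=0.0,                 # nucleus sampling threshold; cannot be combined with top_k > 0
+    return_log_probs=True,     # also return per-token log probabilities
+    num_tokens_to_generate=64,
+)
+```
+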
##### 1.2 Running The Code -An example run script is shown below. Change the tokenizer paths, inference params, and other settings for your model. +An example run script is shown below. Set the tokenizer paths, inference params, and other settings appropriately. -For a quick recap on inference params refer to [this blog](https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910) +For a quick recap on sampling parameters, refer to [this blog](https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910). ``` -#In a slurm cluster (You could also use docker) +# In a slurm cluster (You could also use docker) ACCOUNT= MLM_PATH=/path/to/megatron-lm GPT_CKPT=/path/to/gpt/ckpt @@ -133,8 +133,8 @@ NOTE: Other parameters which can be customized for inference are :- --top_p (top_p sampling) --num-tokens-to-generate (Number of tokens to generate for each prompt) --inference-batch-times-seqlen-threshold (During inference, if batch-size times sequence-length is smaller than this threshold then we will not use pipelining, otherwise we will.') ---use-dist-ckpt (If you are using dist checkpoint format for the model) ---use-legacy-models (If you are using legacy gpt model instead of mcore gpt model) +--use-dist-ckpt (If using dist checkpoint format for the model) +--use-legacy-models (If using legacy gpt model instead of mcore gpt model) ``` @@ -142,16 +142,17 @@ NOTE: Other parameters which can be customized for inference are :-
-#### 2. Flow of Control In MCore Backend -The following is what happens in the [simple_gpt_batch_inference.py](./gpt/simple_gpt_batch_inference.py). -* We call [mcore_engine](../../megatron/core/inference/engines/mcore_engine.py) **generate()** function with all our input prompts. -* The scheduler in the engine will add these prompts to the [active requests] pool (../../megatron/core/inference/inference_request.py) until we hit the max batch size, and then it will put the rest in the waiting requests pool. -* The engine will then run until all requests (waiting + active) are completed +#### 2. Control Flow in the MCore Backend +An example of inference with static batching is provided in [gpt_batch_inference.py](./gpt/gpt_batch_inference.py). A simplified sketch of the loop follows the list below. +* The [mcore_engine](../../megatron/core/inference/engines/mcore_engine.py) **generate()** function is called with the input prompts. +* The `Scheduler` in the engine will add these prompts to the [active requests pool](../../megatron/core/inference/inference_request.py) until the max batch size is hit. Remaining requests will be added to the waiting requests pool. +* The engine will run until all requests (waiting + active) are completed. * The active requests are passed into **generate_all_output_tokens_static_batch()** of the text generation controller . - * This function uses the [model_inference_wrappers](../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) **prep_model_for_inference()** , and then runs an auto regressive loop - * In the auto regressive loop, the **get_batch_for_context_window()** method of the inference wrapper is called to get the required input, passes it into the **run_one_forward_step()** method, which calls the appropriate (PP, TP) model `.forward()` methods to get the output logits - * The output logits are synchronized across all pipeline parallel ranks - * The text generation controller obtains the log probabilities and samples tokens based on the strategy defined in the common inference parameters. + * This function uses the **prep_model_for_inference()** method of the [model_inference_wrappers](../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) and runs an autoregressive sampling loop + * In the autoregressive loop, the **get_batch_for_context_window()** method of the inference wrapper is called to slice out the input tokens and masks + * Input tokens and masks are passed into the **run_one_forward_step()** method, which calls the model `.forward()` method to get the output logits + * Output logits are synchronized across all pipeline parallel ranks + * The text generation controller obtains the log probabilities and samples tokens based on the strategy defined in the sampling parameters. * The sampled tokens are then appended to the input prompt tokens for the next iteration * The **update_generation_status()** method of the text generation controller checks which prompts have finished generating or hit a stop condition * After the inference loop, the result is detokenized and stored as an attribute of the InferenceRequest. These requests are marked as completed. @@ -160,16 +161,18 @@ The following is what happens in the [simple_gpt_batch_inference.py](./gpt/simpl
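+
+The following is a simplified, illustrative sketch of this static-batch loop; pipeline-parallel logit broadcasting, log-prob bookkeeping, and padding are omitted (see **generate_all_output_tokens_static_batch()** for the real implementation):
+
+```python
+import torch
+
+def static_batch_loop(wrapper, controller, tokens, prompt_lengths, sampling_params, eod_id):
+    """tokens: [batch, max_seq_len], already padded; prompt_lengths: [batch]."""
+    batch_size, max_seq_len = tokens.shape
+    done = torch.zeros(batch_size, dtype=torch.bool, device=tokens.device)
+    wrapper.prep_model_for_inference(prompts_tokens=tokens)
+    context_start = 0
+    for context_end in range(int(prompt_lengths.min()), max_seq_len):
+        # Slice out the tokens/masks needed for this step.
+        inference_input = wrapper.get_batch_for_context_window(context_start, context_end)
+        logits = wrapper.run_one_forward_step(inference_input)  # [batch, context_len, vocab]
+        started = prompt_lengths <= context_end                 # prompts that have started generating
+        sampled = controller.sample_from_logits(logits[:, -1, :], sampling_params)
+        tokens[started, context_end] = sampled[started]
+        done |= (tokens[:, context_end] == eod_id) & started
+        context_start = context_end
+        if bool(done.all()):
+            break
+    return tokens
+```
+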
#### 3. Customizing The Inference Pipeline -The following guide will walk you through how you can customize different parts of the inference pipeline. There are three levels at which you can customize the pipeline. -* **Inference engine** - Highest level of customization. Currently we support the MCore Engine. Change this to add a new engine. -* **Text generation controller** - Extend this to customize tokenization, detokenization, or implement a new sampling strategy. + +The inference pipeline supports the following levels of customization: + +* **Inference engine** - The MCore Engine is currently supported. Change this to add a new backend. +* **Text generation controller** - The main sampling loop. This can be customized to support alternative tokenization, detokenization, or to implement a new sampling strategy. * **Inference Wrapped Model** - Change this to support a new model. * **Modify Inference Parameters** - Change this to update top_p, top_k, number of tokens to be generated, temperature, or other sampling parameters.
##### 3.1. Create Your Own Inference Backend -This is the highest level of customization. The [abstract_engine.py](./../../megatron/core/inference/engine/abstract_engine.py) file has a generate method that can be extended to support a new backend. +The [abstract_engine.py](./../../megatron/core/inference/engine/abstract_engine.py) file contains a `generate` method that can be extended to support a new backend. ```python class AbstractEngine(ABC): @@ -177,15 +180,17 @@ class AbstractEngine(ABC): def generate(self) -> dict: """The abstract backend's generate function. - To define your own backend, make sure you implement this and return the outputs as a dictionary . - + To define a new backend, implement this method and return the outputs as a dictionary. +```
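+
+As a minimal illustration, a hypothetical backend implementing this interface might look like the sketch below (the class and its `generate` arguments are invented for the example; a real backend would schedule requests and run model forward passes):
+
+```python
+from typing import List
+
+from megatron.core.inference.engines.abstract_engine import AbstractEngine
+
+
+class EchoEngine(AbstractEngine):
+    """Toy backend that 'generates' by echoing each prompt back. Illustrative only."""
+
+    def generate(self, prompts: List[str]) -> dict:
+        # Return the outputs as a dictionary, as the interface expects.
+        return {prompt: {"generated_text": prompt} for prompt in prompts}
+```
+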
-##### 3.2. Create Your Own Text Generation Controller -In case you want to use the megatron core backend, but would like to overwrite the tokenization, text generation or detokenization extend the [simple_text_generation_controller.py](../../megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py). The class has the following methods +##### 3.2. Implement a new Sampling Loop + +The [TextGenerationController](../../megatron/core/inference/text_generation_controllers/text_generation_controller.py) contains the main sampling loop and can be modified to support new tokenization, detokenization, or sampling strategies. + ``` python -class SimpleTextGenerationController: +class TextGenerationController: def tokenize_prompt(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]: """Utility to tokenize the input prompts""" @@ -193,12 +198,12 @@ class SimpleTextGenerationController: def sample_from_logits( self, last_token_logits: torch.Tensor, - common_inference_params: CommonInferenceParams, + sampling_params: SamplingParams, vocab_size: int, ) -> torch.Tensor: """Samples the logits to generate outputs - Given the logits of the last token, this function samples it according to the parameters defined in common_inference_params and returns the samples + Given the logits of the last token, this function samples according to the parameters defined in sampling_params and returns the sampled tokens. """ def update_generation_status( @@ -229,12 +234,12 @@ class SimpleTextGenerationController:
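+
+For example, a controller that always decodes greedily, ignoring the sampling parameters, could be sketched as follows (illustrative; the class name is made up for the example):
+
+```python
+import torch
+
+from megatron.core.inference.sampling_params import SamplingParams
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+    TextGenerationController,
+)
+
+
+class GreedyTextGenerationController(TextGenerationController):
+    """Overrides only the sampling step; tokenization and the loop itself are inherited."""
+
+    def sample_from_logits(
+        self,
+        last_token_logits: torch.Tensor,
+        sampling_params: SamplingParams = None,
+        vocab_size: int = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        # Pick the highest-probability token for every prompt in the batch.
+        return torch.argmax(last_token_logits, dim=-1)
+```
+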
##### 3.3. Support Other Models -In order to support other models please extend the [abstract_model_inference_wrapper.py](./../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) file. The abstract wrapper already supports the following : -* Forward method which automatically calls the appropriate forward method (PP or TP etc) depending on model parallel settings -* Initalizes the model and puts it in eval mode -* Obtains the input parameters (batch size, max seq length) and has an instance of the input +Extend [abstract_model_inference_wrapper.py](./../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) to support other models. The abstract model wrapper implements: +* Forward method which calls the model `forward` method depending on model parallel settings +* Initializes the model and puts it in `.eval()` mode +* Setup for the input parameters (max batch size, max seq length) -The main methods to change for your model might be the following: +The following methods should be implemented: ```python class AbstractModelInferenceWrapper: def prep_model_for_inference(self, prompts_tokens: torch.Tensor): @@ -247,28 +252,28 @@ class AbstractModelInferenceWrapper: def get_batch_for_context_window(self) -> List: """Returns the input data for inference - This function gets called iteratively in the inference loop . It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference. + This function gets called iteratively in the inference loop. It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference. ``` -Refer to [gpt_inference_wrapper.py](../../megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py) for an example of extending this for GPTModel. +Refer to [gpt_inference_wrapper.py](../../megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py) for an example of implementing this for GPTModel.
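+
+A skeletal wrapper for a hypothetical model could look like the sketch below (attribute names such as `my_prompts_tokens` are invented for illustration, and the `super()` call assumes the base class performs the common setup described above; see the GPT wrapper linked above for a complete implementation):
+
+```python
+from typing import List
+
+import torch
+
+from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
+    AbstractModelInferenceWrapper,
+)
+
+
+class MyModelInferenceWrapper(AbstractModelInferenceWrapper):
+    """Hypothetical wrapper; only the model-specific hooks are filled in."""
+
+    def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
+        # Base-class setup (eval mode, input sizes) is assumed to happen here.
+        super().prep_model_for_inference(prompts_tokens=prompts_tokens)
+        # Keep model-specific state on the wrapper for use in the loop below.
+        self.my_prompts_tokens = prompts_tokens
+        seq_len = prompts_tokens.size(1)
+        self.my_position_ids = (
+            torch.arange(seq_len, device=prompts_tokens.device)
+            .unsqueeze(0)
+            .expand_as(prompts_tokens)
+        )
+
+    def get_batch_for_context_window(
+        self, context_start_position: int, context_end_position: int
+    ) -> List:
+        # Return exactly the inputs your model's forward() expects for this window.
+        tokens2use = self.my_prompts_tokens[:, :context_end_position]
+        positions2use = self.my_position_ids[:, :context_end_position]
+        return [tokens2use, positions2use]
+```
+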
##### 3.3. Modify Inference Parameters -We use [common inference params](../../megatron/core/inference/common_inference_params.py) for text generation. Customize this if you want to change top_p, top_k, number of tokens to generate etc. If you want to add other attributes that you would use in the inference loop, you can do that as shown below +We use [sampling params](../../megatron/core/inference/sampling_params.py) for text generation. Customize them to change top_p, top_k, the number of tokens to generate, and other settings. To add other attributes for use in the inference loop, do so as shown below: ``` -from megatron.core.inference.common_inference_params import CommonInferenceParams +from megatron.core.inference.sampling_params import SamplingParams -c = CommonInferenceParams(temperature=0.5) +c = SamplingParams(temperature=0.5) c.add_attributes({'min_length':4, 'eod_id':153}) ```
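+
+Custom attributes added this way become plain instance attributes, so code in a customized controller or wrapper can read them back with `getattr`. A small usage sketch (not part of the library):
+
+```python
+from megatron.core.inference.sampling_params import SamplingParams
+
+params = SamplingParams(temperature=0.5)
+params.add_attributes({'min_length': 4, 'eod_id': 153})
+
+# Use a default so the code still works when the attribute was not supplied.
+min_length = getattr(params, 'min_length', 0)
+print(min_length)  # 4
+```
+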
#### 4. Future work -The following are planned for the future releases . +The following features are planned for the future releases. * Dynamic batching * Paged Attention * TRTLLM Engine support -* Support for Multimodal model inference \ No newline at end of file +* Support for multimodal inference \ No newline at end of file diff --git a/examples/inference/gpt/simple_gpt_batch_inference.py b/examples/inference/gpt/gpt_batch_inference.py similarity index 91% rename from examples/inference/gpt/simple_gpt_batch_inference.py rename to examples/inference/gpt/gpt_batch_inference.py index 5c7ae5bd77..050b230cef 100644 --- a/examples/inference/gpt/simple_gpt_batch_inference.py +++ b/examples/inference/gpt/gpt_batch_inference.py @@ -6,10 +6,10 @@ from argparse import Namespace from megatron.core.inference.engines.abstract_engine import AbstractEngine from megatron.core.inference.engines.mcore_engine import MCoreEngine -from megatron.core.inference.common_inference_params import CommonInferenceParams +from megatron.core.inference.sampling_params import SamplingParams from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper from megatron.core.inference.inference_request import InferenceRequest -from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import SimpleTextGenerationController +from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController from megatron.core.transformer.module import MegatronModule sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))) @@ -66,7 +66,7 @@ def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngi ) inference_wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config) - text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer) + text_generation_controller = TextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer) return MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size) def main(): @@ -89,7 +89,7 @@ def main(): inference_engine = get_inference_engine(args, model) - common_inference_params = CommonInferenceParams( + sampling_params = SamplingParams( temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, @@ -97,7 +97,7 @@ def main(): num_tokens_to_generate=args.num_tokens_to_generate) results: List[InferenceRequest] = inference_engine.generate( - prompts=args.prompts, common_inference_params=common_inference_params + prompts=args.prompts, sampling_params=sampling_params ) if torch.distributed.get_rank() == 0: diff --git a/examples/inference/t5/simple_t5_batch_inference.py b/examples/inference/t5/simple_t5_batch_inference.py index 3f4557d3c2..b4226d7de0 100644 --- a/examples/inference/t5/simple_t5_batch_inference.py +++ b/examples/inference/t5/simple_t5_batch_inference.py @@ -5,7 +5,7 @@ import torch import pretrain_t5 -from megatron.core.inference.common_inference_params import CommonInferenceParams +from megatron.core.inference.sampling_params import SamplingParams from megatron.core.inference.engines.abstract_engine import AbstractEngine from megatron.core.inference.engines.mcore_engine import MCoreEngine from megatron.core.inference.inference_request import InferenceRequest @@ -120,7 +120,7 @@ def main(): inference_engine = get_inference_engine(args, model) - 
common_inference_params = CommonInferenceParams( + sampling_params = SamplingParams( temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, @@ -138,7 +138,7 @@ def main(): prompts=args.prompts, add_BOS=True, encoder_prompts=args.encoder_prompts, - common_inference_params=common_inference_params, + sampling_params=sampling_params, ) if torch.distributed.get_rank() == 0: diff --git a/megatron/core/inference/common_inference_params.py b/megatron/core/inference/common_inference_params.py index 22353088f8..7955bb6fc1 100644 --- a/megatron/core/inference/common_inference_params.py +++ b/megatron/core/inference/common_inference_params.py @@ -1,29 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from dataclasses import dataclass - - -@dataclass -class CommonInferenceParams: - """Inference parameters sent along with the prompts - - For an explanation of these parameters refer to this blog https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910 - """ - - temperature: float = 1.0 - top_k: int = 0 - top_p: float = 0.0 - return_log_probs: bool = False - num_tokens_to_generate: int = 30 - - def add_attributes(self, attribute_value_pair: dict): - """Utility to add more attributes to inference params - - Use this method to pass in a custom dictonary to add more inference parameter attributes to the instance you created. Use as follows - c = CommonInferenceParams - c.add_attributes({'min_length':4, 'eod_id':153}) - - Args: - attribute_value_pair (dict): A dictionary containing attributes as the key names and their values as the values. - """ - for key, value in attribute_value_pair.items(): - setattr(self, key, value) +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from megatron.core.inference.sampling_params import ( # noqa: F401 # pylint: disable=unused-import + SamplingParams as CommonInferenceParams, +) diff --git a/megatron/core/inference/engines/mcore_engine.py b/megatron/core/inference/engines/mcore_engine.py index fe8160228b..28ef46bf92 100644 --- a/megatron/core/inference/engines/mcore_engine.py +++ b/megatron/core/inference/engines/mcore_engine.py @@ -3,12 +3,12 @@ import torch -from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.engines.abstract_engine import AbstractEngine from megatron.core.inference.inference_request import InferenceRequest +from megatron.core.inference.sampling_params import SamplingParams from megatron.core.inference.scheduler import Scheduler -from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( - SimpleTextGenerationController, +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, ) @@ -19,7 +19,7 @@ class MCoreEngine(AbstractEngine): Supports any model that is callable (Accepts the inputs and outputs the tensor) Args: - text_generation_controller (SimpleTextGenerationController): A text generation + text_generation_controller (TextGenerationController): A text generation controller that will be used to define how to preprocess prompts, generate outputs and detokenizer the output tokens. 
max_batch_size : The maxinum number of requests to process at once @@ -29,7 +29,7 @@ class MCoreEngine(AbstractEngine): def __init__( self, - text_generation_controller: SimpleTextGenerationController, + text_generation_controller: TextGenerationController, max_batch_size, random_seed: int = None, ): @@ -42,7 +42,8 @@ def generate( prompts: List[str], add_BOS: bool = False, encoder_prompts: List[str] = None, - common_inference_params: CommonInferenceParams = None, + common_inference_params: SamplingParams = None, + sampling_params: SamplingParams = None, ) -> dict: """The megatron core inference backend generate function @@ -54,13 +55,19 @@ def generate( prompts (List[str]): All the prompts as a list of strings add_BOS (bool): Whether to add BOS token to beginning of prompts encoder_prompts (List[dict]): All the encoder prompts as a list of strings - common_inference_params (CommonInferenceParams): The inference parameters + common_inference_params: Deprecated. Only used for backward compatibility with + MCore <= 0.9.0. Use `sampling_params` going forward. + sampling_params (SamplingParams): The request-level sampling parameters Returns: List[InferenceRequest]: The output is list of inference requests containing the generated tokens, texts and log probs if required """ # TODO :M core- get rng state tracker + + if common_inference_params: + sampling_params = common_inference_params + if self.random_seed: torch.random.manual_seed(self.random_seed) @@ -73,7 +80,7 @@ def generate( prompt=prompt, prompt_tokens=prompt_tokens, encoder_prompt=encoder_prompt, - inference_parameters=common_inference_params, + inference_parameters=sampling_params, ) self.run_engine() diff --git a/megatron/core/inference/inference_request.py b/megatron/core/inference/inference_request.py index 4825dfd366..ea0d67bfea 100644 --- a/megatron/core/inference/inference_request.py +++ b/megatron/core/inference/inference_request.py @@ -5,7 +5,7 @@ import torch -from megatron.core.inference.common_inference_params import CommonInferenceParams +from megatron.core.inference.sampling_params import SamplingParams # class syntax @@ -28,7 +28,7 @@ class InferenceRequest: request_id: str prompt: str - inference_parameters: CommonInferenceParams + inference_parameters: SamplingParams prompt_tokens: List[int] arrival_time: float status: Status diff --git a/megatron/core/inference/sampling_params.py b/megatron/core/inference/sampling_params.py new file mode 100644 index 0000000000..8ffcb6321d --- /dev/null +++ b/megatron/core/inference/sampling_params.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass + + +@dataclass +class SamplingParams: + """Inference parameters sent along with the prompts. + This class contains request-level attributes that control the sampling techniques used when + generating text. This is distinct from megatron.core.InferenceParams, which is sets model-level + inference attributes such as the maximum sequence length, and contains the KV cache. 
+ + For an explanation of these parameters refer to this blog + https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and- + temperature-parameters-ed6a31313910 + """ + + temperature: float = 1.0 + top_k: int = 0 + top_p: float = 0.0 + return_log_probs: bool = False + num_tokens_to_generate: int = 30 + + def add_attributes(self, attribute_value_pair: dict): + """Utility to add more attributes to sampling params + + Use this method to pass in a custom dictionary to add more sampling parameter attributes. + c = SamplingParams + c.add_attributes({'min_length':4, 'eod_id':153}) + + Args: + attribute_value_pair (dict): A dictionary containing attributes as the key names and + their values as the values. + """ + for key, value in attribute_value_pair.items(): + setattr(self, key, value) diff --git a/megatron/core/inference/scheduler.py b/megatron/core/inference/scheduler.py index 00ab81b4ab..ef177232b4 100644 --- a/megatron/core/inference/scheduler.py +++ b/megatron/core/inference/scheduler.py @@ -6,8 +6,8 @@ import torch -from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.inference_request import InferenceRequest, Status +from megatron.core.inference.sampling_params import SamplingParams from megatron.core.inference.utils import Counter @@ -33,7 +33,7 @@ def add_request( prompt: str, prompt_tokens: torch.Tensor, encoder_prompt: str = None, - inference_parameters: CommonInferenceParams = None, + inference_parameters: SamplingParams = None, arrival_time: float = None, ): """Add an incoming request @@ -45,7 +45,7 @@ def add_request( prompt (str): Input prompt string prompt_tokens (torch.Tensor): A torch tensor having the input prompts tokenized encoder_prompt (str): Encoder input string - inference_parameters (CommonInferenceParams): The inference parameters + inference_parameters (SamplingParams): The inference parameters arrival_time (float, optional): The incoming request time. Defaults to None. 
""" request_id = str(next(self.request_counter)) diff --git a/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py index 61beff0211..0c2a41be44 100644 --- a/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py @@ -4,15 +4,15 @@ import torch from megatron.core.inference.inference_request import InferenceRequest -from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( - SimpleTextGenerationController, +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, ) -class EncoderDecoderTextGenerationController(SimpleTextGenerationController): +class EncoderDecoderTextGenerationController(TextGenerationController): """The text generation controller for encoder-decoder architecture - This class ingherits from SimpleTextGenerationController, adding features + This class inherits from TextGenerationController, adding features relating to encoder input encoder_prompt """ diff --git a/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py index ceea4064d2..f97df13249 100644 --- a/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py @@ -1,400 +1,5 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -from typing import List, OrderedDict, Tuple +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
-import torch -import torch.nn.functional as F - -from megatron.core import parallel_state -from megatron.core.inference.common_inference_params import CommonInferenceParams -from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage -from megatron.core.inference.inference_request import InferenceRequest, Status -from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( - AbstractModelInferenceWrapper, +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( # noqa: F401 # pylint: disable=unused-import + TextGenerationController as SimpleTextGenerationController, ) - - -class SimpleTextGenerationController: - """The basic text generation controller - - This class is responsible for tokenizing the input , running the inference, sampling - and also detokenizing the output - - Args: - inference_wrapped_model (AbstractModelInferenceWrapper): A model that - is wrapped using the specs given in the abstract_model_inference_wrapper.py - tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts - """ - - def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer): - self.inference_wrapped_model = inference_wrapped_model - self.tokenizer = tokenizer - - # For models without pipeline parallelism, is_first_stage and is_last_stage returns True - self.model_is_pipeline_parallel = not ( - parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() - ) - - def tokenize_prompt( - self, prompt: str, add_BOS: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Utility to tokenize the input prompts - - Args: - prompt (str): The input prompt - - Returns: - torch.Tensor: Returns the tokenized prompt - """ - prompt_tokens = self.tokenizer.tokenize(prompt) - - if add_BOS: - prompt_tokens = [self.tokenizer.bos] + prompt_tokens - - return prompt_tokens - - def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str: - """Detokenize the output generations - - Args: - prompt_tokens_with_generated_tokens (torch.Tensor): The input prompt - tokens plus the generated tokens - - Returns: - str: The detokenized output - """ - tokens = prompt_tokens_with_generated_tokens.cpu().numpy().tolist() - return self.tokenizer.detokenize(tokens) - - def sample_from_logits( - self, - last_token_logits: torch.Tensor, - common_inference_params: CommonInferenceParams, - vocab_size: int = None, - ) -> torch.Tensor: - """Samples the logits to generate outputs - - Given the logits of the last token, this function samples it - according to the parameters defined in common_inference_params - and returns the samples - - Args: - last_token_logits (torch.Tensor): The last token logits. A tensor of - size [batch_size, vocab_size] - common_inference_params (CommonInferenceParams): The paramters to use - for inference - vocab_size (int): Obtained from the tokenizer. 
Defaults to None - - Returns: - torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements - """ - - top_p = common_inference_params.top_p - top_k = common_inference_params.top_k - temperature = common_inference_params.temperature - - assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero' - assert top_p <= 1.0, 'top-p should be in (0,1]' - - def modify_logits_for_top_k_filtering(logits, top_k): - """Set the logits for none top-k values to -inf.""" - filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits.masked_fill_(filter_, float('-Inf')) - - def modify_logits_for_top_p_filtering(logits, top_p): - """Set the logits for none top-p values to -inf.""" - # First sort and calculate cumulative sum of probabilities. - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) - - # Filteration based on the cumulative sum. - filter_ = cumulative_probs > top_p - # This shift by 1 is weird and I cannot justify it. This existed - # in the original implementation: - # https://github.com/ari-holtzman/degen/blob/master/gen.py - # and I guess it is needed so keeping it for now. - filter_[:, 1:] = filter_[:, :-1].clone() - # Make sure we at least have one token to select from. - filter_[..., 0] = 0 - - # Fill in the filtered part - filter_ = filter_.scatter(1, sorted_indices, filter_) - logits.masked_fill_(filter_, float('-Inf')) - - # Greedy sampling - if top_k == 1: - sampled_logits = torch.argmax(last_token_logits, dim=-1) - else: - last_token_logits = last_token_logits.clone() - if temperature != 1.0: - last_token_logits.div_(temperature) - - if top_k > 1: - assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.' - if vocab_size: - assert top_k < vocab_size, 'top-k is larger than vocab size.' - modify_logits_for_top_k_filtering(last_token_logits, top_k) - - elif top_p > 0.0: - modify_logits_for_top_p_filtering(last_token_logits, top_p) - - # After filtering, we need to recalculate the distribution. - probabilities = last_token_logits.softmax(dim=-1) - sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1) - - # If vocab size is provided, make sure the samples are in in the range [0, vocab-size). - if vocab_size: - sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1)) - return sampled_logits - - def update_generation_status( - self, - updated_prompts_tokens: torch.Tensor, - generation_started: torch.Tensor, - current_context_end_position: int, - is_generation_done_tensor: torch.Tensor, - generated_sequence_lengths: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Checks which prompts have reached an end condition - - We check which prompts have reached an end condition and set the corresponding - flags of the is_generation_done_tensor to True. The generated sequence lengths - increase as we keep generating, until that prompts hits an end condition. The - generation_started tensor determines which prompts have started generating. - - Args: - updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest - generated tokens. A tensor of shape [batch_size, max_seq_len] - (i.e max_seq_len = max_prompt_len + tokens_to_generate) - generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True - indicates the prompt at that index has started generating tokens. 
- current_context_end_position (int): An integer indicating which position to - extract from the prompts tokens to get the latest generated tokens. - is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size]. - True indicates the prompt at that index has reached end condition. - generated_sequence_lengths (torch.Tensor): A int tensor of shape [batch_size]. - Each value represents the generated sequence lengths for that prompt. - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Returns the boolean - is_generation_done_tensor and the generated_sequence_lengths after updating it - """ - latest_samples = updated_prompts_tokens[:, current_context_end_position] - # Make sure we are checking eod criterion only for prompts that have started generating - # (i.e) We only look at the generated tokenns and not the input tokens. - reached_eod = (latest_samples == self.tokenizer.eod) & generation_started - is_generation_done_tensor = is_generation_done_tensor | reached_eod - # We increment generated sequence lengths when that prompt has not hit the - # EOD and generation has started - generated_sequence_lengths += ~is_generation_done_tensor & generation_started - - return is_generation_done_tensor, generated_sequence_lengths - - def pad_input_prompt_tokens( - self, - batch_prompt_tokens_list: List[List[int]], - max_prompt_length_in_batch: int, - num_tokens_to_generate: int, - ) -> torch.Tensor: - """Method to pad input prompts - - Given a list of prompts, pad them all to uniform length - - Args: - batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens - max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens - num_tokens_togenerate (int): The number of tokens to generate for each prompt - - Returns: - torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e) - max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate, - with extra indices for each tensor padded with mask id. - """ - max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate - - for prompt_tokens in batch_prompt_tokens_list: - padding_size = max_seq_len - len(prompt_tokens) - prompt_tokens.extend([self.tokenizer.eod] * padding_size) - - return torch.tensor(batch_prompt_tokens_list).cuda() - - def generate_output_tokens_dynamic_batch( - self, active_requests: OrderedDict[int, InferenceRequest] - ) -> OrderedDict[int, InferenceRequest]: - """Utility to generate the output tokens and probabilities for the prompts - - This utility generates the output tokens for a dynamic batch. It will run one forward step - at a time, and pass control back to the engine, which will update the request pool and call - this method again. - - Args: - active_requests (OrderedDict[int, InferenceRequest]): The input active requests. - - Returns: - OrderedDict[int, InferenceRequest]: The result for each of the incoming requests - after running one forward step. - """ - raise Exception("Not implemented yet") - - def generate_all_output_tokens_static_batch( - self, active_requests: OrderedDict[int, InferenceRequest] - ) -> OrderedDict[int, InferenceRequest]: - """Utility to generate the all the output tokens and probabilities for the prompts . - - This utility generates the output tokens for a static batch. 
It runs the forward steps till - all prompts complete generation, updates the status of these requests to completed, adds - the generated result and returns these requests - - Args: - active_requests (OrderedDict[int, InferenceRequest]): The input active requests. - - Returns: - OrderedDict[int, InferenceRequest]: The result for each of the incoming requests - """ - batch_prompt_tokens_list = list( - map(lambda request: request.prompt_tokens, active_requests.values()) - ) - prompt_lengths_in_batch = torch.tensor( - [len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list] - ).cuda() - max_prompt_length_in_batch = max(prompt_lengths_in_batch) - min_prompt_length_in_batch = min(prompt_lengths_in_batch) - - # For batch inference the inference params are the same for all request - common_inference_params: CommonInferenceParams = list(active_requests.values())[ - 0 - ].inference_parameters - - # max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate - batch_prompt_tokens = self.pad_input_prompt_tokens( - batch_prompt_tokens_list, - max_prompt_length_in_batch=max_prompt_length_in_batch, - num_tokens_to_generate=common_inference_params.num_tokens_to_generate, - ) - batch_size, max_sequence_length = batch_prompt_tokens.shape - - # Pre allocate log probs tensor - output_log_probs = None - if common_inference_params.return_log_probs: - output_log_probs = torch.empty( - (batch_size, max_sequence_length - 1), dtype=torch.float32 - ).cuda() - - # An array to check which of the prompts have reached end of generation condition - is_generation_done_tensor = torch.zeros(batch_size, dtype=torch.bool).cuda() - - # An array to act as a counter to keep track of generated sequence lengths - generated_sequence_lengths = torch.zeros(batch_size).cuda() - - with torch.no_grad(): - - self.prep_model_for_inference( - prompts_tokens=batch_prompt_tokens, active_requests=active_requests - ) - - context_start_position = 0 - # Pick the context window that we need to pass through the network. - for context_end_position in range(min_prompt_length_in_batch, max_sequence_length): - - inference_input = self.inference_wrapped_model.get_batch_for_context_window( - context_start_position, context_end_position - ) - - # Returns the final logits of shape [batch_size, context_length, vocab_size] - # Note: This is returned in all TP ranks or last PP stage in PP models - logits = self.inference_wrapped_model.run_one_forward_step(inference_input) - if self.model_is_pipeline_parallel: - context_length = context_end_position - context_start_position - logits = broadcast_from_last_pipeline_stage( - [batch_size, context_length, self.inference_wrapped_model.model.vocab_size], - dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype, - tensor=logits, - ) - - # Indicates which of the input prompts have started generating tokens. 
- # A 1D boolean tensor with [batch_size] elements (i.e) The shortest - # prompts will start generating first and so on - generation_started = prompt_lengths_in_batch <= context_end_position - last_token_logits = logits[:, -1, :] - sampled_logits = self.sample_from_logits( - last_token_logits, common_inference_params, self.inference_wrapped_model.model.vocab_size - ) - - # Substitute the sampled logits only for only the prompts that - # have started generating tokens - batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[ - generation_started - ] - - if common_inference_params.return_log_probs: - log_probs = F.log_softmax(logits, dim=2) - indices = torch.unsqueeze( - batch_prompt_tokens[ - :, (context_start_position + 1) : (context_end_position + 1) - ], - 2, - ) - # Get the log probabilities for only the prompt tokens - output_log_probs[:, context_start_position:context_end_position] = torch.gather( - log_probs, 2, indices - ).squeeze(2) - - context_start_position = context_end_position - - # Check end of generation status for each tensor - # and update generated sequence lengths - (is_generation_done_tensor, generated_sequence_lengths) = ( - self.update_generation_status( - updated_prompts_tokens=batch_prompt_tokens, - generation_started=generation_started, - current_context_end_position=context_end_position, - is_generation_done_tensor=is_generation_done_tensor, - generated_sequence_lengths=generated_sequence_lengths, - ) - ) - # Boolean flag indicating if all prompts are finished - all_prompts_done = torch.all(is_generation_done_tensor) - if all_prompts_done: - break - - # Include all the generated tokens - batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)] - if common_inference_params.return_log_probs: - output_log_probs = output_log_probs[:, :context_end_position] - - generated_sequence_lengths[ - generated_sequence_lengths > common_inference_params.num_tokens_to_generate - ] = common_inference_params.num_tokens_to_generate - - for idx, request in enumerate(active_requests.values()): - input_prompt_length = int(prompt_lengths_in_batch[idx]) - # Shorter prompts might have generated more than required tokens. 
So we trim them down - required_sequence_length = int( - min(generated_sequence_lengths[idx], common_inference_params.num_tokens_to_generate) - ) - # Extract only the generated tokens - required_result_tokens = batch_prompt_tokens_with_generations[ - idx, input_prompt_length : (input_prompt_length + required_sequence_length) - ] - - request.generated_length = required_sequence_length - request.generated_tokens = required_result_tokens - request.generated_log_probs = ( - None - if output_log_probs is None - else output_log_probs[idx, input_prompt_length:required_sequence_length] - ) - request.status = Status.COMPLETED - request.generated_text = self.detokenize_generations(required_result_tokens) - - return active_requests - - def prep_model_for_inference( - self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest] - ): - """Preparing batch for inference, using respective wrapper's prep_model_for_inference method - - Args: - prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] - active_requests (OrderedDict[int, InferenceRequest]): The input active requests - """ - self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=prompts_tokens) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py new file mode 100644 index 0000000000..f15c819c43 --- /dev/null +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -0,0 +1,400 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from typing import List, OrderedDict, Tuple + +import torch +import torch.nn.functional as F + +from megatron.core import parallel_state +from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage +from megatron.core.inference.inference_request import InferenceRequest, Status +from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( + AbstractModelInferenceWrapper, +) +from megatron.core.inference.sampling_params import SamplingParams + + +class TextGenerationController: + """The text generation controller (the main sampling loop) + + This class tokenizes the input, runs inference, samples from logits, and detokenizes the output. 
+ + Args: + inference_wrapped_model (AbstractModelInferenceWrapper): A model that + is wrapped using the specs given in the abstract_model_inference_wrapper.py + tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts + """ + + def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer): + self.inference_wrapped_model = inference_wrapped_model + self.tokenizer = tokenizer + + # For models without pipeline parallelism, is_first_stage and is_last_stage returns True + self.model_is_pipeline_parallel = not ( + parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() + ) + + def tokenize_prompt( + self, prompt: str, add_BOS: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Utility to tokenize the input prompts + + Args: + prompt (str): The input prompt + + Returns: + torch.Tensor: Returns the tokenized prompt + """ + prompt_tokens = self.tokenizer.tokenize(prompt) + + if add_BOS: + prompt_tokens = [self.tokenizer.bos] + prompt_tokens + + return prompt_tokens + + def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str: + """Detokenize the output generations + + Args: + prompt_tokens_with_generated_tokens (torch.Tensor): The input prompt + tokens plus the generated tokens + + Returns: + str: The detokenized output + """ + tokens = prompt_tokens_with_generated_tokens.cpu().numpy().tolist() + return self.tokenizer.detokenize(tokens) + + def sample_from_logits( + self, + last_token_logits: torch.Tensor, + sampling_params: SamplingParams = None, + vocab_size: int = None, + **kwargs + ) -> torch.Tensor: + """Samples the logits to generate outputs + + Given the logits of the last token, this function samples it + according to the parameters defined in sampling_params + and returns the samples + + Args: + last_token_logits (torch.Tensor): The last token logits. A tensor of + size [batch_size, vocab_size] + sampling_params (SamplingParams): The parameters to use for inference. + vocab_size (int): Obtained from the tokenizer. Defaults to None + + Returns: + torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements + """ + + if kwargs.get('common_inference_params'): + sampling_params = kwargs['common_inference_params'] + + top_p = sampling_params.top_p + top_k = sampling_params.top_k + temperature = sampling_params.temperature + + assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero' + assert top_p <= 1.0, 'top-p should be in (0,1]' + + def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf.""" + filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(filter_, float('-Inf')) + + def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf.""" + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Filteration based on the cumulative sum. + filter_ = cumulative_probs > top_p + # This shift by 1 is weird and I cannot justify it. This existed + # in the original implementation: + # https://github.com/ari-holtzman/degen/blob/master/gen.py + # and I guess it is needed so keeping it for now. + filter_[:, 1:] = filter_[:, :-1].clone() + # Make sure we at least have one token to select from. 
+ filter_[..., 0] = 0 + + # Fill in the filtered part + filter_ = filter_.scatter(1, sorted_indices, filter_) + logits.masked_fill_(filter_, float('-Inf')) + + # Greedy sampling + if top_k == 1: + sampled_logits = torch.argmax(last_token_logits, dim=-1) + else: + last_token_logits = last_token_logits.clone() + if temperature != 1.0: + last_token_logits.div_(temperature) + + if top_k > 1: + assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.' + if vocab_size: + assert top_k < vocab_size, 'top-k is larger than vocab size.' + modify_logits_for_top_k_filtering(last_token_logits, top_k) + + elif top_p > 0.0: + modify_logits_for_top_p_filtering(last_token_logits, top_p) + + # After filtering, we need to recalculate the distribution. + probabilities = last_token_logits.softmax(dim=-1) + sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1) + + # If vocab size is provided, make sure the samples are in in the range [0, vocab-size). + if vocab_size: + sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1)) + return sampled_logits + + def update_generation_status( + self, + updated_prompts_tokens: torch.Tensor, + generation_started: torch.Tensor, + current_context_end_position: int, + is_generation_done_tensor: torch.Tensor, + generated_sequence_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Checks which prompts have reached an end condition + + We check which prompts have reached an end condition and set the corresponding + flags of the is_generation_done_tensor to True. The generated sequence lengths + increase as we keep generating, until that prompts hits an end condition. The + generation_started tensor determines which prompts have started generating. + + Args: + updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest + generated tokens. A tensor of shape [batch_size, max_seq_len] + (i.e max_seq_len = max_prompt_len + tokens_to_generate) + generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True + indicates the prompt at that index has started generating tokens. + current_context_end_position (int): An integer indicating which position to + extract from the prompts tokens to get the latest generated tokens. + is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size]. + True indicates the prompt at that index has reached end condition. + generated_sequence_lengths (torch.Tensor): A int tensor of shape [batch_size]. + Each value represents the generated sequence lengths for that prompt. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Returns the boolean + is_generation_done_tensor and the generated_sequence_lengths after updating it + """ + latest_samples = updated_prompts_tokens[:, current_context_end_position] + # Make sure we are checking eod criterion only for prompts that have started generating + # (i.e) We only look at the generated tokenns and not the input tokens. 
+ reached_eod = (latest_samples == self.tokenizer.eod) & generation_started + is_generation_done_tensor = is_generation_done_tensor | reached_eod + # We increment generated sequence lengths when that prompt has not hit the + # EOD and generation has started + generated_sequence_lengths += ~is_generation_done_tensor & generation_started + + return is_generation_done_tensor, generated_sequence_lengths + + def pad_input_prompt_tokens( + self, + batch_prompt_tokens_list: List[List[int]], + max_prompt_length_in_batch: int, + num_tokens_to_generate: int, + ) -> torch.Tensor: + """Method to pad input prompts + + Given a list of prompts, pad them all to uniform length + + Args: + batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens + max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens + num_tokens_togenerate (int): The number of tokens to generate for each prompt + + Returns: + torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e) + max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate, + with extra indices for each tensor padded with mask id. + """ + max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate + + for prompt_tokens in batch_prompt_tokens_list: + padding_size = max_seq_len - len(prompt_tokens) + prompt_tokens.extend([self.tokenizer.eod] * padding_size) + + return torch.tensor(batch_prompt_tokens_list).cuda() + + def generate_output_tokens_dynamic_batch( + self, active_requests: OrderedDict[int, InferenceRequest] + ) -> OrderedDict[int, InferenceRequest]: + """Utility to generate the output tokens and probabilities for the prompts + + This utility generates the output tokens for a dynamic batch. It will run one forward step + at a time, and pass control back to the engine, which will update the request pool and call + this method again. + + Args: + active_requests (OrderedDict[int, InferenceRequest]): The input active requests. + + Returns: + OrderedDict[int, InferenceRequest]: The result for each of the incoming requests + after running one forward step. + """ + raise Exception("Not implemented yet") + + def generate_all_output_tokens_static_batch( + self, active_requests: OrderedDict[int, InferenceRequest] + ) -> OrderedDict[int, InferenceRequest]: + """Utility to generate the all the output tokens and probabilities for the prompts . + + This utility generates the output tokens for a static batch. It runs the forward steps till + all prompts complete generation, updates the status of these requests to completed, adds + the generated result and returns these requests + + Args: + active_requests (OrderedDict[int, InferenceRequest]): The input active requests. 
+ + Returns: + OrderedDict[int, InferenceRequest]: The result for each of the incoming requests + """ + batch_prompt_tokens_list = list( + map(lambda request: request.prompt_tokens, active_requests.values()) + ) + prompt_lengths_in_batch = torch.tensor( + [len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list] + ).cuda() + max_prompt_length_in_batch = max(prompt_lengths_in_batch) + min_prompt_length_in_batch = min(prompt_lengths_in_batch) + + # For batch inference the inference params are the same for all request + sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters + + # max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate + batch_prompt_tokens = self.pad_input_prompt_tokens( + batch_prompt_tokens_list, + max_prompt_length_in_batch=max_prompt_length_in_batch, + num_tokens_to_generate=sampling_params.num_tokens_to_generate, + ) + batch_size, max_sequence_length = batch_prompt_tokens.shape + + # Pre allocate log probs tensor + output_log_probs = None + if sampling_params.return_log_probs: + output_log_probs = torch.empty( + (batch_size, max_sequence_length - 1), dtype=torch.float32 + ).cuda() + + # An array to check which of the prompts have reached end of generation condition + is_generation_done_tensor = torch.zeros(batch_size, dtype=torch.bool).cuda() + + # An array to act as a counter to keep track of generated sequence lengths + generated_sequence_lengths = torch.zeros(batch_size).cuda() + + with torch.no_grad(): + + self.prep_model_for_inference( + prompts_tokens=batch_prompt_tokens, active_requests=active_requests + ) + + context_start_position = 0 + # Pick the context window that we need to pass through the network. + for context_end_position in range(min_prompt_length_in_batch, max_sequence_length): + + inference_input = self.inference_wrapped_model.get_batch_for_context_window( + context_start_position, context_end_position + ) + + # Returns the final logits of shape [batch_size, context_length, vocab_size] + # Note: This is returned in all TP ranks or last PP stage in PP models + logits = self.inference_wrapped_model.run_one_forward_step(inference_input) + if self.model_is_pipeline_parallel: + context_length = context_end_position - context_start_position + logits = broadcast_from_last_pipeline_stage( + [batch_size, context_length, self.tokenizer.vocab_size], + dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype, + tensor=logits, + ) + + # Indicates which of the input prompts have started generating tokens. 
+                # A 1D boolean tensor with [batch_size] elements, i.e. the shortest
+                # prompts will start generating first, and so on
+                generation_started = prompt_lengths_in_batch <= context_end_position
+                last_token_logits = logits[:, -1, :]
+                sampled_logits = self.sample_from_logits(
+                    last_token_logits, sampling_params, self.tokenizer.vocab_size
+                )
+
+                # Substitute the sampled logits only for the prompts that
+                # have started generating tokens
+                batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
+                    generation_started
+                ]
+
+                if sampling_params.return_log_probs:
+                    log_probs = F.log_softmax(logits, dim=2)
+                    indices = torch.unsqueeze(
+                        batch_prompt_tokens[
+                            :, (context_start_position + 1) : (context_end_position + 1)
+                        ],
+                        2,
+                    )
+                    # Get the log probabilities for only the tokens in this context window
+                    output_log_probs[:, context_start_position:context_end_position] = torch.gather(
+                        log_probs, 2, indices
+                    ).squeeze(2)
+
+                context_start_position = context_end_position
+
+                # Check the end of generation status for each prompt
+                # and update generated sequence lengths
+                (is_generation_done_tensor, generated_sequence_lengths) = (
+                    self.update_generation_status(
+                        updated_prompts_tokens=batch_prompt_tokens,
+                        generation_started=generation_started,
+                        current_context_end_position=context_end_position,
+                        is_generation_done_tensor=is_generation_done_tensor,
+                        generated_sequence_lengths=generated_sequence_lengths,
+                    )
+                )
+                # Boolean flag indicating if all prompts are finished
+                all_prompts_done = torch.all(is_generation_done_tensor)
+                if all_prompts_done:
+                    break
+
+        # Include all the generated tokens
+        batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
+        if sampling_params.return_log_probs:
+            output_log_probs = output_log_probs[:, :context_end_position]
+
+        generated_sequence_lengths[
+            generated_sequence_lengths > sampling_params.num_tokens_to_generate
+        ] = sampling_params.num_tokens_to_generate
+
+        for idx, request in enumerate(active_requests.values()):
+            input_prompt_length = int(prompt_lengths_in_batch[idx])
+            # Shorter prompts might have generated more than required tokens.
So we trim them down + required_sequence_length = int( + min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate) + ) + # Extract only the generated tokens + required_result_tokens = batch_prompt_tokens_with_generations[ + idx, input_prompt_length : (input_prompt_length + required_sequence_length) + ] + + request.generated_length = required_sequence_length + request.generated_tokens = required_result_tokens + request.generated_log_probs = ( + None + if output_log_probs is None + else output_log_probs[idx, input_prompt_length:required_sequence_length] + ) + request.status = Status.COMPLETED + request.generated_text = self.detokenize_generations(required_result_tokens) + + return active_requests + + def prep_model_for_inference( + self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest] + ): + """Preparing batch for inference, using respective wrapper's prep_model_for_inference method + + Args: + prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] + active_requests (OrderedDict[int, InferenceRequest]): The input active requests + """ + self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=prompts_tokens) diff --git a/tests/unit_tests/inference/engines/test_mcore_engine.py b/tests/unit_tests/inference/engines/test_mcore_engine.py index 8295744d36..1b342db4e6 100644 --- a/tests/unit_tests/inference/engines/test_mcore_engine.py +++ b/tests/unit_tests/inference/engines/test_mcore_engine.py @@ -5,7 +5,6 @@ import torch -from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.engines.mcore_engine import MCoreEngine from megatron.core.inference.inference_request import InferenceRequest, Status from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( @@ -14,8 +13,9 @@ from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( InferenceWrapperConfig, ) -from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( - SimpleTextGenerationController, +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, ) from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.models.gpt.gpt_model import GPTModel @@ -60,7 +60,7 @@ def setup_method(self, method): inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_wrapper_config) self.mock_tokenizer = mock.Mock() - text_generation_controller = SimpleTextGenerationController( + text_generation_controller = TextGenerationController( inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer ) @@ -85,7 +85,7 @@ def test_generate(self): prompts = ["sample" * (i + 1) for i in range(self.batch_size)] results: List[InferenceRequest] = self.mcore_engine.generate( - prompts, common_inference_params=CommonInferenceParams(num_tokens_to_generate=10) + prompts, sampling_params=SamplingParams(num_tokens_to_generate=10) ) for result in results: @@ -110,9 +110,7 @@ def test_generate_empty_prompt(self): prompts = ["" for i in range(self.batch_size)] results: List[InferenceRequest] = self.mcore_engine.generate( - prompts, - add_BOS=True, - common_inference_params=CommonInferenceParams(num_tokens_to_generate=10), + prompts, add_BOS=True, sampling_params=SamplingParams(num_tokens_to_generate=10) ) for result in results: diff --git 
a/tests/unit_tests/inference/test_common_inference_params.py b/tests/unit_tests/inference/test_common_inference_params.py index af51e433df..c7ef4c9ed8 100644 --- a/tests/unit_tests/inference/test_common_inference_params.py +++ b/tests/unit_tests/inference/test_common_inference_params.py @@ -1,10 +1,10 @@ -from megatron.core.inference.common_inference_params import CommonInferenceParams +from megatron.core.inference.sampling_params import SamplingParams -class TestCommonInferenceParams: +class TestSamplingParams: def test_inference_params(self): - inference_parameters = CommonInferenceParams() + inference_parameters = SamplingParams() inference_parameters.add_attributes({"min_tokens": 45}) assert ( inference_parameters.min_tokens == 45 diff --git a/tests/unit_tests/inference/test_scheduler.py b/tests/unit_tests/inference/test_scheduler.py index b1f0ea184e..90caa70a7b 100644 --- a/tests/unit_tests/inference/test_scheduler.py +++ b/tests/unit_tests/inference/test_scheduler.py @@ -2,8 +2,8 @@ import torch -from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.inference_request import InferenceRequest, Status +from megatron.core.inference.sampling_params import SamplingParams from megatron.core.inference.scheduler import Scheduler @@ -25,7 +25,7 @@ def setup_method(self, method): def test_scheduler(self): prompt = "sample prompt" prompt_tokens = torch.randn(5) - inference_parameters = CommonInferenceParams() + inference_parameters = SamplingParams() for i in range(self.max_batch_size): self.scheduler.add_request(prompt, prompt_tokens, inference_parameters) diff --git a/tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py index c28d0c3432..12903a919f 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py @@ -10,7 +10,6 @@ import pytest import torch -from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.inference_request import InferenceRequest, Status from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( InferenceWrapperConfig, @@ -18,6 +17,7 @@ from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import ( T5InferenceWrapper, ) +from megatron.core.inference.sampling_params import SamplingParams from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import ( EncoderDecoderTextGenerationController, ) @@ -126,7 +126,7 @@ def test_generate_all_output_tokens_static_batch(self): request_id=i, prompt=prompt, encoder_prompt=encoder_prompt, - inference_parameters=CommonInferenceParams(num_tokens_to_generate=10), + inference_parameters=SamplingParams(num_tokens_to_generate=10), arrival_time=time.time(), prompt_tokens=prompt_tokens, status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS, diff --git a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py index 1e09cf05fb..1db360f232 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py +++ 
b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py @@ -9,7 +9,6 @@ import pytest import torch -from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.inference_request import InferenceRequest, Status from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( GPTInferenceWrapper, @@ -17,8 +16,9 @@ from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( InferenceWrapperConfig, ) -from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( - SimpleTextGenerationController, +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, ) from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.models.gpt.gpt_model import GPTModel @@ -28,7 +28,7 @@ from tests.unit_tests.test_utilities import Utils -class TestSimpleTextGenerationController: +class TestTextGenerationController: def setup_method(self, method): Utils.initialize_model_parallel( @@ -67,7 +67,7 @@ def setup_method(self, method): self.mock_tokenizer = mock.Mock() - self.text_generation_controller = SimpleTextGenerationController( + self.text_generation_controller = TextGenerationController( inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer ) @@ -78,7 +78,7 @@ def test_sample_from_logits(self): with pytest.raises(AssertionError) as aerror: self.text_generation_controller.sample_from_logits( last_token_logits=None, - common_inference_params=CommonInferenceParams(top_k=2, top_p=0.4), + sampling_params=SamplingParams(top_k=2, top_p=0.4), vocab_size=self.vocab_size, ) assert str(aerror.value) == 'Cannot have top-p and top-k both greater than zero' @@ -86,7 +86,7 @@ def test_sample_from_logits(self): with pytest.raises(AssertionError) as aerror: self.text_generation_controller.sample_from_logits( last_token_logits=None, - common_inference_params=CommonInferenceParams(top_p=1.4, top_k=0), + sampling_params=SamplingParams(top_p=1.4, top_k=0), vocab_size=self.vocab_size, ) assert str(aerror.value) == 'top-p should be in (0,1]' @@ -94,7 +94,7 @@ def test_sample_from_logits(self): with pytest.raises(AssertionError) as aerror: self.text_generation_controller.sample_from_logits( last_token_logits=torch.randn(self.batch_size, 1), - common_inference_params=CommonInferenceParams(top_k=self.vocab_size + 10), + sampling_params=SamplingParams(top_k=self.vocab_size + 10), vocab_size=self.vocab_size, ) assert str(aerror.value) == 'top-k is larger than logit size.' 
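For reviewers unfamiliar with the rename, here is a minimal sketch of constructing the new `SamplingParams`, assuming only the keyword arguments and the `add_attributes` helper exercised by the tests in this patch (the values shown are illustrative, not defaults):

```python
from megatron.core.inference.sampling_params import SamplingParams

# Greedy-style sampling: top_k=1 picks the highest-probability token each step.
greedy_params = SamplingParams(num_tokens_to_generate=10, top_k=1)

# Nucleus sampling: top_p in (0, 1] with top_k disabled; temperature rescales logits.
nucleus_params = SamplingParams(num_tokens_to_generate=10, top_p=0.3, top_k=0, temperature=0.5)

# Arbitrary extra attributes can still be attached dynamically.
nucleus_params.add_attributes({"min_tokens": 45})

# Note: sample_from_logits asserts that top_k and top_p are not both greater than
# zero, that top_p lies in (0, 1], and that top_k does not exceed the vocab size.
```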
@@ -103,14 +103,14 @@ def test_sample_from_logits(self): torch.arange(0, self.vocab_size).repeat(self.batch_size, 1).float().cuda() ) sampled_logits = self.text_generation_controller.sample_from_logits( - last_token_logits, CommonInferenceParams(top_k=1), self.vocab_size + last_token_logits, SamplingParams(top_k=1), self.vocab_size ) assert torch.all( sampled_logits.cpu() == torch.ones(self.batch_size) * self.vocab_size - 1 ), f"The sampled logits should all be {self.vocab_size} but its {sampled_logits}" sampled_logits = self.text_generation_controller.sample_from_logits( - last_token_logits, CommonInferenceParams(top_k=2), self.vocab_size + last_token_logits, SamplingParams(top_k=2), self.vocab_size ) assert torch.all( sampled_logits >= self.vocab_size - 2 @@ -120,7 +120,7 @@ def test_sample_from_logits(self): top_p = 0.3 expected_min_value = l[l.softmax(dim=-1).cumsum(dim=-1) > top_p][0].item() sampled_logits = self.text_generation_controller.sample_from_logits( - last_token_logits, CommonInferenceParams(top_p=top_p, top_k=0), self.vocab_size + last_token_logits, SamplingParams(top_p=top_p, top_k=0), self.vocab_size ) assert torch.all( sampled_logits >= expected_min_value @@ -131,7 +131,7 @@ def test_sample_from_logits(self): expected_min_value = l[l.div_(temperature).softmax(dim=-1).cumsum(dim=-1) > top_p][0].item() sampled_logits = self.text_generation_controller.sample_from_logits( last_token_logits, - CommonInferenceParams(top_p=top_p, temperature=temperature, top_k=0), + SamplingParams(top_p=top_p, temperature=temperature, top_k=0), self.vocab_size, ) assert torch.all( @@ -154,7 +154,7 @@ def test_generate_all_output_tokens_static_batch(self): inference_request = InferenceRequest( request_id=i, prompt=prompt, - inference_parameters=CommonInferenceParams(num_tokens_to_generate=10), + inference_parameters=SamplingParams(num_tokens_to_generate=10), arrival_time=time.time(), prompt_tokens=torch.randint( low=0, high=self.vocab_size - 1, size=(len(prompt),)