diff --git a/examples/inference/README.md b/examples/inference/README.md
index bd8e738e55..b4b07cbc6a 100644
--- a/examples/inference/README.md
+++ b/examples/inference/README.md
@@ -1,5 +1,5 @@
### Megatron Core Inference Documentation
-This guide will walk you through how you can use megatron core for inference on your models.
+This guide provides an example of running model inference with Megatron Core.
### Contents
- [Megatron Core Inference Documentation](#megatron-core-inference-documentation)
@@ -18,21 +18,21 @@ This guide will walk you through how you can use megatron core for inference on
#### 1. Quick Start
-This will walk you through the flow of running batch inference on a GPT model trained using megatron core. The file can be found at [simple_gpt_batch_inference.py](./gpt/simple_gpt_batch_inference.py)
+This example runs batch inference on a GPT model trained with Megatron Core. The entrypoint is [gpt_batch_inference.py](./gpt/gpt_batch_inference.py).
-##### 1.1 Understanding The Code
-***STEP 1 - We initialize model parallel and other default arguments***
-We can default micro batch size to be 1, since for TP models it is not used, and for PP models it is calculated during runtime.
+##### 1.1 Code Walkthrough
+***STEP 1 - Initialize model parallel and other default arguments***
+The micro batch size is set to 1, since it is not used for tensor-parallel-only models and is calculated at runtime for pipeline-parallel models.
```python
initialize_megatron(
args_defaults={'no_load_rng': True, 'no_load_optim': True, 'micro_batch_size': 1}
)
```
-***STEP 2 - We load the model using the model_provider_function***
-NOTE: The model provider function in the script supports MCore and Legacy models.
+***STEP 2 - Load the model using the model_provider_function***
+NOTE: The model provider function supports both MCore and Legacy models.
```python
model = get_model(model_provider, wrap_with_ddp=False)
@@ -41,10 +41,10 @@ NOTE: The model provider function in the script supports MCore and Legacy models
```
***STEP 3 - Choose an engine***
-One of the important elements of the generate function is an inference engine. In this example we will be choosing the [megatron core engine](../../megatron/core/inference/engine/mcore_engine.py) with a [simple text generation controller](../../megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py), the default engine. Other engines that will be supported in the future are TRTLLMEngine.
+Text generation requires an inference engine, which includes a scheduler. The default engine is the [Megatron Core engine](../../megatron/core/inference/engines/mcore_engine.py) with a simple [text generation controller](../../megatron/core/inference/text_generation_controllers/text_generation_controller.py). TRTLLMEngine will be supported in the future.
```python
inference_wrapped_model = GPTInferenceWrapper(model, args)
- text_generation_controller = SimpleTextGenerationController(
+ text_generation_controller = TextGenerationController(
inference_wrapped_model=inference_wrapped_model,
tokenizer=tokenizer
)
@@ -53,12 +53,12 @@ One of the important elements of the generate function is an inference engine. I
)
```
-***STEP 4 - Run the generate function and display results***
-We use default values for the [common inference params](../../megatron/core/inference/common_inference_params.py). Customize this if you want to change top_p, top_k, number of tokens to generate etc.
-*Note that the result is returned as a list of [InferenceRequests](../../megatron/core/inference/inference_request.py)*
+***STEP 4 - Run text generation***
+[SamplingParams](../../megatron/core/inference/sampling_params.py) provides reasonable defaults. Customize it to change top_p, top_k, the number of tokens to generate, etc.
+*Note: The result is returned as a list of [InferenceRequests](../../megatron/core/inference/inference_request.py)*
```python
results: List[InferenceRequest] = inference_engine.generate(
- prompts=args.prompts, common_inference_params=common_inference_params
+ prompts=args.prompts, sampling_params=sampling_params
)
if torch.distributed.get_rank() == 0:
@@ -76,12 +76,12 @@ We use default values for the [common inference params](../../megatron/core/infe
##### 1.2 Running The Code
-An example run script is shown below. Change the tokenizer paths, inference params, and other settings for your model.
+An example run script is shown below. Set the tokenizer paths, sampling params, and other settings appropriately for your model.
-For a quick recap on inference params refer to [this blog](https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910)
+For a quick recap on sampling parameters, refer to [this blog](https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910).
```
-#In a slurm cluster (You could also use docker)
+# In a Slurm cluster (you could also use Docker)
ACCOUNT=
MLM_PATH=/path/to/megatron-lm
GPT_CKPT=/path/to/gpt/ckpt
@@ -133,8 +133,8 @@ NOTE: Other parameters which can be customized for inference are :-
--top_p (top_p sampling)
--num-tokens-to-generate (Number of tokens to generate for each prompt)
--inference-batch-times-seqlen-threshold (During inference, if batch-size times sequence-length is smaller than this threshold then we will not use pipelining, otherwise we will.)
---use-dist-ckpt (If you are using dist checkpoint format for the model)
---use-legacy-models (If you are using legacy gpt model instead of mcore gpt model)
+--use-dist-ckpt (If using the distributed checkpoint format for the model)
+--use-legacy-models (If using the legacy GPT model instead of the MCore GPT model)
```
@@ -142,16 +142,17 @@ NOTE: Other parameters which can be customized for inference are :-
-#### 2. Flow of Control In MCore Backend
-The following is what happens in the [simple_gpt_batch_inference.py](./gpt/simple_gpt_batch_inference.py).
-* We call [mcore_engine](../../megatron/core/inference/engines/mcore_engine.py) **generate()** function with all our input prompts.
-* The scheduler in the engine will add these prompts to the [active requests] pool (../../megatron/core/inference/inference_request.py) until we hit the max batch size, and then it will put the rest in the waiting requests pool.
-* The engine will then run until all requests (waiting + active) are completed
+#### 2. Control Flow in the MCore Backend
+An example of inference with static batching is provided in [gpt_batch_inference.py](./gpt/gpt_batch_inference.py); a condensed sketch of the generation loop follows the list below.
+* The **generate()** function of the [mcore_engine](../../megatron/core/inference/engines/mcore_engine.py) is called with the input prompts.
+* The `Scheduler` in the engine adds these prompts to the [active requests pool](../../megatron/core/inference/inference_request.py) until the max batch size is reached. Remaining requests are added to the waiting requests pool.
+* The engine will run until all requests (waiting + active) are completed.
* The active requests are passed into **generate_all_output_tokens_static_batch()** of the text generation controller.
- * This function uses the [model_inference_wrappers](../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) **prep_model_for_inference()** , and then runs an auto regressive loop
- * In the auto regressive loop, the **get_batch_for_context_window()** method of the inference wrapper is called to get the required input, passes it into the **run_one_forward_step()** method, which calls the appropriate (PP, TP) model `.forward()` methods to get the output logits
- * The output logits are synchronized across all pipeline parallel ranks
- * The text generation controller obtains the log probabilities and samples tokens based on the strategy defined in the common inference parameters.
+ * This function uses the **prep_model_for_inference()** method of the [model_inference_wrappers](../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) and runs an autoregressive sampling loop
+ * In the autoregressive loop, the **get_batch_for_context_window()** method of the inference wrapper is called to slice out the input tokens and masks
+    * Input tokens and masks are passed into the **run_one_forward_step()** method, which calls the model `.forward()` method to get the output logits
+ * Output logits are synchronized across all pipeline parallel ranks
+ * The text generation controller obtains the log probabilities and samples tokens based on the strategy defined in the sampling parameters.
* The sampled tokens are then appended to the input prompt tokens for the next iteration
* The **update_generation_status()** method of the text generation controller checks which prompts have finished generating or hit a stop condition
* After the inference loop, the result is detokenized and stored as an attribute of the InferenceRequest. These requests are marked as completed.
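+
+The sketch below condenses this flow into a short loop. It is illustrative only: `controller`, `wrapped_model`, `sampling_params`, `active_requests`, the padded `batch_prompt_tokens` tensor, the prompt-length variables (`prompt_lengths_in_batch`, `min_prompt_length_in_batch`, `max_sequence_length`), and the tokenizer `vocab_size` are assumed to be set up as in **generate_all_output_tokens_static_batch()**, and padding details, log-prob collection, and the pipeline-parallel logits broadcast are omitted.
+
+```python
+import torch
+
+# Condensed, illustrative view of the static-batch generation loop.
+batch_size = batch_prompt_tokens.size(0)
+done = torch.zeros(batch_size, dtype=torch.bool).cuda()
+seq_lengths = torch.zeros(batch_size).cuda()
+
+with torch.no_grad():
+    controller.prep_model_for_inference(
+        prompts_tokens=batch_prompt_tokens, active_requests=active_requests
+    )
+    context_start = 0
+    for context_end in range(min_prompt_length_in_batch, max_sequence_length):
+        # Slice out the input needed for this step and run one forward pass.
+        inference_input = wrapped_model.get_batch_for_context_window(context_start, context_end)
+        logits = wrapped_model.run_one_forward_step(inference_input)
+
+        # Sample the next token and write it back for prompts that have started generating.
+        started = prompt_lengths_in_batch <= context_end
+        sampled = controller.sample_from_logits(logits[:, -1, :], sampling_params, vocab_size)
+        batch_prompt_tokens[started, context_end] = sampled[started]
+        context_start = context_end
+
+        # Flag prompts that hit the end-of-document token; stop once every prompt is done.
+        done, seq_lengths = controller.update_generation_status(
+            batch_prompt_tokens, started, context_end, done, seq_lengths
+        )
+        if torch.all(done):
+            break
+```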
@@ -160,16 +161,18 @@ The following is what happens in the [simple_gpt_batch_inference.py](./gpt/simpl
#### 3. Customizing The Inference Pipeline
-The following guide will walk you through how you can customize different parts of the inference pipeline. There are three levels at which you can customize the pipeline.
-* **Inference engine** - Highest level of customization. Currently we support the MCore Engine. Change this to add a new engine.
-* **Text generation controller** - Extend this to customize tokenization, detokenization, or implement a new sampling strategy.
+
+The inference pipeline can be customized at the following levels:
+
+* **Inference engine** - The MCore Engine is currently supported. Change this to add a new backend.
+* **Text generation controller** - The main sampling loop. This can be customized to support alternative tokenization, detokenization, or to implement a new sampling strategy.
* **Inference Wrapped Model** - Change this to support a new model.
* **Modify Inference Parameters** - Change this to update top_p, top_k, number of tokens to be generated, temperature, or other sampling parameters.
##### 3.1. Create Your Own Inference Backend
-This is the highest level of customization. The [abstract_engine.py](./../../megatron/core/inference/engine/abstract_engine.py) file has a generate method that can be extended to support a new backend.
+The [abstract_engine.py](../../megatron/core/inference/engines/abstract_engine.py) file defines an abstract `generate` method; implement it to support a new backend.
```python
class AbstractEngine(ABC):
@@ -177,15 +180,17 @@ class AbstractEngine(ABC):
def generate(self) -> dict:
"""The abstract backend's generate function.
- To define your own backend, make sure you implement this and return the outputs as a dictionary .
-
+ To define a new backend, implement this method and return the outputs as a dictionary.
+```
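+
+For example, a new backend might be sketched as follows. The class name, constructor, and the `prompts` argument are illustrative (MCoreEngine accepts prompts similarly); only the requirement to implement `generate()` and return a dictionary comes from `AbstractEngine`:
+
+```python
+from typing import List
+
+from megatron.core.inference.engines.abstract_engine import AbstractEngine
+
+
+class MyCustomEngine(AbstractEngine):  # hypothetical example backend
+    def __init__(self, text_generation_controller, max_batch_size: int):
+        self.controller = text_generation_controller
+        self.max_batch_size = max_batch_size
+
+    def generate(self, prompts: List[str]) -> dict:
+        # Run backend-specific scheduling and generation here, then
+        # return the outputs as a dictionary, as required by AbstractEngine.
+        ...
+```
+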
-##### 3.2. Create Your Own Text Generation Controller
-In case you want to use the megatron core backend, but would like to overwrite the tokenization, text generation or detokenization extend the [simple_text_generation_controller.py](../../megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py). The class has the following methods
+##### 3.2. Implement a New Sampling Loop
+
+The [TextGenerationController](../../megatron/core/inference/text_generation_controllers/text_generation_controller.py) contains the main sampling loop and can be modified to support new tokenization, detokenization, or sampling strategies.
+
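+For example, a hypothetical subclass could force greedy decoding by overriding **sample_from_logits()** (an illustrative sketch, not part of the library):
+
+```python
+import torch
+
+from megatron.core.inference.sampling_params import SamplingParams
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+    TextGenerationController,
+)
+
+
+class GreedyTextGenerationController(TextGenerationController):  # hypothetical subclass
+    def sample_from_logits(
+        self,
+        last_token_logits: torch.Tensor,
+        sampling_params: SamplingParams = None,
+        vocab_size: int = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        # Always pick the highest-probability token, ignoring top_k/top_p/temperature.
+        return torch.argmax(last_token_logits, dim=-1)
+```
+
+The base class exposes the following methods, any of which can be overridden:
+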
``` python
-class SimpleTextGenerationController:
+class TextGenerationController:
def tokenize_prompt(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility to tokenize the input prompts"""
@@ -193,12 +198,12 @@ class SimpleTextGenerationController:
def sample_from_logits(
self,
last_token_logits: torch.Tensor,
- common_inference_params: CommonInferenceParams,
+ sampling_params: SamplingParams,
vocab_size: int,
) -> torch.Tensor:
"""Samples the logits to generate outputs
- Given the logits of the last token, this function samples it according to the parameters defined in common_inference_params and returns the samples
+ Given the logits of the last token, this function samples according to the parameters defined in sampling_params and returns the sampled tokens.
"""
def update_generation_status(
@@ -229,12 +234,12 @@ class SimpleTextGenerationController:
##### 3.3. Support Other Models
-In order to support other models please extend the [abstract_model_inference_wrapper.py](./../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) file. The abstract wrapper already supports the following :
-* Forward method which automatically calls the appropriate forward method (PP or TP etc) depending on model parallel settings
-* Initalizes the model and puts it in eval mode
-* Obtains the input parameters (batch size, max seq length) and has an instance of the input
+Extend [abstract_model_inference_wrapper.py](../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) to support other models. The abstract model wrapper implements:
+* A forward method that calls the appropriate model `forward` method based on the model parallel settings
+* Model initialization and switching to `.eval()` mode
+* Setup of the input parameters (max batch size, max seq length)
-The main methods to change for your model might be the following:
+The following methods should be implemented:
```python
class AbstractModelInferenceWrapper:
def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
@@ -247,28 +252,28 @@ class AbstractModelInferenceWrapper:
def get_batch_for_context_window(self) -> List:
"""Returns the input data for inference
- This function gets called iteratively in the inference loop . It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference.
+ This function gets called iteratively in the inference loop. It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference.
```
-Refer to [gpt_inference_wrapper.py](../../megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py) for an example of extending this for GPTModel.
+Refer to [gpt_inference_wrapper.py](../../megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py) for an example of implementing this for GPTModel.
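+
+For instance, a wrapper for a new decoder-only model might be sketched as follows. The class name and the attribute used to cache the prompt tokens are illustrative; only the two overridden methods come from the abstract wrapper:
+
+```python
+from typing import List
+
+import torch
+
+from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
+    AbstractModelInferenceWrapper,
+)
+
+
+class MyModelInferenceWrapper(AbstractModelInferenceWrapper):  # hypothetical wrapper
+    def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
+        super().prep_model_for_inference(prompts_tokens=prompts_tokens)
+        # Cache whatever per-batch state the model needs (masks, positions, ...).
+        self.prompts_tokens = prompts_tokens
+
+    def get_batch_for_context_window(self, context_start: int, context_end: int) -> List:
+        # Return the slice of inputs required for this step of the inference loop.
+        return [self.prompts_tokens[:, context_start:context_end]]
+```
+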
##### 3.4. Modify Inference Parameters
-We use [common inference params](../../megatron/core/inference/common_inference_params.py) for text generation. Customize this if you want to change top_p, top_k, number of tokens to generate etc. If you want to add other attributes that you would use in the inference loop, you can do that as shown below
+We use [sampling params](../../megatron/core/inference/sampling_params.py) for text generation. Customize this to change top_p, top_k, the number of tokens to generate, etc. To add other attributes for use in the inference loop, use `add_attributes` as shown below.
```
-from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.sampling_params import SamplingParams
-c = CommonInferenceParams(temperature=0.5)
+c = SamplingParams(temperature=0.5)
c.add_attributes({'min_length':4, 'eod_id':153})
```
#### 4. Future work
-The following are planned for the future releases .
+The following features are planned for future releases.
* Dynamic batching
* Paged Attention
* TRTLLM Engine support
-* Support for Multimodal model inference
\ No newline at end of file
+* Support for multimodal inference
\ No newline at end of file
diff --git a/examples/inference/gpt/simple_gpt_batch_inference.py b/examples/inference/gpt/gpt_batch_inference.py
similarity index 91%
rename from examples/inference/gpt/simple_gpt_batch_inference.py
rename to examples/inference/gpt/gpt_batch_inference.py
index 5c7ae5bd77..050b230cef 100644
--- a/examples/inference/gpt/simple_gpt_batch_inference.py
+++ b/examples/inference/gpt/gpt_batch_inference.py
@@ -6,10 +6,10 @@
from argparse import Namespace
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
-from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
from megatron.core.inference.inference_request import InferenceRequest
-from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import SimpleTextGenerationController
+from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController
from megatron.core.transformer.module import MegatronModule
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir, os.path.pardir)))
@@ -66,7 +66,7 @@ def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngi
)
inference_wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config)
- text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer)
+ text_generation_controller = TextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer)
return MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size)
def main():
@@ -89,7 +89,7 @@ def main():
inference_engine = get_inference_engine(args, model)
- common_inference_params = CommonInferenceParams(
+ sampling_params = SamplingParams(
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
@@ -97,7 +97,7 @@ def main():
num_tokens_to_generate=args.num_tokens_to_generate)
results: List[InferenceRequest] = inference_engine.generate(
- prompts=args.prompts, common_inference_params=common_inference_params
+ prompts=args.prompts, sampling_params=sampling_params
)
if torch.distributed.get_rank() == 0:
diff --git a/examples/inference/t5/simple_t5_batch_inference.py b/examples/inference/t5/simple_t5_batch_inference.py
index 3f4557d3c2..b4226d7de0 100644
--- a/examples/inference/t5/simple_t5_batch_inference.py
+++ b/examples/inference/t5/simple_t5_batch_inference.py
@@ -5,7 +5,7 @@
import torch
import pretrain_t5
-from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.inference_request import InferenceRequest
@@ -120,7 +120,7 @@ def main():
inference_engine = get_inference_engine(args, model)
- common_inference_params = CommonInferenceParams(
+ sampling_params = SamplingParams(
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
@@ -138,7 +138,7 @@ def main():
prompts=args.prompts,
add_BOS=True,
encoder_prompts=args.encoder_prompts,
- common_inference_params=common_inference_params,
+ sampling_params=sampling_params,
)
if torch.distributed.get_rank() == 0:
diff --git a/megatron/core/inference/common_inference_params.py b/megatron/core/inference/common_inference_params.py
index 22353088f8..7955bb6fc1 100644
--- a/megatron/core/inference/common_inference_params.py
+++ b/megatron/core/inference/common_inference_params.py
@@ -1,29 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
-from dataclasses import dataclass
-
-
-@dataclass
-class CommonInferenceParams:
- """Inference parameters sent along with the prompts
-
- For an explanation of these parameters refer to this blog https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910
- """
-
- temperature: float = 1.0
- top_k: int = 0
- top_p: float = 0.0
- return_log_probs: bool = False
- num_tokens_to_generate: int = 30
-
- def add_attributes(self, attribute_value_pair: dict):
- """Utility to add more attributes to inference params
-
- Use this method to pass in a custom dictonary to add more inference parameter attributes to the instance you created. Use as follows
- c = CommonInferenceParams
- c.add_attributes({'min_length':4, 'eod_id':153})
-
- Args:
- attribute_value_pair (dict): A dictionary containing attributes as the key names and their values as the values.
- """
- for key, value in attribute_value_pair.items():
- setattr(self, key, value)
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
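+# This module is kept for backwards compatibility (MCore <= 0.9.0); use
+# megatron.core.inference.sampling_params.SamplingParams going forward.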
+from megatron.core.inference.sampling_params import ( # noqa: F401 # pylint: disable=unused-import
+ SamplingParams as CommonInferenceParams,
+)
diff --git a/megatron/core/inference/engines/mcore_engine.py b/megatron/core/inference/engines/mcore_engine.py
index fe8160228b..28ef46bf92 100644
--- a/megatron/core/inference/engines/mcore_engine.py
+++ b/megatron/core/inference/engines/mcore_engine.py
@@ -3,12 +3,12 @@
import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.inference_request import InferenceRequest
+from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.scheduler import Scheduler
-from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
- SimpleTextGenerationController,
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+ TextGenerationController,
)
@@ -19,7 +19,7 @@ class MCoreEngine(AbstractEngine):
Supports any model that is callable (Accepts the inputs and outputs the tensor)
Args:
- text_generation_controller (SimpleTextGenerationController): A text generation
+ text_generation_controller (TextGenerationController): A text generation
controller that will be used to define how to preprocess prompts, generate
outputs and detokenizer the output tokens.
max_batch_size : The maxinum number of requests to process at once
@@ -29,7 +29,7 @@ class MCoreEngine(AbstractEngine):
def __init__(
self,
- text_generation_controller: SimpleTextGenerationController,
+ text_generation_controller: TextGenerationController,
max_batch_size,
random_seed: int = None,
):
@@ -42,7 +42,8 @@ def generate(
prompts: List[str],
add_BOS: bool = False,
encoder_prompts: List[str] = None,
- common_inference_params: CommonInferenceParams = None,
+ common_inference_params: SamplingParams = None,
+ sampling_params: SamplingParams = None,
) -> dict:
"""The megatron core inference backend generate function
@@ -54,13 +55,19 @@ def generate(
prompts (List[str]): All the prompts as a list of strings
add_BOS (bool): Whether to add BOS token to beginning of prompts
encoder_prompts (List[dict]): All the encoder prompts as a list of strings
- common_inference_params (CommonInferenceParams): The inference parameters
+ common_inference_params: Deprecated. Only used for backward compatibility with
+ MCore <= 0.9.0. Use `sampling_params` going forward.
+ sampling_params (SamplingParams): The request-level sampling parameters
Returns:
List[InferenceRequest]: The output is list of inference requests containing the
generated tokens, texts and log probs if required
"""
# TODO :M core- get rng state tracker
+
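+        # Backwards compatibility with MCore <= 0.9.0, which passed `common_inference_params`.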
+ if common_inference_params:
+ sampling_params = common_inference_params
+
if self.random_seed:
torch.random.manual_seed(self.random_seed)
@@ -73,7 +80,7 @@ def generate(
prompt=prompt,
prompt_tokens=prompt_tokens,
encoder_prompt=encoder_prompt,
- inference_parameters=common_inference_params,
+ inference_parameters=sampling_params,
)
self.run_engine()
diff --git a/megatron/core/inference/inference_request.py b/megatron/core/inference/inference_request.py
index 4825dfd366..ea0d67bfea 100644
--- a/megatron/core/inference/inference_request.py
+++ b/megatron/core/inference/inference_request.py
@@ -5,7 +5,7 @@
import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.sampling_params import SamplingParams
# class syntax
@@ -28,7 +28,7 @@ class InferenceRequest:
request_id: str
prompt: str
- inference_parameters: CommonInferenceParams
+ inference_parameters: SamplingParams
prompt_tokens: List[int]
arrival_time: float
status: Status
diff --git a/megatron/core/inference/sampling_params.py b/megatron/core/inference/sampling_params.py
new file mode 100644
index 0000000000..8ffcb6321d
--- /dev/null
+++ b/megatron/core/inference/sampling_params.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+from dataclasses import dataclass
+
+
+@dataclass
+class SamplingParams:
+ """Inference parameters sent along with the prompts.
+ This class contains request-level attributes that control the sampling techniques used when
+    generating text. This is distinct from megatron.core.InferenceParams, which sets model-level
+ inference attributes such as the maximum sequence length, and contains the KV cache.
+
+ For an explanation of these parameters refer to this blog
+ https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-
+ temperature-parameters-ed6a31313910
+ """
+
+ temperature: float = 1.0
+ top_k: int = 0
+ top_p: float = 0.0
+ return_log_probs: bool = False
+ num_tokens_to_generate: int = 30
+
+ def add_attributes(self, attribute_value_pair: dict):
+ """Utility to add more attributes to sampling params
+
+ Use this method to pass in a custom dictionary to add more sampling parameter attributes.
+        c = SamplingParams()
+ c.add_attributes({'min_length':4, 'eod_id':153})
+
+ Args:
+ attribute_value_pair (dict): A dictionary containing attributes as the key names and
+ their values as the values.
+ """
+ for key, value in attribute_value_pair.items():
+ setattr(self, key, value)
diff --git a/megatron/core/inference/scheduler.py b/megatron/core/inference/scheduler.py
index 00ab81b4ab..ef177232b4 100644
--- a/megatron/core/inference/scheduler.py
+++ b/megatron/core/inference/scheduler.py
@@ -6,8 +6,8 @@
import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest, Status
+from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import Counter
@@ -33,7 +33,7 @@ def add_request(
prompt: str,
prompt_tokens: torch.Tensor,
encoder_prompt: str = None,
- inference_parameters: CommonInferenceParams = None,
+ inference_parameters: SamplingParams = None,
arrival_time: float = None,
):
"""Add an incoming request
@@ -45,7 +45,7 @@ def add_request(
prompt (str): Input prompt string
prompt_tokens (torch.Tensor): A torch tensor having the input prompts tokenized
encoder_prompt (str): Encoder input string
- inference_parameters (CommonInferenceParams): The inference parameters
+ inference_parameters (SamplingParams): The inference parameters
arrival_time (float, optional): The incoming request time. Defaults to None.
"""
request_id = str(next(self.request_counter))
diff --git a/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py
index 61beff0211..0c2a41be44 100644
--- a/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py
+++ b/megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py
@@ -4,15 +4,15 @@
import torch
from megatron.core.inference.inference_request import InferenceRequest
-from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
- SimpleTextGenerationController,
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+ TextGenerationController,
)
-class EncoderDecoderTextGenerationController(SimpleTextGenerationController):
+class EncoderDecoderTextGenerationController(TextGenerationController):
"""The text generation controller for encoder-decoder architecture
- This class ingherits from SimpleTextGenerationController, adding features
+ This class inherits from TextGenerationController, adding features
relating to encoder input encoder_prompt
"""
diff --git a/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py
index ceea4064d2..f97df13249 100644
--- a/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py
+++ b/megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py
@@ -1,400 +1,5 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
-from typing import List, OrderedDict, Tuple
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
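+# This module is kept for backwards compatibility; use TextGenerationController from
+# megatron.core.inference.text_generation_controllers.text_generation_controller going forward.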
-import torch
-import torch.nn.functional as F
-
-from megatron.core import parallel_state
-from megatron.core.inference.common_inference_params import CommonInferenceParams
-from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
-from megatron.core.inference.inference_request import InferenceRequest, Status
-from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
- AbstractModelInferenceWrapper,
+from megatron.core.inference.text_generation_controllers.text_generation_controller import ( # noqa: F401 # pylint: disable=unused-import
+ TextGenerationController as SimpleTextGenerationController,
)
-
-
-class SimpleTextGenerationController:
- """The basic text generation controller
-
- This class is responsible for tokenizing the input , running the inference, sampling
- and also detokenizing the output
-
- Args:
- inference_wrapped_model (AbstractModelInferenceWrapper): A model that
- is wrapped using the specs given in the abstract_model_inference_wrapper.py
- tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
- """
-
- def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer):
- self.inference_wrapped_model = inference_wrapped_model
- self.tokenizer = tokenizer
-
- # For models without pipeline parallelism, is_first_stage and is_last_stage returns True
- self.model_is_pipeline_parallel = not (
- parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
- )
-
- def tokenize_prompt(
- self, prompt: str, add_BOS: bool = False
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Utility to tokenize the input prompts
-
- Args:
- prompt (str): The input prompt
-
- Returns:
- torch.Tensor: Returns the tokenized prompt
- """
- prompt_tokens = self.tokenizer.tokenize(prompt)
-
- if add_BOS:
- prompt_tokens = [self.tokenizer.bos] + prompt_tokens
-
- return prompt_tokens
-
- def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str:
- """Detokenize the output generations
-
- Args:
- prompt_tokens_with_generated_tokens (torch.Tensor): The input prompt
- tokens plus the generated tokens
-
- Returns:
- str: The detokenized output
- """
- tokens = prompt_tokens_with_generated_tokens.cpu().numpy().tolist()
- return self.tokenizer.detokenize(tokens)
-
- def sample_from_logits(
- self,
- last_token_logits: torch.Tensor,
- common_inference_params: CommonInferenceParams,
- vocab_size: int = None,
- ) -> torch.Tensor:
- """Samples the logits to generate outputs
-
- Given the logits of the last token, this function samples it
- according to the parameters defined in common_inference_params
- and returns the samples
-
- Args:
- last_token_logits (torch.Tensor): The last token logits. A tensor of
- size [batch_size, vocab_size]
- common_inference_params (CommonInferenceParams): The paramters to use
- for inference
- vocab_size (int): Obtained from the tokenizer. Defaults to None
-
- Returns:
- torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements
- """
-
- top_p = common_inference_params.top_p
- top_k = common_inference_params.top_k
- temperature = common_inference_params.temperature
-
- assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
- assert top_p <= 1.0, 'top-p should be in (0,1]'
-
- def modify_logits_for_top_k_filtering(logits, top_k):
- """Set the logits for none top-k values to -inf."""
- filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
- logits.masked_fill_(filter_, float('-Inf'))
-
- def modify_logits_for_top_p_filtering(logits, top_p):
- """Set the logits for none top-p values to -inf."""
- # First sort and calculate cumulative sum of probabilities.
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
- cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
-
- # Filteration based on the cumulative sum.
- filter_ = cumulative_probs > top_p
- # This shift by 1 is weird and I cannot justify it. This existed
- # in the original implementation:
- # https://github.com/ari-holtzman/degen/blob/master/gen.py
- # and I guess it is needed so keeping it for now.
- filter_[:, 1:] = filter_[:, :-1].clone()
- # Make sure we at least have one token to select from.
- filter_[..., 0] = 0
-
- # Fill in the filtered part
- filter_ = filter_.scatter(1, sorted_indices, filter_)
- logits.masked_fill_(filter_, float('-Inf'))
-
- # Greedy sampling
- if top_k == 1:
- sampled_logits = torch.argmax(last_token_logits, dim=-1)
- else:
- last_token_logits = last_token_logits.clone()
- if temperature != 1.0:
- last_token_logits.div_(temperature)
-
- if top_k > 1:
- assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
- if vocab_size:
- assert top_k < vocab_size, 'top-k is larger than vocab size.'
- modify_logits_for_top_k_filtering(last_token_logits, top_k)
-
- elif top_p > 0.0:
- modify_logits_for_top_p_filtering(last_token_logits, top_p)
-
- # After filtering, we need to recalculate the distribution.
- probabilities = last_token_logits.softmax(dim=-1)
- sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)
-
- # If vocab size is provided, make sure the samples are in in the range [0, vocab-size).
- if vocab_size:
- sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
- return sampled_logits
-
- def update_generation_status(
- self,
- updated_prompts_tokens: torch.Tensor,
- generation_started: torch.Tensor,
- current_context_end_position: int,
- is_generation_done_tensor: torch.Tensor,
- generated_sequence_lengths: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Checks which prompts have reached an end condition
-
- We check which prompts have reached an end condition and set the corresponding
- flags of the is_generation_done_tensor to True. The generated sequence lengths
- increase as we keep generating, until that prompts hits an end condition. The
- generation_started tensor determines which prompts have started generating.
-
- Args:
- updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
- generated tokens. A tensor of shape [batch_size, max_seq_len]
- (i.e max_seq_len = max_prompt_len + tokens_to_generate)
- generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
- indicates the prompt at that index has started generating tokens.
- current_context_end_position (int): An integer indicating which position to
- extract from the prompts tokens to get the latest generated tokens.
- is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
- True indicates the prompt at that index has reached end condition.
- generated_sequence_lengths (torch.Tensor): A int tensor of shape [batch_size].
- Each value represents the generated sequence lengths for that prompt.
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Returns the boolean
- is_generation_done_tensor and the generated_sequence_lengths after updating it
- """
- latest_samples = updated_prompts_tokens[:, current_context_end_position]
- # Make sure we are checking eod criterion only for prompts that have started generating
- # (i.e) We only look at the generated tokenns and not the input tokens.
- reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
- is_generation_done_tensor = is_generation_done_tensor | reached_eod
- # We increment generated sequence lengths when that prompt has not hit the
- # EOD and generation has started
- generated_sequence_lengths += ~is_generation_done_tensor & generation_started
-
- return is_generation_done_tensor, generated_sequence_lengths
-
- def pad_input_prompt_tokens(
- self,
- batch_prompt_tokens_list: List[List[int]],
- max_prompt_length_in_batch: int,
- num_tokens_to_generate: int,
- ) -> torch.Tensor:
- """Method to pad input prompts
-
- Given a list of prompts, pad them all to uniform length
-
- Args:
- batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
- max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens
- num_tokens_togenerate (int): The number of tokens to generate for each prompt
-
- Returns:
- torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e)
- max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate,
- with extra indices for each tensor padded with mask id.
- """
- max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
-
- for prompt_tokens in batch_prompt_tokens_list:
- padding_size = max_seq_len - len(prompt_tokens)
- prompt_tokens.extend([self.tokenizer.eod] * padding_size)
-
- return torch.tensor(batch_prompt_tokens_list).cuda()
-
- def generate_output_tokens_dynamic_batch(
- self, active_requests: OrderedDict[int, InferenceRequest]
- ) -> OrderedDict[int, InferenceRequest]:
- """Utility to generate the output tokens and probabilities for the prompts
-
- This utility generates the output tokens for a dynamic batch. It will run one forward step
- at a time, and pass control back to the engine, which will update the request pool and call
- this method again.
-
- Args:
- active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
-
- Returns:
- OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
- after running one forward step.
- """
- raise Exception("Not implemented yet")
-
- def generate_all_output_tokens_static_batch(
- self, active_requests: OrderedDict[int, InferenceRequest]
- ) -> OrderedDict[int, InferenceRequest]:
- """Utility to generate the all the output tokens and probabilities for the prompts .
-
- This utility generates the output tokens for a static batch. It runs the forward steps till
- all prompts complete generation, updates the status of these requests to completed, adds
- the generated result and returns these requests
-
- Args:
- active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
-
- Returns:
- OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
- """
- batch_prompt_tokens_list = list(
- map(lambda request: request.prompt_tokens, active_requests.values())
- )
- prompt_lengths_in_batch = torch.tensor(
- [len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list]
- ).cuda()
- max_prompt_length_in_batch = max(prompt_lengths_in_batch)
- min_prompt_length_in_batch = min(prompt_lengths_in_batch)
-
- # For batch inference the inference params are the same for all request
- common_inference_params: CommonInferenceParams = list(active_requests.values())[
- 0
- ].inference_parameters
-
- # max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
- batch_prompt_tokens = self.pad_input_prompt_tokens(
- batch_prompt_tokens_list,
- max_prompt_length_in_batch=max_prompt_length_in_batch,
- num_tokens_to_generate=common_inference_params.num_tokens_to_generate,
- )
- batch_size, max_sequence_length = batch_prompt_tokens.shape
-
- # Pre allocate log probs tensor
- output_log_probs = None
- if common_inference_params.return_log_probs:
- output_log_probs = torch.empty(
- (batch_size, max_sequence_length - 1), dtype=torch.float32
- ).cuda()
-
- # An array to check which of the prompts have reached end of generation condition
- is_generation_done_tensor = torch.zeros(batch_size, dtype=torch.bool).cuda()
-
- # An array to act as a counter to keep track of generated sequence lengths
- generated_sequence_lengths = torch.zeros(batch_size).cuda()
-
- with torch.no_grad():
-
- self.prep_model_for_inference(
- prompts_tokens=batch_prompt_tokens, active_requests=active_requests
- )
-
- context_start_position = 0
- # Pick the context window that we need to pass through the network.
- for context_end_position in range(min_prompt_length_in_batch, max_sequence_length):
-
- inference_input = self.inference_wrapped_model.get_batch_for_context_window(
- context_start_position, context_end_position
- )
-
- # Returns the final logits of shape [batch_size, context_length, vocab_size]
- # Note: This is returned in all TP ranks or last PP stage in PP models
- logits = self.inference_wrapped_model.run_one_forward_step(inference_input)
- if self.model_is_pipeline_parallel:
- context_length = context_end_position - context_start_position
- logits = broadcast_from_last_pipeline_stage(
- [batch_size, context_length, self.inference_wrapped_model.model.vocab_size],
- dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
- tensor=logits,
- )
-
- # Indicates which of the input prompts have started generating tokens.
- # A 1D boolean tensor with [batch_size] elements (i.e) The shortest
- # prompts will start generating first and so on
- generation_started = prompt_lengths_in_batch <= context_end_position
- last_token_logits = logits[:, -1, :]
- sampled_logits = self.sample_from_logits(
- last_token_logits, common_inference_params, self.inference_wrapped_model.model.vocab_size
- )
-
- # Substitute the sampled logits only for only the prompts that
- # have started generating tokens
- batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
- generation_started
- ]
-
- if common_inference_params.return_log_probs:
- log_probs = F.log_softmax(logits, dim=2)
- indices = torch.unsqueeze(
- batch_prompt_tokens[
- :, (context_start_position + 1) : (context_end_position + 1)
- ],
- 2,
- )
- # Get the log probabilities for only the prompt tokens
- output_log_probs[:, context_start_position:context_end_position] = torch.gather(
- log_probs, 2, indices
- ).squeeze(2)
-
- context_start_position = context_end_position
-
- # Check end of generation status for each tensor
- # and update generated sequence lengths
- (is_generation_done_tensor, generated_sequence_lengths) = (
- self.update_generation_status(
- updated_prompts_tokens=batch_prompt_tokens,
- generation_started=generation_started,
- current_context_end_position=context_end_position,
- is_generation_done_tensor=is_generation_done_tensor,
- generated_sequence_lengths=generated_sequence_lengths,
- )
- )
- # Boolean flag indicating if all prompts are finished
- all_prompts_done = torch.all(is_generation_done_tensor)
- if all_prompts_done:
- break
-
- # Include all the generated tokens
- batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
- if common_inference_params.return_log_probs:
- output_log_probs = output_log_probs[:, :context_end_position]
-
- generated_sequence_lengths[
- generated_sequence_lengths > common_inference_params.num_tokens_to_generate
- ] = common_inference_params.num_tokens_to_generate
-
- for idx, request in enumerate(active_requests.values()):
- input_prompt_length = int(prompt_lengths_in_batch[idx])
- # Shorter prompts might have generated more than required tokens. So we trim them down
- required_sequence_length = int(
- min(generated_sequence_lengths[idx], common_inference_params.num_tokens_to_generate)
- )
- # Extract only the generated tokens
- required_result_tokens = batch_prompt_tokens_with_generations[
- idx, input_prompt_length : (input_prompt_length + required_sequence_length)
- ]
-
- request.generated_length = required_sequence_length
- request.generated_tokens = required_result_tokens
- request.generated_log_probs = (
- None
- if output_log_probs is None
- else output_log_probs[idx, input_prompt_length:required_sequence_length]
- )
- request.status = Status.COMPLETED
- request.generated_text = self.detokenize_generations(required_result_tokens)
-
- return active_requests
-
- def prep_model_for_inference(
- self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
- ):
- """Preparing batch for inference, using respective wrapper's prep_model_for_inference method
-
- Args:
- prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
- active_requests (OrderedDict[int, InferenceRequest]): The input active requests
- """
- self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=prompts_tokens)
diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
new file mode 100644
index 0000000000..f15c819c43
--- /dev/null
+++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
@@ -0,0 +1,400 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+from typing import List, OrderedDict, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from megatron.core import parallel_state
+from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
+from megatron.core.inference.inference_request import InferenceRequest, Status
+from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
+ AbstractModelInferenceWrapper,
+)
+from megatron.core.inference.sampling_params import SamplingParams
+
+
+class TextGenerationController:
+ """The text generation controller (the main sampling loop)
+
+ This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.
+
+ Args:
+ inference_wrapped_model (AbstractModelInferenceWrapper): A model that
+ is wrapped using the specs given in the abstract_model_inference_wrapper.py
+ tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
+ """
+
+ def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer):
+ self.inference_wrapped_model = inference_wrapped_model
+ self.tokenizer = tokenizer
+
+        # For models without pipeline parallelism, is_first_stage and is_last_stage return True
+ self.model_is_pipeline_parallel = not (
+ parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
+ )
+
+ def tokenize_prompt(
+ self, prompt: str, add_BOS: bool = False
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Utility to tokenize the input prompts
+
+ Args:
+ prompt (str): The input prompt
+
+ Returns:
+ torch.Tensor: Returns the tokenized prompt
+ """
+ prompt_tokens = self.tokenizer.tokenize(prompt)
+
+ if add_BOS:
+ prompt_tokens = [self.tokenizer.bos] + prompt_tokens
+
+ return prompt_tokens
+
+ def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str:
+ """Detokenize the output generations
+
+ Args:
+ prompt_tokens_with_generated_tokens (torch.Tensor): The input prompt
+ tokens plus the generated tokens
+
+ Returns:
+ str: The detokenized output
+ """
+ tokens = prompt_tokens_with_generated_tokens.cpu().numpy().tolist()
+ return self.tokenizer.detokenize(tokens)
+
+ def sample_from_logits(
+ self,
+ last_token_logits: torch.Tensor,
+ sampling_params: SamplingParams = None,
+ vocab_size: int = None,
+ **kwargs
+ ) -> torch.Tensor:
+ """Samples the logits to generate outputs
+
+ Given the logits of the last token, this function samples it
+ according to the parameters defined in sampling_params
+ and returns the samples
+
+ Args:
+ last_token_logits (torch.Tensor): The last token logits. A tensor of
+ size [batch_size, vocab_size]
+ sampling_params (SamplingParams): The parameters to use for inference.
+ vocab_size (int): Obtained from the tokenizer. Defaults to None
+
+ Returns:
+ torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements
+ """
+
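+        # Backwards compatibility for callers that still pass `common_inference_params`
+        # (MCore <= 0.9.0) instead of `sampling_params`.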
+ if kwargs.get('common_inference_params'):
+ sampling_params = kwargs['common_inference_params']
+
+ top_p = sampling_params.top_p
+ top_k = sampling_params.top_k
+ temperature = sampling_params.temperature
+
+ assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
+ assert top_p <= 1.0, 'top-p should be in (0,1]'
+
+ def modify_logits_for_top_k_filtering(logits, top_k):
+            """Set the logits for non-top-k values to -inf."""
+ filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits.masked_fill_(filter_, float('-Inf'))
+
+ def modify_logits_for_top_p_filtering(logits, top_p):
+            """Set the logits for non-top-p values to -inf."""
+ # First sort and calculate cumulative sum of probabilities.
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+            # Filtering based on the cumulative sum.
+ filter_ = cumulative_probs > top_p
+ # This shift by 1 is weird and I cannot justify it. This existed
+ # in the original implementation:
+ # https://github.com/ari-holtzman/degen/blob/master/gen.py
+ # and I guess it is needed so keeping it for now.
+ filter_[:, 1:] = filter_[:, :-1].clone()
+ # Make sure we at least have one token to select from.
+ filter_[..., 0] = 0
+
+ # Fill in the filtered part
+ filter_ = filter_.scatter(1, sorted_indices, filter_)
+ logits.masked_fill_(filter_, float('-Inf'))
+
+ # Greedy sampling
+ if top_k == 1:
+ sampled_logits = torch.argmax(last_token_logits, dim=-1)
+ else:
+ last_token_logits = last_token_logits.clone()
+ if temperature != 1.0:
+ last_token_logits.div_(temperature)
+
+ if top_k > 1:
+ assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
+ if vocab_size:
+ assert top_k < vocab_size, 'top-k is larger than vocab size.'
+ modify_logits_for_top_k_filtering(last_token_logits, top_k)
+
+ elif top_p > 0.0:
+ modify_logits_for_top_p_filtering(last_token_logits, top_p)
+
+ # After filtering, we need to recalculate the distribution.
+ probabilities = last_token_logits.softmax(dim=-1)
+ sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)
+
+        # If vocab size is provided, make sure the samples are in the range [0, vocab-size).
+ if vocab_size:
+ sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
+ return sampled_logits
+
+ def update_generation_status(
+ self,
+ updated_prompts_tokens: torch.Tensor,
+ generation_started: torch.Tensor,
+ current_context_end_position: int,
+ is_generation_done_tensor: torch.Tensor,
+ generated_sequence_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Checks which prompts have reached an end condition
+
+ We check which prompts have reached an end condition and set the corresponding
+ flags of the is_generation_done_tensor to True. The generated sequence lengths
+        increase as we keep generating, until that prompt hits an end condition. The
+ generation_started tensor determines which prompts have started generating.
+
+ Args:
+ updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
+ generated tokens. A tensor of shape [batch_size, max_seq_len]
+ (i.e max_seq_len = max_prompt_len + tokens_to_generate)
+ generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
+ indicates the prompt at that index has started generating tokens.
+ current_context_end_position (int): An integer indicating which position to
+ extract from the prompts tokens to get the latest generated tokens.
+ is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
+ True indicates the prompt at that index has reached end condition.
+ generated_sequence_lengths (torch.Tensor): A int tensor of shape [batch_size].
+            generated_sequence_lengths (torch.Tensor): An int tensor of shape [batch_size].
+
+ Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Returns the boolean
+ is_generation_done_tensor and the generated_sequence_lengths after updating it
+ """
+ latest_samples = updated_prompts_tokens[:, current_context_end_position]
+ # Make sure we are checking eod criterion only for prompts that have started generating
+        # (i.e.) We only look at the generated tokens and not the input tokens.
+ reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
+ is_generation_done_tensor = is_generation_done_tensor | reached_eod
+ # We increment generated sequence lengths when that prompt has not hit the
+ # EOD and generation has started
+ generated_sequence_lengths += ~is_generation_done_tensor & generation_started
+
+ return is_generation_done_tensor, generated_sequence_lengths
+
+ def pad_input_prompt_tokens(
+ self,
+ batch_prompt_tokens_list: List[List[int]],
+ max_prompt_length_in_batch: int,
+ num_tokens_to_generate: int,
+ ) -> torch.Tensor:
+ """Method to pad input prompts
+
+ Given a list of prompts, pad them all to uniform length
+
+ Args:
+ batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
+ max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens
+            num_tokens_to_generate (int): The number of tokens to generate for each prompt
+
+ Returns:
+ torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e)
+ max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate,
+ with extra indices for each tensor padded with mask id.
+ """
+ max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
+
+ for prompt_tokens in batch_prompt_tokens_list:
+ padding_size = max_seq_len - len(prompt_tokens)
+ prompt_tokens.extend([self.tokenizer.eod] * padding_size)
+
+ return torch.tensor(batch_prompt_tokens_list).cuda()
+
+ def generate_output_tokens_dynamic_batch(
+ self, active_requests: OrderedDict[int, InferenceRequest]
+ ) -> OrderedDict[int, InferenceRequest]:
+ """Utility to generate the output tokens and probabilities for the prompts
+
+ This utility generates the output tokens for a dynamic batch. It will run one forward step
+ at a time, and pass control back to the engine, which will update the request pool and call
+ this method again.
+
+ Args:
+ active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
+
+ Returns:
+ OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
+ after running one forward step.
+ """
+ raise Exception("Not implemented yet")
+
+ def generate_all_output_tokens_static_batch(
+ self, active_requests: OrderedDict[int, InferenceRequest]
+ ) -> OrderedDict[int, InferenceRequest]:
+        """Utility to generate all the output tokens and probabilities for the prompts.
+
+ This utility generates the output tokens for a static batch. It runs the forward steps till
+ all prompts complete generation, updates the status of these requests to completed, adds
+ the generated result and returns these requests
+
+ Args:
+ active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
+
+ Returns:
+ OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
+ """
+ batch_prompt_tokens_list = list(
+ map(lambda request: request.prompt_tokens, active_requests.values())
+ )
+ prompt_lengths_in_batch = torch.tensor(
+ [len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list]
+ ).cuda()
+ max_prompt_length_in_batch = max(prompt_lengths_in_batch)
+ min_prompt_length_in_batch = min(prompt_lengths_in_batch)
+
+        # For batch inference the inference params are the same for all requests
+ sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters
+
+ # max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
+ batch_prompt_tokens = self.pad_input_prompt_tokens(
+ batch_prompt_tokens_list,
+ max_prompt_length_in_batch=max_prompt_length_in_batch,
+ num_tokens_to_generate=sampling_params.num_tokens_to_generate,
+ )
+ batch_size, max_sequence_length = batch_prompt_tokens.shape
+
+ # Pre allocate log probs tensor
+ output_log_probs = None
+ if sampling_params.return_log_probs:
+ output_log_probs = torch.empty(
+ (batch_size, max_sequence_length - 1), dtype=torch.float32
+ ).cuda()
+
+ # An array to check which of the prompts have reached end of generation condition
+ is_generation_done_tensor = torch.zeros(batch_size, dtype=torch.bool).cuda()
+
+ # An array to act as a counter to keep track of generated sequence lengths
+ generated_sequence_lengths = torch.zeros(batch_size).cuda()
+
+ with torch.no_grad():
+
+ self.prep_model_for_inference(
+ prompts_tokens=batch_prompt_tokens, active_requests=active_requests
+ )
+
+ context_start_position = 0
+ # Pick the context window that we need to pass through the network.
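+ # The first step processes all tokens up to the shortest prompt (prefill); each later step advances the window by one token.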
+ for context_end_position in range(min_prompt_length_in_batch, max_sequence_length):
+
+ inference_input = self.inference_wrapped_model.get_batch_for_context_window(
+ context_start_position, context_end_position
+ )
+
+ # Returns the final logits of shape [batch_size, context_length, vocab_size]
+ # Note: Logits are available on all TP ranks, but only on the last stage of pipeline-parallel models
+ logits = self.inference_wrapped_model.run_one_forward_step(inference_input)
+ if self.model_is_pipeline_parallel:
+ context_length = context_end_position - context_start_position
+ logits = broadcast_from_last_pipeline_stage(
+ [batch_size, context_length, self.tokenizer.vocab_size],
+ dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
+ tensor=logits,
+ )
+
+ # Indicates which of the input prompts have started generating tokens.
+ # A 1D boolean tensor with [batch_size] elements; the shortest
+ # prompts start generating first, and so on.
+ generation_started = prompt_lengths_in_batch <= context_end_position
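+ # The logits at the last position of the context window are used to sample the next token.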
+ last_token_logits = logits[:, -1, :]
+ sampled_logits = self.sample_from_logits(
+ last_token_logits, sampling_params, self.tokenizer.vocab_size
+ )
+
+ # Substitute the sampled logits only for the prompts that
+ # have started generating tokens
+ batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
+ generation_started
+ ]
+
+ if sampling_params.return_log_probs:
+ log_probs = F.log_softmax(logits, dim=2)
+ indices = torch.unsqueeze(
+ batch_prompt_tokens[
+ :, (context_start_position + 1) : (context_end_position + 1)
+ ],
+ 2,
+ )
+ # Gather the log probability assigned to the actual next token at each position in the window
+ output_log_probs[:, context_start_position:context_end_position] = torch.gather(
+ log_probs, 2, indices
+ ).squeeze(2)
+
+ context_start_position = context_end_position
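+ # Subsequent iterations therefore only process the single newly generated position.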
+
+ # Check the end-of-generation status of each prompt
+ # and update the generated sequence lengths
+ (is_generation_done_tensor, generated_sequence_lengths) = (
+ self.update_generation_status(
+ updated_prompts_tokens=batch_prompt_tokens,
+ generation_started=generation_started,
+ current_context_end_position=context_end_position,
+ is_generation_done_tensor=is_generation_done_tensor,
+ generated_sequence_lengths=generated_sequence_lengths,
+ )
+ )
+ # Boolean flag indicating if all prompts are finished
+ all_prompts_done = torch.all(is_generation_done_tensor)
+ if all_prompts_done:
+ break
+
+ # Include all the generated tokens
+ batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
+ if sampling_params.return_log_probs:
+ output_log_probs = output_log_probs[:, :context_end_position]
+
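+ # Shorter prompts may keep generating while longer prompts are still being processed,
+ # so their counters can exceed the requested budget; cap them at num_tokens_to_generate.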
+ generated_sequence_lengths[
+ generated_sequence_lengths > sampling_params.num_tokens_to_generate
+ ] = sampling_params.num_tokens_to_generate
+
+ for idx, request in enumerate(active_requests.values()):
+ input_prompt_length = int(prompt_lengths_in_batch[idx])
+ # Shorter prompts might have generated more tokens than required, so we trim them down
+ required_sequence_length = int(
+ min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate)
+ )
+ # Extract only the generated tokens
+ required_result_tokens = batch_prompt_tokens_with_generations[
+ idx, input_prompt_length : (input_prompt_length + required_sequence_length)
+ ]
+
+ request.generated_length = required_sequence_length
+ request.generated_tokens = required_result_tokens
+ request.generated_log_probs = (
+ None
+ if output_log_probs is None
+ else output_log_probs[idx, input_prompt_length:required_sequence_length]
+ )
+ request.status = Status.COMPLETED
+ request.generated_text = self.detokenize_generations(required_result_tokens)
+
+ return active_requests
+
+ def prep_model_for_inference(
+ self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
+ ):
+ """Preparing batch for inference, using respective wrapper's prep_model_for_inference method
+
+ Args:
+ prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
+ active_requests (OrderedDict[int, InferenceRequest]): The input active requests
+ """
+ self.inference_wrapped_model.prep_model_for_inference(prompts_tokens=prompts_tokens)
diff --git a/tests/unit_tests/inference/engines/test_mcore_engine.py b/tests/unit_tests/inference/engines/test_mcore_engine.py
index 8295744d36..1b342db4e6 100644
--- a/tests/unit_tests/inference/engines/test_mcore_engine.py
+++ b/tests/unit_tests/inference/engines/test_mcore_engine.py
@@ -5,7 +5,6 @@
import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
@@ -14,8 +13,9 @@
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
-from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
- SimpleTextGenerationController,
+from megatron.core.inference.sampling_params import SamplingParams
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+ TextGenerationController,
)
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt.gpt_model import GPTModel
@@ -60,7 +60,7 @@ def setup_method(self, method):
inference_wrapped_model = GPTInferenceWrapper(gpt_model, inference_wrapper_config)
self.mock_tokenizer = mock.Mock()
- text_generation_controller = SimpleTextGenerationController(
+ text_generation_controller = TextGenerationController(
inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer
)
@@ -85,7 +85,7 @@ def test_generate(self):
prompts = ["sample" * (i + 1) for i in range(self.batch_size)]
results: List[InferenceRequest] = self.mcore_engine.generate(
- prompts, common_inference_params=CommonInferenceParams(num_tokens_to_generate=10)
+ prompts, sampling_params=SamplingParams(num_tokens_to_generate=10)
)
for result in results:
@@ -110,9 +110,7 @@ def test_generate_empty_prompt(self):
prompts = ["" for i in range(self.batch_size)]
results: List[InferenceRequest] = self.mcore_engine.generate(
- prompts,
- add_BOS=True,
- common_inference_params=CommonInferenceParams(num_tokens_to_generate=10),
+ prompts, add_BOS=True, sampling_params=SamplingParams(num_tokens_to_generate=10)
)
for result in results:
diff --git a/tests/unit_tests/inference/test_common_inference_params.py b/tests/unit_tests/inference/test_common_inference_params.py
index af51e433df..c7ef4c9ed8 100644
--- a/tests/unit_tests/inference/test_common_inference_params.py
+++ b/tests/unit_tests/inference/test_common_inference_params.py
@@ -1,10 +1,10 @@
-from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.sampling_params import SamplingParams
-class TestCommonInferenceParams:
+class TestSamplingParams:
def test_inference_params(self):
- inference_parameters = CommonInferenceParams()
+ inference_parameters = SamplingParams()
inference_parameters.add_attributes({"min_tokens": 45})
assert (
inference_parameters.min_tokens == 45
diff --git a/tests/unit_tests/inference/test_scheduler.py b/tests/unit_tests/inference/test_scheduler.py
index b1f0ea184e..90caa70a7b 100644
--- a/tests/unit_tests/inference/test_scheduler.py
+++ b/tests/unit_tests/inference/test_scheduler.py
@@ -2,8 +2,8 @@
import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest, Status
+from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.scheduler import Scheduler
@@ -25,7 +25,7 @@ def setup_method(self, method):
def test_scheduler(self):
prompt = "sample prompt"
prompt_tokens = torch.randn(5)
- inference_parameters = CommonInferenceParams()
+ inference_parameters = SamplingParams()
for i in range(self.max_batch_size):
self.scheduler.add_request(prompt, prompt_tokens, inference_parameters)
diff --git a/tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py
index c28d0c3432..12903a919f 100644
--- a/tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py
+++ b/tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py
@@ -10,7 +10,6 @@
import pytest
import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
@@ -18,6 +17,7 @@
from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import (
T5InferenceWrapper,
)
+from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import (
EncoderDecoderTextGenerationController,
)
@@ -126,7 +126,7 @@ def test_generate_all_output_tokens_static_batch(self):
request_id=i,
prompt=prompt,
encoder_prompt=encoder_prompt,
- inference_parameters=CommonInferenceParams(num_tokens_to_generate=10),
+ inference_parameters=SamplingParams(num_tokens_to_generate=10),
arrival_time=time.time(),
prompt_tokens=prompt_tokens,
status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS,
diff --git a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
index 1e09cf05fb..1db360f232 100644
--- a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
+++ b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
@@ -9,7 +9,6 @@
import pytest
import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
GPTInferenceWrapper,
@@ -17,8 +16,9 @@
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
-from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
- SimpleTextGenerationController,
+from megatron.core.inference.sampling_params import SamplingParams
+from megatron.core.inference.text_generation_controllers.text_generation_controller import (
+ TextGenerationController,
)
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt.gpt_model import GPTModel
@@ -28,7 +28,7 @@
from tests.unit_tests.test_utilities import Utils
-class TestSimpleTextGenerationController:
+class TestTextGenerationController:
def setup_method(self, method):
Utils.initialize_model_parallel(
@@ -67,7 +67,7 @@ def setup_method(self, method):
self.mock_tokenizer = mock.Mock()
- self.text_generation_controller = SimpleTextGenerationController(
+ self.text_generation_controller = TextGenerationController(
inference_wrapped_model=inference_wrapped_model, tokenizer=self.mock_tokenizer
)
@@ -78,7 +78,7 @@ def test_sample_from_logits(self):
with pytest.raises(AssertionError) as aerror:
self.text_generation_controller.sample_from_logits(
last_token_logits=None,
- common_inference_params=CommonInferenceParams(top_k=2, top_p=0.4),
+ sampling_params=SamplingParams(top_k=2, top_p=0.4),
vocab_size=self.vocab_size,
)
assert str(aerror.value) == 'Cannot have top-p and top-k both greater than zero'
@@ -86,7 +86,7 @@ def test_sample_from_logits(self):
with pytest.raises(AssertionError) as aerror:
self.text_generation_controller.sample_from_logits(
last_token_logits=None,
- common_inference_params=CommonInferenceParams(top_p=1.4, top_k=0),
+ sampling_params=SamplingParams(top_p=1.4, top_k=0),
vocab_size=self.vocab_size,
)
assert str(aerror.value) == 'top-p should be in (0,1]'
@@ -94,7 +94,7 @@ def test_sample_from_logits(self):
with pytest.raises(AssertionError) as aerror:
self.text_generation_controller.sample_from_logits(
last_token_logits=torch.randn(self.batch_size, 1),
- common_inference_params=CommonInferenceParams(top_k=self.vocab_size + 10),
+ sampling_params=SamplingParams(top_k=self.vocab_size + 10),
vocab_size=self.vocab_size,
)
assert str(aerror.value) == 'top-k is larger than logit size.'
@@ -103,14 +103,14 @@ def test_sample_from_logits(self):
torch.arange(0, self.vocab_size).repeat(self.batch_size, 1).float().cuda()
)
sampled_logits = self.text_generation_controller.sample_from_logits(
- last_token_logits, CommonInferenceParams(top_k=1), self.vocab_size
+ last_token_logits, SamplingParams(top_k=1), self.vocab_size
)
assert torch.all(
sampled_logits.cpu() == torch.ones(self.batch_size) * self.vocab_size - 1
), f"The sampled logits should all be {self.vocab_size} but its {sampled_logits}"
sampled_logits = self.text_generation_controller.sample_from_logits(
- last_token_logits, CommonInferenceParams(top_k=2), self.vocab_size
+ last_token_logits, SamplingParams(top_k=2), self.vocab_size
)
assert torch.all(
sampled_logits >= self.vocab_size - 2
@@ -120,7 +120,7 @@ def test_sample_from_logits(self):
top_p = 0.3
expected_min_value = l[l.softmax(dim=-1).cumsum(dim=-1) > top_p][0].item()
sampled_logits = self.text_generation_controller.sample_from_logits(
- last_token_logits, CommonInferenceParams(top_p=top_p, top_k=0), self.vocab_size
+ last_token_logits, SamplingParams(top_p=top_p, top_k=0), self.vocab_size
)
assert torch.all(
sampled_logits >= expected_min_value
@@ -131,7 +131,7 @@ def test_sample_from_logits(self):
expected_min_value = l[l.div_(temperature).softmax(dim=-1).cumsum(dim=-1) > top_p][0].item()
sampled_logits = self.text_generation_controller.sample_from_logits(
last_token_logits,
- CommonInferenceParams(top_p=top_p, temperature=temperature, top_k=0),
+ SamplingParams(top_p=top_p, temperature=temperature, top_k=0),
self.vocab_size,
)
assert torch.all(
@@ -154,7 +154,7 @@ def test_generate_all_output_tokens_static_batch(self):
inference_request = InferenceRequest(
request_id=i,
prompt=prompt,
- inference_parameters=CommonInferenceParams(num_tokens_to_generate=10),
+ inference_parameters=SamplingParams(num_tokens_to_generate=10),
arrival_time=time.time(),
prompt_tokens=torch.randint(
low=0, high=self.vocab_size - 1, size=(len(prompt),)