Merge branch 'helenn-refactor-textgen' into 'main'
MCore Inference misc changes

See merge request ADLR/megatron-lm!2439
jaredcasper committed Dec 18, 2024
2 parents 1b7553e + 281cbe6 commit 8d2bc43
Showing 16 changed files with 555 additions and 530 deletions.
105 changes: 55 additions & 50 deletions examples/inference/README.md
@@ -1,5 +1,5 @@
### Megatron Core Inference Documentation
This guide will walk you through how you can use megatron core for inference on your models.
This guide provides an example of running model inference with Megatron Core.

### Contents
- [Megatron Core Inference Documentation](#megatron-core-inference-documentation)
@@ -18,21 +18,21 @@ This guide will walk you through how you can use megatron core for inference on
<br>

#### 1. Quick Start
This will walk you through the flow of running batch inference on a GPT model trained using megatron core. The file can be found at [simple_gpt_batch_inference.py](./gpt/simple_gpt_batch_inference.py)
This example runs batch inference on a GPT model trained using Megatron Core. The entrypoint is [gpt_batch_inference.py](./gpt/gpt_batch_inference.py).

<br>

##### 1.1 Understanding The Code
***STEP 1 - We initialize model parallel and other default arguments***
We can default micro batch size to be 1, since for TP models it is not used, and for PP models it is calculated during runtime.
##### 1.1 Code Walkthrough
***STEP 1 - Initialize model parallel and other default arguments***
The micro batch size is set to 1, since it is not used for tensor-parallel-only models, and for pipeline-parallel models it is calculated at runtime.
```python
initialize_megatron(
args_defaults={'no_load_rng': True, 'no_load_optim': True, 'micro_batch_size': 1}
)
```

***STEP 2 - We load the model using the model_provider_function***
NOTE: The model provider function in the script supports MCore and Legacy models.
***STEP 2 - Load the model using the model_provider_function***
NOTE: The model provider function supports both MCore and Legacy models.

```python
model = get_model(model_provider, wrap_with_ddp=False)
@@ -41,10 +41,10 @@ NOTE: The model provider function in the script supports MCore and Legacy models
```

***STEP 3 - Choose an engine***
One of the important elements of the generate function is an inference engine. In this example we will be choosing the [megatron core engine](../../megatron/core/inference/engine/mcore_engine.py) with a [simple text generation controller](../../megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py), the default engine. Other engines that will be supported in the future are TRTLLMEngine.
Text generation requires an inference engine, which includes a scheduler. The default engine is the [Megatron Core engine](../../megatron/core/inference/engine/mcore_engine.py) with a simple [text generation controller](../../megatron/core/inference/text_generation_controllers/text_generation_controller.py). TRTLLMEngine will be supported in the future.
```python
inference_wrapped_model = GPTInferenceWrapper(model, args)
text_generation_controller = SimpleTextGenerationController(
text_generation_controller = TextGenerationController(
inference_wrapped_model=inference_wrapped_model,
tokenizer=tokenizer
)
@@ -53,12 +53,12 @@ One of the important elements of the generate function is an inference engine. I
)
```

***STEP 4 - Run the generate function and display results***
We use default values for the [common inference params](../../megatron/core/inference/common_inference_params.py). Customize this if you want to change top_p, top_k, number of tokens to generate etc.
*Note that the result is returned as a list of [InferenceRequests](../../megatron/core/inference/inference_request.py)*
***STEP 4 - Run text generation***
The [SamplingParams](../../megatron/core/inference/sampling_params.py) contains suggested defaults. Customize this to change top_p, top_k, the number of tokens to generate, etc.
*Note: The result is returned as a list of [InferenceRequests](../../megatron/core/inference/inference_request.py)*
```python
results: List[InferenceRequest] = inference_engine.generate(
prompts=args.prompts, common_inference_params=common_inference_params
prompts=args.prompts, sampling_params=sampling_params
)

if torch.distributed.get_rank() == 0:
@@ -76,12 +76,12 @@ We use default values for the [common inference params](../../megatron/core/infe
<br>

##### 1.2 Running The Code
An example run script is shown below. Change the tokenizer paths, inference params, and other settings for your model.
An example run script is shown below. Set the tokenizer paths, inference params, and other settings appropriately.

For a quick recap on inference params refer to [this blog](https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910)
For a quick recap on sampling parameters, refer to [this blog](https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910).

```
#In a slurm cluster (You could also use docker)
# In a slurm cluster (You could also use docker)
ACCOUNT=<account>
MLM_PATH=/path/to/megatron-lm
GPT_CKPT=/path/to/gpt/ckpt
@@ -133,25 +133,26 @@ NOTE: Other parameters which can be customized for inference are :-
--top_p (top_p sampling)
--num-tokens-to-generate (Number of tokens to generate for each prompt)
--inference-batch-times-seqlen-threshold (During inference, if batch-size times sequence-length is smaller than this threshold then we will not use pipelining, otherwise we will.)
--use-dist-ckpt (If you are using dist checkpoint format for the model)
--use-legacy-models (If you are using legacy gpt model instead of mcore gpt model)
--use-dist-ckpt (If using dist checkpoint format for the model)
--use-legacy-models (If using legacy gpt model instead of mcore gpt model)
```


<br>


#### 2. Flow of Control In MCore Backend
The following is what happens in the [simple_gpt_batch_inference.py](./gpt/simple_gpt_batch_inference.py).
* We call [mcore_engine](../../megatron/core/inference/engines/mcore_engine.py) **generate()** function with all our input prompts.
* The scheduler in the engine will add these prompts to the [active requests] pool (../../megatron/core/inference/inference_request.py) until we hit the max batch size, and then it will put the rest in the waiting requests pool.
* The engine will then run until all requests (waiting + active) are completed
#### 2. Control Flow in the MCore Backend
An example of inference with static batching is provided in [gpt_batch_inference.py](./gpt/gpt_batch_inference.py).
* [mcore_engine](../../megatron/core/inference/engines/mcore_engine.py) **generate()** function is called with the input prompts.
* The `Scheduler` in the engine will add these prompts to the [active requests pool](../../megatron/core/inference/inference_request.py) until the max batch size is hit. Remaining requests will be added to the waiting requests pool.
* The engine will run until all requests (waiting + active) are completed.
* The active requests are passed into **generate_all_output_tokens_static_batch()** of the text generation controller.
* This function uses the [model_inference_wrappers](../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) **prep_model_for_inference()** , and then runs an auto regressive loop
* In the auto regressive loop, the **get_batch_for_context_window()** method of the inference wrapper is called to get the required input, passes it into the **run_one_forward_step()** method, which calls the appropriate (PP, TP) model `.forward()` methods to get the output logits
* The output logits are synchronized across all pipeline parallel ranks
* The text generation controller obtains the log probabilities and samples tokens based on the strategy defined in the common inference parameters.
* This function uses the **prep_model_for_inference()** method of the [model_inference_wrappers](../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) and runs an autoregressive sampling loop
* In the autoregressive loop, the **get_batch_for_context_window()** method of the inference wrapper is called to slice out the input tokens and masks
* Input tokens and masks are passed into the **run_one_forward_step()** method, which calls the model `.forward()` method to get the output logits
* Output logits are synchronized across all pipeline parallel ranks
* The text generation controller obtains the log probabilities and samples tokens based on the strategy defined in the sampling parameters.
* The sampled tokens are then appended to the input prompt tokens for the next iteration
* The **update_generation_status()** method of the text generation controller checks which prompts have finished generating or hit a stop condition
* After the inference loop, the result is detokenized and stored as an attribute of the InferenceRequest. These requests are marked as completed.
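
To make the loop structure above concrete, here is a minimal, self-contained sketch. It is not the MCore implementation: the function name, the dummy forward pass (random logits), and the plain greedy sampling are illustrative stand-ins for the wrapper and controller methods named in the list.

```python
import torch

# Toy illustration of the static-batch autoregressive loop described above.
# The real engine delegates these steps to the inference wrapper and the text
# generation controller; random logits and greedy sampling stand in here so
# the loop structure can be run on its own.

def toy_static_batch_generate(prompt_tokens: torch.Tensor,
                              num_tokens_to_generate: int,
                              vocab_size: int = 128) -> torch.Tensor:
    """prompt_tokens: [batch, prompt_len] integer tensor of token ids."""
    tokens = prompt_tokens.clone()
    for _ in range(num_tokens_to_generate):
        # Stands in for get_batch_for_context_window() + run_one_forward_step().
        last_token_logits = torch.randn(tokens.size(0), vocab_size)
        # Stands in for sample_from_logits() driven by the sampling parameters
        # (plain greedy sampling shown here).
        next_tokens = torch.argmax(last_token_logits, dim=-1)
        # Sampled tokens are appended to the inputs for the next iteration;
        # update_generation_status() would check stop conditions at this point.
        tokens = torch.cat([tokens, next_tokens.unsqueeze(1)], dim=1)
    return tokens

print(toy_static_batch_generate(torch.zeros(2, 4, dtype=torch.long), 8).shape)
# torch.Size([2, 12])
```

The real loop additionally pads requests into a fixed batch, tracks per-request completion, and synchronizes logits across pipeline-parallel ranks, as described above.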
@@ -160,45 +161,49 @@ The following is what happens in the [simple_gpt_batch_inference.py](./gpt/simpl
<br>

#### 3. Customizing The Inference Pipeline
The following guide will walk you through how you can customize different parts of the inference pipeline. There are three levels at which you can customize the pipeline.
* **Inference engine** - Highest level of customization. Currently we support the MCore Engine. Change this to add a new engine.
* **Text generation controller** - Extend this to customize tokenization, detokenization, or implement a new sampling strategy.

The inference pipeline supports three levels of customization:

* **Inference engine** - The MCore Engine is currently supported. Change this to add a new backend.
* **Text generation controller** - The main sampling loop. This can be customized to support alternative tokenization, detokenization, or to implement a new sampling strategy.
* **Inference Wrapped Model** - Change this to support a new model.
* **Modify Inference Parameters** - Change this to update top_p, top_k, number of tokens to be generated, temperature, or other sampling parameters.

<br>

##### 3.1. Create Your Own Inference Backend
This is the highest level of customization. The [abstract_engine.py](./../../megatron/core/inference/engine/abstract_engine.py) file has a generate method that can be extended to support a new backend.
The [abstract_engine.py](./../../megatron/core/inference/engine/abstract_engine.py) file contains a `generate` method that can be extended to support a new backend.

```python
class AbstractEngine(ABC):
@staticmethod
def generate(self) -> dict:
"""The abstract backend's generate function.
To define your own backend, make sure you implement this and return the outputs as a dictionary .
To define a new backend, implement this method and return the outputs as a dictionary.
```
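
As a rough illustration, a new backend could subclass `AbstractEngine` as sketched below. Only the import path and the "return the outputs as a dictionary" contract come from the repository; the class name, constructor, and echo behaviour are assumptions, and the sketch assumes `generate` is the only abstract method.

```python
from typing import List

from megatron.core.inference.engines.abstract_engine import AbstractEngine


class EchoEngine(AbstractEngine):
    """Hypothetical backend that 'generates' by echoing each prompt."""

    def __init__(self, suffix: str = " ..."):
        self.suffix = suffix

    def generate(self, prompts: List[str]) -> dict:
        # A real backend would schedule requests and run model forward passes;
        # echoing the prompts just shows the expected dictionary return shape.
        return {i: prompt + self.suffix for i, prompt in enumerate(prompts)}
```

How much wiring beyond `generate()` a real backend needs (scheduling, controllers, model wrappers) depends on the rest of the stack, which this sketch omits.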
<br>
##### 3.2. Create Your Own Text Generation Controller
In case you want to use the megatron core backend, but would like to overwrite the tokenization, text generation or detokenization extend the [simple_text_generation_controller.py](../../megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py). The class has the following methods
##### 3.2. Implement a New Sampling Loop
The [TextGenerationController](../../megatron/core/inference/text_generation_controllers/text_generation_controller.py) contains the main sampling loop and can be modified to support new tokenization, detokenization, or sampling strategies.
``` python
class SimpleTextGenerationController:
class TextGenerationController:
def tokenize_prompt(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""Utility to tokenize the input prompts"""
def sample_from_logits(
self,
last_token_logits: torch.Tensor,
common_inference_params: CommonInferenceParams,
sampling_params: SamplingParams,
vocab_size: int,
) -> torch.Tensor:
"""Samples the logits to generate outputs

Given the logits of the last token, this function samples it according to the parameters defined in common_inference_params and returns the samples
Given the logits of the last token, this function samples according to the parameters defined in sampling_params and returns the sampled tokens.
"""
def update_generation_status(
@@ -229,12 +234,12 @@ class SimpleTextGenerationController:
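
For instance, a controller that always samples greedily (ignoring top_k, top_p, and temperature) could be sketched as below. This is an illustrative assumption, not repository code; the import paths and the `sample_from_logits` signature are taken from this document, and everything else is inherited from the base class.

```python
import torch

from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)


class GreedyTextGenerationController(TextGenerationController):
    """Hypothetical controller that always picks the arg-max token."""

    def sample_from_logits(
        self,
        last_token_logits: torch.Tensor,
        sampling_params: SamplingParams,
        vocab_size: int,
        **kwargs,  # ignore any extra keyword arguments the base loop may pass
    ) -> torch.Tensor:
        # last_token_logits: [batch, vocab_size] logits for the newest position.
        return torch.argmax(last_token_logits, dim=-1)
```

Such a controller can then be passed to `MCoreEngine` in place of `TextGenerationController`, as in Step 3 of the quick start.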
<br>
##### 3.3. Support Other Models
In order to support other models please extend the [abstract_model_inference_wrapper.py](./../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) file. The abstract wrapper already supports the following :
* Forward method which automatically calls the appropriate forward method (PP or TP etc) depending on model parallel settings
* Initalizes the model and puts it in eval mode
* Obtains the input parameters (batch size, max seq length) and has an instance of the input
Extend [abstract_model_inference_wrapper.py](./../../megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py) to support other models. The abstract model wrapper implements:
* Forward method which calls the model `forward` method depending on model parallel settings
* Initializes the model and puts it in `.eval()` mode
* Sets up the input parameters (max batch size, max seq length)
The main methods to change for your model might be the following:
The following methods should be implemented:
```python
class AbstractModelInferenceWrapper:
def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
@@ -247,28 +252,28 @@ class AbstractModelInferenceWrapper:
def get_batch_for_context_window(self) -> List:
"""Returns the input data for inference

This function gets called iteratively in the inference loop . It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference.
This function gets called iteratively in the inference loop. It can be used to extract relevant input from the prompt tokens, attention mask etc. required for each step in inference.
```

Refer to [gpt_inference_wrapper.py](../../megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py) for an example of extending this for GPTModel.
Refer to [gpt_inference_wrapper.py](../../megatron/core/inference/model_inference_wrappers/gpt/gpt_inference_wrapper.py) for an example of implementing this for GPTModel.
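
As a hedged starting point (not repository code), a wrapper for a new model might look like the skeleton below. The class name, the context window arguments, and the returned batch contents are assumptions; only the base class and the two overridden method names come from the excerpt above.

```python
from typing import List

import torch

from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
    AbstractModelInferenceWrapper,
)


class MyModelInferenceWrapper(AbstractModelInferenceWrapper):
    """Hypothetical wrapper for a new model type."""

    def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
        # Reuse the base-class setup (eval mode, input bookkeeping), then keep
        # whatever model-specific state the autoregressive loop will need.
        super().prep_model_for_inference(prompts_tokens)
        self._prompts_tokens = prompts_tokens

    def get_batch_for_context_window(self, context_start: int, context_end: int) -> List:
        # Slice out what this model's forward pass needs for the current step;
        # the position arguments and the returned structure are illustrative.
        tokens2use = self._prompts_tokens[:, context_start:context_end]
        return [tokens2use]
```

The GPT wrapper referenced above is the best guide for what `prep_model_for_inference()` and `get_batch_for_context_window()` need to produce in practice.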

<br>

##### 3.4. Modify Inference Parameters
We use [common inference params](../../megatron/core/inference/common_inference_params.py) for text generation. Customize this if you want to change top_p, top_k, number of tokens to generate etc. If you want to add other attributes that you would use in the inference loop, you can do that as shown below
We use [sampling params](../../megatron/core/inference/sampling_params.py) for text generation. Customize this if you want to change top_p, top_k, the number of tokens to generate, etc. If you want to add other attributes for use in the inference loop, you can do so as shown below:

```
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.sampling_params import SamplingParams

c = CommonInferenceParams(temperature=0.5)
c = SamplingParams(temperature=0.5)
c.add_attributes({'min_length':4, 'eod_id':153})
```

<br>

#### 4. Future work
The following are planned for the future releases .
The following features are planned for future releases.
* Dynamic batching
* Paged Attention
* TRTLLM Engine support
* Support for Multimodal model inference
* Support for multimodal inference
@@ -6,10 +6,10 @@
from argparse import Namespace
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import SimpleTextGenerationController
from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController
from megatron.core.transformer.module import MegatronModule
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir, os.path.pardir)))
@@ -66,7 +66,7 @@ def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngi
)

inference_wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config)
text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer)
text_generation_controller = TextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer)
return MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size)

def main():
@@ -89,15 +89,15 @@ def main():

inference_engine = get_inference_engine(args, model)

common_inference_params = CommonInferenceParams(
sampling_params = SamplingParams(
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
return_log_probs=args.return_log_probs,
num_tokens_to_generate=args.num_tokens_to_generate)

results: List[InferenceRequest] = inference_engine.generate(
prompts=args.prompts, common_inference_params=common_inference_params
prompts=args.prompts, sampling_params=sampling_params
)

if torch.distributed.get_rank() == 0:
6 changes: 3 additions & 3 deletions examples/inference/t5/simple_t5_batch_inference.py
@@ -5,7 +5,7 @@
import torch

import pretrain_t5
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.inference_request import InferenceRequest
@@ -120,7 +120,7 @@ def main():

inference_engine = get_inference_engine(args, model)

common_inference_params = CommonInferenceParams(
sampling_params = SamplingParams(
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
@@ -138,7 +138,7 @@ def main():
prompts=args.prompts,
add_BOS=True,
encoder_prompts=args.encoder_prompts,
common_inference_params=common_inference_params,
sampling_params=sampling_params,
)

if torch.distributed.get_rank() == 0: