[rollout] feat: support vLLM v0.6.3 and fix hf rollout import issue (#33)

* [feat] support vllm spmd version in v0.6.3
* [misc] fix hf_weight loader
* [misc] rollout: update vllm version and fix hf import
* lint
* [misc] fix init
* [doc] feat: modify doc to support vllm v6
Showing 21 changed files with 3,004 additions and 10 deletions.
@@ -1,12 +1,10 @@
transformers
hydra-core
tensordict < 0.3.1
tensordict==0.5.0
numpy
pytest
deepspeed
pybind11
codetiming
yapf
wandb
git+https://github.com/NVIDIA/TransformerEngine.git@stable
# vllm==0.5.4 # vllm is installed in image building to avoid ray conflicts
git+https://github.com/NVIDIA/TransformerEngine.git@stable
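In this hunk the tensordict pin moves from the loose `tensordict < 0.3.1` constraint to an exact `tensordict==0.5.0`, while vLLM itself stays commented out because it is installed during image building to avoid Ray conflicts. As a minimal sketch (not part of this commit), one could sanity-check at runtime that an environment built from these requirements matches the versions this change targets; the expected versions below are assumptions taken from this diff:

```python
# Minimal sketch: check installed versions against what this commit appears to target.
# The expected versions are assumptions from the diff, not an official compatibility check.
from importlib.metadata import version, PackageNotFoundError

def check_versions() -> None:
    expected = {"tensordict": "0.5.0", "vllm": "0.6.3"}
    for pkg, want in expected.items():
        try:
            have = version(pkg)
        except PackageNotFoundError:
            print(f"{pkg} is not installed (vllm is expected to come from the image build)")
            continue
        status = "OK" if have == want else "MISMATCH"
        print(f"{pkg}: installed {have}, expected {want} -> {status}")

if __name__ == "__main__":
    check_versions()
```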
@@ -35,7 +35,7 @@
 ]

 install_optional = [
-    'vllm==0.5.4',
+    'vllm==0.6.3',
 ]

 extras_require = {
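Here the optional dependency group bumps vLLM from 0.5.4 to 0.6.3. Because vLLM is optional rather than a hard requirement, rollout code typically guards the import so environments without vLLM still load; a minimal, hypothetical sketch of such a guard (the helper name and message are illustrative, not from this commit):

```python
# Hypothetical import guard for an optional vLLM dependency; names are illustrative.
try:
    import vllm
    _VLLM_AVAILABLE = True
except ImportError:
    vllm = None
    _VLLM_AVAILABLE = False

def require_vllm():
    """Raise a clear error if vLLM was not installed via the optional extra."""
    if not _VLLM_AVAILABLE:
        raise ImportError("vLLM is required for this rollout backend; "
                          "install the optional dependency, e.g. vllm==0.6.3")
    return vllm
```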
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,78 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py

import os
from dataclasses import dataclass

from transformers import PretrainedConfig
from vllm.config import EngineConfig
from vllm.engine.arg_utils import EngineArgs

from .config import LoadConfig, ModelConfig


@dataclass
class EngineArgs(EngineArgs):
    model_hf_config: PretrainedConfig = None  # for verl

    def __post_init__(self):
        pass

    def create_model_config(self) -> ModelConfig:
        return ModelConfig(
            hf_config=self.model_hf_config,
            tokenizer_mode=self.tokenizer_mode,
            trust_remote_code=self.trust_remote_code,
            dtype=self.dtype,
            seed=self.seed,
            revision=self.revision,
            code_revision=self.code_revision,
            rope_scaling=self.rope_scaling,
            rope_theta=self.rope_theta,
            tokenizer_revision=self.tokenizer_revision,
            max_model_len=self.max_model_len,
            quantization=self.quantization,
            quantization_param_path=self.quantization_param_path,
            enforce_eager=self.enforce_eager,
            max_context_len_to_capture=self.max_context_len_to_capture,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            disable_sliding_window=self.disable_sliding_window,
            skip_tokenizer_init=self.skip_tokenizer_init,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            use_async_output_proc=not self.disable_async_output_proc,
            override_neuron_config=self.override_neuron_config,
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
        )

    def create_load_config(self) -> LoadConfig:
        return LoadConfig(
            load_format=self.load_format,
            download_dir=self.download_dir,
            model_loader_extra_config=self.model_loader_extra_config,
            ignore_patterns=self.ignore_patterns,
        )

    def create_engine_config(self) -> EngineConfig:
        engine_config = super().create_engine_config()

        # NOTE[VERL]: Use the world_size set by torchrun
        world_size = int(os.getenv("WORLD_SIZE", "-1"))
        assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
        engine_config.parallel_config.world_size = world_size

        return engine_config
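This subclass builds the vLLM engine configuration from an in-memory HuggingFace config rather than a model path, and overrides the parallel world size with the WORLD_SIZE environment variable set by torchrun. A short usage sketch, under the assumption that the module above is importable (the import path, model name, and environment setup below are illustrative, not from this commit), and that a working vLLM 0.6.3 install is present:

```python
# Hypothetical usage sketch of the verl EngineArgs subclass shown above.
import os
from transformers import AutoConfig

# from verl_rollout.arg_utils import EngineArgs  # hypothetical import path for the module above

# torchrun normally exports WORLD_SIZE; set it here only so the sketch is self-contained.
os.environ.setdefault("WORLD_SIZE", "1")

# The model name is just an example; any HF config object works the same way.
hf_config = AutoConfig.from_pretrained("facebook/opt-125m")

# Note this is the verl subclass, not vllm.engine.arg_utils.EngineArgs directly.
engine_args = EngineArgs(model_hf_config=hf_config, dtype="bfloat16", enforce_eager=True)
engine_config = engine_args.create_engine_config()

# The parallel config reflects WORLD_SIZE from the environment, not vLLM's own inference.
print(engine_config.parallel_config.world_size)
```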
@@ -0,0 +1,105 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py

import enum
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Union

from transformers import PretrainedConfig

# Add for verl
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip

if TYPE_CHECKING:
    from vllm.model_executor.model_loader.loader import BaseModelLoader

logger = init_logger(__name__)


class LoadFormat(str, enum.Enum):
    AUTO = "auto"
    MEGATRON = "megatron"
    HF = "hf"
    DTENSOR = "dtensor"
    DUMMY_HF = "dummy_hf"
    DUMMY_MEGATRON = "dummy_megatron"
    DUMMY_DTENSOR = "dummy_dtensor"


class ModelConfig(ModelConfig):

    def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None:
        super().__init__(model=hf_config._name_or_path, tokenizer=hf_config._name_or_path, *args, **kwargs)
        self.hf_config = hf_config


@dataclass
class LoadConfig:
    """
    download_dir: Directory to download and load the weights, default to the
        default cache directory of huggingface.
    load_format: The format of the model weights to load:
        "auto" will try to load the weights in the safetensors format and
            fall back to the pytorch bin format if safetensors format is
            not available.
        "pt" will load the weights in the pytorch bin format.
        "safetensors" will load the weights in the safetensors format.
        "npcache" will load the weights in pytorch format and store
            a numpy cache to speed up the loading.
        "dummy" will initialize the weights with random values, which is
            mainly for profiling.
        "tensorizer" will use CoreWeave's tensorizer library for
            fast weight loading.
        "bitsandbytes" will load nf4 type weights.
    ignore_patterns: The list of patterns to ignore when loading the model.
        Default to "original/**/*" to avoid repeated loading of llama's
        checkpoints.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
    download_dir: Optional[str] = None
    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
    ignore_patterns: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        model_loader_extra_config = self.model_loader_extra_config or {}
        if isinstance(model_loader_extra_config, str):
            self.model_loader_extra_config = json.loads(model_loader_extra_config)
        self._verify_load_format()

        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
            logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns)
        else:
            self.ignore_patterns = ["original/**/*"]

    def _verify_load_format(self) -> None:
        if not isinstance(self.load_format, str):
            return

        load_format = self.load_format.lower()
        self.load_format = LoadFormat(load_format)

        rocm_not_supported_load_format: List[str] = []
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(f"load format '{load_format}' is not supported in ROCm. "
                             f"Supported load formats are "
                             f"{rocm_supported_load_format}")