[rollout] feat: support vLLM v0.6.3 and fix hf rollout import issue (#33)

* [feat] support vllm spmd version in v0.6.3

* [misc] fix hf_weight loader

* [misc] rollout: update vllm version and fix hf import

* lint

* [misc] fix init

* [doc] feat: modify doc to support vllm v6
PeterSH6 authored Dec 6, 2024
1 parent b4a3d6b commit c592a8b
Showing 21 changed files with 3,004 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -69,7 +69,7 @@ The following dependencies are required for all backends.
 pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121

 # install vllm
-pip3 install vllm==0.5.4
+pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1
 pip3 install ray==2.10 # other version may have bug

 # flash attention 2
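A quick sanity check (an illustrative sketch, not part of this commit) to confirm that the installed vLLM is one of the versions verl's rollout dispatch accepts:

    # check the installed vLLM against the versions verl supports after this change
    from importlib.metadata import version

    installed = version("vllm")
    assert installed in ("0.3.1", "0.4.2", "0.5.4", "0.6.3"), f"unsupported vllm {installed}"
    print(f"vllm {installed} is supported")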
2 changes: 1 addition & 1 deletion docs/preparation/install.rst
@@ -45,7 +45,7 @@ found in :doc:`FSDP Workers<../workers/fsdp_workers>`.
 pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121

 # install vllm
-pip3 install vllm==0.5.4
+pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1
 pip3 install ray==2.10 # other version may have bug

 # flash attention 2
6 changes: 2 additions & 4 deletions requirements.txt
@@ -1,12 +1,10 @@
 transformers
 hydra-core
-tensordict < 0.3.1
+tensordict==0.5.0
 numpy
 pytest
-deepspeed
 pybind11
 codetiming
 yapf
 wandb
-git+https://github.com/NVIDIA/TransformerEngine.git@stable
+# vllm==0.5.4 # vllm is installed in image building to avoid ray conflicts
+git+https://github.com/NVIDIA/TransformerEngine.git@stable
2 changes: 1 addition & 1 deletion setup.py
@@ -35,7 +35,7 @@
 ]

 install_optional = [
-    'vllm==0.5.4',
+    'vllm==0.6.3',
 ]

 extras_require = {
8 changes: 7 additions & 1 deletion verl/third_party/vllm/__init__.py
@@ -40,6 +40,12 @@ def get_version(pkg):
     from .vllm_v_0_5_4.llm import LLM
     from .vllm_v_0_5_4.llm import LLMEngine
     from .vllm_v_0_5_4 import parallel_state
+elif package_version == '0.6.3':
+    vllm_version = '0.6.3'
+    from .vllm_v_0_6_3.llm import LLM
+    from .vllm_v_0_6_3.llm import LLMEngine
+    from .vllm_v_0_6_3 import parallel_state
 else:
     raise ValueError(
-        f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, and 0.5.4.')
+        f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4 and 0.6.3.'
+    )
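With the new branch in place, callers keep importing the version-matched classes from the package root. A minimal usage sketch, assuming vLLM 0.6.3 is installed:

    # the dispatch above resolves these names to the vllm_v_0_6_3 variants
    from verl.third_party.vllm import LLM, LLMEngine, parallel_state, vllm_version

    print(vllm_version)  # '0.6.3' when vLLM 0.6.3 is installed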
13 changes: 13 additions & 0 deletions verl/third_party/vllm/vllm_v_0_6_3/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
78 changes: 78 additions & 0 deletions verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py
@@ -0,0 +1,78 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py

import os
from dataclasses import dataclass

from transformers import PretrainedConfig
from vllm.config import EngineConfig
from vllm.engine.arg_utils import EngineArgs

from .config import LoadConfig, ModelConfig


@dataclass
class EngineArgs(EngineArgs):
model_hf_config: PretrainedConfig = None # for verl

def __post_init__(self):
pass

def create_model_config(self) -> ModelConfig:
return ModelConfig(
hf_config=self.model_hf_config,
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
code_revision=self.code_revision,
rope_scaling=self.rope_scaling,
rope_theta=self.rope_theta,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
quantization_param_path=self.quantization_param_path,
enforce_eager=self.enforce_eager,
max_context_len_to_capture=self.max_context_len_to_capture,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config,
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
)

def create_load_config(self) -> LoadConfig:
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
)

def create_engine_config(self) -> EngineConfig:
engine_config = super().create_engine_config()

# NOTE[VERL]: Use the world_size set by torchrun
world_size = int(os.getenv("WORLD_SIZE", "-1"))
assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN"
engine_config.parallel_config.world_size = world_size

return engine_config
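The `create_engine_config` override is the load-bearing change for SPMD rollout: the engine's parallel world size is taken from the `WORLD_SIZE` variable that torchrun exports, rather than computed by stock vLLM. A rough sketch of the intended flow (assumes vLLM 0.6.3; the model name is an arbitrary example, and torchrun would normally set the variable):

    # build an engine config whose world size follows the torchrun launch
    import os
    from transformers import AutoConfig
    from verl.third_party.vllm.vllm_v_0_6_3.arg_utils import EngineArgs

    os.environ["WORLD_SIZE"] = "2"  # set here only for illustration
    args = EngineArgs(model_hf_config=AutoConfig.from_pretrained("facebook/opt-125m"))
    engine_config = args.create_engine_config()
    assert engine_config.parallel_config.world_size == 2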
105 changes: 105 additions & 0 deletions verl/third_party/vllm/vllm_v_0_6_3/config.py
@@ -0,0 +1,105 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py

import enum
import json
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Union

from transformers import PretrainedConfig

# Add for verl
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip

if TYPE_CHECKING:
from vllm.model_executor.model_loader.loader import BaseModelLoader

logger = init_logger(__name__)


class LoadFormat(str, enum.Enum):
AUTO = "auto"
MEGATRON = "megatron"
HF = "hf"
DTENSOR = "dtensor"
DUMMY_HF = "dummy_hf"
DUMMY_MEGATRON = "dummy_megatron"
DUMMY_DTENSOR = "dummy_dtensor"


class ModelConfig(ModelConfig):

def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None:
super().__init__(model=hf_config._name_or_path, tokenizer=hf_config._name_or_path, *args, **kwargs)
self.hf_config = hf_config


@dataclass
class LoadConfig:
"""
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""

load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None

def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(model_loader_extra_config)
self._verify_load_format()

if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns)
else:
self.ignore_patterns = ["original/**/*"]

def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str):
return

load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)

rocm_not_supported_load_format: List[str] = []
if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format)
]
raise ValueError(f"load format '{load_format}' is not supported in ROCm. "
f"Supported load formats are "
f"{rocm_supported_load_format}")