diff --git a/README.md b/README.md index 8ba30df..1919dfd 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ The following dependencies are required for all backends. pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 # install vllm -pip3 install vllm==0.5.4 +pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1 pip3 install ray==2.10 # other version may have bug # flash attention 2 diff --git a/docs/preparation/install.rst b/docs/preparation/install.rst index 1c623c1..9a932e9 100644 --- a/docs/preparation/install.rst +++ b/docs/preparation/install.rst @@ -45,7 +45,7 @@ found in :doc:`FSDP Workers<../workers/fsdp_workers>`. pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 # install vllm - pip3 install vllm==0.5.4 + pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1 pip3 install ray==2.10 # other version may have bug # flash attention 2 diff --git a/requirements.txt b/requirements.txt index 823cfda..ca102e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,10 @@ transformers hydra-core -tensordict < 0.3.1 +tensordict==0.5.0 numpy pytest -deepspeed pybind11 codetiming yapf wandb -git+https://github.com/NVIDIA/TransformerEngine.git@stable -# vllm==0.5.4 # vllm is installed in image building to avoid ray conflicts \ No newline at end of file +git+https://github.com/NVIDIA/TransformerEngine.git@stable \ No newline at end of file diff --git a/setup.py b/setup.py index 8289d3c..8b2907f 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ ] install_optional = [ - 'vllm==0.5.4', + 'vllm==0.6.3', ] extras_require = { diff --git a/verl/third_party/vllm/__init__.py b/verl/third_party/vllm/__init__.py index 9eee28f..290c837 100644 --- a/verl/third_party/vllm/__init__.py +++ b/verl/third_party/vllm/__init__.py @@ -40,6 +40,12 @@ def get_version(pkg): from .vllm_v_0_5_4.llm import LLM from .vllm_v_0_5_4.llm import LLMEngine from .vllm_v_0_5_4 import parallel_state +elif package_version == '0.6.3': + vllm_version = '0.6.3' + from .vllm_v_0_6_3.llm import LLM + from .vllm_v_0_6_3.llm import LLMEngine + from .vllm_v_0_6_3 import parallel_state else: raise ValueError( - f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, and 0.5.4.') + f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4 and 0.6.3.' + ) diff --git a/verl/third_party/vllm/vllm_v_0_6_3/__init__.py b/verl/third_party/vllm/vllm_v_0_6_3/__init__.py new file mode 100644 index 0000000..1ce90c5 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py b/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py new file mode 100644 index 0000000..bc4685c --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py @@ -0,0 +1,78 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py + +import os +from dataclasses import dataclass + +from transformers import PretrainedConfig +from vllm.config import EngineConfig +from vllm.engine.arg_utils import EngineArgs + +from .config import LoadConfig, ModelConfig + + +@dataclass +class EngineArgs(EngineArgs): + model_hf_config: PretrainedConfig = None # for verl + + def __post_init__(self): + pass + + def create_model_config(self) -> ModelConfig: + return ModelConfig( + hf_config=self.model_hf_config, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + enforce_eager=self.enforce_eager, + max_context_len_to_capture=self.max_context_len_to_capture, + max_seq_len_to_capture=self.max_seq_len_to_capture, + max_logprobs=self.max_logprobs, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name, + limit_mm_per_prompt=self.limit_mm_per_prompt, + use_async_output_proc=not self.disable_async_output_proc, + override_neuron_config=self.override_neuron_config, + config_format=self.config_format, + mm_processor_kwargs=self.mm_processor_kwargs, + ) + + def create_load_config(self) -> LoadConfig: + return LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + + def create_engine_config(self) -> EngineConfig: + engine_config = super().create_engine_config() + + # NOTE[VERL]: Use the world_size set by torchrun + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + engine_config.parallel_config.world_size = world_size + + return engine_config diff --git a/verl/third_party/vllm/vllm_v_0_6_3/config.py b/verl/third_party/vllm/vllm_v_0_6_3/config.py new file mode 100644 index 0000000..d7cee45 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/config.py @@ -0,0 +1,105 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py + +import enum +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Union + +from transformers import PretrainedConfig + +# Add for verl +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.utils import is_hip + +if TYPE_CHECKING: + from vllm.model_executor.model_loader.loader import BaseModelLoader + +logger = init_logger(__name__) + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + MEGATRON = "megatron" + HF = "hf" + DTENSOR = "dtensor" + DUMMY_HF = "dummy_hf" + DUMMY_MEGATRON = "dummy_megatron" + DUMMY_DTENSOR = "dummy_dtensor" + + +class ModelConfig(ModelConfig): + + def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None: + super().__init__(model=hf_config._name_or_path, tokenizer=hf_config._name_or_path, *args, **kwargs) + self.hf_config = hf_config + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. + + """ + + load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) + ignore_patterns: Optional[Union[List[str], str]] = None + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads(model_loader_extra_config) + self._verify_load_format() + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format) + ] + raise ValueError(f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}") diff --git a/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py new file mode 100644 index 0000000..a3042ca --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py @@ -0,0 +1,380 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch.nn as nn +from torch.distributed._tensor import DTensor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import is_pp_missing_parameter + + +def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + stacked_name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if stacked_name.endswith(".bias") and stacked_name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[stacked_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "lm_head.weight" in name: + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight) + + +def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +from vllm.model_executor.layers.fused_moe import FusedMoE + + +def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=vllm_model.config.n_routed_experts, + ) + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + local_loaded_weight.to(dtype=param.dtype), + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + pass + + +def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None): + param_name = _process_parameter_names(name=param_name) + if parallelize_plan is not None: + assert ( + param_name + in parallelize_plan.keys()), f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" + placement = parallelize_plan[param_name] + local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh, + placements=placement).to_local() + else: + local_loaded_weights = loaded_weights.full_tensor() + return local_loaded_weights + + +def _process_parameter_names(name): + # Remove '.weight' if it exists at the end of the string + if name.endswith(".weight"): + name = name[:-7] + + # Remove 'model.layers.x.' or 'model.' prefix + if "model.layers" in name: + parts = name.split(".") + # Reconstruct the string without 'model.layers.x.' + name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' + elif name.startswith("model."): + name = name[6:] # Remove 'model.' + + return name + + +__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = { + "GPT2LMHeadModel": gpt2_dtensor_weight_loader, + "LlamaForCausalLM": llama_dtensor_weight_loader, + "LLaMAForCausalLM": llama_dtensor_weight_loader, + "MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM + "InternLMForCausalLM": llama_dtensor_weight_loader, + "AquilaModel": llama_dtensor_weight_loader, + "AquilaForCausalLM": llama_dtensor_weight_loader, + "Phi3ForCausalLM": llama_dtensor_weight_loader, + "GemmaForCausalLM": gemma_dtensor_weight_loader, + "Gemma2ForCausalLM": gemma_dtensor_weight_loader, + "GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights, + "Starcoder2ForCausalLM": starcoder2_dtensor_load_weights, + "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, + "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader, + "Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader, +} + + +# the actor model is .state_dict() +# Load dtensor weights +def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): + weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) + weight_loader(actor_weights, vllm_model) + # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu + # after init, and we need this after sync model weights for in first iter. + vllm_model = vllm_model.cuda() + + +def _get_model_weight_loader(arch: str): + if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: + return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch] + raise ValueError(f"Model architectures {arch} are not supported for now. " + f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}") + + +# NOTE(sgm): we use per-parameter weight loader in each vllm sub +def update_dtensor_weight_loader(): + pass diff --git a/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py b/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py new file mode 100644 index 0000000..a3e5b22 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py @@ -0,0 +1,41 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch.nn as nn +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + + +def update_hf_weight_loader(): + print("no hf weight loader need to be updated") + return + + +def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): + assert isinstance(actor_weights, Dict) + with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys(): + del actor_weights["lm_head.weight"] + vllm_model.load_weights(actor_weights.items()) + for _, module in vllm_model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + vllm_model = vllm_model.cuda() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm.py b/verl/third_party/vllm/vllm_v_0_6_3/llm.py new file mode 100644 index 0000000..9351457 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/llm.py @@ -0,0 +1,200 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py + +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast +from verl.trainer.ppo.rollout.tokenizer import HybridEngineBaseTokenizer +from vllm import LLM +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.utils import Counter + +from .arg_utils import EngineArgs +from .llm_engine_sp import LLMEngine + + +class LLM(LLM): + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. + + NOTE: This class is intended to be used for offline inference. For online + serving, use the `AsyncLLMEngine` class instead. + NOTE: For the comprehensive list of arguments, see `EngineArgs`. + + Args: + model: A HuggingFace Transformers model instance. + tokenizer: A HuggingFace Transformers tokenizer instance. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq". If None, we assume the model weights are not + quantized and use `dtype` to determine the data type of the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. If all + requests will have `best_of=1`, you can safely set this to 0. + Otherwise, too small values may cause out-of-memory (OOM) errors. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. + disable_custom_all_reduce: See ParallelConfig + """ + + def __init__( + self, + model: Union[nn.Module, Dict], # model itself or its parameter dict + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer], + model_hf_config: PretrainedConfig, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + skip_tokenizer_init: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + cpu_offload_gb: float = 0, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + load_format="auto", + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + removed_vision_keys = ("image_token_id", "image_feature_size", "image_input_shape", "image_input_type") + if any(k in kwargs for k in removed_vision_keys): + raise TypeError("There is no need to pass vision-related arguments anymore.") + engine_args = EngineArgs( + model_hf_config=model_hf_config, + # tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + load_format=load_format, + **kwargs, + ) + tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer) + if not isinstance(tokenizer, tokenizer_cls): + raise ValueError( + f"Unexpected tokenizer type: {type(tokenizer)}. Must be" + "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.ppo.rollout.HybridEngineBaseTokenizer" + ) + self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext + self.request_counter = Counter() + + def init_cache_engine(self): + self.llm_engine.init_cache_engine() + + def free_cache_engine(self): + self.llm_engine.free_cache_engine() + + def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + return self.llm_engine.tokenizer + + def set_tokenizer( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + ) -> None: + self.llm_engine.tokenizer = tokenizer + + def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + outputs = super()._run_engine(use_tqdm=use_tqdm) + return self._post_process_outputs(outputs) + + # # NOTE(shengguangming): add for verl + # # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding. + # def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]: + # # remove the left padding in the prompt token_id + # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + # non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + # token_ids = prompt_token_ids[non_pad_index:].tolist() + # return token_ids + + # NOTE(shengguangming): add for verl + def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]: + output_token_ids = [] + logprobs = [] + for request_output in request_outputs: # List[RequestOutput] + outputs = request_output.outputs + for output in outputs: # List[CompletionOutput], usually len == 1 + output_token_ids.append(torch.tensor(output.token_ids)) + # TODO(shengguangming): can be optimzied by rewrite the Sampler._get_logprobs() logits + logprobs_dicts = output.logprobs + if logprobs_dicts is not None: + logprob = [] + for logprobs_dict, id in zip(logprobs_dicts, output.token_ids): + logprob.append(logprobs_dict[id].logprob) + logprobs.append(torch.tensor(logprob)) + + pad_token_id = (self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None + else self.llm_engine.tokenizer.eos_token_id) + output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id) + if len(logprobs) > 0: + logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id) + return output_token_ids, logprobs + + def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None: + self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + def offload_model_weights(self) -> None: + self.llm_engine.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py b/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py new file mode 100644 index 0000000..10b112b --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py @@ -0,0 +1,408 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py + +from functools import partial +from typing import Callable, Dict, Optional, Type, Union + +import torch +import torch.nn as nn +from vllm.config import ( + CacheConfig, + DecodingConfig, + DeviceConfig, + EngineConfig, + LoadConfig, + LoRAConfig, + ModelConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) +from vllm.core.scheduler import Scheduler +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine, SchedulerContext, SchedulerOutputState, _load_generation_config_dict +from vllm.engine.metrics_types import StatLoggerBase +from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.sequence import Sequence +from vllm.tracing import init_tracer +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message +from vllm.utils import Counter, weak_bind +from vllm.version import __version__ as VLLM_VERSION + +from .arg_utils import EngineArgs +from .config import LoadConfig, ModelConfig +from .tokenizer import TokenizerGroup + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + + +class LLMEngine(LLMEngine): + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. + + The config arguments are derived from :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) + + Args: + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + device_config: The configuration related to the device. + lora_config (Optional): The configuration related to serving multi-LoRA. + speculative_config (Optional): The configuration related to speculative + decoding. + executor_class: The model executor class for managing distributed + execution. + prompt_adapter_config (Optional): The configuration related to serving + prompt adapters. + log_stats: Whether to log statistics. + usage_context: Specified entry point, used for usage info collection. + """ + + def __init__( + self, + # NOTE(sgm): first two arguments are added for verl + model: Union[nn.Module, Dict], # model itself or its parameter dict + tokenizer: nn.Module, + # NOTE(sgm): vllm original arguments + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], + observability_config: Optional[ObservabilityConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + executor_class: Type[ExecutorBase], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + use_cached_outputs: bool = False, + ) -> None: + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, " + "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " + "num_scheduler_steps=%d, chunked_prefill_enabled=%s " + "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " + "use_async_output_proc=%s, use_cached_outputs=%s, " + "mm_processor_kwargs=%s)", + VLLM_VERSION, + model_config.model, + speculative_config, + model_config.tokenizer, + model_config.skip_tokenizer_init, + model_config.tokenizer_mode, + model_config.revision, + model_config.override_neuron_config, + model_config.rope_scaling, + model_config.rope_theta, + model_config.tokenizer_revision, + model_config.trust_remote_code, + model_config.dtype, + model_config.max_model_len, + load_config.download_dir, + load_config.load_format, + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.disable_custom_all_reduce, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + observability_config, + model_config.seed, + model_config.served_model_name, + scheduler_config.use_v2_block_manager, + scheduler_config.num_scheduler_steps, + scheduler_config.chunked_prefill_enabled, + scheduler_config.multi_step_stream_outputs, + cache_config.enable_prefix_caching, + model_config.use_async_output_proc, + use_cached_outputs, + model_config.mm_processor_kwargs, + ) + # TODO(woosuk): Print more configs in debug mode. + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.load_config = load_config + self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config or ObservabilityConfig() + self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs + + if not self.model_config.skip_tokenizer_init: + self.tokenizer = self._init_tokenizer(tokenizer) + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() + else: + self.tokenizer = None + self.detokenizer = None + tokenizer_group = None + + # Ensure that the function doesn't contain a reference to self, + # to avoid engine GC issues + def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: + assert tokenizer_group, "tokenizer_group cannot be None, " "make sure skip_tokenizer_init is False" + return tokenizer_group.get_lora_tokenizer(sequence.lora_request) + + self.seq_counter = Counter() + self.generation_config_fields = _load_generation_config_dict(model_config) + + self.input_preprocessor = InputPreprocessor(model_config, self.tokenizer) + + self.input_registry = input_registry + self.input_processor = input_registry.create_input_processor(model_config) + + self.model_executor = executor_class( + model=model, # add for spmd_gpu_executor + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + speculative_config=speculative_config, + load_config=load_config, + prompt_adapter_config=prompt_adapter_config, + observability_config=self.observability_config, + ) + + if not self.model_config.embedding_mode: + self._initialize_kv_caches() + + # If usage stat is enabled, collect relevant info. + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import get_architecture_class_name + + usage_message.report_usage( + get_architecture_class_name(model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": str(model_config.dtype), + "tensor_parallel_size": parallel_config.tensor_parallel_size, + "block_size": cache_config.block_size, + "gpu_memory_utilization": cache_config.gpu_memory_utilization, + # Quantization + "quantization": model_config.quantization, + "kv_cache_dtype": str(cache_config.cache_dtype), + # Feature flags + "enable_lora": bool(lora_config), + "enable_prompt_adapter": bool(prompt_adapter_config), + "enable_prefix_caching": cache_config.enable_prefix_caching, + "enforce_eager": model_config.enforce_eager, + "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, + }, + ) + + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + + self.cached_scheduler_outputs = [ + SchedulerOutputState() for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.scheduler_contexts = [ + SchedulerContext(multi_step_stream_outputs=self.scheduler_config.multi_step_stream_outputs) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + if model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [ + partial(process_model_outputs, ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + else: + self.async_callbacks = [] + + # Currently used by AsyncLLMEngine to ensure quick append + # of request outputs to asyncio queues + self.process_request_outputs_callback: Optional[Callable] = None + + # Create the scheduler. + # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + self.scheduler = [ + Scheduler( + scheduler_config, + cache_config, + lora_config, + parallel_config.pipeline_parallel_size, + self.async_callbacks[v_id] if model_config.use_async_output_proc else None, + ) for v_id in range(parallel_config.pipeline_parallel_size) + ] + + # Metric Logging. + if self.log_stats: + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from vllm.engine.metrics import LoggingStatLogger, PrometheusStatLogger + + self.stat_loggers = { + "logging": + LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC), + "prometheus": + PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len, + ), + } + self.stat_loggers["prometheus"].info("cache_config", self.cache_config) + + self.tracer = None + if self.observability_config.otlp_traces_endpoint: + self.tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint) + + # Create sequence output processor, e.g. for beam search or + # speculative decoding. + self.output_processor = SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + get_tokenizer_for_seq, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + get_tokenizer_for_seq, + ), + ) + + # TODO(sgm): add for verl but we may not tokenizer in Rollout + def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs): + init_kwargs = dict(enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None) + init_kwargs.update(tokenizer_init_kwargs) + return TokenizerGroup(tokenizer, **init_kwargs) + + def init_cache_engine(self): + # TODO: check whether we should rebuild the CUDAGraph every iter when offload/load KVCache + # Re-capture CUDAGraph would be time-consuming + self.model_executor.init_cache_engine() + + def free_cache_engine(self): + self.model_executor.free_cache_engine() + + # NOTE(sgm): currently, we only support GPU executor + # The GPUExecutor remove the Ray dependency + @classmethod + def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]: + distributed_executor_backend = engine_config.parallel_config.distributed_executor_backend + # Initialize the cluster and specify the executor class.] + assert (engine_config.device_config.device_type == "cuda" + ), "Currently, the vllm in verl only support running on GPU" + + # print('Waiting for debugger'); import os,debugpy; debugpy.listen(('localhost', 5678 + int(os.getenv('RANK', '0')))); debugpy.wait_for_client() + if engine_config.parallel_config.world_size == 1: + engine_config.load_config.load_format = "dummy_hf" + + from .spmd_gpu_executor import SPMDGPUExecutor + + executor_class = SPMDGPUExecutor + + return executor_class + + @classmethod + def from_engine_args( + cls, + model, + tokenizer, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(engine_config) + # Initialize the cluster and specify the executor class. + assert (engine_config.device_config.device_type == "cuda" + ), "Currently, the vllm in verl only support running on GPU" + + from .spmd_gpu_executor import SPMDGPUExecutor + + executor_class = SPMDGPUExecutor + + # Create the LLM engine. + engine = cls( + model, + tokenizer, + **engine_config.to_dict(), + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine + + def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None: + self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + def offload_model_weights(self) -> None: + self.model_executor.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py new file mode 100644 index 0000000..7fd6c0e --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py @@ -0,0 +1,308 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch +import torch.nn as nn +from vllm.model_executor.layers.linear import * +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding +from vllm.model_executor.models import ModelRegistry + + +# NOTE(shengguangming): replace the origin weight loader function in the class +def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Parallel Linear weight loader.""" + assert (param.size() == loaded_weight.size( + )), "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format( + param.size(), loaded_weight.size()) + assert (param.data.dtype == loaded_weight.data.dtype + ), "if we want to shared weights, the data type should also be the same" + + param.data = loaded_weight.data + + +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + assert param.size() == loaded_weight.size() + assert (param.data.dtype == loaded_weight.data.dtype + ), "if we want to shared weights, the data type should also be the same" + + param.data = loaded_weight.data + + +def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + # TODO: check megatron + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), + ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ( + "input_layernorm", + "input_layernorm", + ), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def _replace_name(megatron_name, name_mapping): + for m_name, v_name in name_mapping: + if m_name not in megatron_name: + continue + if "layers" in megatron_name: # deal with decoder layers + megatron_name = megatron_name.replace("decoder", "model") + megatron_name_list = megatron_name.split(".") + if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: + param_name_list = megatron_name_list[:3] + param_name_list.append(v_name) + param_name = ".".join(param_name_list) + else: + param_name_list = megatron_name_list[:3] + weight_or_bias = megatron_name_list[-1] + param_name_list.append(v_name) + param_name_list.append(weight_or_bias) + param_name = ".".join(param_name_list) + return param_name + else: + param_name = megatron_name.replace(m_name, v_name) + return param_name + + +def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), + ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ( + "input_layernorm", + "input_layernorm", + ), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def _replace_name(megatron_name, name_mapping): + for m_name, v_name in name_mapping: + if m_name not in megatron_name: + continue + if "layers" in megatron_name: # deal with decoder layers + megatron_name = megatron_name.replace("decoder", "model") + megatron_name_list = megatron_name.split(".") + if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: + param_name_list = megatron_name_list[:3] + param_name_list.append(v_name) + param_name = ".".join(param_name_list) + else: + param_name_list = megatron_name_list[:3] + weight_or_bias = megatron_name_list[-1] + param_name_list.append(v_name) + param_name_list.append(weight_or_bias) + param_name = ".".join(param_name_list) + return param_name + else: + param_name = megatron_name.replace(m_name, v_name) + return param_name + + +def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + # TODO: need to implement a general way to deal with prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { + ColumnParallelLinear: parallel_weight_loader, + MergedColumnParallelLinear: parallel_weight_loader, + QKVParallelLinear: parallel_weight_loader, + RowParallelLinear: parallel_weight_loader, + VocabParallelEmbedding: parallel_weight_loader, + ParallelLMHead: parallel_weight_loader, + # "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights + # "default_weight_loader": default_weight_loader +} + +# for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): +# # setattr(layer_class, 'megatron_weight_loader', weight_loader) +# layer_class.weight_loader = weight_loader + +__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = { + "GPT2LMHeadModel": gpt2_weight_loader, + "LlamaForCausalLM": llama_megatron_weight_loader, # use te backend for open-source megatron + "LLaMAForCausalLM": llama_megatron_weight_loader, + "MistralForCausalLM": mistral_megatron_weight_loader, +} + + +# the actor model is .state_dict() +# Load megatron weights +def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module): + weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) + weight_loader(actor_weights, vllm_model) + # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu + # after init, and we need this after sync model weights for in first iter. + vllm_model = vllm_model.cuda() + + +def _get_model_weight_loader(arch: str): + if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__: + return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch] + raise ValueError(f"Model architectures {arch} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def update_megatron_weight_loader(): + for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): + layer_class.weight_loader = weight_loader diff --git a/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py b/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py new file mode 100644 index 0000000..2f32a91 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py @@ -0,0 +1,332 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models +"""Utilities for selecting and loading models.""" +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +from transformers import PreTrainedModel +from vllm.config import CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig +from vllm.distributed.communication_op import tensor_model_parallel_all_gather +from vllm.model_executor.model_loader import BaseModelLoader +from vllm.model_executor.model_loader.loader import _initialize_model +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + +from .config import LoadConfig, LoadFormat, ModelConfig +from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader +from .hf_weight_loader import update_hf_weight_loader +from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader + + +def get_model( + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + cache_config: CacheConfig = None, +) -> nn.Module: + loader = get_model_loader(load_config) + if load_config.load_format.startswith("dummy"): + return loader.load_model( + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) + else: + return loader.load_model( + actor_model=actor_model, + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.AUTO: + update_megatron_weight_loader() + return MegatronLoader(load_config) + + # NOTE(sgm): change the weight_loader function in runtime + if load_config.load_format == LoadFormat.MEGATRON: + update_megatron_weight_loader() + return MegatronLoader(load_config) + + if load_config.load_format == LoadFormat.HF: + update_hf_weight_loader() + return HFLoader(load_config) + + if load_config.load_format == LoadFormat.DTENSOR: + update_dtensor_weight_loader() + return DTensorLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_HF: + update_hf_weight_loader() + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_MEGATRON: + update_megatron_weight_loader() + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_DTENSOR: + update_dtensor_weight_loader() + return DummyModelLoader(load_config) + + raise ValueError("load format not supported in verl: {}, only support {} and {}".format( + load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF)) + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + # initialize_dummy_weights(model) + return model.eval() + + +class MegatronLoader(BaseModelLoader): + """Model loader that can load the model weights from partitioned megatron model.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): + # NOTE(shengguangming) Load the weights from the actor model + pass + # if isinstance(actor_model, nn.Module): + # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + # else: + # load_weights(actor_weights=actor_model, vllm_model=model) + # return actor_model + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + + # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm + if isinstance(actor_model, nn.Module): + load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), + vllm_model=model) + else: + load_megatron_weights(actor_weights=actor_model, vllm_model=model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +class HFLoader(BaseModelLoader): + """Model loader that can load the model weights from model's full params.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]): + if isinstance(actor_model, Dict): + return actor_model.items() + elif isinstance(actor_model, nn.Module): + return dict(actor_model.named_parameters()).items() + else: + raise ValueError(f"actor model should be Dict or nn.Module, but get {type(actor_model)}") + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + # with torch.device(device_config.device): + # NOTE(sgm): init the model in cpu + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + model.load_weights(self._get_weights_iterator(actor_model)) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +class DTensorLoader(BaseModelLoader): + """Model loader that can load the model weights from partitioned megatron model.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): + # NOTE(shengguangming) Load the weights from the actor model + pass + # if isinstance(actor_model, nn.Module): + # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + # else: + # load_weights(actor_weights=actor_model, vllm_model=model) + # return actor_model + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + + # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm + if isinstance(actor_model, nn.Module): + load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), + vllm_model=model) + else: + load_dtensor_weights(actor_weights=actor_model, vllm_model=model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +# FIXME(sgm): hack the _get_logits function in vllm v0.4.2 +# as they use ray, the _get_logits result will only need to return to the driver node, +# therefore gather is enough. However, we use SPMD instead of a central scheduler, +# all_gather is required (aligned with v0.2.6) +def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + + +from vllm.model_executor.layers.logits_processor import LogitsProcessor + + +def logitsprocessor_init( + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None, +) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super(LogitsProcessor, self).__init__() + self.scale = scale + self.vocab_size = vocab_size + # Whether the input is logits (default is hidden states). + self.logits_as_input = logits_as_input + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + # Soft cap the logits. Used in Gemma 2. + self.soft_cap = soft_cap + # Whether to use gather or all-gather to gather the logits. + self.use_gather = False + + +LogitsProcessor.__init__ = logitsprocessor_init # use all_gather diff --git a/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py b/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py new file mode 100644 index 0000000..b0cceff --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py @@ -0,0 +1,182 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py + +import warnings +from enum import IntEnum +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +import vllm.envs as envs +from vllm.compilation.levels import CompilationLevel +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + LoRAConfig, + ModelConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, +) +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor.models.interfaces import supports_lora +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.prompt_adapter.worker_manager import LRUCacheWorkerPromptAdapterManager +from vllm.utils import DeviceMemoryProfiler, is_hip, supports_dynamo +from vllm.worker.model_runner import ModelRunner + +from .config import LoadConfig, ModelConfig +from .model_loader import get_model + +logger = init_logger(__name__) + + +# How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. + MIXED = 2 + + +class ModelRunner(ModelRunner): + + def __init__( + self, + model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + + super().__init__( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config, + kv_cache_dtype, + is_driver_worker=True, # a hack + prompt_adapter_config=prompt_adapter_config, + return_hidden_states=return_hidden_states, + observability_config=observability_config, + input_registry=input_registry, + mm_registry=mm_registry, + ) + + # NOTE(sgm): add for verl + self.model = model # this will be replaced by get_model() + + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: + self.model = get_model( + self.model, + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config, + ) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) + + if self.lora_config: + assert supports_lora(self.model), f"{self.model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. + if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = self.model.config.max_position_embeddings + else: + max_pos_embeddings = self.model.config.text_config.max_position_embeddings + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.device, + self.prompt_adapter_config, + ) + self.model = self.prompt_adapter_manager.create_prompt_adapter_manager(self.model) + + if self.kv_cache_dtype == "fp8" and is_hip(): + # Currently only ROCm accepts kv-cache scaling factors + # via quantization_param_path and this will be deprecated + # in the future. + if self.model_config.quantization_param_path is not None: + if callable(getattr(self.model, "load_kv_cache_scales", None)): + warnings.warn( + "Loading kv cache scaling factor from JSON is " + "deprecated and will be removed. Please include " + "kv cache scaling factors in the model checkpoint.", + FutureWarning, + stacklevel=2, + ) + self.model.load_kv_cache_scales(self.model_config.quantization_param_path) + logger.info("Loaded KV cache scaling factors from %s", self.model_config.quantization_param_path) + else: + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but " + "model %s does not support loading scaling factors.", + self.model.__class__, + ) + else: + logger.warning("Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!") + + if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + from vllm.plugins import get_torch_compile_backend + + backend = get_torch_compile_backend() or "eager" + self.model = torch.compile(self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, backend=backend) diff --git a/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py b/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py new file mode 100644 index 0000000..0150c1c --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py @@ -0,0 +1,312 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""Model and data parallel groups.""" +import os +from typing import Optional + +import torch +import torch.distributed +import vllm.distributed.parallel_state as ps +from vllm.distributed.parallel_state import ( + get_pp_group, + get_world_group, + init_distributed_environment, + init_model_parallel_group, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) +""" +This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. +- We assume the Megatron tp+dp+pp world is already established before calling this function. + +""" + +# Device mesh for using DTensor +_DEVICE_MESH = None + +# Tensor model parallel group that the current rank belongs to. +_TP = None +# Pipeline model parallel group that the current rank belongs to. +_PP = None + + +# This method is for initializing the ParallelGroup when using HybridEngine +def initialize_parallel_state( + distributed_init_method: str = "env://", + backend: str = "nccl", + tensor_model_parallel_size: int = 1, + num_tp_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +): + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. + rank = int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) + if torch.distributed.get_world_size() > 1: + # NOTE: build a sepearate inference group with infer tp & micro dp + initialize_model_parallel_for_vllm( + tensor_model_parallel_size=tensor_model_parallel_size, + num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, + ) + else: + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + # get the backend of _DEVICE_WORLD_GROUP + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + return + + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( + "tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}") + pp_world_size = get_pp_group().world_size + assert pp_world_size == pipeline_model_parallel_size, ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}") + + +# TODO(sgm): deviate from the v0.5.4, not pp now +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ps._TP is not None + # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + + +def initialize_model_parallel_for_vllm( + tensor_model_parallel_size: int, + num_tensor_model_parallel_groups_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +) -> None: + pass + + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + + assert isinstance(tensor_model_parallel_size, int) + + # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group + # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group + + # Build the tensor model-parallel groups. + assert ps._TP is None, "tensor model parallel group is already initialized" + + global _TP + + world_size: int = torch.distributed.get_world_size() + + rank = torch.distributed.get_rank() + + backend = torch.distributed.get_backend() + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + + if num_tensor_model_parallel_groups_per_train_tp == 1: + # if tensor_model_parallel_size == train_tensor_parallel_size: + # using the same tp group as Megatron/vllm + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine + else: + # initialize a micro_dp group and a tp group + # assume training tp=4, infer tp=2, then, weight is partitioned as + # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference + + # Build the inference tp groups + # train_tp = train_tensor_parallel_size + train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size + # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): + start = train_tp * i + end = train_tp * (i + 1) + for j in range(num_tensor_model_parallel_groups_per_train_tp): + ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp)) + for i in range(len(ranks)): + ranks[i] += j + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # Build the pipeline model-parallel groups. + # global _PIPELINE_MODEL_PARALLEL_GROUP + # global _PIPELINE_GLOBAL_RANKS + # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized") + + # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group() + # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks() + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + NOTE: This method is a hack from the open-sourced version without + asertion of world_size = tp * pp + + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) + + # NOTE(sgm) we don't assert world_size == tp * pp + # DP is not managed by vllm but by the VeRL WorkerGroup + # if (world_size != + # tensor_model_parallel_size * pipeline_model_parallel_size): + # raise RuntimeError( + # f"world_size ({world_size}) is not equal to " + # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + rank = torch.distributed.get_rank() + global _TP + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +""" +Device mesh utilities +""" + + +def get_device_mesh(): + assert _DEVICE_MESH is not None, "device mesh is not initialized" + return _DEVICE_MESH + + +""" +Tensor model parallel utilities +""" + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP.device_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size diff --git a/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py b/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py new file mode 100644 index 0000000..229a424 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py @@ -0,0 +1,256 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py + +import os +import socket +from typing import Dict, List, Optional, Set, Tuple + +import torch +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest + +from .config import LoadConfig, ModelConfig + +logger = init_logger(__name__) + + +class SPMDGPUExecutor(ExecutorBase): + """SPMD-based multi-GPU executor implementations.""" + + def __init__( + self, + model, # pytorch model itself or its parameter dict + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + observability_config: Optional[ObservabilityConfig], + ) -> None: + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config + + distributed_init_method = initialize_cluster(parallel_config) + self._init_executor(model, distributed_init_method) + + # TODO(sgm): verl not support speculative decode now + def _init_executor(self, model, distributed_init_method) -> None: + assert not self.speculative_config, "Speculative decoding not yet supported for multi-GPU backend." + + # Create the parallel worker for each GPU. + self._init_workers_sp(model, distributed_init_method) + + def _init_workers_sp(self, model, distributed_init_method: str): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from .worker import Worker # pylint: disable=import-outside-toplevel + + rank = int(os.getenv("RANK")) + local_rank = int(os.getenv("LOCAL_RANK")) + print(f"local rank {local_rank}") + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ["NCCL_CUMEM_ENABLE"] = "0" + + self.worker = Worker( + model, + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + self.cache_config, + self.load_config, + local_rank, + rank, + distributed_init_method, + lora_config=self.lora_config, + speculative_config=None, + prompt_adapter_config=self.speculative_config, + is_driver_worker=True, + model_runner_cls=None, # use the default one + ) + + # NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model() + self.worker.init_device() + self.worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self.worker.determine_num_available_blocks() + + # NOTE(shengguangming): Now we don't use a shared centralized controler but each process will + # have its own scheduler + num_gpu_blocks = num_blocks[0] + num_cpu_blocks = num_blocks[1] + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers.""" + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + if torch.distributed.get_rank() == 0: + print( + f"before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB" + ) + self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + if torch.distributed.get_rank() == 0: + print( + f"after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB" + ) + + # NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache + def init_cache_engine(self) -> None: + self.worker._init_cache_engine() + + def free_cache_engine(self) -> None: + self.worker.free_cache_engine() + + def execute_model(self, execute_model_req) -> List[SamplerOutput]: + all_outputs = self.worker.execute_model(execute_model_req=execute_model_req) + + # NOTE(sgm): + # Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs + # In vllm with ray, only the driver worker returns the sampling results. + return all_outputs + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self.worker.add_lora(lora_request=lora_request) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.worker.remove_lora(lora_id=lora_id) + + def list_loras(self) -> Set[int]: + return self.worker.list_loras() + + def check_health(self) -> None: + # SPMDExecutor will always be healthy as long as + # it's running. + return + + # NOTE(sgm) add for verl to pass the abstract class test, not used + from vllm.prompt_adapter.request import PromptAdapterRequest + + def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.add_prompt_adapter(prompt_adapter_request) + + def list_prompt_adapters(self) -> Set[int]: + return self.worker.list_prompt_adapters() + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.worker.pin_lora(lora_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.pin_prompt_adapter(prompt_adapter_id) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.remove_prompt_adapter(prompt_adapter_id) + + # NOTE(sgm): add for verl + def offload_model_weights(self) -> None: + self.worker.offload_model_weights() + + def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None: + self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + +def initialize_cluster( + parallel_config: ParallelConfig, + engine_use_ray: bool = False, + ray_address: Optional[str] = None, +) -> Tuple[str, Optional[None]]: + """Initialize the distributed cluster probably with Ray. + + Args: + parallel_config: The configurations for parallel execution. + + Returns: + The `distributed_init_method` is the address for initializing the + distributed backend. + """ + + # Initialize cluster locally. + port = get_open_port() + # We need to setup the distributed init method to make sure + # the distributed megatron code (e.g., get world size) works correctly. + # distributed_init_method = f"tcp://localhost:{port}" + distributed_init_method = "env://" + return distributed_init_method + + +def get_open_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +# TODO(sgm): not implemented async executor yet +class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase): + + async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + """Executes one model step on the given sequences.""" + raise NotImplementedError + + async def check_health_async(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + self.check_health() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py b/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py new file mode 100644 index 0000000..b0b4d0e --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py @@ -0,0 +1,40 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py + +from typing import Optional + +from transformers import PreTrainedTokenizer +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.utils import LRUCache + + +class TokenizerGroup(TokenizerGroup): + """A group of tokenizers that can be used for LoRA adapters.""" + + def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int]): + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = tokenizer + self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None + + # FIXME(sgm): for simplicity, we assign the special token here + @property + def pad_token_id(self): + return self.tokenizer.pad_token_id + + @property + def eos_token_id(self): + return self.tokenizer.eos_token_id diff --git a/verl/third_party/vllm/vllm_v_0_6_3/worker.py b/verl/third_party/vllm/vllm_v_0_6_3/worker.py new file mode 100644 index 0000000..cb1a7ab --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/worker.py @@ -0,0 +1,333 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py +"""A GPU worker class.""" +import gc +import os +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.distributed +import torch.nn as nn +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) + +# TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state +from vllm.distributed import get_tensor_model_parallel_group, init_distributed_environment, set_custom_all_reduce +from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner +from vllm.worker.model_runner import GPUModelRunnerBase +from vllm.worker.model_runner_base import ModelRunnerInputBase +from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype +from vllm.worker.worker_base import WorkerInput + +from .config import LoadConfig, LoadFormat, ModelConfig +from .dtensor_weight_loaders import load_dtensor_weights +from .hf_weight_loader import load_hf_weights +from .megatron_weight_loaders import load_megatron_weights +from .model_runner import ModelRunner +from .parallel_state import ensure_model_parallel_initialized + + +class Worker(Worker): + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. + """ + + def __init__( + self, + model: Union[nn.Module, Dict], # model itself or its parameter dict + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, + ) -> None: + # self.model = model # will be replaced in the init_model + self.model_config = model_config + self.parallel_config = parallel_config + self.parallel_config.rank = rank + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker # TODO: we don't need driver + # if parallel_config and is_driver_worker: + # assert rank % parallel_config.tensor_parallel_size == 0, \ + # "Driver worker should be rank 0 of tensor parallel group." + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + + init_cached_hf_modules() + + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_args = ( + {} if speculative_config is None or (speculative_config.draft_model_config.model == model_config.model) or + (speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) else { + "return_hidden_states": True + }) + + # TODO(sgm): set correct model runner class + ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner + if model_runner_cls is not None: + ModelRunnerClass = model_runner_cls + elif self.model_config.embedding_mode: + ModelRunnerClass = EmbeddingModelRunner + self.model_runner: GPUModelRunnerBase = ModelRunnerClass( + model, # [VERL]: add for verl + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + **speculative_args, + ) + + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CacheEngine] = None + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.gpu_cache: Optional[List[List[torch.Tensor]]] = None + + # NOTE(sgm): [VERL] For offloading inference engine params + self.cpu_model = None + + def init_device(self) -> None: + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. + self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + self.device = torch.device(f"cuda:{local_rank}") + if self.rank < 0: + raise ValueError("Invalid or unspecified rank.") + torch.cuda.set_device(self.device) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + self.parallel_config.world_size = world_size + + _check_if_gpu_supports_dtype(self.model_config.dtype) + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError(f"Not support device type: {self.device_config.device}") + + # Initialize the distributed environment. + init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, + self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + # self.model = get_model(actor_model=self.model, model_config=self.model_config) + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + # torch.cuda.reset_peak_memory_stats() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = total_gpu_memory - free_gpu_memory + + assert peak_memory > 0, ("Error in memory profiling. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + cache_block_size = self.get_cache_block_size_bytes() + + # NOTE(sgm) [VERL] use the remaining memory + num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size) + # num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size) + + num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + + # NOTE(sgm): Add for [VERL], synchronize number of blocks with all the rank + num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda") + num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda") + + torch.distributed.all_reduce(num_gpu_blocks, + op=torch.distributed.ReduceOp.MIN, + group=get_tensor_model_parallel_group().device_group) + torch.distributed.all_reduce(num_cpu_blocks, + op=torch.distributed.ReduceOp.MIN, + group=get_tensor_model_parallel_group().device_group) + num_gpu_blocks = num_gpu_blocks.item() + num_cpu_blocks = num_cpu_blocks.item() + gc.collect() + torch.cuda.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def _init_cache_engine(self): + if self.cache_engine is None and self.gpu_cache is None: + super()._init_cache_engine() + + def free_cache_engine(self): + # ensure `enforce_eager=True` + self.cache_engine = None + self.gpu_cache = None + + # NOTE(sgm): [VERL]: adapt from _execute_model_spmd() + def execute_model(self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]: + """ + Execute model in Single Program Multiple Data (SPMD) fashion. + All workers take the same request, prepare the input and + execute the model. + """ + assert execute_model_req is not None, ("_execute_model_spmd() requires each worker to take in an " + "ExecuteModelRequest") + worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list) + + # verl.worker.workerbase.WorkerBase + # swap cache + super().execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + return self.model_runner.execute_model( + model_input, + self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, + intermediate_tensors, + ) + + # assume the input is .state_dict() + def sync_model_weights(self, actor_weights: Dict, load_format: str): + if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]: + load_megatron_weights(actor_weights, self.model_runner.model) + elif load_format == LoadFormat.HF: + # full model state dict without no sharding + load_hf_weights(actor_weights, self.model_runner.model) + elif load_format == LoadFormat.DTENSOR: + load_dtensor_weights(actor_weights, self.model_runner.model) + + def offload_model_weights(self) -> None: + if self.cpu_model == None: + self.cpu_model = {} + for name, params in self.model_runner.model.named_parameters(): + self.cpu_model[name] = torch.empty_like(params, device="cpu") + params.data = self.cpu_model[name] + else: + for name, params in self.model_runner.model.named_parameters(): + params.data = self.cpu_model[name] + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = "env://", + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + # NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron + init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) + + ensure_model_parallel_initialized( + tensor_model_parallel_size=parallel_config.tensor_parallel_size, + pipeline_model_parallel_size=parallel_config.pipeline_parallel_size, + ) + + # TODO(sgm): check whether need this + # if pynccl_utils.is_initialized(): + # pynccl_world_size = pynccl_utils.get_world_size() + # if pynccl_world_size != parallel_config.world_size: + # raise RuntimeError( + # "pynccl is already initialized but the pynccl world " + # "size does not match parallel_config.world_size " + # f"({pynccl_world_size} vs. {parallel_config.world_size}).") + # elif parallel_config.world_size > 1: + # # NOTE(woosuk): We don't initialize pynccl process group when world size + # # is 1. + # # NOTE(kaichao): By default, pynccl is initialized for tp group. + # pynccl_utils.init_process_group( + # group=get_tensor_model_parallel_cpu_group()) + + # # Initialize a custom fast all-reduce implementation. + # if not parallel_config.disable_custom_all_reduce: + # init_custom_ar() + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) + # if pynccl_utils.is_initialized(): + # pynccl_utils.all_reduce(torch.zeros(1).cuda()) diff --git a/verl/trainer/ppo/hybrid_engine/__init__.py b/verl/trainer/ppo/hybrid_engine/__init__.py index aebff5b..3713733 100644 --- a/verl/trainer/ppo/hybrid_engine/__init__.py +++ b/verl/trainer/ppo/hybrid_engine/__init__.py @@ -14,6 +14,8 @@ from verl.utils.import_utils import is_vllm_available, is_megatron_core_available +from .base import BaseShardingManager + AllGatherPPModel = None if is_megatron_core_available() and is_vllm_available(): diff --git a/verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py b/verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py index a631e81..e66275c 100644 --- a/verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py @@ -82,7 +82,7 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model os.environ['MEGATRON_IMPORT_TIMERS'] = '0' train_tp = kwargs.get('train_tp', None) num_tp_per_train_tp = train_tp // tensor_parallel_size - if vllm_version == '0.4.2' or vllm_version == '0.5.4': + if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp) @@ -109,7 +109,7 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model ) # we may detokenize the result all together later - if vllm_version == '0.4.2' or vllm_version == '0.5.4': + if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): kwargs['detokenize'] = False # supporting adding any sampling params from the config file