[platform] Move executor class init to platform #11085

Closed · wants to merge 1 commit
58 changes: 6 additions & 52 deletions vllm/engine/async_llm_engine.py
@@ -29,13 +29,14 @@
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import deprecate_kwargs, weak_bind
from vllm.utils import deprecate_kwargs, resolve_obj_by_qualname, weak_bind

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -630,58 +631,11 @@ def _get_executor_cls(
"distributed_executor_backend must be a subclass of "
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
if distributed_executor_backend == "ray":
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
executor_class = RayTPUExecutorAsync
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.device_config.device_type == "hpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
executor_class = RayHPUExecutorAsync
else:
from vllm.executor.hpu_executor import HPUExecutorAsync
executor_class = HPUExecutorAsync
elif engine_config.device_config.device_type == "openvino":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with "
"the OpenVINO backend.")
from vllm.executor.openvino_executor import OpenVINOExecutorAsync
executor_class = OpenVINOExecutorAsync
elif engine_config.device_config.device_type == "xpu":
if distributed_executor_backend is None:
from vllm.executor.xpu_executor import XPUExecutorAsync
executor_class = XPUExecutorAsync
elif distributed_executor_backend == "ray":
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
executor_class = RayXPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_xpu_executor import (
MultiprocessingXPUExecutorAsync)
executor_class = MultiprocessingXPUExecutorAsync
else:
raise RuntimeError(
"Not supported distributed execution model on XPU device.")
elif distributed_executor_backend == "ray":
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutorAsync)
executor_class = MultiprocessingGPUExecutorAsync
else:
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
executor_cls = current_platform.get_executor_cls(
distributed_executor_backend=distributed_executor_backend,
is_async=True)
executor_class = resolve_obj_by_qualname(executor_cls)
return executor_class

@classmethod
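For context, a minimal sketch of how the new lookup path fits together. It assumes `resolve_obj_by_qualname` (imported from `vllm.utils` in this diff) does a standard importlib-based import of a dotted `module.ClassName` string; the helper below is illustrative, not vLLM's actual implementation.

```python
# Minimal sketch of the new platform-driven executor lookup. The helper below
# is an assumption about what resolve_obj_by_qualname does, not vLLM's code.
import importlib
from typing import Any


def resolve_obj_by_qualname(qualname: str) -> Any:
    """Import and return the object named by a dotted 'module.Name' string."""
    module_name, obj_name = qualname.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), obj_name)


# Conceptually, the engine now does:
#   qualname = current_platform.get_executor_cls(
#       distributed_executor_backend="mp", is_async=True)
#   # e.g. "vllm.executor.multiproc_gpu_executor.MultiprocessingGPUExecutorAsync"
#   executor_class = resolve_obj_by_qualname(qualname)
```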
62 changes: 7 additions & 55 deletions vllm/engine/llm_engine.py
@@ -13,7 +13,6 @@
import torch
from typing_extensions import TypeVar, deprecated

import vllm.envs as envs
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
@@ -43,6 +42,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -60,7 +60,8 @@
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.utils import (Counter, Device, deprecate_kwargs,
resolve_obj_by_qualname, weak_bind)
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -461,61 +462,12 @@ def _get_executor_cls(cls,
if distributed_executor_backend.uses_ray: # type: ignore
initialize_ray_cluster(engine_config.parallel_config)
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutor
executor_class = RayTPUExecutor
else:
assert distributed_executor_backend is None
from vllm.executor.tpu_executor import TPUExecutor
executor_class = TPUExecutor
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.device_config.device_type == "hpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_hpu_executor import RayHPUExecutor
executor_class = RayHPUExecutor
else:
from vllm.executor.hpu_executor import HPUExecutor
executor_class = HPUExecutor
elif engine_config.device_config.device_type == "openvino":
from vllm.executor.openvino_executor import OpenVINOExecutor
executor_class = OpenVINOExecutor
elif engine_config.device_config.device_type == "xpu":
else:
executor_cls = current_platform.get_executor_cls(
distributed_executor_backend)
executor_class = resolve_obj_by_qualname(executor_cls)
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutor
executor_class = RayXPUExecutor
elif distributed_executor_backend == "mp":
# FIXME(kunshang):
# spawn needs calling `if __name__ == '__main__':``
# fork is not supported for xpu start new process.
logger.error(
"Both start methods (spawn and fork) have issue "
"on XPU if you use mp backend, Please try ray instead.")
else:
from vllm.executor.xpu_executor import XPUExecutor
executor_class = XPUExecutor
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingGPUExecutor
else:
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
return executor_class

@classmethod
8 changes: 8 additions & 0 deletions vllm/platforms/cpu.py
@@ -98,3 +98,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"vllm.worker.cpu_worker.CPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"

@classmethod
def get_executor_cls(cls,
distributed_executor_backend: Optional[str] = None,
is_async: Optional[bool] = None) -> str:
if is_async:
return "vllm.executor.cpu_executor.CPUExecutorAsync"
return "vllm.executor.cpu_executor.CPUExecutor"
24 changes: 24 additions & 0 deletions vllm/platforms/cuda.py
@@ -137,6 +137,30 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"

@classmethod
def get_executor_cls(cls,
distributed_executor_backend: Optional[str] = None,
is_async: Optional[bool] = None) -> str:
if distributed_executor_backend == "ray":
if is_async:
return "vllm.executor.ray_gpu_executor.RayGPUExecutorAsync"
else:
return "vllm.executor.ray_gpu_executor.RayGPUExecutor"
if distributed_executor_backend == "mp":
if is_async:
return "vllm.executor.multiproc_gpu_executor." \
"MultiprocessingGPUExecutorAsync"
else:
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
return "vllm.executor.multiproc_gpu_executor." \
"MultiprocessingGPUExecutor"
if is_async:
return "vllm.executor.gpu_executor.GPUExecutorAsync"
else:
return "vllm.executor.gpu_executor.GPUExecutor"


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
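A quick usage sketch of the mapping encoded above, assuming the process runs on a CUDA host so that `current_platform` resolves to this platform; the expected strings are read directly from the branch logic in the diff:

```python
# Usage sketch; only meaningful on a CUDA host where current_platform is the
# CUDA platform. Expected values come straight from the method above.
from vllm.platforms import current_platform

assert current_platform.get_executor_cls(
    distributed_executor_backend="ray",
    is_async=False) == "vllm.executor.ray_gpu_executor.RayGPUExecutor"
assert current_platform.get_executor_cls(
    distributed_executor_backend=None,
    is_async=True) == "vllm.executor.gpu_executor.GPUExecutorAsync"
```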
12 changes: 12 additions & 0 deletions vllm/platforms/hpu.py
@@ -43,3 +43,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"

@classmethod
def get_executor_cls(cls,
distributed_executor_backend: Optional[str] = None,
is_async: Optional[bool] = None):
if distributed_executor_backend == "ray":
if is_async:
return "vllm.executor.ray_hpu_executor.RayHPUExecutorAsync"
return "vllm.executor.ray_hpu_executor.RayHPUExecutor"
if is_async:
return "vllm.executor.hpu_executor.HPUExecutorAsync"
return "vllm.executor.hpu_executor.HPUExecutor"
11 changes: 11 additions & 0 deletions vllm/platforms/interface.py
@@ -10,8 +10,10 @@

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
else:
VllmConfig = None
ExecutorBase = None

logger = init_logger(__name__)

@@ -221,6 +223,15 @@ def get_cpu_architecture(cls) -> CpuArchEnum:

return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

@classmethod
def get_executor_cls(cls,
distributed_executor_backend: Optional[str] = None,
is_async: Optional[bool] = None) -> str:
"""
Get the executor class for the current platform.
"""
raise NotImplementedError


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
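With `get_executor_cls` on the `Platform` interface, a platform only needs to return a dotted path string. A hedged sketch of how an out-of-tree platform might override it; `MyAcceleratorPlatform` and the `my_vllm_plugin.*` paths are made up for illustration and are not part of vLLM:

```python
# Hypothetical example: MyAcceleratorPlatform and my_vllm_plugin.* are
# illustrative names only.
from typing import Optional

from vllm.platforms.interface import Platform, PlatformEnum


class MyAcceleratorPlatform(Platform):
    _enum = PlatformEnum.UNSPECIFIED  # placeholder enum value for the sketch

    @classmethod
    def get_executor_cls(cls,
                         distributed_executor_backend: Optional[str] = None,
                         is_async: Optional[bool] = None) -> str:
        if distributed_executor_backend is not None:
            raise RuntimeError(
                "Distributed execution is not supported on this platform.")
        if is_async:
            return "my_vllm_plugin.executor.MyExecutorAsync"
        return "my_vllm_plugin.executor.MyExecutor"
```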
8 changes: 8 additions & 0 deletions vllm/platforms/neuron.py
@@ -28,3 +28,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = \
"vllm.worker.neuron_worker.NeuronWorker"

@classmethod
def get_executor_cls(cls,
distributed_executor_backend: Optional[str] = None,
is_async: Optional[bool] = None):

[Review comment from @MengqingCao (Contributor), Dec 12, 2024]
Maybe add an assertion here for distributed_executor_backend, like openvino does?

if is_async:
return "vllm.executor.neuron_executor.NeuronExecutorAsync"
return "vllm.executor.neuron_executor.NeuronExecutor"
11 changes: 11 additions & 0 deletions vllm/platforms/openvino.py
@@ -135,3 +135,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
raise RuntimeError(
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
f" {kv_cache_space}, expect a positive integer value.")

@classmethod
def get_executor_cls(cls,
distributed_executor_backend: Optional[str] = None,
is_async: Optional[bool] = None):
assert distributed_executor_backend is None, (
"Distributed execution is not supported with "
"the OpenVINO backend.")
if is_async:
return "vllm.executor.openvino_executor.OpenVINOExecutorAsync"
return "vllm.executor.openvino_executor.OpenVINOExecutor"
24 changes: 24 additions & 0 deletions vllm/platforms/rocm.py
@@ -106,3 +106,27 @@ def verify_quantization(cls, quant: str) -> None:
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True

@classmethod
def get_executor_cls(cls,
distributed_executor_backend: Optional[str] = None,
is_async: Optional[bool] = None):
if distributed_executor_backend == "ray":
if is_async:
return "vllm.executor.ray_gpu_executor.RayGPUExecutorAsync"
else:
return "vllm.executor.ray_gpu_executor.RayGPUExecutor"
if distributed_executor_backend == "mp":
if is_async:
return "vllm.executor.multiproc_gpu_executor." \
"MultiprocessingGPUExecutorAsync"
else:
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
return "vllm.executor.multiproc_gpu_executor." \
"MultiprocessingGPUExecutor"
if is_async:
return "vllm.executor.gpu_executor.GPUExecutorAsync"
else:
return "vllm.executor.gpu_executor.GPUExecutor"
12 changes: 12 additions & 0 deletions vllm/platforms/tpu.py
@@ -67,3 +67,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"

@classmethod
def get_executor_cls(cls,
distributed_executor_backend: Optional[str] = None,
is_async: Optional[bool] = None):
if distributed_executor_backend == "ray":
if is_async:
return "vllm.executor.ray_tpu_executor.RayTPUExecutorAsync"
return "vllm.executor.ray_tpu_executor.RayTPUExecutor"
if is_async:
return "vllm.executor.tpu_executor.TPUExecutorAsync"
return "vllm.executor.tpu_executor.TPUExecutor"
22 changes: 22 additions & 0 deletions vllm/platforms/xpu.py
@@ -78,3 +78,25 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config.distributed_executor_backend = "ray"
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"

@classmethod
def get_executor_cls(cls,
distributed_executor_backend: Optional[str] = None,
is_async: Optional[bool] = None):
if distributed_executor_backend == "ray":
if is_async:
return "vllm.executor.ray_xpu_executor.RayXPUExecutorAsync"
return "vllm.executor.ray_xpu_executor.RayXPUExecutor"
if distributed_executor_backend == "mp":
if is_async:
return "vllm.executor.multiproc_xpu_executor." \
"MultiprocessingXPUExecutorAsync"
# FIXME(kunshang):
# spawn needs calling `if __name__ == '__main__':``
# fork is not supported for xpu start new process.
logger.error(
"Both start methods (spawn and fork) have issue "
"on XPU if you use mp backend, Please try ray instead.")
if is_async:
return "vllm.executor.xpu_executor.XPUExecutorAsync"
return "vllm.executor.xpu_executor.XPUExecutor"