diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 60dccd7a0812c..a05408c153c31 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -29,13 +29,14 @@
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.outputs import PoolingRequestOutput, RequestOutput
+from vllm.platforms import current_platform
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import deprecate_kwargs, resolve_obj_by_qualname, weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -630,58 +631,11 @@ def _get_executor_cls(
                     "distributed_executor_backend must be a subclass of "
                     f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
             executor_class = distributed_executor_backend
-        elif engine_config.device_config.device_type == "neuron":
-            from vllm.executor.neuron_executor import NeuronExecutorAsync
-            executor_class = NeuronExecutorAsync
-        elif engine_config.device_config.device_type == "tpu":
-            if distributed_executor_backend == "ray":
-                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
-                executor_class = RayTPUExecutorAsync
-            else:
-                assert distributed_executor_backend is None
-                from vllm.executor.tpu_executor import TPUExecutorAsync
-                executor_class = TPUExecutorAsync
-        elif engine_config.device_config.device_type == "cpu":
-            from vllm.executor.cpu_executor import CPUExecutorAsync
-            executor_class = CPUExecutorAsync
-        elif engine_config.device_config.device_type == "hpu":
-            if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
-                from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
-                executor_class = RayHPUExecutorAsync
-            else:
-                from vllm.executor.hpu_executor import HPUExecutorAsync
-                executor_class = HPUExecutorAsync
-        elif engine_config.device_config.device_type == "openvino":
-            assert distributed_executor_backend is None, (
-                "Distributed execution is not supported with "
-                "the OpenVINO backend.")
-            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
-            executor_class = OpenVINOExecutorAsync
-        elif engine_config.device_config.device_type == "xpu":
-            if distributed_executor_backend is None:
-                from vllm.executor.xpu_executor import XPUExecutorAsync
-                executor_class = XPUExecutorAsync
-            elif distributed_executor_backend == "ray":
-                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
-                executor_class = RayXPUExecutorAsync
-            elif distributed_executor_backend == "mp":
-                from vllm.executor.multiproc_xpu_executor import (
-                    MultiprocessingXPUExecutorAsync)
-                executor_class = MultiprocessingXPUExecutorAsync
-            else:
-                raise RuntimeError(
-                    "Not supported distributed execution model on XPU device.")
-        elif distributed_executor_backend == "ray":
-            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
-            executor_class = RayGPUExecutorAsync
-        elif distributed_executor_backend == "mp":
-            from vllm.executor.multiproc_gpu_executor import (
-                MultiprocessingGPUExecutorAsync)
-            executor_class = MultiprocessingGPUExecutorAsync
         else:
-            from vllm.executor.gpu_executor import GPUExecutorAsync
-            executor_class = GPUExecutorAsync
+            executor_cls = current_platform.get_executor_cls(
+                distributed_executor_backend=distributed_executor_backend,
+                is_async=True)
+            executor_class = resolve_obj_by_qualname(executor_cls)
         return executor_class
 
     @classmethod
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 6eca304b45f07..c9f6ce580a0a2 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -43,6 +43,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
+from vllm.platforms import current_platform
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -60,7 +61,8 @@
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
+from vllm.utils import (Counter, Device, deprecate_kwargs,
+                        resolve_obj_by_qualname, weak_bind)
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -461,61 +463,12 @@ def _get_executor_cls(cls,
             if distributed_executor_backend.uses_ray:  # type: ignore
                 initialize_ray_cluster(engine_config.parallel_config)
             executor_class = distributed_executor_backend
-        elif engine_config.device_config.device_type == "neuron":
-            from vllm.executor.neuron_executor import NeuronExecutor
-            executor_class = NeuronExecutor
-        elif engine_config.device_config.device_type == "tpu":
-            if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
-                from vllm.executor.ray_tpu_executor import RayTPUExecutor
-                executor_class = RayTPUExecutor
-            else:
-                assert distributed_executor_backend is None
-                from vllm.executor.tpu_executor import TPUExecutor
-                executor_class = TPUExecutor
-        elif engine_config.device_config.device_type == "cpu":
-            from vllm.executor.cpu_executor import CPUExecutor
-            executor_class = CPUExecutor
-        elif engine_config.device_config.device_type == "hpu":
-            if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
-                from vllm.executor.ray_hpu_executor import RayHPUExecutor
-                executor_class = RayHPUExecutor
-            else:
-                from vllm.executor.hpu_executor import HPUExecutor
-                executor_class = HPUExecutor
-        elif engine_config.device_config.device_type == "openvino":
-            from vllm.executor.openvino_executor import OpenVINOExecutor
-            executor_class = OpenVINOExecutor
-        elif engine_config.device_config.device_type == "xpu":
+        else:
+            executor_cls = current_platform.get_executor_cls(
+                distributed_executor_backend)
+            executor_class = resolve_obj_by_qualname(executor_cls)
             if distributed_executor_backend == "ray":
                 initialize_ray_cluster(engine_config.parallel_config)
-                from vllm.executor.ray_xpu_executor import RayXPUExecutor
-                executor_class = RayXPUExecutor
-            elif distributed_executor_backend == "mp":
-                # FIXME(kunshang):
-                # spawn needs calling `if __name__ == '__main__':``
-                # fork is not supported for xpu start new process.
-                logger.error(
-                    "Both start methods (spawn and fork) have issue "
-                    "on XPU if you use mp backend, Please try ray instead.")
-            else:
-                from vllm.executor.xpu_executor import XPUExecutor
-                executor_class = XPUExecutor
-        elif distributed_executor_backend == "ray":
-            initialize_ray_cluster(engine_config.parallel_config)
-            from vllm.executor.ray_gpu_executor import RayGPUExecutor
-            executor_class = RayGPUExecutor
-        elif distributed_executor_backend == "mp":
-            from vllm.executor.multiproc_gpu_executor import (
-                MultiprocessingGPUExecutor)
-            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
-                "multiprocessing distributed executor backend does not "
-                "support VLLM_USE_RAY_SPMD_WORKER=1")
-            executor_class = MultiprocessingGPUExecutor
-        else:
-            from vllm.executor.gpu_executor import GPUExecutor
-            executor_class = GPUExecutor
         return executor_class
 
     @classmethod
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index e5142b985d1f2..b5f254797649d 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -98,3 +98,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                     "vllm.worker.cpu_worker.CPUWorker"
         else:
             parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None) -> str:
+        if is_async:
+            return "vllm.executor.cpu_executor.CPUExecutorAsync"
+        return "vllm.executor.cpu_executor.CPUExecutor"
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index ae1fd6d5ce068..7fe39203db8fa 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -137,6 +137,29 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"
 
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None) -> str:
+        if distributed_executor_backend == "ray":
+            if is_async:
+                return "vllm.executor.ray_gpu_executor.RayGPUExecutorAsync"
+            else:
+                return "vllm.executor.ray_gpu_executor.RayGPUExecutor"
+        if distributed_executor_backend == "mp":
+            if is_async:
+                return "vllm.executor.multiproc_gpu_executor." \
+                    "MultiprocessingGPUExecutorAsync"
+            else:
+                assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
+                    "multiprocessing distributed executor backend does not "
+                    "support VLLM_USE_RAY_SPMD_WORKER=1")
+                return "vllm.executor.multiproc_gpu_executor." \
+                    "MultiprocessingGPUExecutor"
+        if is_async:
+            return "vllm.executor.gpu_executor.GPUExecutorAsync"
+        else:
+            return "vllm.executor.gpu_executor.GPUExecutor"
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index 7f22bee3eaa74..fa714a2db20f2 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -43,3 +43,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None):
+        if distributed_executor_backend == "ray":
+            if is_async:
+                return "vllm.executor.ray_hpu_executor.RayHPUExecutorAsync"
+            return "vllm.executor.ray_hpu_executor.RayHPUExecutor"
+        if is_async:
+            return "vllm.executor.hpu_executor.HPUExecutorAsync"
+        return "vllm.executor.hpu_executor.HPUExecutor"
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index db06d2c18e681..ba9ba2cf49c66 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -10,8 +10,10 @@
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
+    from vllm.executor.executor_base import ExecutorBase
 else:
     VllmConfig = None
+    ExecutorBase = None
 
 logger = init_logger(__name__)
 
@@ -221,6 +223,14 @@ def get_cpu_architecture(cls) -> CpuArchEnum:
 
         return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
 
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None) -> str:
+        """
+        Get the executor class for the current platform.
+        """
+        raise NotImplementedError
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index 1e5c4bddfa24f..1218c86fe2ff4 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -28,3 +28,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = \
                 "vllm.worker.neuron_worker.NeuronWorker"
+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None):
+        if is_async:
+            return "vllm.executor.neuron_executor.NeuronExecutorAsync"
+        return "vllm.executor.neuron_executor.NeuronExecutor"
diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py
index e0f8e8b4b49fe..a6afbd8fca83f 100644
--- a/vllm/platforms/openvino.py
+++ b/vllm/platforms/openvino.py
@@ -135,3 +135,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             raise RuntimeError(
                 "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
                 f" {kv_cache_space}, expect a positive integer value.")
+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None):
+        assert distributed_executor_backend is None, (
+            "Distributed execution is not supported with "
+            "the OpenVINO backend.")
+        if is_async:
+            return "vllm.executor.openvino_executor.OpenVINOExecutorAsync"
+        return "vllm.executor.openvino_executor.OpenVINOExecutor"
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 0133f26a0b1bc..d56b18e645fab 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -106,3 +106,26 @@ def verify_quantization(cls, quant: str) -> None:
                 "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                 " is not set, enabling VLLM_USE_TRITON_AWQ.")
             envs.VLLM_USE_TRITON_AWQ = True
+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None):
+        if distributed_executor_backend == "ray":
+            if is_async:
+                return "vllm.executor.ray_gpu_executor.RayGPUExecutorAsync"
+            else:
+                return "vllm.executor.ray_gpu_executor.RayGPUExecutor"
+        if distributed_executor_backend == "mp":
+            if is_async:
+                return "vllm.executor.multiproc_gpu_executor." \
+                    "MultiprocessingGPUExecutorAsync"
+            else:
+                assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
+                    "multiprocessing distributed executor backend does not "
+                    "support VLLM_USE_RAY_SPMD_WORKER=1")
+                return "vllm.executor.multiproc_gpu_executor." \
+                    "MultiprocessingGPUExecutor"
+        if is_async:
+            return "vllm.executor.gpu_executor.GPUExecutorAsync"
+        else:
+            return "vllm.executor.gpu_executor.GPUExecutor"
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 10d874349f36b..f5d53259f100a 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -67,3 +67,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
         else:
             parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None):
+        if distributed_executor_backend == "ray":
+            if is_async:
+                return "vllm.executor.ray_tpu_executor.RayTPUExecutorAsync"
+            return "vllm.executor.ray_tpu_executor.RayTPUExecutor"
+        if is_async:
+            return "vllm.executor.tpu_executor.TPUExecutorAsync"
+        return "vllm.executor.tpu_executor.TPUExecutor"
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 11dbd04d55671..eceac43039ad3 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -78,3 +78,24 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             parallel_config.distributed_executor_backend = "ray"
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None):
+        if distributed_executor_backend == "ray":
+            if is_async:
+                return "vllm.executor.ray_xpu_executor.RayXPUExecutorAsync"
+            return "vllm.executor.ray_xpu_executor.RayXPUExecutor"
+        if distributed_executor_backend == "mp":
+            if is_async:
+                return "vllm.executor.multiproc_xpu_executor." \
+                    "MultiprocessingXPUExecutorAsync"
+            # FIXME(kunshang):
+            # spawn needs calling `if __name__ == '__main__':``
+            # fork is not supported for xpu start new process.
+            logger.error(
+                "Both start methods (spawn and fork) have issue "
+                "on XPU if you use mp backend, Please try ray instead.")
+        if is_async:
+            return "vllm.executor.xpu_executor.XPUExecutorAsync"
+        return "vllm.executor.xpu_executor.XPUExecutor"
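
Note: the patch makes each platform return its executor class as a dotted string, which the engine then resolves with vllm.utils.resolve_obj_by_qualname. The standalone sketch below is an assumption about how that resolution behaves (split on the last dot, import the module, fetch the attribute); it is not the library's implementation, and the helper name resolve_by_qualname plus the executor path in the comment are illustrative only. It runs against the standard library so it can be tried without vLLM installed.

import importlib
from typing import Any


def resolve_by_qualname(qualname: str) -> Any:
    """Resolve 'pkg.module.Attr' to the attribute object (hypothetical helper)."""
    module_name, _, attr_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


if __name__ == "__main__":
    # The engine would pass a string like
    # "vllm.executor.gpu_executor.GPUExecutorAsync" returned by
    # Platform.get_executor_cls(); a stdlib class is used here so the
    # sketch runs anywhere.
    print(resolve_by_qualname("collections.Counter"))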