[platform] Move executor class init to platform
Signed-off-by: wangxiyuan <[email protected]>
wangxiyuan committed Dec 11, 2024
1 parent 9a93973 commit e39e97f
Showing 11 changed files with 136 additions and 106 deletions.
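In short: the per-device if/elif chains in LLMEngine._get_executor_cls and AsyncLLMEngine._get_executor_cls are replaced by a single Platform.get_executor_cls() hook. Each platform returns the fully qualified name of its executor class as a string, which the engine resolves via resolve_obj_by_qualname. GitHub Actions flags the new X | Y union annotations under mypy on Python 3.9 ("X | Y syntax for unions requires Python 3.10 [syntax]") in interface.py, openvino.py, rocm.py, tpu.py, and xpu.py.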
58 changes: 6 additions & 52 deletions vllm/engine/async_llm_engine.py
@@ -29,13 +29,14 @@
    get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
+from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import deprecate_kwargs, resolve_obj_by_qualname, weak_bind

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -630,58 +631,11 @@ def _get_executor_cls(
"distributed_executor_backend must be a subclass of "
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
executor_class = distributed_executor_backend
-        elif engine_config.device_config.device_type == "neuron":
-            from vllm.executor.neuron_executor import NeuronExecutorAsync
-            executor_class = NeuronExecutorAsync
-        elif engine_config.device_config.device_type == "tpu":
-            if distributed_executor_backend == "ray":
-                from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
-                executor_class = RayTPUExecutorAsync
-            else:
-                assert distributed_executor_backend is None
-                from vllm.executor.tpu_executor import TPUExecutorAsync
-                executor_class = TPUExecutorAsync
-        elif engine_config.device_config.device_type == "cpu":
-            from vllm.executor.cpu_executor import CPUExecutorAsync
-            executor_class = CPUExecutorAsync
-        elif engine_config.device_config.device_type == "hpu":
-            if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
-                from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
-                executor_class = RayHPUExecutorAsync
-            else:
-                from vllm.executor.hpu_executor import HPUExecutorAsync
-                executor_class = HPUExecutorAsync
-        elif engine_config.device_config.device_type == "openvino":
-            assert distributed_executor_backend is None, (
-                "Distributed execution is not supported with "
-                "the OpenVINO backend.")
-            from vllm.executor.openvino_executor import OpenVINOExecutorAsync
-            executor_class = OpenVINOExecutorAsync
-        elif engine_config.device_config.device_type == "xpu":
-            if distributed_executor_backend is None:
-                from vllm.executor.xpu_executor import XPUExecutorAsync
-                executor_class = XPUExecutorAsync
-            elif distributed_executor_backend == "ray":
-                from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
-                executor_class = RayXPUExecutorAsync
-            elif distributed_executor_backend == "mp":
-                from vllm.executor.multiproc_xpu_executor import (
-                    MultiprocessingXPUExecutorAsync)
-                executor_class = MultiprocessingXPUExecutorAsync
-            else:
-                raise RuntimeError(
-                    "Not supported distributed execution model on XPU device.")
-        elif distributed_executor_backend == "ray":
-            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
-            executor_class = RayGPUExecutorAsync
-        elif distributed_executor_backend == "mp":
-            from vllm.executor.multiproc_gpu_executor import (
-                MultiprocessingGPUExecutorAsync)
-            executor_class = MultiprocessingGPUExecutorAsync
        else:
-            from vllm.executor.gpu_executor import GPUExecutorAsync
-            executor_class = GPUExecutorAsync
+            executor_cls = current_platform.get_executor_cls(
+                distributed_executor_backend=distributed_executor_backend,
+                is_async=True)
+            executor_class = resolve_obj_by_qualname(executor_cls)
        return executor_class

    @classmethod
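For reference, resolve_obj_by_qualname maps the dotted string returned by the platform back to a class object. A minimal standalone sketch of the pattern, assuming the helper behaves like a plain importlib lookup (the real implementation lives in vllm/utils.py):

    import importlib

    def resolve_obj_by_qualname(qualname: str):
        # "pkg.module.ClassName" -> import pkg.module, then fetch ClassName.
        module_name, obj_name = qualname.rsplit(".", 1)
        module = importlib.import_module(module_name)
        return getattr(module, obj_name)

    # Example: resolves the built-in OrderedDict class.
    assert resolve_obj_by_qualname("collections.OrderedDict").__name__ == "OrderedDict"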
61 changes: 7 additions & 54 deletions vllm/engine/llm_engine.py
@@ -43,6 +43,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                          RequestOutputFactory)
+from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -60,7 +61,8 @@
    BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
-from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
+from vllm.utils import (Counter, Device, deprecate_kwargs,
+                        resolve_obj_by_qualname, weak_bind)
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -461,61 +463,12 @@ def _get_executor_cls(cls,
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            executor_class = distributed_executor_backend
-        elif engine_config.device_config.device_type == "neuron":
-            from vllm.executor.neuron_executor import NeuronExecutor
-            executor_class = NeuronExecutor
-        elif engine_config.device_config.device_type == "tpu":
-            if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
-                from vllm.executor.ray_tpu_executor import RayTPUExecutor
-                executor_class = RayTPUExecutor
-            else:
-                assert distributed_executor_backend is None
-                from vllm.executor.tpu_executor import TPUExecutor
-                executor_class = TPUExecutor
-        elif engine_config.device_config.device_type == "cpu":
-            from vllm.executor.cpu_executor import CPUExecutor
-            executor_class = CPUExecutor
-        elif engine_config.device_config.device_type == "hpu":
-            if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
-                from vllm.executor.ray_hpu_executor import RayHPUExecutor
-                executor_class = RayHPUExecutor
-            else:
-                from vllm.executor.hpu_executor import HPUExecutor
-                executor_class = HPUExecutor
-        elif engine_config.device_config.device_type == "openvino":
-            from vllm.executor.openvino_executor import OpenVINOExecutor
-            executor_class = OpenVINOExecutor
-        elif engine_config.device_config.device_type == "xpu":
+        else:
+            executor_cls = current_platform.get_executor_cls(
+                distributed_executor_backend)
+            executor_class = resolve_obj_by_qualname(executor_cls)
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
-                from vllm.executor.ray_xpu_executor import RayXPUExecutor
-                executor_class = RayXPUExecutor
-            elif distributed_executor_backend == "mp":
-                # FIXME(kunshang):
-                # spawn needs calling `if __name__ == '__main__':``
-                # fork is not supported for xpu start new process.
-                logger.error(
-                    "Both start methods (spawn and fork) have issue "
-                    "on XPU if you use mp backend, Please try ray instead.")
-            else:
-                from vllm.executor.xpu_executor import XPUExecutor
-                executor_class = XPUExecutor
-        elif distributed_executor_backend == "ray":
-            initialize_ray_cluster(engine_config.parallel_config)
-            from vllm.executor.ray_gpu_executor import RayGPUExecutor
-            executor_class = RayGPUExecutor
-        elif distributed_executor_backend == "mp":
-            from vllm.executor.multiproc_gpu_executor import (
-                MultiprocessingGPUExecutor)
-            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
-                "multiprocessing distributed executor backend does not "
-                "support VLLM_USE_RAY_SPMD_WORKER=1")
-            executor_class = MultiprocessingGPUExecutor
-        else:
-            from vllm.executor.gpu_executor import GPUExecutor
-            executor_class = GPUExecutor
        return executor_class

    @classmethod
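A hedged sketch of the sync selection path after this change (assumes a checkout at this commit; current_platform is whichever platform vLLM detected at import time, so the resolved class is platform-dependent):

    from vllm.platforms import current_platform
    from vllm.utils import resolve_obj_by_qualname

    # is_async is left unset, so the platform returns its synchronous executor.
    executor_cls = current_platform.get_executor_cls(
        distributed_executor_backend=None)
    executor_class = resolve_obj_by_qualname(executor_cls)
    # On a CUDA machine with no distributed backend this resolves
    # vllm.executor.gpu_executor.GPUExecutor.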
7 changes: 7 additions & 0 deletions vllm/platforms/cpu.py
@@ -98,3 +98,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"vllm.worker.cpu_worker.CPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"

@classmethod
def get_executor_cls(cls, distributed_executor_backend: str | None = None,
is_async: bool | None = None) -> str:
if is_async:
return "vllm.executor.cpu_executor.CPUExecutorAsync"
return "vllm.executor.cpu_executor.CPUExecutor"
23 changes: 23 additions & 0 deletions vllm/platforms/cuda.py
@@ -137,6 +137,29 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        else:
            parallel_config.worker_cls = "vllm.worker.worker.Worker"

+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None) -> str:
+        if distributed_executor_backend == "ray":
+            if is_async:
+                return "vllm.executor.ray_gpu_executor.RayGPUExecutorAsync"
+            else:
+                return "vllm.executor.ray_gpu_executor.RayGPUExecutor"
+        if distributed_executor_backend == "mp":
+            if is_async:
+                return "vllm.executor.multiproc_gpu_executor." \
+                    "MultiprocessingGPUExecutorAsync"
+            else:
+                assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
+                    "multiprocessing distributed executor backend does not "
+                    "support VLLM_USE_RAY_SPMD_WORKER=1")
+                return "vllm.executor.multiproc_gpu_executor." \
+                    "MultiprocessingGPUExecutor"
+        if is_async:
+            return "vllm.executor.gpu_executor.GPUExecutorAsync"
+        else:
+            return "vllm.executor.gpu_executor.GPUExecutor"
+

# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
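A quick, hedged spot-check of the CUDA mapping above (assumes the process is running on a CUDA machine, so that current_platform is the platform defined in this file):

    from vllm.platforms import current_platform

    # Ray backend + async engine -> async Ray GPU executor.
    assert current_platform.get_executor_cls("ray", is_async=True) == \
        "vllm.executor.ray_gpu_executor.RayGPUExecutorAsync"
    # No distributed backend + sync engine -> single-process GPU executor.
    assert current_platform.get_executor_cls() == \
        "vllm.executor.gpu_executor.GPUExecutor"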
11 changes: 11 additions & 0 deletions vllm/platforms/hpu.py
@@ -43,3 +43,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None):
+        if distributed_executor_backend == "ray":
+            if is_async:
+                return "vllm.executor.ray_hpu_executor.RayHPUExecutorAsync"
+            return "vllm.executor.ray_hpu_executor.RayHPUExecutor"
+        if is_async:
+            return "vllm.executor.hpu_executor.HPUExecutorAsync"
+        return "vllm.executor.hpu_executor.HPUExecutor"
10 changes: 10 additions & 0 deletions vllm/platforms/interface.py
@@ -10,8 +10,10 @@

if TYPE_CHECKING:
    from vllm.config import VllmConfig
+    from vllm.executor.executor_base import ExecutorBase
else:
    VllmConfig = None
+    ExecutorBase = None

logger = init_logger(__name__)

@@ -221,6 +223,14 @@ def get_cpu_architecture(cls) -> CpuArchEnum:

        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None) -> str:
+        """
+        Get the executor class for the current platform.
+        """
+        raise NotImplementedError


class UnspecifiedPlatform(Platform):
    _enum = PlatformEnum.UNSPECIFIED
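This base-class hook is the contract the platform files below fill in, and it is also what an out-of-tree platform would override. A hypothetical sketch (the platform name, enum value, and executor paths are invented for illustration):

    from vllm.platforms.interface import Platform, PlatformEnum

    class MyDevicePlatform(Platform):
        _enum = PlatformEnum.UNSPECIFIED  # placeholder for the sketch

        @classmethod
        def get_executor_cls(cls, distributed_executor_backend=None,
                             is_async=None) -> str:
            # Return a fully qualified class name; the engine resolves it
            # with resolve_obj_by_qualname.
            if is_async:
                return "my_plugin.executor.MyDeviceExecutorAsync"
            return "my_plugin.executor.MyDeviceExecutor"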
7 changes: 7 additions & 0 deletions vllm/platforms/neuron.py
@@ -28,3 +28,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = \
                "vllm.worker.neuron_worker.NeuronWorker"
+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None):
+        if is_async:
+            return "vllm.executor.neuron_executor.NeuronExecutorAsync"
+        return "vllm.executor.neuron_executor.NeuronExecutor"
10 changes: 10 additions & 0 deletions vllm/platforms/openvino.py
@@ -135,3 +135,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
            raise RuntimeError(
                "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"
                f" {kv_cache_space}, expect a positive integer value.")
+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None):
+        assert distributed_executor_backend is None, (
+            "Distributed execution is not supported with "
+            "the OpenVINO backend.")
+        if is_async:
+            return "vllm.executor.openvino_executor.OpenVINOExecutorAsync"
+        return "vllm.executor.openvino_executor.OpenVINOExecutor"
23 changes: 23 additions & 0 deletions vllm/platforms/rocm.py
@@ -106,3 +106,26 @@ def verify_quantization(cls, quant: str) -> None:
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ.")
envs.VLLM_USE_TRITON_AWQ = True

@classmethod
def get_executor_cls(cls, distributed_executor_backend: str | None = None,

Check failure on line 111 in vllm/platforms/rocm.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

X | Y syntax for unions requires Python 3.10 [syntax]
is_async: bool | None = None):

Check failure on line 112 in vllm/platforms/rocm.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

X | Y syntax for unions requires Python 3.10 [syntax]
if distributed_executor_backend == "ray":
if is_async:
return "vllm.executor.ray_gpu_executor.RayGPUExecutorAsync"
else:
return "vllm.executor.ray_gpu_executor.RayGPUExecutor"
if distributed_executor_backend == "mp":
if is_async:
return "vllm.executor.multiproc_gpu_executor." \
"MultiprocessingGPUExecutorAsync"
else:
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
return "vllm.executor.multiproc_gpu_executor." \
"MultiprocessingGPUExecutor"
if is_async:
return "vllm.executor.gpu_executor.GPUExecutorAsync"
else:
return "vllm.executor.gpu_executor.GPUExecutor"
11 changes: 11 additions & 0 deletions vllm/platforms/tpu.py
@@ -67,3 +67,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"

@classmethod
def get_executor_cls(cls, distributed_executor_backend: str | None = None,

Check failure on line 72 in vllm/platforms/tpu.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

X | Y syntax for unions requires Python 3.10 [syntax]
is_async: bool | None = None):

Check failure on line 73 in vllm/platforms/tpu.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

X | Y syntax for unions requires Python 3.10 [syntax]
if distributed_executor_backend == "ray":
if is_async:
return "vllm.executor.ray_tpu_executor.RayTPUExecutorAsync"
return "vllm.executor.ray_tpu_executor.RayTPUExecutor"
if is_async:
return "vllm.executor.tpu_executor.TPUExecutorAsync"
return "vllm.executor.tpu_executor.TPUExecutor"
21 changes: 21 additions & 0 deletions vllm/platforms/xpu.py
@@ -78,3 +78,24 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
            parallel_config.distributed_executor_backend = "ray"
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
+
+    @classmethod
+    def get_executor_cls(cls, distributed_executor_backend: str | None = None,
+                         is_async: bool | None = None):
+        if distributed_executor_backend == "ray":
+            if is_async:
+                return "vllm.executor.ray_xpu_executor.RayXPUExecutorAsync"
+            return "vllm.executor.ray_xpu_executor.RayXPUExecutor"
+        if distributed_executor_backend == "mp":
+            if is_async:
+                return "vllm.executor.multiproc_xpu_executor." \
+                    "MultiprocessingXPUExecutorAsync"
+            # FIXME(kunshang):
+            # spawn needs calling `if __name__ == '__main__':``
+            # fork is not supported for xpu start new process.
+            logger.error(
+                "Both start methods (spawn and fork) have issue "
+                "on XPU if you use mp backend, Please try ray instead.")
+        if is_async:
+            return "vllm.executor.xpu_executor.XPUExecutorAsync"
+        return "vllm.executor.xpu_executor.XPUExecutor"
