From 35580c6162962eabb139da1673b0460c8f913bd6 Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Sun, 3 Nov 2024 04:03:04 +0000
Subject: [PATCH] Plumb aborts

---
 vllm/v1/engine/__init__.py    | 10 ++++++
 vllm/v1/engine/async_llm.py   | 10 ++++--
 vllm/v1/engine/core.py        | 64 ++++++++++++++++++++++++++++--------
 vllm/v1/engine/core_client.py | 47 +++++++++++++++++++++----
 vllm/v1/engine/llm_engine.py  | 14 ++++----
 5 files changed, 113 insertions(+), 32 deletions(-)

diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 2d1d2f7a5b7b2..96edc8aab6822 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -1,3 +1,4 @@
+import enum
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
@@ -56,3 +57,12 @@ class EngineCoreOutputs(msgspec.Struct, array_like=True):
 
     # [num_reqs]
     outputs: List[EngineCoreOutput]
+
+
+class EngineCoreRequestType(enum.Enum):
+    """
+    Request types defined as hex byte strings, so they can be sent over
+    sockets without a separate encoding step.
+    """
+    AddRequest = b'\x00'
+    AbortRequest = b'\x01'
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 15563cb774907..716bfab094752 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import AsyncGenerator, Dict, Mapping, Optional, Type, Union
+from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
 
 from vllm.config import EngineConfig, ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -110,6 +110,9 @@ def from_engine_args(
     def _get_executor_cls(cls, engine_config: EngineConfig):
         return GPUExecutor
 
+    async def abort_request(self, request_ids: List[str]) -> None:
+        await self.engine_core.abort_requests_async(request_ids)
+
     async def add_request(
         self,
         request_id: str,
@@ -200,9 +203,10 @@ async def _run_output_handler(self):
                 # NOTE: we could simplify the Detokenizer code by returning full
                 # List[RequestOutput] rather than pushing to the Queue at the
                 # expense of doing another loop through List[RequestOutput].
-                _to_abort = self.detokenizer.step_streaming(outputs)
+                requests_to_abort = self.detokenizer.step_streaming(outputs)
 
-                # TODO: send aborts (in one message)
+                if requests_to_abort:
+                    await self.abort_request(requests_to_abort)
 
         except BaseException as e:
             logger.error(e)
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index fa983fc39ce76..9ac5e58706a2b 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -2,7 +2,7 @@
 import queue
 from multiprocessing.process import BaseProcess
 from threading import Thread
-from typing import List, Tuple, Type
+from typing import Any, List, Tuple, Type
 
 import msgspec
 import zmq
@@ -13,9 +13,10 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
-                            EngineCoreOutputs, EngineCoreRequest)
+                            EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.executor.gpu_executor import GPUExecutor
-from vllm.v1.request import Request
+from vllm.v1.request import Request, RequestStatus
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -123,6 +124,10 @@ def add_request(self, request: EngineCoreRequest):
         req = Request.from_engine_core_request(request)
         self.scheduler.add_request(req)
 
+    def abort_requests(self, request_ids: List[str]):
+        self.scheduler.finish_requests(request_ids,
+                                       RequestStatus.FINISHED_ABORTED)
+
     def step(self) -> List[EngineCoreOutput]:
         """Schedule, execute, and make output."""
 
@@ -153,7 +158,10 @@ def __init__(
         super().__init__(vllm_config, executor_class, usage_context)
 
         self.msgpack_encoder = msgspec.msgpack.Encoder()
-        self.msgpack_decoder = msgspec.msgpack.Decoder(EngineCoreRequest)
+        self.msgpack_add_request_decoder = \
+            msgspec.msgpack.Decoder(EngineCoreRequest)
+        self.msgpack_abort_requests_decoder = \
+            msgspec.msgpack.Decoder(List[str])
 
         self.ctx = zmq.Context()  # type: ignore[attr-defined]
 
@@ -182,10 +190,27 @@ def __init__(
             ready_socket.close(linger=0)
 
     def process_input_socket(self):
+
+        def get_decoder_from_request_type(req_type: bytes) \
+                -> msgspec.msgpack.Decoder:
+            # Identify the msgpack decoder based on the request type.
+            if req_type == EngineCoreRequestType.AddRequest.value:
+                return self.msgpack_add_request_decoder
+            elif req_type == EngineCoreRequestType.AbortRequest.value:
+                return self.msgpack_abort_requests_decoder
+            else:
+                raise ValueError(f"Unhandled request type {req_type}")
+
         while True:
-            frames = self.input_socket.recv_multipart(copy=False)
-            request = self.msgpack_decoder.decode(frames[0].buffer)
-            self.input_queue.put_nowait(request)
+            request_type, request_data = \
+                self.input_socket.recv_multipart(copy=False)
+
+            # Decode request_data based on the request type.
+            decoder: msgspec.msgpack.Decoder = \
+                get_decoder_from_request_type(request_type.buffer)
+            request: Any = decoder.decode(request_data.buffer)
+
+            self.input_queue.put_nowait((request_type.buffer, request))
 
     def process_output_socket(self):
         while True:
@@ -267,8 +292,7 @@ def run_busy_loop(self):
         while True:
             # Poll the input socket until there is work to do.
             if not self.scheduler.has_unfinished_requests():
-                request = self.input_queue.get()
-                self._handle_request(request)
+                self._handle_request(self.input_queue.get())
 
             # Handle new input from the socket.
             self._handle_new_input()
@@ -282,17 +306,27 @@ def _handle_new_input(self):
         """Handle new input from the AsyncLLMEngine for async mode."""
         while not self.input_queue.empty():
-            request = self.input_queue.get_nowait()
-            self._handle_request(request)
+            self._handle_request(self.input_queue.get_nowait())
+
+    def _handle_request(self, request: Tuple[bytes, Any]):
 
-    def _handle_request(self, request: EngineCoreRequest):
         try:
-            self.add_request(request)
+            request_type, request_data = request
+            # Process request_data based on the request type.
+            if request_type == EngineCoreRequestType.AddRequest.value:
+                assert isinstance(request_data, EngineCoreRequest), \
+                    f"Unexpected datatype {type(request_data)}"
+                self.add_request(request_data)
+            elif request_type == EngineCoreRequestType.AbortRequest.value:
+                assert isinstance(request_data, list), \
+                    f"Unexpected datatype {type(request_data)}"
+                self.scheduler.finish_requests(request_data,
+                                               RequestStatus.FINISHED_ABORTED)
+            else:
+                raise ValueError(f"Unhandled request type {request_type}")
 
             # TODO: handle logits processors via cloudpickle
             # TODO: handle profiling
-
         except Exception as e:
             # TODO: handle gracefully
             raise e
 
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index d1f5e94ff437f..7fa14cc3dbc3b 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Union
 
 import msgspec
 import zmq
@@ -7,7 +7,8 @@
 from vllm.logger import init_logger
 from vllm.utils import get_open_zmq_ipc_path
 from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
-                            EngineCoreOutputs, EngineCoreRequest)
+                            EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 
 logger = init_logger(__name__)
@@ -52,12 +53,18 @@ def get_output(self) -> List[EngineCoreOutput]:
     def add_request(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError
 
+    def abort_requests(self, request_ids: List[str]) -> None:
+        raise NotImplementedError
+
     async def get_output_async(self) -> List[EngineCoreOutput]:
         raise NotImplementedError
 
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError
 
+    async def abort_requests_async(self, request_ids: List[str]) -> None:
+        raise NotImplementedError
+
 
 class InprocClient(EngineCoreClient):
     """
@@ -80,6 +87,9 @@ def get_output(self) -> List[EngineCoreOutput]:
     def add_request(self, request: EngineCoreRequest) -> None:
         self.engine_core.add_request(request)
 
+    def abort_requests(self, request_ids: List[str]) -> None:
+        self.engine_core.abort_requests(request_ids)
+
 
 class MPClient(EngineCoreClient):
     """
@@ -153,12 +163,23 @@ def get_output(self) -> List[EngineCoreOutput]:
 
         return engine_core_outputs
 
-    def add_request(self, request: EngineCoreRequest) -> None:
-
-        self.input_socket.send_multipart((self.encoder.encode(request), ),
+    def _send_input(self,
+                    request_type: EngineCoreRequestType,
+                    request: Union[EngineCoreRequest, List[str]]) \
+            -> None:
+        self.input_socket.send_multipart((
+            request_type.value,
+            self.encoder.encode(request),
+        ),
                                          copy=False,
                                          flags=zmq.NOBLOCK)
+
+    def add_request(self, request: EngineCoreRequest) -> None:
+        self._send_input(EngineCoreRequestType.AddRequest, request)
+
+    def abort_requests(self, request_ids: List[str]) -> None:
+        self._send_input(EngineCoreRequestType.AbortRequest, request_ids)
 
 
 class AsyncMPClient(MPClient):
    """Asyncio-compatible client for multi-proc EngineCore."""
@@ -176,7 +197,19 @@ async def get_output_async(self) -> List[EngineCoreOutput]:
 
         return engine_core_outputs
 
+    async def _send_input(self,
+                          request_type: EngineCoreRequestType,
+                          request: Union[EngineCoreRequest, List[str]]) \
+            -> None:
+        await self.input_socket.send_multipart((
+            request_type.value,
+            self.encoder.encode(request),
+        ),
+                                               copy=False,
+                                               flags=zmq.NOBLOCK)
+
     async def add_request_async(self, request: EngineCoreRequest) -> None:
+        await self._send_input(EngineCoreRequestType.AddRequest, request)
 
-        await self.input_socket.send_multipart(
-            (self.encoder.encode(request), ), copy=False, flags=zmq.NOBLOCK)
+    async def abort_requests_async(self, request_ids: List[str]) -> None:
+        await self._send_input(EngineCoreRequestType.AbortRequest, request_ids)
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 24a3274c21484..f2d5506f45380 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -1,4 +1,4 @@
-from typing import Dict, Iterable, List, Mapping, Optional, Type, Union
+from typing import Dict, List, Mapping, Optional, Type, Union
 
 from vllm.config import EngineConfig
 from vllm.engine.arg_utils import EngineArgs
@@ -107,10 +107,8 @@ def has_unfinished_requests(self) -> bool:
     def validate_outputs(cls, outputs, output_type):
         return outputs
 
-    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
-        # TODO: send to EngineCore
-        # TODO: send to Detokenizer
-        pass
+    def abort_request(self, request_ids: List[str]) -> None:
+        self.engine_core.abort_requests(request_ids)
 
     def add_request(
         self,
@@ -141,9 +139,11 @@ def step(self) -> List[RequestOutput]:
         engine_core_outputs = self.engine_core.get_output()
 
         # 2) Detokenize the EngineCoreOutput.
-        request_outputs, to_abort = self.detokenizer.step(engine_core_outputs)
+        request_outputs, requests_to_abort = self.detokenizer.step(
+            engine_core_outputs)
 
         # 3) Abort requests that finished due to stopping criteria.
-        self.abort_request(to_abort)
+        if requests_to_abort:
+            self.abort_request(requests_to_abort)
 
         return request_outputs
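
Note on the wire format this patch introduces, for reviewers: every message on
the engine's input socket is now a two-frame zmq multipart payload, where
frame 0 is the one-byte EngineCoreRequestType value and frame 1 is the
msgpack-encoded body (an EngineCoreRequest for AddRequest, a List[str] of
request ids for AbortRequest). The standalone sketch below illustrates just
that framing with pyzmq and msgspec, outside of vLLM. It is illustrative only:
the PAIR sockets, the inproc address, and the variable names are stand-ins,
not the socket setup the engine actually uses (which is not shown in this
diff).

    import enum
    from typing import List

    import msgspec
    import zmq


    class EngineCoreRequestType(enum.Enum):
        # Mirrors vllm/v1/engine/__init__.py: the enum value itself is the
        # byte string sent as frame 0 of every multipart message.
        AddRequest = b'\x00'
        AbortRequest = b'\x01'


    ctx = zmq.Context()
    # PAIR over inproc keeps the demo in one process; the real client and
    # EngineCoreProc use their own sockets and IPC paths.
    server = ctx.socket(zmq.PAIR)
    server.bind("inproc://engine-input")
    client = ctx.socket(zmq.PAIR)
    client.connect("inproc://engine-input")

    encoder = msgspec.msgpack.Encoder()
    abort_decoder = msgspec.msgpack.Decoder(List[str])

    # Client side: frame 0 carries the request type, frame 1 the payload,
    # so no extra envelope or encoding step is needed for dispatch.
    client.send_multipart((EngineCoreRequestType.AbortRequest.value,
                           encoder.encode(["req-0", "req-1"])))

    # Server side: pick the typed decoder from frame 0, decode frame 1.
    request_type, request_data = server.recv_multipart()
    assert request_type == EngineCoreRequestType.AbortRequest.value
    print(abort_decoder.decode(request_data))  # ['req-0', 'req-1']

    client.close()
    server.close()
    ctx.term()

This dispatch-on-first-frame scheme is what lets process_input_socket choose a
typed msgspec decoder before touching the payload; supporting a new request
kind only requires a new enum value, a decoder, and a branch in
_handle_request.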