diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index d929fa7884dd6..ade94dd6efc89 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -1,3 +1,4 @@
+import enum
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
@@ -57,3 +58,12 @@ class EngineCoreOutputs(msgspec.Struct):
 
     # [num_reqs]
     outputs: List[EngineCoreOutput]
+
+
+class EngineCoreRequestType(enum.Enum):
+    """
+    Request types defined as hex byte strings, so it can be sent over sockets
+    without separate encoding step.
+    """
+    AddRequest = b'\x00'
+    AbortRequest = b'\x01'
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index ba2d5345d674a..8a9e19103e0d2 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import AsyncGenerator, Dict, Mapping, Optional, Type, Union
+from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
 
 from vllm.config import EngineConfig, ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -110,6 +110,9 @@ def from_engine_args(
     def _get_executor_cls(cls, engine_config: EngineConfig):
         return GPUExecutor
 
+    async def abort_request(self, request_ids: List[str]) -> None:
+        await self.engine_core.abort_requests_async(request_ids)
+
     async def add_request(
         self,
         request_id: str,
@@ -200,9 +203,10 @@ async def _run_output_handler(self):
                 # NOTE: we could simplify the Detokenizer code by returning full
                 # List[RequestOutput] rather than pushing to the Queue at the
                 # expense of doing another loop through List[RequestOutput].
-                _to_abort = self.detokenizer.step_streaming(outputs)
+                requests_to_abort = self.detokenizer.step_streaming(outputs)
 
-                # TODO: send aborts (in one message)
+                if requests_to_abort:
+                    await self.abort_request(requests_to_abort)
 
         except BaseException as e:
             logger.error(e)
diff --git a/vllm/v1/engine/async_stream.py b/vllm/v1/engine/async_stream.py
index e79f1562a0e67..4bd57b7a9ef89 100644
--- a/vllm/v1/engine/async_stream.py
+++ b/vllm/v1/engine/async_stream.py
@@ -1,5 +1,6 @@
 import asyncio
 from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
+
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 
 
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index f5378b7a25a11..ad5d4fa247db9 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -11,9 +11,10 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
-                            EngineCoreOutputs, EngineCoreRequest)
+                            EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.executor.gpu_executor import GPUExecutor
-from vllm.v1.request import Request
+from vllm.v1.request import Request, RequestStatus
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -121,6 +122,10 @@ def add_request(self, request: EngineCoreRequest):
         req = Request.from_engine_core_request(request)
         self.scheduler.add_request(req)
 
+    def abort_requests(self, request_ids: List[str]):
+        self.scheduler.finish_requests(request_ids,
+                                       RequestStatus.FINISHED_ABORTED)
+
     def step(self) -> List[EngineCoreOutput]:
         """Schedule, execute, and make output."""
 
@@ -151,7 +156,10 @@ def __init__(
         super().__init__(vllm_config, executor_class, usage_context)
 
         self.msgpack_encoder = msgspec.msgpack.Encoder()
-        self.msgpack_decoder = msgspec.msgpack.Decoder(EngineCoreRequest)
+        self.msgpack_add_request_decoder = \
+            msgspec.msgpack.Decoder(EngineCoreRequest)
+        self.msgpack_abort_requests_decoder = \
+            msgspec.msgpack.Decoder(list[str])
 
         self.ctx = zmq.Context()  # type: ignore[attr-defined]
 
@@ -259,14 +267,32 @@ def run_busy_loop(self):
 
     def _handle_new_input(self):
         """Handle new input from the AsyncLLMEngine for async mode."""
 
+        def process_add_request(request_data: bytes) -> None:
+            engine_core_request: EngineCoreRequest = \
+                self.msgpack_add_request_decoder.decode(request_data)
+            self.add_request(engine_core_request)
+
+        def process_abort_requests(request_data: bytes) -> None:
+            request_ids: List[str] = \
+                self.msgpack_abort_requests_decoder.decode(request_data)
+            self.scheduler.finish_requests(request_ids,
+                                           RequestStatus.FINISHED_ABORTED)
+
         try:
             if self.input_socket.poll(timeout=0) != 0:
-                frames = self.input_socket.recv_multipart(copy=False)
-                engine_core_request = self.msgpack_decoder.decode(
-                    frames[0].buffer)
-                self.add_request(engine_core_request)
+                request_type, request_data = \
+                    self.input_socket.recv_multipart(copy=False)
+
+                # Process request_data based on request_type
+                if request_type.buffer == \
+                        EngineCoreRequestType.AddRequest.value:
+                    process_add_request(request_data.buffer)
+                elif request_type.buffer == \
+                        EngineCoreRequestType.AbortRequest.value:
+                    process_abort_requests(request_data.buffer)
+                else:
+                    raise ValueError(f"Unhandled request type {request_type}")
 
-                # TODO: handle abort via another socket
                 # TODO: handle logits processors via cloudpickle
                 # TODO: handle profiling
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index d1f5e94ff437f..7fa14cc3dbc3b 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Union
 
 import msgspec
 import zmq
@@ -7,7 +7,8 @@
 from vllm.logger import init_logger
 from vllm.utils import get_open_zmq_ipc_path
 from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
-                            EngineCoreOutputs, EngineCoreRequest)
+                            EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 
 logger = init_logger(__name__)
@@ -52,12 +53,18 @@ def get_output(self) -> List[EngineCoreOutput]:
     def add_request(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError
 
+    def abort_requests(self, request_ids: List[str]) -> None:
+        raise NotImplementedError
+
     async def get_output_async(self) -> List[EngineCoreOutput]:
         raise NotImplementedError
 
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError
 
+    async def abort_requests_async(self, request_ids: List[str]) -> None:
+        raise NotImplementedError
+
 
 class InprocClient(EngineCoreClient):
     """
@@ -80,6 +87,9 @@ def get_output(self) -> List[EngineCoreOutput]:
     def add_request(self, request: EngineCoreRequest) -> None:
         self.engine_core.add_request(request)
 
+    def abort_requests(self, request_ids: List[str]) -> None:
+        self.engine_core.abort_requests(request_ids)
+
 
 class MPClient(EngineCoreClient):
     """
@@ -153,12 +163,23 @@ def get_output(self) -> List[EngineCoreOutput]:
 
         return engine_core_outputs
 
-    def add_request(self, request: EngineCoreRequest) -> None:
-
-        self.input_socket.send_multipart((self.encoder.encode(request), ),
+    def _send_input(self,
+                    request_type: EngineCoreRequestType,
+                    request: Union[EngineCoreRequest, List[str]]) \
+            -> None:
+        self.input_socket.send_multipart((
+            request_type.value,
+            self.encoder.encode(request),
+        ),
                                          copy=False,
                                          flags=zmq.NOBLOCK)
 
+    def add_request(self, request: EngineCoreRequest) -> None:
+        self._send_input(EngineCoreRequestType.AddRequest, request)
+
+    def abort_requests(self, request_ids: List[str]) -> None:
+        self._send_input(EngineCoreRequestType.AbortRequest, request_ids)
+
 
 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -176,7 +197,19 @@ async def get_output_async(self) -> List[EngineCoreOutput]:
 
         return engine_core_outputs
 
+    async def _send_input(self,
+                          request_type: EngineCoreRequestType,
+                          request: Union[EngineCoreRequest, List[str]])\
+            -> None:
+        await self.input_socket.send_multipart((
+            request_type.value,
+            self.encoder.encode(request),
+        ),
+                                               copy=False,
+                                               flags=zmq.NOBLOCK)
+
     async def add_request_async(self, request: EngineCoreRequest) -> None:
+        await self._send_input(EngineCoreRequestType.AddRequest, request)
 
-        await self.input_socket.send_multipart(
-            (self.encoder.encode(request), ), copy=False, flags=zmq.NOBLOCK)
+    async def abort_requests_async(self, request_ids: List[str]) -> None:
+        await self._send_input(EngineCoreRequestType.AbortRequest, request_ids)
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 24a3274c21484..f2d5506f45380 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -1,4 +1,4 @@
-from typing import Dict, Iterable, List, Mapping, Optional, Type, Union
+from typing import Dict, List, Mapping, Optional, Type, Union
 
 from vllm.config import EngineConfig
 from vllm.engine.arg_utils import EngineArgs
@@ -107,10 +107,8 @@ def has_unfinished_requests(self) -> bool:
     def validate_outputs(cls, outputs, output_type):
         return outputs
 
-    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
-        # TODO: send to EngineCore
-        # TODO: send to Detokenizer
-        pass
+    def abort_request(self, request_ids: List[str]) -> None:
+        self.engine_core.abort_requests(request_ids)
 
     def add_request(
         self,
@@ -141,9 +139,11 @@ def step(self) -> List[RequestOutput]:
         engine_core_outputs = self.engine_core.get_output()
 
         # 2) Detokenizer the EngineCoreOutput.
-        request_outputs, to_abort = self.detokenizer.step(engine_core_outputs)
+        request_outputs, requests_to_abort = self.detokenizer.step(
+            engine_core_outputs)
 
         # 3) Abort requests that finished due to stopping criteria.
-        self.abort_request(to_abort)
+        if requests_to_abort:
+            self.abort_request(requests_to_abort)
 
         return request_outputs
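
For reference, below is a minimal standalone sketch of the two-frame wire format this change introduces: frame 0 carries the one-byte request type and frame 1 carries the msgpack payload, so the receiver can select a decoder without a separate envelope encoding step. It is illustrative only; the round-trip happens in-process rather than over a ZMQ socket, and the request ids are hypothetical. The real paths are `MPClient._send_input()` on the client side and `EngineCoreProc._handle_new_input()` on the server side in the diff above.

```python
import enum
from typing import List

import msgspec


class EngineCoreRequestType(enum.Enum):
    # Byte values match the patch: they are sent verbatim as ZMQ frame 0.
    AddRequest = b'\x00'
    AbortRequest = b'\x01'


encoder = msgspec.msgpack.Encoder()
abort_decoder = msgspec.msgpack.Decoder(List[str])

# Client side: roughly what _send_input() would put on the socket for an
# abort -- a (type byte, msgpack payload) pair. Request ids are made up.
frames = (EngineCoreRequestType.AbortRequest.value,
          encoder.encode(["req-1", "req-2"]))

# Server side: roughly what _handle_new_input() does with the two frames --
# dispatch on the type byte, then decode the payload with the matching decoder.
request_type, request_data = frames
if request_type == EngineCoreRequestType.AbortRequest.value:
    request_ids = abort_decoder.decode(request_data)
    assert request_ids == ["req-1", "req-2"]
```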