
Commit

Plumb aborts
Varun Sundar Rabindranath committed Nov 5, 2024
1 parent bb1a75b commit 78af701
Showing 6 changed files with 99 additions and 25 deletions.
10 changes: 10 additions & 0 deletions vllm/v1/engine/__init__.py
@@ -1,3 +1,4 @@
import enum
from dataclasses import dataclass
from typing import List, Optional, Union

@@ -57,3 +58,12 @@ class EngineCoreOutputs(msgspec.Struct):

# [num_reqs]
outputs: List[EngineCoreOutput]


class EngineCoreRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so they can be sent over sockets
without a separate encoding step.
"""
AddRequest = b'\x00'
AbortRequest = b'\x01'
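
The request-type enum above is the core of the new wire protocol: every message to EngineCore is a two-frame ZMQ multipart message whose first frame is one of these type bytes and whose second frame is a msgpack payload. A minimal standalone sketch (illustrative only, not the vLLM code itself; the request ids are made up) of how a type byte pairs with a msgspec-encoded payload and how the receiver dispatches on it:

from typing import List

import msgspec

ADD_REQUEST = b'\x00'
ABORT_REQUEST = b'\x01'

encoder = msgspec.msgpack.Encoder()
abort_decoder = msgspec.msgpack.Decoder(List[str])

# Sender: the two frames that would go into socket.send_multipart().
frames = (ABORT_REQUEST, encoder.encode(["req-0", "req-7"]))

# Receiver: dispatch on frame 0, decode frame 1 with the matching decoder.
request_type, payload = frames
if request_type == ABORT_REQUEST:
    print(abort_decoder.decode(payload))  # ['req-0', 'req-7']

Because the enum values are already raw bytes, the type tag can be sent as its own frame with no extra serialization.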
10 changes: 7 additions & 3 deletions vllm/v1/engine/async_llm.py
@@ -1,5 +1,5 @@
import asyncio
from typing import AsyncGenerator, Dict, Mapping, Optional, Type, Union
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union

from vllm.config import EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -110,6 +110,9 @@ def from_engine_args(
def _get_executor_cls(cls, engine_config: EngineConfig):
return GPUExecutor

async def abort_request(self, request_ids: List[str]) -> None:
await self.engine_core.abort_requests_async(request_ids)

async def add_request(
self,
request_id: str,
@@ -200,9 +203,10 @@ async def _run_output_handler(self):
# NOTE: we could simplify the Detokenizer code by returning full
# List[RequestOutput] rather than pushing to the Queue at the
# expense of doing another loop through List[RequestOutput].
_to_abort = self.detokenizer.step_streaming(outputs)
requests_to_abort = self.detokenizer.step_streaming(outputs)

# TODO: send aborts (in one message)
if requests_to_abort:
await self.abort_request(requests_to_abort)
except BaseException as e:
logger.error(e)

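With abort_request plumbed through to EngineCore, the output handler above pushes stop-condition aborts in a single message, and a serving layer can reuse the same path when a caller goes away. A hypothetical usage sketch (the on_client_disconnect hook and the request_id are illustrative, not part of this commit):

async def on_client_disconnect(engine: "AsyncLLM", request_id: str) -> None:
    # Release the request in EngineCore rather than letting it keep
    # generating tokens for a caller that has already disconnected.
    await engine.abort_request([request_id])
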
1 change: 1 addition & 0 deletions vllm/v1/engine/async_stream.py
@@ -1,5 +1,6 @@
import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union

from vllm.outputs import EmbeddingRequestOutput, RequestOutput


42 changes: 34 additions & 8 deletions vllm/v1/engine/core.py
@@ -11,9 +11,10 @@
from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
EngineCoreOutputs, EngineCoreRequest)
EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request
from vllm.v1.request import Request, RequestStatus
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -121,6 +122,10 @@ def add_request(self, request: EngineCoreRequest):
req = Request.from_engine_core_request(request)
self.scheduler.add_request(req)

def abort_requests(self, request_ids: List[str]):
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)

def step(self) -> List[EngineCoreOutput]:
"""Schedule, execute, and make output."""

@@ -151,7 +156,10 @@ def __init__(
super().__init__(vllm_config, executor_class, usage_context)

self.msgpack_encoder = msgspec.msgpack.Encoder()
self.msgpack_decoder = msgspec.msgpack.Decoder(EngineCoreRequest)
self.msgpack_add_request_decoder = \
msgspec.msgpack.Decoder(EngineCoreRequest)
self.msgpack_abort_requests_decoder = \
msgspec.msgpack.Decoder(list[str])

self.ctx = zmq.Context() # type: ignore[attr-defined]

@@ -259,14 +267,32 @@ def run_busy_loop(self):
def _handle_new_input(self):
"""Handle new input from the AsyncLLMEngine for async mode."""

def process_add_request(request_data: bytes) -> None:
engine_core_request: EngineCoreRequest = \
self.msgpack_add_request_decoder.decode(request_data)
self.add_request(engine_core_request)

def process_abort_requests(request_data: bytes) -> None:
request_ids: List[str] = \
self.msgpack_abort_requests_decoder.decode(request_data)
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)

try:
if self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
engine_core_request = self.msgpack_decoder.decode(
frames[0].buffer)
self.add_request(engine_core_request)
request_type, request_data = \
self.input_socket.recv_multipart(copy=False)

# Process request_data based on request_type
if request_type.buffer == \
EngineCoreRequestType.AddRequest.value:
process_add_request(request_data.buffer)
elif request_type.buffer == \
EngineCoreRequestType.AbortRequest.value:
process_abort_requests(request_data.buffer)
else:
raise ValueError(f"Unhandled request type {request_type}")

# TODO: handle abort via another socket
# TODO: handle logits processors via cloudpickle
# TODO: handle profiling

47 changes: 40 additions & 7 deletions vllm/v1/engine/core_client.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Union

import msgspec
import zmq
@@ -7,7 +7,8 @@
from vllm.logger import init_logger
from vllm.utils import get_open_zmq_ipc_path
from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
EngineCoreOutputs, EngineCoreRequest)
EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc

logger = init_logger(__name__)
@@ -52,12 +53,18 @@ def get_output(self) -> List[EngineCoreOutput]:
def add_request(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

def abort_requests(self, request_ids: List[str]) -> None:
raise NotImplementedError

async def get_output_async(self) -> List[EngineCoreOutput]:
raise NotImplementedError

async def add_request_async(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

async def abort_requests_async(self, request_ids: List[str]) -> None:
raise NotImplementedError


class InprocClient(EngineCoreClient):
"""
@@ -80,6 +87,9 @@ def get_output(self) -> List[EngineCoreOutput]:
def add_request(self, request: EngineCoreRequest) -> None:
self.engine_core.add_request(request)

def abort_requests(self, request_ids: List[str]) -> None:
self.engine_core.abort_requests(request_ids)


class MPClient(EngineCoreClient):
"""
@@ -153,12 +163,23 @@ def get_output(self) -> List[EngineCoreOutput]:

return engine_core_outputs

def add_request(self, request: EngineCoreRequest) -> None:

self.input_socket.send_multipart((self.encoder.encode(request), ),
def _send_input(self,
request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, List[str]]) \
-> None:
self.input_socket.send_multipart((
request_type.value,
self.encoder.encode(request),
),
copy=False,
flags=zmq.NOBLOCK)

def add_request(self, request: EngineCoreRequest) -> None:
self._send_input(EngineCoreRequestType.AddRequest, request)

def abort_requests(self, request_ids: List[str]) -> None:
self._send_input(EngineCoreRequestType.AbortRequest, request_ids)


class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
@@ -176,7 +197,19 @@ async def get_output_async(self) -> List[EngineCoreOutput]:

return engine_core_outputs

async def _send_input(self,
request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, List[str]])\
-> None:
await self.input_socket.send_multipart((
request_type.value,
self.encoder.encode(request),
),
copy=False,
flags=zmq.NOBLOCK)

async def add_request_async(self, request: EngineCoreRequest) -> None:
await self._send_input(EngineCoreRequestType.AddRequest, request)

await self.input_socket.send_multipart(
(self.encoder.encode(request), ), copy=False, flags=zmq.NOBLOCK)
async def abort_requests_async(self, request_ids: List[str]) -> None:
await self._send_input(EngineCoreRequestType.AbortRequest, request_ids)
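
The new _send_input() on the client and _handle_new_input() in EngineCoreProc agree on one framing: frame 0 is the request-type byte, frame 1 is the msgpack payload. A self-contained sketch of that round trip, assuming a local inproc PUSH/PULL pair purely for illustration (the real client reaches EngineCoreProc over a ZMQ IPC path, see get_open_zmq_ipc_path above):

from typing import List

import msgspec
import zmq

ABORT_REQUEST = b'\x01'

ctx = zmq.Context()
engine_input = ctx.socket(zmq.PULL)            # engine side
engine_input.bind("inproc://engine-input")
client_input = ctx.socket(zmq.PUSH)            # client side
client_input.connect("inproc://engine-input")

encoder = msgspec.msgpack.Encoder()
abort_decoder = msgspec.msgpack.Decoder(List[str])

# Client: same two-frame layout as _send_input().
client_input.send_multipart((ABORT_REQUEST, encoder.encode(["req-42"])),
                            copy=False, flags=zmq.NOBLOCK)

# Engine: same poll / recv_multipart / dispatch shape as _handle_new_input().
if engine_input.poll(timeout=100) != 0:
    request_type, request_data = engine_input.recv_multipart(copy=False)
    if request_type.buffer == ABORT_REQUEST:
        print(abort_decoder.decode(request_data.buffer))  # ['req-42']

ctx.destroy()

Both AddRequest and AbortRequest travel over this one input socket; only the leading type byte and the decoder chosen for the payload differ.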
14 changes: 7 additions & 7 deletions vllm/v1/engine/llm_engine.py
@@ -1,4 +1,4 @@
from typing import Dict, Iterable, List, Mapping, Optional, Type, Union
from typing import Dict, List, Mapping, Optional, Type, Union

from vllm.config import EngineConfig
from vllm.engine.arg_utils import EngineArgs
@@ -107,10 +107,8 @@ def has_unfinished_requests(self) -> bool:
def validate_outputs(cls, outputs, output_type):
return outputs

def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
# TODO: send to EngineCore
# TODO: send to Detokenizer
pass
def abort_request(self, request_ids: List[str]) -> None:
self.engine_core.abort_requests(request_ids)

def add_request(
self,
@@ -141,9 +139,11 @@ def step(self) -> List[RequestOutput]:
engine_core_outputs = self.engine_core.get_output()

# 2) Detokenize the EngineCoreOutput.
request_outputs, to_abort = self.detokenizer.step(engine_core_outputs)
request_outputs, requests_to_abort = self.detokenizer.step(
engine_core_outputs)

# 3) Abort requests that finished due to stopping criteria.
self.abort_request(to_abort)
if requests_to_abort:
self.abort_request(requests_to_abort)

return request_outputs
