Commit

Plumb aborts
Varun Sundar Rabindranath committed Nov 5, 2024
1 parent 3ad8684 commit 35580c6
Showing 5 changed files with 114 additions and 32 deletions.
10 changes: 10 additions & 0 deletions vllm/v1/engine/__init__.py
@@ -1,3 +1,4 @@
import enum
from dataclasses import dataclass
from typing import List, Optional, Union

@@ -56,3 +57,12 @@ class EngineCoreOutputs(msgspec.Struct, array_like=True):

# [num_reqs]
outputs: List[EngineCoreOutput]


class EngineCoreRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so they can be sent over sockets
without a separate encoding step.
"""
AddRequest = b'\x00'
AbortRequest = b'\x01'
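For context, a minimal sketch of the framing this enables (illustrative, not part of the commit; the request ids are made up): the type byte travels as its own message frame, so a receiver can pick the right msgpack decoder before touching the payload.

    from typing import List

    import msgspec

    # Mirror of EngineCoreRequestType's values: one byte names the payload schema.
    ADD_REQUEST, ABORT_REQUEST = b'\x00', b'\x01'

    encoder = msgspec.msgpack.Encoder()

    # A complete two-frame message: (type byte, msgpack payload).
    # No extra envelope or length prefix is needed.
    frames = (ABORT_REQUEST, encoder.encode(["req-1", "req-2"]))

    # Receiver side: branch on frame 0, decode frame 1 with the matching decoder.
    if frames[0] == ABORT_REQUEST:
        abort_decoder = msgspec.msgpack.Decoder(List[str])
        print(abort_decoder.decode(frames[1]))  # ['req-1', 'req-2']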
10 changes: 7 additions & 3 deletions vllm/v1/engine/async_llm.py
@@ -1,5 +1,5 @@
import asyncio
from typing import AsyncGenerator, Dict, Mapping, Optional, Type, Union
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union

from vllm.config import EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -110,6 +110,9 @@ def from_engine_args(
def _get_executor_cls(cls, engine_config: EngineConfig):
return GPUExecutor

async def abort_request(self, request_ids: List[str]) -> None:
await self.engine_core.abort_requests_async(request_ids)

async def add_request(
self,
request_id: str,
@@ -200,9 +203,10 @@ async def _run_output_handler(self):
# NOTE: we could simplify the Detokenizer code by returning full
# List[RequestOutput] rather than pushing to the Queue at the
# expense of doing another loop through List[RequestOutput].
_to_abort = self.detokenizer.step_streaming(outputs)
requests_to_abort = self.detokenizer.step_streaming(outputs)

# TODO: send aborts (in one message)
if requests_to_abort:
await self.abort_request(requests_to_abort)
except BaseException as e:
logger.error(e)

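Why the abort originates in the frontend: stop strings are only detectable after detokenization, so when the Detokenizer finishes a request the core must be told to stop scheduling it. Below is a standalone sketch of that loop's shape; the method names match the diff, but the stand-in core and detokenizer are invented for illustration.

    import asyncio
    from typing import List

    class StubEngineCore:
        """Stand-in for the real engine-core client."""

        def __init__(self) -> None:
            self.aborted: List[str] = []

        async def get_output_async(self) -> List[str]:
            return ["partial text for req-7 ...STOP"]

        async def abort_requests_async(self, request_ids: List[str]) -> None:
            self.aborted.extend(request_ids)

    def step_streaming(outputs: List[str]) -> List[str]:
        # Stand-in detokenizer step: pretend req-7 just matched a stop string.
        return ["req-7"]

    async def run_output_handler(core: StubEngineCore) -> None:
        outputs = await core.get_output_async()
        requests_to_abort = step_streaming(outputs)
        if requests_to_abort:  # one message per batch of aborts
            await core.abort_requests_async(requests_to_abort)

    core = StubEngineCore()
    asyncio.run(run_output_handler(core))
    print(core.aborted)  # ['req-7']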
65 changes: 50 additions & 15 deletions vllm/v1/engine/core.py
@@ -2,7 +2,8 @@
import queue
from multiprocessing.process import BaseProcess
from threading import Thread
from typing import List, Tuple, Type
from typing import List, Tuple, Type, Any
from collections import namedtuple

import msgspec
import zmq
@@ -13,9 +14,10 @@
from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
EngineCoreOutputs, EngineCoreRequest)
EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request
from vllm.v1.request import Request, RequestStatus
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -123,6 +125,10 @@ def add_request(self, request: EngineCoreRequest):
req = Request.from_engine_core_request(request)
self.scheduler.add_request(req)

def abort_requests(self, request_ids: List[str]):
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)

def step(self) -> List[EngineCoreOutput]:
"""Schedule, execute, and make output."""

@@ -153,7 +159,10 @@ def __init__(
super().__init__(vllm_config, executor_class, usage_context)

self.msgpack_encoder = msgspec.msgpack.Encoder()
self.msgpack_decoder = msgspec.msgpack.Decoder(EngineCoreRequest)
self.msgpack_add_request_decoder = \
msgspec.msgpack.Decoder(EngineCoreRequest)
self.msgpack_abort_requests_decoder = \
msgspec.msgpack.Decoder(list[str])

self.ctx = zmq.Context() # type: ignore[attr-defined]

@@ -182,10 +191,27 @@ def __init__(
ready_socket.close(linger=0)

def process_input_socket(self):

def get_decoder_from_request_type(req_type: bytes) \
-> msgspec.msgpack.Decoder:
# Identify msgpack decoder based on request_type.
if req_type == EngineCoreRequestType.AddRequest.value:
return self.msgpack_add_request_decoder
elif req_type == EngineCoreRequestType.AbortRequest.value:
return self.msgpack_abort_requests_decoder
else:
raise ValueError(f"Unhandled request type {request_type}")

while True:
frames = self.input_socket.recv_multipart(copy=False)
request = self.msgpack_decoder.decode(frames[0].buffer)
self.input_queue.put_nowait(request)
request_type, request_data = \
self.input_socket.recv_multipart(copy=False)

# Decode request_data
msgpack_decoder: msgspec.msgpack.Decoder = \
get_decoder_from_request_type(request_type.buffer)
request_data: Any = msgpack_decoder.decode(request_data.buffer)

self.input_queue.put_nowait((request_type.buffer, request_data))

def process_output_socket(self):
while True:
@@ -267,8 +293,7 @@ def run_busy_loop(self):
while True:
# Poll the input socket until there is work to do.
if not self.scheduler.has_unfinished_requests():
request = self.input_queue.get()
self._handle_request(request)
self._handle_request(self.input_queue.get())

# Handle new input from the socket.
self._handle_new_input()
@@ -282,17 +307,27 @@ def run_busy_loop(self):
def _handle_new_input(self):
"""Handle new input from the AsyncLLMEngine for async mode."""
while not self.input_queue.empty():
request = self.input_queue.get_nowait()
self._handle_request(request)
self._handle_request(self.input_queue.get_nowait())

def _handle_request(self, request: EngineCoreRequest):
def _handle_request(self, request: Tuple[bytes, Any]):
try:
self.add_request(request)
request_type, request_data = request
# Process request_data based on request_type
if request_type == EngineCoreRequestType.AddRequest.value:
assert isinstance(request_data, EngineCoreRequest), \
f'Unexpected datatype {type(request_data)}'
self.add_request(request_data)
elif request_type == EngineCoreRequestType.AbortRequest.value:
assert isinstance(request_data, list), \
f'Unexpected datatype {type(request_data)}'
self.scheduler.finish_requests(request_data,
RequestStatus.FINISHED_ABORTED)
else:
raise ValueError(f"Unhandled request type {request_type}")

# TODO: handle abort via another socket
# TODO: handle logits processors via cloudpickle
# TODO: handle profiling

except Exception as e:
# TODO: handle gracefully
raise e
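A self-contained sketch of the dispatch the busy loop now performs: an in-memory queue stands in for the ZMQ input socket, and prints stand in for the scheduler calls.

    import queue
    from typing import Any, List, Tuple

    import msgspec

    ADD_REQUEST, ABORT_REQUEST = b'\x00', b'\x01'

    encoder = msgspec.msgpack.Encoder()
    abort_decoder = msgspec.msgpack.Decoder(List[str])

    input_queue: "queue.Queue[Tuple[bytes, Any]]" = queue.Queue()

    def handle_request(request: Tuple[bytes, Any]) -> None:
        request_type, request_data = request
        if request_type == ADD_REQUEST:
            print("scheduler.add_request:", request_data)
        elif request_type == ABORT_REQUEST:
            print("scheduler.finish_requests (FINISHED_ABORTED):", request_data)
        else:
            raise ValueError(f"Unhandled request type {request_type!r}")

    # What process_input_socket would enqueue after decoding the payload frame.
    payload = encoder.encode(["req-1"])
    input_queue.put_nowait((ABORT_REQUEST, abort_decoder.decode(payload)))

    # Mirrors _handle_new_input(): drain the queue without blocking.
    while not input_queue.empty():
        handle_request(input_queue.get_nowait())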
47 changes: 40 additions & 7 deletions vllm/v1/engine/core_client.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Union

import msgspec
import zmq
@@ -7,7 +7,8 @@
from vllm.logger import init_logger
from vllm.utils import get_open_zmq_ipc_path
from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
EngineCoreOutputs, EngineCoreRequest)
EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc

logger = init_logger(__name__)
@@ -52,12 +53,18 @@ def get_output(self) -> List[EngineCoreOutput]:
def add_request(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

def abort_requests(self, request_ids: List[str]) -> None:
raise NotImplementedError

async def get_output_async(self) -> List[EngineCoreOutput]:
raise NotImplementedError

async def add_request_async(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

async def abort_requests_async(self, request_ids: List[str]) -> None:
raise NotImplementedError


class InprocClient(EngineCoreClient):
"""
@@ -80,6 +87,9 @@ def get_output(self) -> List[EngineCoreOutput]:
def add_request(self, request: EngineCoreRequest) -> None:
self.engine_core.add_request(request)

def abort_requests(self, request_ids: List[str]) -> None:
self.engine_core.abort_requests(request_ids)


class MPClient(EngineCoreClient):
"""
@@ -153,12 +163,23 @@ def get_output(self) -> List[EngineCoreOutput]:

return engine_core_outputs

def add_request(self, request: EngineCoreRequest) -> None:

self.input_socket.send_multipart((self.encoder.encode(request), ),
def _send_input(self,
request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, List[str]]) \
-> None:
self.input_socket.send_multipart((
request_type.value,
self.encoder.encode(request),
),
copy=False,
flags=zmq.NOBLOCK)

def add_request(self, request: EngineCoreRequest) -> None:
self._send_input(EngineCoreRequestType.AddRequest, request)

def abort_requests(self, request_ids: List[str]) -> None:
self._send_input(EngineCoreRequestType.AbortRequest, request_ids)


class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
@@ -176,7 +197,19 @@ async def get_output_async(self) -> List[EngineCoreOutput]:

return engine_core_outputs

async def _send_input(self,
request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, List[str]])\
-> None:
await self.input_socket.send_multipart((
request_type.value,
self.encoder.encode(request),
),
copy=False,
flags=zmq.NOBLOCK)

async def add_request_async(self, request: EngineCoreRequest) -> None:
await self.input_socket.send_multipart(
(self.encoder.encode(request), ), copy=False, flags=zmq.NOBLOCK)
await self._send_input(EngineCoreRequestType.AddRequest, request)

async def abort_requests_async(self, request_ids: List[str]) -> None:
await self._send_input(EngineCoreRequestType.AbortRequest, request_ids)
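To see the pipe end to end, here is a round-trip sketch of the framing that _send_input produces and process_input_socket consumes, over an in-process socket pair (the endpoint name and request ids are illustrative):

    from typing import List

    import msgspec
    import zmq

    ABORT_REQUEST = b'\x01'

    ctx = zmq.Context()
    rx = ctx.socket(zmq.PAIR)
    rx.bind("inproc://engine-core-input")
    tx = ctx.socket(zmq.PAIR)
    tx.connect("inproc://engine-core-input")

    # Client side: type byte in frame 0, msgpack payload in frame 1.
    encoder = msgspec.msgpack.Encoder()
    tx.send_multipart((ABORT_REQUEST, encoder.encode(["req-1", "req-2"])))

    # Core side: branch on the type frame, then decode the payload frame.
    request_type, request_data = rx.recv_multipart(copy=False)
    assert bytes(request_type.buffer) == ABORT_REQUEST
    abort_decoder = msgspec.msgpack.Decoder(List[str])
    print(abort_decoder.decode(request_data.buffer))  # ['req-1', 'req-2']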
14 changes: 7 additions & 7 deletions vllm/v1/engine/llm_engine.py
@@ -1,4 +1,4 @@
from typing import Dict, Iterable, List, Mapping, Optional, Type, Union
from typing import Dict, List, Mapping, Optional, Type, Union

from vllm.config import EngineConfig
from vllm.engine.arg_utils import EngineArgs
@@ -107,10 +107,8 @@ def has_unfinished_requests(self) -> bool:
def validate_outputs(cls, outputs, output_type):
return outputs

def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
# TODO: send to EngineCore
# TODO: send to Detokenizer
pass
def abort_request(self, request_ids: List[str]) -> None:
self.engine_core.abort_requests(request_ids)

def add_request(
self,
@@ -141,9 +139,11 @@ def step(self) -> List[RequestOutput]:
engine_core_outputs = self.engine_core.get_output()

# 2) Detokenize the EngineCoreOutput.
request_outputs, to_abort = self.detokenizer.step(engine_core_outputs)
request_outputs, requests_to_abort = self.detokenizer.step(
engine_core_outputs)

# 3) Abort requests that finished due to stopping criteria.
self.abort_request(to_abort)
if requests_to_abort:
self.abort_request(requests_to_abort)

return request_outputs
