Merge pull request #24 from neuralmagic/varun/9826-stop-strings
Stop String Plumbing
robertgshaw2-neuralmagic authored Nov 5, 2024
2 parents 3ad8684 + d517017 commit ac7b8a7
Showing 5 changed files with 113 additions and 32 deletions.
10 changes: 10 additions & 0 deletions vllm/v1/engine/__init__.py
@@ -1,3 +1,4 @@
import enum
from dataclasses import dataclass
from typing import List, Optional, Union

@@ -56,3 +57,12 @@ class EngineCoreOutputs(msgspec.Struct, array_like=True):

# [num_reqs]
outputs: List[EngineCoreOutput]


class EngineCoreRequestType(enum.Enum):
"""
    Request types defined as hex byte strings, so they can be sent over
    sockets without a separate encoding step.
"""
AddRequest = b'\x00'
AbortRequest = b'\x01'
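
A minimal sketch of the framing this enum enables (illustrative only: plain
msgspec, no sockets; the real transport is in core.py below). The type byte
is already wire-ready, so it can travel as its own message frame:

import enum
from typing import List

import msgspec

class EngineCoreRequestType(enum.Enum):
    AddRequest = b'\x00'
    AbortRequest = b'\x01'

encoder = msgspec.msgpack.Encoder()
abort_decoder = msgspec.msgpack.Decoder(List[str])

# Frame 0 carries the raw type byte; frame 1 carries the msgpack payload.
frames = (EngineCoreRequestType.AbortRequest.value,
          encoder.encode(["req-0", "req-1"]))

# Receiver side: a single byte comparison selects the right decoder.
assert frames[0] == EngineCoreRequestType.AbortRequest.value
print(abort_decoder.decode(frames[1]))  # ['req-0', 'req-1']
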
10 changes: 7 additions & 3 deletions vllm/v1/engine/async_llm.py
@@ -1,5 +1,5 @@
import asyncio
-from typing import AsyncGenerator, Dict, Mapping, Optional, Type, Union
+from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union

from vllm.config import EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -110,6 +110,9 @@ def from_engine_args(
def _get_executor_cls(cls, engine_config: EngineConfig):
return GPUExecutor

async def abort_request(self, request_ids: List[str]) -> None:
await self.engine_core.abort_requests_async(request_ids)

async def add_request(
self,
request_id: str,
@@ -200,9 +203,10 @@ async def _run_output_handler(self):
# NOTE: we could simplify the Detokenizer code by returning full
# List[RequestOutput] rather than pushing to the Queue at the
# expense of doing another loop through List[RequestOutput].
-            _to_abort = self.detokenizer.step_streaming(outputs)
+            requests_to_abort = self.detokenizer.step_streaming(outputs)

-            # TODO: send aborts (in one message)
+            if requests_to_abort:
+                await self.abort_request(requests_to_abort)
except BaseException as e:
logger.error(e)

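An illustrative-only sketch of the batching above (all names here are
hypothetical): stop-string hits from one detokenizer step are aborted with a
single batched call rather than one message per finished request.

import asyncio
from typing import Awaitable, Callable, List

async def drain_aborts(
        queue: "asyncio.Queue[List[str]]",
        abort_request: Callable[[List[str]], Awaitable[None]]) -> None:
    # Each queue item is the abort list produced by one detokenizer step.
    while True:
        requests_to_abort = await queue.get()
        if requests_to_abort:
            # One batched abort per step keeps socket traffic proportional
            # to steps, not to finished requests.
            await abort_request(requests_to_abort)
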
64 changes: 49 additions & 15 deletions vllm/v1/engine/core.py
@@ -2,7 +2,7 @@
import queue
from multiprocessing.process import BaseProcess
from threading import Thread
-from typing import List, Tuple, Type
+from typing import List, Tuple, Type, Any

import msgspec
import zmq
@@ -13,9 +13,10 @@
from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
-                            EngineCoreOutputs, EngineCoreRequest)
+                            EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType)
from vllm.v1.executor.gpu_executor import GPUExecutor
-from vllm.v1.request import Request
+from vllm.v1.request import Request, RequestStatus
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -123,6 +124,10 @@ def add_request(self, request: EngineCoreRequest):
req = Request.from_engine_core_request(request)
self.scheduler.add_request(req)

def abort_requests(self, request_ids: List[str]):
self.scheduler.finish_requests(request_ids,
RequestStatus.FINISHED_ABORTED)

def step(self) -> List[EngineCoreOutput]:
"""Schedule, execute, and make output."""

@@ -153,7 +158,10 @@ def __init__(
super().__init__(vllm_config, executor_class, usage_context)

self.msgpack_encoder = msgspec.msgpack.Encoder()
-        self.msgpack_decoder = msgspec.msgpack.Decoder(EngineCoreRequest)
+        self.msgpack_add_request_decoder = \
+            msgspec.msgpack.Decoder(EngineCoreRequest)
+        self.msgpack_abort_requests_decoder = \
+            msgspec.msgpack.Decoder(list[str])

self.ctx = zmq.Context() # type: ignore[attr-defined]

@@ -182,10 +190,27 @@ def __init__(
ready_socket.close(linger=0)

def process_input_socket(self):

+        def get_decoder_from_request_type(req_type: bytes) \
+                -> msgspec.msgpack.Decoder:
+            # Identify the msgpack decoder based on the request type byte.
+            if req_type == EngineCoreRequestType.AddRequest.value:
+                return self.msgpack_add_request_decoder
+            elif req_type == EngineCoreRequestType.AbortRequest.value:
+                return self.msgpack_abort_requests_decoder
+            else:
+                raise ValueError(f"Unhandled request type {req_type}")

         while True:
-            frames = self.input_socket.recv_multipart(copy=False)
-            request = self.msgpack_decoder.decode(frames[0].buffer)
-            self.input_queue.put_nowait(request)
+            request_type, request_data = \
+                self.input_socket.recv_multipart(copy=False)
+
+            # Decode request_data based on its request type.
+            msgpack_decoder: msgspec.msgpack.Decoder = \
+                get_decoder_from_request_type(request_type.buffer)
+            request_data: Any = msgpack_decoder.decode(request_data.buffer)
+
+            self.input_queue.put_nowait((request_type.buffer, request_data))

def process_output_socket(self):
while True:
@@ -267,8 +292,7 @@ def run_busy_loop(self):
while True:
# Poll the input socket until there is work to do.
if not self.scheduler.has_unfinished_requests():
-                request = self.input_queue.get()
-                self._handle_request(request)
+                self._handle_request(self.input_queue.get())

# Handle new input from the socket.
self._handle_new_input()
@@ -282,17 +306,27 @@ def run_busy_loop(self):
def _handle_new_input(self):
"""Handle new input from the AsyncLLMEngine for async mode."""
while not self.input_queue.empty():
-            request = self.input_queue.get_nowait()
-            self._handle_request(request)
+            self._handle_request(self.input_queue.get_nowait())

-    def _handle_request(self, request: EngineCoreRequest):
+    def _handle_request(self, request: Tuple[bytes, Any]):
         try:
-            self.add_request(request)
+            request_type, request_data = request
+            # Process request_data based on request_type.
+            if request_type == EngineCoreRequestType.AddRequest.value:
+                assert isinstance(request_data, EngineCoreRequest), \
+                    f'Unexpected datatype {type(request_data)}'
+                self.add_request(request_data)
+            elif request_type == EngineCoreRequestType.AbortRequest.value:
+                assert isinstance(request_data, list), \
+                    f'Unexpected datatype {type(request_data)}'
+                self.scheduler.finish_requests(request_data,
+                                               RequestStatus.FINISHED_ABORTED)
+            else:
+                raise ValueError(f"Unhandled request type {request_type}")

-            # TODO: handle abort via another socket
             # TODO: handle logits processors via cloudpickle
             # TODO: handle profiling

except Exception as e:
# TODO: handle gracefully
raise e
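
A self-contained sketch of the zero-copy input path above (the inproc
endpoint and names are illustrative, not vLLM's actual socket setup): with
copy=False, recv_multipart() returns zmq.Frame objects whose .buffer
memoryviews go straight into msgspec without copying the payload.

import msgspec
import zmq

ctx = zmq.Context()
pull = ctx.socket(zmq.PULL)
pull.bind("inproc://engine-input")  # inproc requires bind before connect
push = ctx.socket(zmq.PUSH)
push.connect("inproc://engine-input")

abort_decoder = msgspec.msgpack.Decoder(list[str])
push.send_multipart((b'\x01', msgspec.msgpack.encode(["req-0"])))

request_type, request_data = pull.recv_multipart(copy=False)
assert request_type.buffer == b'\x01'             # memoryview == bytes works
print(abort_decoder.decode(request_data.buffer))  # ['req-0']
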
47 changes: 40 additions & 7 deletions vllm/v1/engine/core_client.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Union

import msgspec
import zmq
@@ -7,7 +7,8 @@
from vllm.logger import init_logger
from vllm.utils import get_open_zmq_ipc_path
from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
-                            EngineCoreOutputs, EngineCoreRequest)
+                            EngineCoreOutputs, EngineCoreRequest,
+                            EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc

logger = init_logger(__name__)
@@ -52,12 +53,18 @@ def get_output(self) -> List[EngineCoreOutput]:
def add_request(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

def abort_requests(self, request_ids: List[str]) -> None:
raise NotImplementedError

async def get_output_async(self) -> List[EngineCoreOutput]:
raise NotImplementedError

async def add_request_async(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

async def abort_requests_async(self, request_ids: List[str]) -> None:
raise NotImplementedError


class InprocClient(EngineCoreClient):
"""
@@ -80,6 +87,9 @@ def get_output(self) -> List[EngineCoreOutput]:
def add_request(self, request: EngineCoreRequest) -> None:
self.engine_core.add_request(request)

def abort_requests(self, request_ids: List[str]) -> None:
self.engine_core.abort_requests(request_ids)


class MPClient(EngineCoreClient):
"""
@@ -153,12 +163,23 @@ def get_output(self) -> List[EngineCoreOutput]:

return engine_core_outputs

-    def add_request(self, request: EngineCoreRequest) -> None:
-
-        self.input_socket.send_multipart((self.encoder.encode(request), ),
-                                         copy=False,
-                                         flags=zmq.NOBLOCK)
+    def _send_input(self,
+                    request_type: EngineCoreRequestType,
+                    request: Union[EngineCoreRequest, List[str]]) \
+                    -> None:
+        self.input_socket.send_multipart((
+            request_type.value,
+            self.encoder.encode(request),
+        ),
+                                         copy=False,
+                                         flags=zmq.NOBLOCK)

def add_request(self, request: EngineCoreRequest) -> None:
self._send_input(EngineCoreRequestType.AddRequest, request)

def abort_requests(self, request_ids: List[str]) -> None:
self._send_input(EngineCoreRequestType.AbortRequest, request_ids)


class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
@@ -176,7 +197,19 @@ async def get_output_async(self) -> List[EngineCoreOutput]:

return engine_core_outputs

async def _send_input(self,
request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, List[str]])\
-> None:
await self.input_socket.send_multipart((
request_type.value,
self.encoder.encode(request),
),
copy=False,
flags=zmq.NOBLOCK)

     async def add_request_async(self, request: EngineCoreRequest) -> None:
-        await self.input_socket.send_multipart(
-            (self.encoder.encode(request), ), copy=False, flags=zmq.NOBLOCK)
+        await self._send_input(EngineCoreRequestType.AddRequest, request)
+
+    async def abort_requests_async(self, request_ids: List[str]) -> None:
+        await self._send_input(EngineCoreRequestType.AbortRequest, request_ids)
14 changes: 7 additions & 7 deletions vllm/v1/engine/llm_engine.py
@@ -1,4 +1,4 @@
-from typing import Dict, Iterable, List, Mapping, Optional, Type, Union
+from typing import Dict, List, Mapping, Optional, Type, Union

from vllm.config import EngineConfig
from vllm.engine.arg_utils import EngineArgs
@@ -107,10 +107,8 @@ def has_unfinished_requests(self) -> bool:
def validate_outputs(cls, outputs, output_type):
return outputs

-    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
-        # TODO: send to EngineCore
-        # TODO: send to Detokenizer
-        pass
+    def abort_request(self, request_ids: List[str]) -> None:
+        self.engine_core.abort_requests(request_ids)

def add_request(
self,
@@ -141,9 +139,11 @@ def step(self) -> List[RequestOutput]:
engine_core_outputs = self.engine_core.get_output()

         # 2) Detokenize the EngineCoreOutputs.
-        request_outputs, to_abort = self.detokenizer.step(engine_core_outputs)
+        request_outputs, requests_to_abort = self.detokenizer.step(
+            engine_core_outputs)

         # 3) Abort requests that finished due to stopping criteria.
-        self.abort_request(to_abort)
+        if requests_to_abort:
+            self.abort_request(requests_to_abort)

return request_outputs

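For context, a hypothetical end-to-end driver (engine construction omitted;
the stop strings ride in SamplingParams, per vLLM's API): a request whose
text hits a stop string finishes in the Detokenizer, and step() then aborts
it in the EngineCore so the scheduler frees its slot.

from vllm import SamplingParams

# Assumes `engine` is an already-constructed v1 LLMEngine.
params = SamplingParams(max_tokens=128, stop=["\n\n"])
engine.add_request(request_id="req-0", prompt="Hello", params=params)

while engine.has_unfinished_requests():
    # step() detokenizes EngineCoreOutputs; requests that hit a stop string
    # come back finished and are aborted in the core within the same step.
    for out in engine.step():
        if out.finished:
            print(out.request_id, out.outputs[0].text)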