Skip to content

Commit ac7b8a7

Browse files
Merge pull request #24 from neuralmagic/varun/9826-stop-strings
Stop String Plumbing
2 parents 3ad8684 + d517017 commit ac7b8a7

File tree

5 files changed

+113
-32
lines changed

5 files changed

+113
-32
lines changed

vllm/v1/engine/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
from dataclasses import dataclass
23
from typing import List, Optional, Union
34

@@ -56,3 +57,12 @@ class EngineCoreOutputs(msgspec.Struct, array_like=True):
5657

5758
# [num_reqs]
5859
outputs: List[EngineCoreOutput]
60+
61+
62+
class EngineCoreRequestType(enum.Enum):
63+
"""
64+
Request types defined as hex byte strings, so they can be sent over sockets
65+
without a separate encoding step.
66+
"""
67+
AddRequest = b'\x00'
68+
AbortRequest = b'\x01'

vllm/v1/engine/async_llm.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from typing import AsyncGenerator, Dict, Mapping, Optional, Type, Union
2+
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
33

44
from vllm.config import EngineConfig, ModelConfig
55
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -110,6 +110,9 @@ def from_engine_args(
110110
def _get_executor_cls(cls, engine_config: EngineConfig):
111111
return GPUExecutor
112112

113+
async def abort_request(self, request_ids: List[str]) -> None:
114+
await self.engine_core.abort_requests_async(request_ids)
115+
113116
async def add_request(
114117
self,
115118
request_id: str,
@@ -200,9 +203,10 @@ async def _run_output_handler(self):
200203
# NOTE: we could simplify the Detokenizer code by returning full
201204
# List[RequestOutput] rather than pushing to the Queue at the
202205
# expense of doing another loop through List[RequestOutput].
203-
_to_abort = self.detokenizer.step_streaming(outputs)
206+
requests_to_abort = self.detokenizer.step_streaming(outputs)
204207

205-
# TODO: send aborts (in one message)
208+
if requests_to_abort:
209+
await self.abort_request(requests_to_abort)
206210
except BaseException as e:
207211
logger.error(e)
208212

vllm/v1/engine/core.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import queue
33
from multiprocessing.process import BaseProcess
44
from threading import Thread
5-
from typing import List, Tuple, Type
5+
from typing import List, Tuple, Type, Any
66

77
import msgspec
88
import zmq
@@ -13,9 +13,10 @@
1313
from vllm.usage.usage_lib import UsageContext
1414
from vllm.v1.core.scheduler import Scheduler
1515
from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
16-
EngineCoreOutputs, EngineCoreRequest)
16+
EngineCoreOutputs, EngineCoreRequest,
17+
EngineCoreRequestType)
1718
from vllm.v1.executor.gpu_executor import GPUExecutor
18-
from vllm.v1.request import Request
19+
from vllm.v1.request import Request, RequestStatus
1920
from vllm.version import __version__ as VLLM_VERSION
2021

2122
logger = init_logger(__name__)
@@ -123,6 +124,10 @@ def add_request(self, request: EngineCoreRequest):
123124
req = Request.from_engine_core_request(request)
124125
self.scheduler.add_request(req)
125126

127+
def abort_requests(self, request_ids: List[str]):
128+
self.scheduler.finish_requests(request_ids,
129+
RequestStatus.FINISHED_ABORTED)
130+
126131
def step(self) -> List[EngineCoreOutput]:
127132
"""Schedule, execute, and make output."""
128133

@@ -153,7 +158,10 @@ def __init__(
153158
super().__init__(vllm_config, executor_class, usage_context)
154159

155160
self.msgpack_encoder = msgspec.msgpack.Encoder()
156-
self.msgpack_decoder = msgspec.msgpack.Decoder(EngineCoreRequest)
161+
self.msgpack_add_request_decoder = \
162+
msgspec.msgpack.Decoder(EngineCoreRequest)
163+
self.msgpack_abort_requests_decoder = \
164+
msgspec.msgpack.Decoder(list[str])
157165

158166
self.ctx = zmq.Context() # type: ignore[attr-defined]
159167

@@ -182,10 +190,27 @@ def __init__(
182190
ready_socket.close(linger=0)
183191

184192
def process_input_socket(self):
193+
194+
def get_decoder_from_request_type(req_type: bytes) \
195+
-> msgspec.msgpack.Decoder:
196+
# Identify msgpack decoder based on request_type.
197+
if req_type == EngineCoreRequestType.AddRequest.value:
198+
return self.msgpack_add_request_decoder
199+
elif req_type == EngineCoreRequestType.AbortRequest.value:
200+
return self.msgpack_abort_requests_decoder
201+
else:
202+
raise ValueError(f"Unhandled request type {req_type}")
203+
185204
while True:
186-
frames = self.input_socket.recv_multipart(copy=False)
187-
request = self.msgpack_decoder.decode(frames[0].buffer)
188-
self.input_queue.put_nowait(request)
205+
request_type, request_data = \
206+
self.input_socket.recv_multipart(copy=False)
207+
208+
# Decode request_data
209+
msgpack_decoder: msgspec.msgpack.Decoder = \
210+
get_decoder_from_request_type(request_type.buffer)
211+
request_data: Any = msgpack_decoder.decode(request_data.buffer)
212+
213+
self.input_queue.put_nowait((request_type.buffer, request_data))
189214

190215
def process_output_socket(self):
191216
while True:
@@ -267,8 +292,7 @@ def run_busy_loop(self):
267292
while True:
268293
# Poll the input socket until there is work to do.
269294
if not self.scheduler.has_unfinished_requests():
270-
request = self.input_queue.get()
271-
self._handle_request(request)
295+
self._handle_request(self.input_queue.get())
272296

273297
# Handle new input from the socket.
274298
self._handle_new_input()
@@ -282,17 +306,27 @@ def run_busy_loop(self):
282306
def _handle_new_input(self):
283307
"""Handle new input from the AsyncLLMEngine for async mode."""
284308
while not self.input_queue.empty():
285-
request = self.input_queue.get_nowait()
286-
self._handle_request(request)
309+
self._handle_request(self.input_queue.get_nowait())
310+
311+
def _handle_request(self, request: Tuple[bytes, Any]):
287312

288-
def _handle_request(self, request: EngineCoreRequest):
289313
try:
290-
self.add_request(request)
314+
request_type, request_data = request
315+
# Process request_data based on request_type
316+
if request_type == EngineCoreRequestType.AddRequest.value:
317+
assert isinstance(request_data, EngineCoreRequest), \
318+
f'Unexpected datatype {type(request_data)}'
319+
self.add_request(request_data)
320+
elif request_type == EngineCoreRequestType.AbortRequest.value:
321+
assert isinstance(request_data, list), \
322+
f'Unexpected datatype {type(request_data)}'
323+
self.scheduler.finish_requests(request_data,
324+
RequestStatus.FINISHED_ABORTED)
325+
else:
326+
raise ValueError(f"Unhandled request type {request_type}")
291327

292-
# TODO: handle abort via another socket
293328
# TODO: handle logits processors via cloudpickle
294329
# TODO: handle profiling
295-
296330
except Exception as e:
297331
# TODO: handle gracefully
298332
raise e

vllm/v1/engine/core_client.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Union
22

33
import msgspec
44
import zmq
@@ -7,7 +7,8 @@
77
from vllm.logger import init_logger
88
from vllm.utils import get_open_zmq_ipc_path
99
from vllm.v1.engine import (POLLING_TIMEOUT_MS, EngineCoreOutput,
10-
EngineCoreOutputs, EngineCoreRequest)
10+
EngineCoreOutputs, EngineCoreRequest,
11+
EngineCoreRequestType)
1112
from vllm.v1.engine.core import EngineCore, EngineCoreProc
1213

1314
logger = init_logger(__name__)
@@ -52,12 +53,18 @@ def get_output(self) -> List[EngineCoreOutput]:
5253
def add_request(self, request: EngineCoreRequest) -> None:
5354
raise NotImplementedError
5455

56+
def abort_requests(self, request_ids: List[str]) -> None:
57+
raise NotImplementedError
58+
5559
async def get_output_async(self) -> List[EngineCoreOutput]:
5660
raise NotImplementedError
5761

5862
async def add_request_async(self, request: EngineCoreRequest) -> None:
5963
raise NotImplementedError
6064

65+
async def abort_requests_async(self, request_ids: List[str]) -> None:
66+
raise NotImplementedError
67+
6168

6269
class InprocClient(EngineCoreClient):
6370
"""
@@ -80,6 +87,9 @@ def get_output(self) -> List[EngineCoreOutput]:
8087
def add_request(self, request: EngineCoreRequest) -> None:
8188
self.engine_core.add_request(request)
8289

90+
def abort_requests(self, request_ids: List[str]) -> None:
91+
self.engine_core.abort_requests(request_ids)
92+
8393

8494
class MPClient(EngineCoreClient):
8595
"""
@@ -153,12 +163,23 @@ def get_output(self) -> List[EngineCoreOutput]:
153163

154164
return engine_core_outputs
155165

156-
def add_request(self, request: EngineCoreRequest) -> None:
157-
158-
self.input_socket.send_multipart((self.encoder.encode(request), ),
166+
def _send_input(self,
167+
request_type: EngineCoreRequestType,
168+
request: Union[EngineCoreRequest, List[str]]) \
169+
-> None:
170+
self.input_socket.send_multipart((
171+
request_type.value,
172+
self.encoder.encode(request),
173+
),
159174
copy=False,
160175
flags=zmq.NOBLOCK)
161176

177+
def add_request(self, request: EngineCoreRequest) -> None:
178+
self._send_input(EngineCoreRequestType.AddRequest, request)
179+
180+
def abort_requests(self, request_ids: List[str]) -> None:
181+
self._send_input(EngineCoreRequestType.AbortRequest, request_ids)
182+
162183

163184
class AsyncMPClient(MPClient):
164185
"""Asyncio-compatible client for multi-proc EngineCore."""
@@ -176,7 +197,19 @@ async def get_output_async(self) -> List[EngineCoreOutput]:
176197

177198
return engine_core_outputs
178199

200+
async def _send_input(self,
201+
request_type: EngineCoreRequestType,
202+
request: Union[EngineCoreRequest, List[str]])\
203+
-> None:
204+
await self.input_socket.send_multipart((
205+
request_type.value,
206+
self.encoder.encode(request),
207+
),
208+
copy=False,
209+
flags=zmq.NOBLOCK)
210+
179211
async def add_request_async(self, request: EngineCoreRequest) -> None:
212+
await self._send_input(EngineCoreRequestType.AddRequest, request)
180213

181-
await self.input_socket.send_multipart(
182-
(self.encoder.encode(request), ), copy=False, flags=zmq.NOBLOCK)
214+
async def abort_requests_async(self, request_ids: List[str]) -> None:
215+
await self._send_input(EngineCoreRequestType.AbortRequest, request_ids)

vllm/v1/engine/llm_engine.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Iterable, List, Mapping, Optional, Type, Union
1+
from typing import Dict, List, Mapping, Optional, Type, Union
22

33
from vllm.config import EngineConfig
44
from vllm.engine.arg_utils import EngineArgs
@@ -107,10 +107,8 @@ def has_unfinished_requests(self) -> bool:
107107
def validate_outputs(cls, outputs, output_type):
108108
return outputs
109109

110-
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
111-
# TODO: send to EngineCore
112-
# TODO: send to Detokenizer
113-
pass
110+
def abort_request(self, request_ids: List[str]) -> None:
111+
self.engine_core.abort_requests(request_ids)
114112

115113
def add_request(
116114
self,
@@ -141,9 +139,11 @@ def step(self) -> List[RequestOutput]:
141139
engine_core_outputs = self.engine_core.get_output()
142140

143141
# 2) Detokenize the EngineCoreOutput.
144-
request_outputs, to_abort = self.detokenizer.step(engine_core_outputs)
142+
request_outputs, requests_to_abort = self.detokenizer.step(
143+
engine_core_outputs)
145144

146145
# 3) Abort requests that finished due to stopping criteria.
147-
self.abort_request(to_abort)
146+
if requests_to_abort:
147+
self.abort_request(requests_to_abort)
148148

149149
return request_outputs

0 commit comments

Comments
 (0)