Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Fix cancellation propagation, log req completion consistently #30

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ async def add_request(
raise KeyError(f"Request {request_id} already exists.")

# 1) Create a new AsyncStream for the request.
stream = self._add_request_to_streams(request_id,
verbose=self.log_requests)
stream = self._add_request_to_streams(request_id)

# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
detokenizer_req, engine_core_req = self.processor.process_inputs(
Expand Down Expand Up @@ -212,7 +211,6 @@ def _finish_stream(self, request_id: str):
def _add_request_to_streams(
self,
request_id: str,
verbose: bool = False,
) -> AsyncStream:

if request_id in self.request_streams:
Expand All @@ -223,26 +221,26 @@ def _add_request_to_streams(
stream = AsyncStream(request_id, aborted_reqs.append)
self.request_streams[request_id] = stream

if verbose:
if self.log_requests:
logger.info("Added request %s.", request_id)

return stream

async def _process_cancellations(self) -> None:
"""
Process requests cancelled from user user disconnecting.
Process requests cancelled from user disconnecting.

When a client disconnects, AsyncStream._cancel() is called.
We passed a callback to AsyncStream(), which appends to
self.client_aborted_requests.

As a result, if any requests are cancels from the user side
As a result, if any requests are canceled from the user side
the request_id will show up in self.client_aborted_requests.
"""

# Avoid streams having circular ref to parent AsyncLLM object.
if not self.client_aborted_requests:
return []
return
reqs_to_abort = self.client_aborted_requests.copy()
self.client_aborted_requests.clear()

Expand All @@ -251,10 +249,11 @@ async def _process_cancellations(self) -> None:

# Remove from RequestStreams.
for request_id in reqs_to_abort:
if self.log_requests:
logger.info("User-cancelled request %s.", request_id)
self._finish_stream(request_id)

# Remove from EngineCore.
print(f"{reqs_to_abort=}")
await self.engine_core.abort_requests_async(reqs_to_abort)

def _process_request_outputs(self, request_outputs: List[RequestOutput]):
Expand All @@ -271,6 +270,8 @@ def _process_request_outputs(self, request_outputs: List[RequestOutput]):

# If finished, remove from the tracker.
if request_output.finished:
if self.log_requests:
logger.info("Finished request %s.", request_id)
self._finish_stream(request_id)

async def _run_output_handler(self):
Expand Down
16 changes: 7 additions & 9 deletions vllm/v1/engine/async_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ class AsyncStream:

STOP_ITERATION = Exception() # Sentinel

def __init__(self, request_id: str,
cancel: Callable[[str], Awaitable[None]]) -> None:
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self.request_id = request_id
self._cancel = cancel
self._queue: asyncio.Queue = asyncio.Queue()
Expand All @@ -32,24 +31,23 @@ def finish(
self._queue.put_nowait(exception if self._is_raisable(exception)
else AsyncStream.STOP_ITERATION)

@property
def finished(self) -> bool:
return self._finished

async def generator(
self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
finished = False
try:
while True:
result = await self._queue.get()
if self._is_raisable(result):
finished = True
if result == AsyncStream.STOP_ITERATION:
return
raise result
yield result
except GeneratorExit:
self._cancel(self.request_id)
raise asyncio.CancelledError from None
finally:
self._finished = True
if not finished:
self._cancel(self.request_id)

@staticmethod
def _is_raisable(value: Any):
Expand Down