[Bugfix] Fix request cancellation without polling (#11190)
joerunde authored Dec 17, 2024
1 parent f9ecbb1 commit 2d1b9ba
Showing 12 changed files with 164 additions and 103 deletions.
51 changes: 51 additions & 0 deletions tests/entrypoints/openai/test_basic.py
@@ -1,6 +1,8 @@
import asyncio
from http import HTTPStatus
from typing import List

import openai
import pytest
import pytest_asyncio
import requests
@@ -103,3 +105,52 @@ async def test_check_health(server: RemoteOpenAIServer):
response = requests.get(server.url_for("health"))

assert response.status_code == HTTPStatus.OK


@pytest.mark.parametrize(
"server_args",
[
pytest.param(["--max-model-len", "10100"],
id="default-frontend-multiprocessing"),
pytest.param(
["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
id="disable-frontend-multiprocessing")
],
indirect=True,
)
@pytest.mark.asyncio
async def test_request_cancellation(server: RemoteOpenAIServer):
# clunky test: send an ungodly amount of load in with short timeouts
# then ensure that it still responds quickly afterwards

chat_input = [{"role": "user", "content": "Write a long story"}]
client = server.get_async_client(timeout=0.5)
tasks = []
# Request about 2 million tokens
for _ in range(200):
task = asyncio.create_task(
client.chat.completions.create(messages=chat_input,
model=MODEL_NAME,
max_tokens=10000,
extra_body={"min_tokens": 10000}))
tasks.append(task)

done, pending = await asyncio.wait(tasks,
return_when=asyncio.ALL_COMPLETED)

# Make sure all requests were sent to the server and timed out
# (We don't want to hide other errors like 400s that would invalidate this
# test)
assert len(pending) == 0
for d in done:
with pytest.raises(openai.APITimeoutError):
d.result()

# If the server had not cancelled all the other requests, then it would not
# be able to respond to this one within the timeout
client = server.get_async_client(timeout=5)
response = await client.chat.completions.create(messages=chat_input,
model=MODEL_NAME,
max_tokens=10)

assert len(response.choices) == 1
6 changes: 1 addition & 5 deletions tests/test_utils.py
@@ -1,7 +1,6 @@
import asyncio
import os
import socket
from functools import partial
from typing import AsyncIterator, Tuple

import pytest
@@ -26,10 +25,7 @@ async def mock_async_iterator(idx: int):
print(f"iterator {idx} cancelled")

iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator = merge_async_iterators(*iterators,
is_cancelled=partial(asyncio.sleep,
0,
result=False))
merged_iterator = merge_async_iterators(*iterators)

async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
async for idx, output in generator:
11 changes: 5 additions & 6 deletions tests/utils.py
@@ -163,12 +163,11 @@ def get_client(self):
api_key=self.DUMMY_API_KEY,
)

def get_async_client(self):
return openai.AsyncOpenAI(
base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY,
max_retries=0,
)
def get_async_client(self, **kwargs):
return openai.AsyncOpenAI(base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY,
max_retries=0,
**kwargs)


def _test_completion(
46 changes: 27 additions & 19 deletions vllm/engine/async_llm_engine.py
@@ -1065,16 +1065,20 @@ async def generate(
>>> # Process and return the final output
>>> ...
"""
async for output in await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
):
yield LLMEngine.validate_output(output, RequestOutput)
try:
async for output in await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
):
yield LLMEngine.validate_output(output, RequestOutput)
except asyncio.CancelledError:
await self.abort(request_id)
raise

async def encode(
self,
@@ -1147,15 +1151,19 @@ async def encode(
>>> # Process and return the final output
>>> ...
"""
async for output in await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
):
yield LLMEngine.validate_output(output, PoolingRequestOutput)
try:
async for output in await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
):
yield LLMEngine.validate_output(output, PoolingRequestOutput)
except asyncio.CancelledError:
await self.abort(request_id)
raise

async def abort(self, request_id: str) -> None:
"""Abort a request.
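
A minimal caller-side sketch (not part of this diff) of what the new try/except buys: when the task consuming generate() is cancelled, the engine aborts the request rather than relying on a polled is_cancelled check. The engine, prompt, and params names below are placeholders for an AsyncLLMEngine instance and its usual arguments.

import asyncio

async def consume(engine, prompt, params):
    # Iterate the engine's output stream exactly as a server handler would.
    async for output in engine.generate(prompt, params, "req-1"):
        pass  # stream outputs to the client here

async def main(engine, prompt, params):
    task = asyncio.create_task(consume(engine, prompt, params))
    await asyncio.sleep(0.1)
    task.cancel()  # e.g. the HTTP client disconnected
    try:
        await task
    except asyncio.CancelledError:
        # generate() catches the CancelledError, awaits self.abort("req-1"),
        # and re-raises, so the scheduler frees the request without polling.
        pass
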
11 changes: 7 additions & 4 deletions vllm/entrypoints/api_server.py
@@ -17,11 +17,11 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
random_uuid)
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger("vllm.entrypoints.api_server")
@@ -47,15 +47,18 @@ async def generate(request: Request) -> Response:
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
return await _generate(request_dict, raw_request=request)


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
results_generator = iterate_with_cancellation(
results_generator, is_cancelled=request.is_disconnected)

# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
8 changes: 8 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -59,6 +59,7 @@
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
@@ -311,6 +312,7 @@ async def health(raw_request: Request) -> Response:


@router.post("/tokenize")
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)

@@ -325,6 +327,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):


@router.post("/detokenize")
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)

@@ -353,6 +356,7 @@ async def show_version():


@router.post("/v1/chat/completions")
@with_cancellation
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
handler = chat(raw_request)
@@ -373,6 +377,7 @@ async def create_chat_completion(request: ChatCompletionRequest,


@router.post("/v1/completions")
@with_cancellation
async def create_completion(request: CompletionRequest, raw_request: Request):
handler = completion(raw_request)
if handler is None:
@@ -390,6 +395,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):


@router.post("/v1/embeddings")
@with_cancellation
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
if handler is None:
@@ -407,6 +413,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):


@router.post("/score")
@with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request)
if handler is None:
@@ -424,6 +431,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):


@router.post("/v1/score")
@with_cancellation
async def create_score_v1(request: ScoreRequest, raw_request: Request):
logger.warning(
"To indicate that Score API is not part of standard OpenAI API, we "
5 changes: 0 additions & 5 deletions vllm/entrypoints/openai/serving_chat.py
@@ -32,7 +32,6 @@
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
from vllm.utils import iterate_with_cancellation

logger = init_logger(__name__)

@@ -234,10 +233,6 @@ async def create_chat_completion(
assert len(generators) == 1
result_generator, = generators

if raw_request:
result_generator = iterate_with_cancellation(
result_generator, raw_request.is_disconnected)

# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
3 changes: 1 addition & 2 deletions vllm/entrypoints/openai/serving_completion.py
@@ -159,8 +159,7 @@ async def create_completion(
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
result_generator = merge_async_iterators(*generators)

model_name = self._get_model_name(lora_request)
num_prompts = len(engine_prompts)
5 changes: 1 addition & 4 deletions vllm/entrypoints/openai/serving_embedding.py
@@ -202,10 +202,7 @@ async def create_embedding(
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

result_generator = merge_async_iterators(
*generators,
is_cancelled=raw_request.is_disconnected if raw_request else None,
)
result_generator = merge_async_iterators(*generators)

num_prompts = len(engine_prompts)

5 changes: 1 addition & 4 deletions vllm/entrypoints/openai/serving_score.py
@@ -186,10 +186,7 @@ async def create_score(
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

result_generator = merge_async_iterators(
*generators,
is_cancelled=raw_request.is_disconnected if raw_request else None,
)
result_generator = merge_async_iterators(*generators)

num_prompts = len(engine_prompts)

57 changes: 57 additions & 0 deletions vllm/entrypoints/utils.py
@@ -0,0 +1,57 @@
import asyncio
import functools

from fastapi import Request


async def listen_for_disconnect(request: Request) -> None:
"""Returns if a disconnect message is received"""
while True:
message = await request.receive()
if message["type"] == "http.disconnect":
break


def with_cancellation(handler_func):
"""Decorator that allows a route handler to be cancelled by client
disconnections.
This does _not_ use request.is_disconnected, which does not work with
middleware. Instead this follows the pattern from
starlette.StreamingResponse, which simultaneously awaits on two tasks: one
to wait for an http disconnect message, and the other to do the work that we
want done. When the first task finishes, the other is cancelled.
A core assumption of this method is that the body of the request has already
been read. This is a safe assumption to make for fastapi handlers that have
already parsed the body of the request into a pydantic model for us.
This decorator is unsafe to use elsewhere, as it will consume and throw away
all incoming messages for the request while it looks for a disconnect
message.
In the case where a `StreamingResponse` is returned by the handler, this
wrapper will stop listening for disconnects and instead the response object
will start listening for disconnects.
"""

# Functools.wraps is required for this wrapper to appear to fastapi as a
# normal route handler, with the correct request type hinting.
@functools.wraps(handler_func)
async def wrapper(*args, **kwargs):

# The request is either the second positional arg or `raw_request`
request = args[1] if len(args) > 1 else kwargs["raw_request"]

handler_task = asyncio.create_task(handler_func(*args, **kwargs))
cancellation_task = asyncio.create_task(listen_for_disconnect(request))

done, pending = await asyncio.wait([handler_task, cancellation_task],
return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()

if handler_task in done:
return handler_task.result()
return None

return wrapper
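
For reference, a hedged usage sketch (not part of this diff) of how with_cancellation is applied to a FastAPI route, mirroring the openai/api_server.py handlers above. The /echo endpoint and EchoRequest model are hypothetical; the decorator only assumes the raw Request is the second positional argument (or the raw_request keyword) and that FastAPI has already parsed the body.

from fastapi import FastAPI, Request
from pydantic import BaseModel

from vllm.entrypoints.utils import with_cancellation

app = FastAPI()

class EchoRequest(BaseModel):
    text: str

@app.post("/echo")
@with_cancellation
async def echo(request: EchoRequest, raw_request: Request):
    # Any long-running await here is cancelled if the client disconnects.
    return {"text": request.text}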