Stop strings #23

Merged
37 changes: 22 additions & 15 deletions vllm/engine/output_processor/stop_checker.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional
+from typing import Callable, List, Optional, Tuple

 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
@@ -67,9 +67,13 @@ def maybe_stop_sequence(
             return

         # Check if any stop strings are matched.
-        stop_str = self._check_stop_strings(seq, new_char_count,
-                                            sampling_params)
-        if stop_str is not None:
+        stop = self.check_stop_strings(
+            seq.output_text, new_char_count, sampling_params.stop,
+            sampling_params.include_stop_str_in_output)
+        if stop is not None:
+            stop_str, truncate_to = stop
+            if truncate_to != -1:
+                seq.output_text = seq.output_text[:truncate_to]
             seq.status = SequenceStatus.FINISHED_STOPPED
             seq.stop_reason = stop_str
             return
@@ -85,33 +89,36 @@ def maybe_stop_sequence(
             return

     @staticmethod
-    def _check_stop_strings(seq: Sequence, new_char_count: int,
-                            sampling_params: SamplingParams) -> Optional[str]:
+    def check_stop_strings(
+        output_text: str,
+        new_char_count: int,
+        stop: List[str],
+        include_in_output: bool,
+    ) -> Optional[Tuple[str, int]]:
         """Check if any stop strings are matched and truncate sequence
         output text accordingly.

         Returns the stop string if matched or else None.
         """
-        if not new_char_count or not sampling_params.stop:
+        if not new_char_count or not stop:
             return None

-        for stop_str in sampling_params.stop:
+        for stop_str in stop:
             stop_string_len = len(stop_str)
             # Avoid searching already-searched text.
-            stop_index = seq.output_text.find(
-                stop_str, -new_char_count - stop_string_len)
+            stop_index = output_text.find(stop_str,
+                                          -new_char_count - stop_string_len)
             if stop_index == -1:
                 continue

-            if sampling_params.include_stop_str_in_output:
+            if include_in_output:
                 # Truncate to end of stop string.
                 stop_index += stop_string_len
-                if stop_index >= len(seq.output_text):
+                if stop_index >= len(output_text):
                     # No truncation required.
-                    return stop_str
+                    return stop_str, -1

             # Truncate the output text to either the beginning
             # or end of the stop string.
-            seq.output_text = seq.output_text[:stop_index]
-            return stop_str
+            return stop_str, stop_index
         return None
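
To make the refactor concrete, here is a minimal usage sketch (not part of the diff): check_stop_strings now operates on a plain string plus the stop parameters and returns a (stop_str, truncate_to) tuple instead of mutating the Sequence, with truncate_to == -1 meaning no truncation is needed. The literal values below are illustrative assumptions.

# Illustrative sketch only; values are made up, behavior follows the new code above.
from vllm.engine.output_processor.stop_checker import StopChecker

output_text = "Hello world! END extra"
new_char_count = 10        # characters appended by the latest decode step
stop = ["END"]             # sampling_params.stop
include_in_output = False  # sampling_params.include_stop_str_in_output

result = StopChecker.check_stop_strings(output_text, new_char_count, stop,
                                        include_in_output)
if result is not None:
    stop_str, truncate_to = result               # ("END", 13)
    if truncate_to != -1:
        output_text = output_text[:truncate_to]  # "Hello world! "
    # stop_str is what the caller records as the sequence's stop_reason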
24 changes: 12 additions & 12 deletions vllm/outputs.py
@@ -114,20 +114,24 @@ def __init__(
         self.encoder_prompt_token_ids = encoder_prompt_token_ids

     @classmethod
-    def create_empty(cls, request_id: str, prompt: Optional[str],
-                     prompt_token_ids: Optional[List[int]]) -> "RequestOutput":
+    def new(
+        cls,
+        request_id: str,
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]],
+        text: str,
+        token_ids: List[int],
+        finished: bool = False,
+    ) -> "RequestOutput":
         """Initialize a new "empty" RequestOutput object."""

         # TODO: Support `n` > 1.
         completion_output = CompletionOutput(
             index=0,
-            text="",
-            token_ids=[],
+            text=text,
+            token_ids=token_ids,
             cumulative_logprob=None,
             logprobs=None,  # TODO
-            finish_reason=None,
-            stop_reason=None,
-            lora_request=None,
         )

         return RequestOutput(
@@ -136,11 +140,7 @@ def create_empty(cls, request_id: str, prompt: Optional[str],
             prompt_token_ids=prompt_token_ids,
             prompt_logprobs=None,  # TODO
             outputs=[completion_output],
-            finished=False,
-            metrics=None,
-            lora_request=None,
-            encoder_prompt=None,
-            encoder_prompt_token_ids=None,
+            finished=finished,
         )

     @classmethod
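
A hedged sketch of calling the renamed constructor (only the signature comes from the diff; the argument values are made up):

# Illustrative call; literal values are assumptions, not from the PR.
from vllm.outputs import RequestOutput

request_output = RequestOutput.new(
    request_id="req-0",
    prompt="Hello",
    prompt_token_ids=[15496],
    text=" world",      # detokenized output text produced so far
    token_ids=[995],    # corresponding output token ids
    finished=False,     # set True once a stop condition is hit
)

Compared with create_empty, this builds a fully populated single-completion output in one call; `n` > 1 remains a TODO.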
6 changes: 2 additions & 4 deletions vllm/v1/engine/__init__.py
@@ -1,11 +1,9 @@
-import asyncio
 from dataclasses import dataclass
 from typing import List, Optional, Union

 import msgspec

 from vllm.lora.request import LoRARequest
-from vllm.outputs import RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams

 POLLING_TIMEOUT_MS = 5000
@@ -21,8 +19,8 @@ class DetokenizerRequest:
     spaces_between_special_tokens: bool
     output_kind: RequestOutputKind

-    # Queue for streaming outputs to clients.
-    output_queue: Optional[asyncio.Queue[RequestOutput]] = None
+    stop: List[str]
+    include_stop_str_in_output: bool


 class EngineCoreRequest(msgspec.Struct):
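
Since DetokenizerRequest now carries each request's stop strings instead of an output queue, the detokenizer side can apply stop handling itself via the static check_stop_strings shown earlier. A sketch under assumptions (the helper below is hypothetical; only the new fields and check_stop_strings come from this PR):

# Hypothetical helper, not from the PR; shown only to connect the pieces.
from typing import Tuple

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.v1.engine import DetokenizerRequest


def apply_stop_strings(output_text: str, new_char_count: int,
                       req: DetokenizerRequest) -> Tuple[str, bool]:
    """Return (possibly truncated text, whether a stop string matched)."""
    stop = StopChecker.check_stop_strings(output_text, new_char_count,
                                          req.stop,
                                          req.include_stop_str_in_output)
    if stop is None:
        return output_text, False
    _stop_str, truncate_to = stop
    if truncate_to != -1:
        output_text = output_text[:truncate_to]
    return output_text, True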
12 changes: 11 additions & 1 deletion vllm/v1/engine/async_llm_engine.py
@@ -131,6 +131,10 @@ async def add_request(
             raise KeyError(f"Request {request_id} already exists.")

         # TODO: handle abort.
+        # IDEA(Nick): we could batch up aborts rather than sending
+        # them individually, so that we send at most one batch of
+        # aborts per step (added to any that we're doing due to
+        # stop string matches for that step)
         def _abort():
             pass

@@ -152,6 +156,11 @@ def _abort():

         return stream.generator()

+    # TODO: we should support multiple prompts in one call, as you
+    # can do with LLM.generate. So that for multi-prompt completion
+    # requests we don't need to send multiple messages to core proc,
+    # and so we don't need multiple streams which then get
+    # re-multiplexed in the API server anyhow.
     async def generate(
         self,
         prompt: PromptType,
@@ -188,7 +197,8 @@ async def run_output_handler(self):
                 # NOTE: we could simplify the Detokenizer code by returning full
                 # List[RequestOutput] rather than pushing to the Queue at the
                 # expense of doing another loop through List[RequestOutput].
-                self.detokenizer.step_streaming(outputs)
+                _to_abort = self.detokenizer.step_streaming(outputs)

+                # TODO: send aborts (in one message)
         except BaseException as e:
             logger.error(e)
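
The remaining TODO is to send those aborts as one batched message per step. A speculative sketch of that idea (nothing below is in this PR; send_abort_message is a placeholder name):

# Speculative sketch of the batched-abort TODO; not part of this PR.
from typing import Iterable, List


def batch_aborts(to_abort: Iterable[str]) -> List[str]:
    """Deduplicate the request ids that hit a stop string this step."""
    return sorted(set(to_abort))


# Conceptually, in run_output_handler:
#     _to_abort = self.detokenizer.step_streaming(outputs)
#     if _to_abort:
#         send_abort_message(batch_aborts(_to_abort))  # one message per step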