Commit

Merge pull request #23 from njhill/stop-strings
Stop strings
robertgshaw2-neuralmagic authored Nov 2, 2024
2 parents 2ff6fb4 + 70c8344 commit ed8ef9d
Showing 7 changed files with 228 additions and 144 deletions.
37 changes: 22 additions & 15 deletions vllm/engine/output_processor/stop_checker.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional
+from typing import Callable, List, Optional, Tuple
 
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
@@ -67,9 +67,13 @@ def maybe_stop_sequence(
             return
 
         # Check if any stop strings are matched.
-        stop_str = self._check_stop_strings(seq, new_char_count,
-                                            sampling_params)
-        if stop_str is not None:
+        stop = self.check_stop_strings(
+            seq.output_text, new_char_count, sampling_params.stop,
+            sampling_params.include_stop_str_in_output)
+        if stop is not None:
+            stop_str, truncate_to = stop
+            if truncate_to != -1:
+                seq.output_text = seq.output_text[:truncate_to]
             seq.status = SequenceStatus.FINISHED_STOPPED
             seq.stop_reason = stop_str
             return
@@ -85,33 +89,36 @@ def maybe_stop_sequence(
             return
 
     @staticmethod
-    def _check_stop_strings(seq: Sequence, new_char_count: int,
-                            sampling_params: SamplingParams) -> Optional[str]:
+    def check_stop_strings(
+        output_text: str,
+        new_char_count: int,
+        stop: List[str],
+        include_in_output: bool,
+    ) -> Optional[Tuple[str, int]]:
         """Check if any stop strings are matched and truncate sequence
         output text accordingly.
 
         Returns the stop string if matched or else None.
         """
-        if not new_char_count or not sampling_params.stop:
+        if not new_char_count or not stop:
             return None
 
-        for stop_str in sampling_params.stop:
+        for stop_str in stop:
             stop_string_len = len(stop_str)
             # Avoid searching already-searched text.
-            stop_index = seq.output_text.find(
-                stop_str, -new_char_count - stop_string_len)
+            stop_index = output_text.find(stop_str,
                                           -new_char_count - stop_string_len)
             if stop_index == -1:
                 continue
 
-            if sampling_params.include_stop_str_in_output:
+            if include_in_output:
                 # Truncate to end of stop string.
                 stop_index += stop_string_len
-                if stop_index >= len(seq.output_text):
+                if stop_index >= len(output_text):
                     # No truncation required.
-                    return stop_str
+                    return stop_str, -1
 
             # Truncate the output text to either the beginning
             # or end of the stop string.
-            seq.output_text = seq.output_text[:stop_index]
-            return stop_str
+            return stop_str, stop_index
         return None
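
The refactored check is now a pure static helper on the StopChecker class: instead of mutating the Sequence, it returns the matched stop string together with the offset to truncate to (-1 meaning no truncation), and the caller applies the truncation, as maybe_stop_sequence does above. A minimal usage sketch, with the text, character count, and stop list invented purely for illustration:

    from vllm.engine.output_processor.stop_checker import StopChecker

    # Invented example values: the last 19 characters arrived in this step.
    output_text = "The answer is 42.\nDONE extra tokens"
    new_char_count = 19
    stop = ["DONE"]

    match = StopChecker.check_stop_strings(
        output_text, new_char_count, stop, include_in_output=False)
    if match is not None:
        stop_str, truncate_to = match  # ("DONE", 18)
        if truncate_to != -1:
            # Drop everything from the stop string onwards.
            output_text = output_text[:truncate_to]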
24 changes: 12 additions & 12 deletions vllm/outputs.py
@@ -114,20 +114,24 @@ def __init__(
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
 
     @classmethod
-    def create_empty(cls, request_id: str, prompt: Optional[str],
-                     prompt_token_ids: Optional[List[int]]) -> "RequestOutput":
+    def new(
+        cls,
+        request_id: str,
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]],
+        text: str,
+        token_ids: List[int],
+        finished: bool = False,
+    ) -> "RequestOutput":
         """Initialize a new "empty" RequestOutput object."""
 
         # TODO: Support `n` > 1.
         completion_output = CompletionOutput(
             index=0,
-            text="",
-            token_ids=[],
+            text=text,
+            token_ids=token_ids,
             cumulative_logprob=None,
             logprobs=None,  # TODO
-            finish_reason=None,
-            stop_reason=None,
-            lora_request=None,
         )
 
         return RequestOutput(
@@ -136,11 +140,7 @@ def create_empty(cls, request_id: str, prompt: Optional[str],
             prompt_token_ids=prompt_token_ids,
             prompt_logprobs=None,  # TODO
             outputs=[completion_output],
-            finished=False,
-            metrics=None,
-            lora_request=None,
-            encoder_prompt=None,
-            encoder_prompt_token_ids=None,
+            finished=finished,
         )
 
     @classmethod
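
create_empty() is generalized into new(), which seeds the single CompletionOutput with detokenized text and token IDs and can mark the request finished, rather than always starting empty. A rough sketch of how the new constructor might be called (the request ID, prompt, and token IDs below are placeholders):

    from vllm.outputs import RequestOutput

    # Start of a request: equivalent to the old create_empty() behaviour.
    out = RequestOutput.new(
        request_id="req-0",        # placeholder
        prompt="Hello",
        prompt_token_ids=[9906],   # placeholder token IDs
        text="",
        token_ids=[],
    )

    # A later step can carry the detokenized output and the finished flag.
    out = RequestOutput.new(
        request_id="req-0",
        prompt="Hello",
        prompt_token_ids=[9906],
        text=" world",
        token_ids=[1917],
        finished=True,
    )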
6 changes: 2 additions & 4 deletions vllm/v1/engine/__init__.py
@@ -1,11 +1,9 @@
-import asyncio
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
 import msgspec
 
 from vllm.lora.request import LoRARequest
-from vllm.outputs import RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 
 POLLING_TIMEOUT_MS = 5000
@@ -21,8 +19,8 @@ class DetokenizerRequest:
     spaces_between_special_tokens: bool
     output_kind: RequestOutputKind
 
-    # Queue for streaming outputs to clients.
-    output_queue: Optional[asyncio.Queue[RequestOutput]] = None
+    stop: List[str]
+    include_stop_str_in_output: bool
 
 
 class EngineCoreRequest(msgspec.Struct):
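
With the per-request asyncio output queue removed, DetokenizerRequest now carries the request's stop strings and include_stop_str_in_output flag, which the v1 detokenizer can feed into the refactored StopChecker.check_stop_strings helper. A hypothetical sketch of that wiring (apply_stop is not part of this diff; it only illustrates how the two new fields would be consumed):

    from typing import List, Optional, Tuple

    from vllm.engine.output_processor.stop_checker import StopChecker


    def apply_stop(output_text: str, new_char_count: int, stop: List[str],
                   include_stop_str_in_output: bool) -> Tuple[str, Optional[str]]:
        """Truncate output_text if one of the request's stop strings matches.

        Returns the (possibly truncated) text and the matched stop string,
        or None when no stop string matched.
        """
        match = StopChecker.check_stop_strings(
            output_text, new_char_count, stop, include_stop_str_in_output)
        if match is None:
            return output_text, None
        stop_str, truncate_to = match
        if truncate_to != -1:
            output_text = output_text[:truncate_to]
        return output_text, stop_str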
12 changes: 11 additions & 1 deletion vllm/v1/engine/async_llm_engine.py
@@ -131,6 +131,10 @@ async def add_request(
             raise KeyError(f"Request {request_id} already exists.")
 
         # TODO: handle abort.
+        # IDEA(Nick): we could batch up aborts rather than sending
+        # them individually, so that we send at most one batch of
+        # aborts per step (added to any that we're doing due to
+        # stop string matches for that step)
         def _abort():
             pass
 
@@ -152,6 +156,11 @@ def _abort():
 
         return stream.generator()
 
+    # TODO: we should support multiple prompts in one call, as you
+    # can do with LLM.generate. So that for multi-prompt completion
+    # requests we don't need to send multiple messages to core proc,
+    # and so we don't need multiple streams which then get
+    # re-multiplexed in the API server anyhow.
     async def generate(
         self,
         prompt: PromptType,
@@ -188,7 +197,8 @@ async def run_output_handler(self):
                 # NOTE: we could simplify the Detokenizer code by returning full
                 # List[RequestOutput] rather than pushing to the Queue at the
                 # expense of doing another loop through List[RequestOutput].
-                self.detokenizer.step_streaming(outputs)
+                _to_abort = self.detokenizer.step_streaming(outputs)
 
+                # TODO: send aborts (in one message)
             except BaseException as e:
                 logger.error(e)
(Diffs for the remaining 3 changed files are not shown.)