make incremental detokenization OO, handle stop strings
njhill committed Nov 1, 2024
1 parent 3c14bdf commit ae26a38
Showing 6 changed files with 203 additions and 124 deletions.
37 changes: 22 additions & 15 deletions vllm/engine/output_processor/stop_checker.py
@@ -1,4 +1,4 @@
from typing import Callable, Optional
from typing import Callable, List, Optional, Tuple

from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
@@ -67,9 +67,13 @@ def maybe_stop_sequence(
return

# Check if any stop strings are matched.
stop_str = self._check_stop_strings(seq, new_char_count,
sampling_params)
if stop_str is not None:
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
@@ -85,33 +89,36 @@ def maybe_stop_sequence(
return

@staticmethod
def _check_stop_strings(seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> Optional[str]:
def check_stop_strings(
output_text: str,
new_char_count: int,
stop: List[str],
include_in_output: bool,
) -> Optional[Tuple[str, int]]:
"""Check if any stop strings are matched and truncate sequence
output text accordingly.
Returns the stop string if matched or else None.
"""
if not new_char_count or not sampling_params.stop:
if not new_char_count or not stop:
return None

for stop_str in sampling_params.stop:
for stop_str in stop:
stop_string_len = len(stop_str)
# Avoid searching already-searched text.
stop_index = seq.output_text.find(
stop_str, -new_char_count - stop_string_len)
stop_index = output_text.find(stop_str,
-new_char_count - stop_string_len)
if stop_index == -1:
continue

if sampling_params.include_stop_str_in_output:
if include_in_output:
# Truncate to end of stop string.
stop_index += stop_string_len
if stop_index >= len(seq.output_text):
if stop_index >= len(output_text):
# No truncation required.
return stop_str
return stop_str, -1

# Truncate the output text to either the beginning
# or end of the stop string.
seq.output_text = seq.output_text[:stop_index]
return stop_str
return stop_str, stop_index
return None
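With this change the helper is a pure function of the output text and stop parameters rather than a method that mutates the Sequence, which makes it easy to exercise in isolation. Below is a minimal, self-contained sketch that mirrors the logic in the hunk above; the example text and the caller at the bottom are illustrative, not part of the commit:

```python
from typing import List, Optional, Tuple


def check_stop_strings(
    output_text: str,
    new_char_count: int,
    stop: List[str],
    include_in_output: bool,
) -> Optional[Tuple[str, int]]:
    """Return (stop string, truncation index) for the first match, else None.

    A truncation index of -1 means the output text needs no truncation.
    """
    if not new_char_count or not stop:
        return None

    for stop_str in stop:
        stop_string_len = len(stop_str)
        # Only search the newly appended characters, padded by the stop
        # string length so a match straddling the previous chunk is found.
        stop_index = output_text.find(stop_str,
                                      -new_char_count - stop_string_len)
        if stop_index == -1:
            continue

        if include_in_output:
            # Truncate to the end of the stop string.
            stop_index += stop_string_len
            if stop_index >= len(output_text):
                # No truncation required.
                return stop_str, -1
        return stop_str, stop_index
    return None


# Caller-side handling, mirroring the updated maybe_stop_sequence above.
text = "Hello world<|end|> extra"
match = check_stop_strings(text, new_char_count=8, stop=["<|end|>"],
                           include_in_output=False)
if match is not None:
    stop_str, truncate_to = match
    if truncate_to != -1:
        text = text[:truncate_to]
    # text is now "Hello world"; stop_str ("<|end|>") becomes the stop reason.
```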
6 changes: 2 additions & 4 deletions vllm/v1/engine/__init__.py
@@ -1,11 +1,9 @@
import asyncio
from dataclasses import dataclass
from typing import List, Optional, Union

import msgspec

from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams

POLLING_TIMEOUT_MS = 5000
@@ -21,8 +19,8 @@ class DetokenizerRequest:
spaces_between_special_tokens: bool
output_kind: RequestOutputKind

# Queue for streaming outputs to clients.
output_queue: Optional[asyncio.Queue[RequestOutput]] = None
stop: List[str]
include_stop_str_in_output: bool


class EngineCoreRequest(msgspec.Struct):
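For context, a rough sketch of what the request record looks like after this change, assuming a plain dataclass and showing only the fields visible in this hunk (the request's remaining fields are elided): the per-request output queue is gone, and the stop-string parameters now travel with the request so the Detokenizer can apply them itself.

```python
from dataclasses import dataclass
from typing import List

from vllm.sampling_params import RequestOutputKind


@dataclass
class DetokenizerRequest:
    # ...request id, prompt, and detokenization options elided...
    spaces_between_special_tokens: bool
    output_kind: RequestOutputKind

    # Stop-string handling moves into the Detokenizer, so these fields
    # replace the old per-request streaming output queue.
    stop: List[str]
    include_stop_str_in_output: bool
```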
12 changes: 11 additions & 1 deletion vllm/v1/engine/async_llm_engine.py
@@ -131,6 +131,10 @@ async def add_request(
raise KeyError(f"Request {request_id} already exists.")

# TODO: handle abort.
# IDEA(Nick): we could batch up aborts rather than sending
# them individually, so that we send at most one batch of
# aborts per step (added to any that we're doing due to
# stop string matches for that step)
def _abort():
pass

@@ -152,6 +156,11 @@ def _abort():

return stream.generator()

# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
# requests we don't need to send multiple messages to core proc,
# and so we don't need multiple streams which then get
# re-multiplexed in the API server anyhow.
async def generate(
self,
prompt: PromptType,
@@ -188,7 +197,8 @@ async def run_output_handler(self):
# NOTE: we could simplify the Detokenizer code by returning full
# List[RequestOutput] rather than pushing to the Queue at the
# expense of doing another loop through List[RequestOutput].
self.detokenizer.step_streaming(outputs)
_to_abort = self.detokenizer.step_streaming(outputs)

# TODO: send aborts (in one message)
except BaseException as e:
logger.error(e)
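The TODO above suggests sending the request ids returned by step_streaming as a single abort message per step. A hedged sketch of that shape, where only run_output_handler and detokenizer.step_streaming come from this diff and the other method names are illustrative placeholders:

```python
async def run_output_handler(self):
    while True:
        # Placeholder for however this handler receives the next batch of
        # engine-core outputs (not shown in this diff).
        outputs = await self._next_engine_core_outputs()

        # step_streaming() now also returns the ids of requests that
        # finished on a stop string and must be aborted in the core.
        to_abort = self.detokenizer.step_streaming(outputs)

        if to_abort:
            # Per the TODO: send at most one batch of aborts per step,
            # rather than one message per aborted request.
            await self._send_abort_message(to_abort)  # illustrative name
```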