regex stopping condition #2035

Closed
2 changes: 2 additions & 0 deletions docs/references/sampling_params.md
@@ -40,6 +40,8 @@ The `sampling_params` follows this format
max_new_tokens: int = 128,
# Stop when hitting any of the strings in this list.
stop: Optional[Union[str, List[str]]] = None,
# Stop when hitting any of the regex patterns in this list.
stop_regex: Optional[Union[str, List[str]]] = None,
# Stop when hitting any of the token_ids in this list. Could be useful when mixed with
# `min_new_tokens`.
stop_token_ids: Optional[List[int]] = [],
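For context, a minimal sketch of passing the new parameter through `sampling_params` on the native `/generate` endpoint; the URL, prompt, and patterns here are illustrative, not taken from the PR:

```python
import requests

# Assumes a locally running sglang server; host and port are illustrative.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "List some fruits: ",
        "sampling_params": {
            "max_new_tokens": 64,
            # Generation stops as soon as any pattern matches the output.
            "stop_regex": [r"\d\.", r"\n\n"],
        },
    },
)
print(resp.json())
```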
6 changes: 6 additions & 0 deletions python/sglang/api.py
@@ -81,6 +81,7 @@ def gen(
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -121,6 +122,7 @@ def gen(
max_tokens,
min_tokens,
stop,
stop_regex,
stop_token_ids,
temperature,
top_p,
@@ -143,6 +145,7 @@ def gen_int(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -161,6 +164,7 @@ def gen_int(
max_tokens,
None,
stop,
stop_regex,
stop_token_ids,
temperature,
top_p,
@@ -182,6 +186,7 @@ def gen_string(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -200,6 +205,7 @@ def gen_string(
max_tokens,
None,
stop,
stop_regex,
stop_token_ids,
temperature,
top_p,
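A hedged sketch of how the new argument could be used from the frontend `gen` call above; the backend endpoint and prompt are illustrative:

```python
import sglang as sgl

@sgl.function
def summarize(s, text):
    s += "Summarize in one line: " + text + "\n"
    # Stop at the first blank line or markdown heading, whichever comes first.
    s += sgl.gen("summary", max_tokens=128, stop_regex=r"\n\n|### ")

# Assumes a running sglang server on this port.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = summarize.run(text="SGLang is a fast serving framework for LLMs.")
print(state["summary"])
```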
1 change: 1 addition & 0 deletions python/sglang/lang/interpreter.py
@@ -678,6 +678,7 @@ def _resolve_sampling_params(self, sampling_params):
"max_new_tokens",
"min_new_tokens",
"stop",
"stop_regex",
"stop_token_ids",
"temperature",
"top_p",
13 changes: 13 additions & 0 deletions python/sglang/lang/ir.py
@@ -19,6 +19,7 @@ class SglSamplingParams:
max_new_tokens: int = 128
min_new_tokens: int = 0
stop: Union[str, List[str]] = ()
stop_regex: Union[str, List[str]] = ()
stop_token_ids: Optional[List[int]] = ()
temperature: float = 1.0
top_p: float = 1.0
@@ -42,6 +43,7 @@ def clone(self):
self.max_new_tokens,
self.min_new_tokens,
self.stop,
self.stop_regex,
self.stop_token_ids,
self.temperature,
self.top_p,
@@ -117,6 +119,7 @@ def to_srt_kwargs(self):
"max_new_tokens": self.max_new_tokens,
"min_new_tokens": self.min_new_tokens,
"stop": self.stop,
"stop_regex": self.stop_regex,
"stop_token_ids": self.stop_token_ids,
"temperature": self.temperature,
"top_p": self.top_p,
@@ -154,6 +157,7 @@ def run(
*args,
max_new_tokens: int = 128,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
@@ -178,10 +182,13 @@ def run(
stop = []
if stop_token_ids is None:
stop_token_ids = []
if stop_regex is None:
stop_regex = []

default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
stop_regex=stop_regex,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
@@ -212,6 +219,7 @@ def run_batch(
*,
max_new_tokens: int = 128,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
@@ -234,6 +242,8 @@ def run_batch(
stop = []
if stop_token_ids is None:
stop_token_ids = []
if stop_regex is None:
stop_regex = []

assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
@@ -256,6 +266,7 @@ def run_batch(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
stop_regex=stop_regex,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
@@ -438,6 +449,7 @@ def __init__(
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -461,6 +473,7 @@ def __init__(
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
stop=stop,
stop_regex=stop_regex,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
26 changes: 17 additions & 9 deletions python/sglang/srt/managers/schedule_batch.py
@@ -31,6 +31,7 @@

import dataclasses
import logging
import re
from typing import List, Optional, Tuple, Union

import torch
@@ -210,6 +211,7 @@ def __init__(
# 3: last token
self.vid = 0 # version id to sync decode status with in detokenizer_manager
self.decoded_text = ""
self.stop_check_text = ""
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
self.read_offset = None

@@ -350,16 +352,22 @@ def check_finished(self):
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
return

# Check stop strings
if len(self.sampling_params.stop_strs) > 0:
tail_str = self.tokenizer.decode(
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
)
if (
len(self.sampling_params.stop_strs) > 0
or len(self.sampling_params.stop_regex_strs) > 0
):
self.stop_check_text += self.tokenizer.decode(last_token_id)
Review comment (Contributor) on the line above:
You cannot decode text token by token and concatenate the string. This will lead to wrong outputs.

Reply (@jancervenka, author, Nov 24, 2024):
Oh right, I didn't know there are tokenizers where this doesn't work. Is it then OK to decode the entire output each time? Or decode a fixed window and accept that it's not 100% reliable?
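To illustrate the reviewer's point, a small self-contained sketch; the tokenizer repo below is a public test artifact chosen for illustration. SentencePiece-style tokenizers fold leading spaces into the tokens themselves, so decoding tokens one at a time and concatenating the pieces can silently drop those spaces:

```python
from transformers import AutoTokenizer

# Illustrative tokenizer; any SentencePiece-based one shows the effect.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
ids = tok.encode("Hello world", add_special_tokens=False)

whole = tok.decode(ids)                            # "Hello world"
piecewise = "".join(tok.decode([t]) for t in ids)  # spaces may be lost: "Helloworld"
print(repr(whole), repr(piecewise), whole == piecewise)
```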
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
# Check stop strings
for stop_str in self.sampling_params.stop_strs:
if stop_str in self.stop_check_text or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return

for stop_regex_str in self.sampling_params.stop_regex_strs:
if re.search(stop_regex_str, self.stop_check_text):
self.finished_reason = FINISH_MATCHED_STR(matched=stop_regex_str)
return

def jump_forward_and_retokenize(self, jump_forward_str, next_state):
if self.origin_input_text is None:
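In isolation, the stop-regex check added to `check_finished` amounts to the following self-contained sketch; the function name and signature are illustrative, not sglang internals:

```python
import re
from typing import List, Optional

def match_stop_regex(stop_check_text: str, stop_regex_strs: List[str]) -> Optional[str]:
    """Return the first pattern that matches the accumulated text, else None."""
    for pattern in stop_regex_strs:
        if re.search(pattern, stop_check_text):
            return pattern
    return None

print(match_stop_regex("apples and oranges", [r"and |or "]))  # -> "and |or "
```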
5 changes: 5 additions & 0 deletions python/sglang/srt/openai_api/adapter.py
@@ -508,6 +508,7 @@ def v1_generate_request(
"max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens,
"stop": request.stop,
"stop_regex": request.stop_regex,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
@@ -892,6 +893,7 @@ def v1_chat_generate_request(
if assistant_prefix:
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
stop = request.stop
stop_regex = request.stop_regex
image_data = None
modalities = []
else:
@@ -900,6 +902,7 @@
image_data = conv.image_data
modalities = conv.modalities
stop = conv.stop_str or []
stop_regex = []
if request.stop:
if isinstance(request.stop, str):
stop.append(request.stop)
@@ -910,6 +913,7 @@
# Use the raw prompt and stop strings if the messages is already a string.
prompt_ids = request.messages
stop = request.stop
stop_regex = request.stop_regex
image_data = None
modalities = []
input_ids.append(prompt_ids)
@@ -922,6 +926,7 @@
"max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens,
"stop": stop,
"stop_regex": stop_regex,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
2 changes: 2 additions & 0 deletions python/sglang/srt/openai_api/protocol.py
@@ -174,6 +174,7 @@ class CompletionRequest(BaseModel):
regex: Optional[str] = None
min_tokens: int = 0
repetition_penalty: float = 1.0
stop_regex: Optional[Union[str, List[str]]] = Field(default_factory=list)
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
@@ -280,6 +281,7 @@ class ChatCompletionRequest(BaseModel):
regex: Optional[str] = None
min_tokens: int = 0
repetition_penalty: float = 1.0
stop_regex: Optional[Union[str, List[str]]] = Field(default_factory=list)
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
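With the new field on both request models, an OpenAI-compatible call could look like the sketch below; the server URL and model name are illustrative, and `stop_regex` accepts either a single pattern or a list, per the `Union` type above:

```python
import requests

payload = {
    "model": "default",                  # placeholder; use your served model name
    "prompt": "Write one sentence about fruit.",
    "max_tokens": 200,
    "stop_regex": r"and |or ",           # a list of patterns also validates
}
resp = requests.post("http://localhost:30000/v1/completions", json=payload)
print(resp.json())
```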
8 changes: 8 additions & 0 deletions python/sglang/srt/sampling/sampling_params.py
@@ -26,6 +26,7 @@ def __init__(
max_new_tokens: int = 128,
min_new_tokens: int = 0,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
@@ -50,6 +51,7 @@
self.presence_penalty = presence_penalty
self.repetition_penalty = repetition_penalty
self.stop_strs = stop
self.stop_regex_strs = stop_regex
if stop_token_ids:
self.stop_token_ids = set(stop_token_ids)
else:
@@ -133,3 +135,9 @@ def normalize(self, tokenizer):
else:
stop_str_max_len = max(stop_str_max_len, len(stop_str))
self.stop_str_max_len = stop_str_max_len

if self.stop_regex_strs is None:
self.stop_regex_strs = []
else:
if isinstance(self.stop_regex_strs, str):
self.stop_regex_strs = [self.stop_regex_strs]
23 changes: 23 additions & 0 deletions test/srt/test_matched_stop.py
@@ -39,6 +39,7 @@ def run_completions_generation(
prompt=MANY_NEW_TOKENS_PROMPT,
max_tokens=1,
stop=None,
stop_regex=None,
finish_reason=None,
matched_stop=None,
):
@@ -53,6 +54,9 @@
if stop is not None:
payload["stop"] = stop

if stop_regex is not None:
payload["stop_regex"] = stop_regex

response_completions = requests.post(
self.base_url + "/v1/completions",
json=payload,
@@ -70,6 +74,7 @@
prompt=MANY_NEW_TOKENS_PROMPT,
max_tokens=1,
stop=None,
stop_regex=None,
finish_reason=None,
matched_stop=None,
):
@@ -87,6 +92,9 @@
if stop is not None:
chat_payload["stop"] = stop

if stop_regex is not None:
chat_payload["stop_regex"] = stop_regex

response_chat = requests.post(
self.base_url + "/v1/chat/completions",
json=chat_payload,
@@ -105,6 +113,21 @@ def test_finish_stop_str(self):
max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n"
)

def test_finish_stop_regex_str(self):
stop_regex = r"and |or "
self.run_completions_generation(
max_tokens=1000,
stop_regex=stop_regex,
finish_reason="stop",
matched_stop=stop_regex,
)
self.run_chat_completions_generation(
max_tokens=1000,
stop_regex=stop_regex,
finish_reason="stop",
matched_stop=stop_regex,
)

def test_finish_stop_eos(self):
llama_format_prompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>