[Feature] Support new parameter - EBNF in xgrammar #2526

Merged 6 commits on Dec 26, 2024
10 changes: 10 additions & 0 deletions python/sglang/lang/backend/openai.py
@@ -366,6 +366,11 @@ def select(
def openai_completion(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
# if "ebnf" is in kwargs, warn and remove
if "ebnf" in kwargs:
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
del kwargs["ebnf"]

for attempt in range(retries):
try:
if is_chat:
@@ -398,6 +403,11 @@ def openai_completion(
def openai_completion_stream(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
# if "ebnf" is in kwargs, warn and remove
if "ebnf" in kwargs:
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
del kwargs["ebnf"]

for attempt in range(retries):
try:
if is_chat:
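The same guard appears in both the blocking and streaming completion paths. For reference, a minimal standalone sketch of that logic (the helper name `strip_ebnf_kwarg` is made up for illustration and is not part of this PR):

```python
import warnings

def strip_ebnf_kwarg(kwargs: dict) -> dict:
    # OpenAI endpoints do not accept an "ebnf" argument, so warn and drop it,
    # mirroring the guard added to openai_completion / openai_completion_stream.
    if "ebnf" in kwargs:
        warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
        kwargs = {k: v for k, v in kwargs.items() if k != "ebnf"}
    return kwargs

print(strip_ebnf_kwarg({"ebnf": 'root ::= "yes"', "temperature": 0}))
# -> {'temperature': 0}, plus a UserWarning
```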
6 changes: 6 additions & 0 deletions python/sglang/srt/constrained/xgrammar_backend.py
@@ -126,6 +126,12 @@ def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
)
return None
elif key_type == "ebnf":
try:
ctx = self.grammar_compiler.compile_grammar(key_string)
except RuntimeError as e:
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
return None
elif key_type == "regex":
logger.warning(
"regex hasn't been supported by xgrammar yet. This is skipped."
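This adds "ebnf" as a third key type next to "json" and "regex" in the grammar cache. A rough, self-contained sketch of the dispatch, with a stub compiler standing in for xgrammar's grammar compiler (only the `compile_grammar` call is taken from the diff above; everything else is assumed):

```python
import logging

logger = logging.getLogger(__name__)

class StubGrammarCompiler:
    """Toy stand-in for the real compiler; only checks that the string looks like EBNF."""

    def compile_grammar(self, grammar: str) -> str:
        if "::=" not in grammar:
            raise RuntimeError("not a valid EBNF grammar")
        return grammar

def compile_key(compiler, key):
    # key is the (key_type, key_string) tuple used as the cache key.
    key_type, key_string = key
    if key_type == "ebnf":
        try:
            return compiler.compile_grammar(key_string)
        except RuntimeError as e:
            logger.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
            return None
    raise NotImplementedError(f"unhandled key type: {key_type}")

print(compile_key(StubGrammarCompiler(), ("ebnf", 'root ::= "test"')))  # root ::= "test"
```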
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -589,12 +589,15 @@ def handle_generate_request(
if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
or req.sampling_params.ebnf is not None
):
assert self.grammar_backend is not None
if req.sampling_params.json_schema is not None:
key = ("json", req.sampling_params.json_schema)
elif req.sampling_params.regex is not None:
key = ("regex", req.sampling_params.regex)
elif req.sampling_params.ebnf is not None:
key = ("ebnf", req.sampling_params.ebnf)

req.grammar = self.grammar_backend.get_cached_value(key)
if not req.grammar:
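The scheduler picks the grammar cache key with a fixed precedence: json_schema first, then regex, then ebnf. A small sketch of that selection with plain arguments in place of the real request object (the function name is illustrative only):

```python
from typing import Optional, Tuple

def grammar_cache_key(
    json_schema: Optional[str] = None,
    regex: Optional[str] = None,
    ebnf: Optional[str] = None,
) -> Optional[Tuple[str, str]]:
    # Mirrors the branch order in handle_generate_request above.
    if json_schema is not None:
        return ("json", json_schema)
    if regex is not None:
        return ("regex", regex)
    if ebnf is not None:
        return ("ebnf", ebnf)
    return None

print(grammar_cache_key(ebnf='root ::= "Hello"'))  # ('ebnf', 'root ::= "Hello"')
```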
19 changes: 19 additions & 0 deletions python/sglang/srt/openai_api/adapter.py
@@ -517,6 +517,7 @@ def v1_generate_request(
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"json_schema": request.json_schema,
"ebnf": request.ebnf,
"n": request.n,
"no_stop_trim": request.no_stop_trim,
"ignore_eos": request.ignore_eos,
@@ -692,6 +693,14 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):

async def v1_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
if "extra_body" in request_json:
extra = request_json["extra_body"]
if "ebnf" in extra:
request_json["ebnf"] = extra["ebnf"]
if "regex" in extra:
request_json["regex"] = extra["regex"]
# remove extra_body to avoid pydantic conflict
del request_json["extra_body"]
all_requests = [CompletionRequest(**request_json)]
adapted_request, request = v1_generate_request(all_requests)

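With this flattening, a raw HTTP client can pass "ebnf" either at the top level of the JSON body or nested under "extra_body". A hedged example against a locally running server (the URL, port, and model name are assumptions):

```python
import requests

payload = {
    "model": "default",
    "prompt": "Generate a greeting:",
    "max_tokens": 16,
    "temperature": 0,
    # Nested form; v1_completions copies it to the top level before validation.
    "extra_body": {"ebnf": 'root ::= "Hello" | "Hi" | "Hey"'},
}
resp = requests.post("http://127.0.0.1:30000/v1/completions", json=payload)
print(resp.json()["choices"][0]["text"])
```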
@@ -936,6 +945,7 @@ def v1_chat_generate_request(
"frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"ebnf": request.ebnf,
"n": request.n,
"no_stop_trim": request.no_stop_trim,
"ignore_eos": request.ignore_eos,
@@ -1108,6 +1118,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):

async def v1_chat_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
if "extra_body" in request_json:
extra = request_json["extra_body"]
# Promote grammar-related fields from extra_body to the top level:
if "ebnf" in extra:
request_json["ebnf"] = extra["ebnf"]
if "regex" in extra:
request_json["regex"] = extra["regex"]
# remove extra_body to avoid pydantic conflict
del request_json["extra_body"]
all_requests = [ChatCompletionRequest(**request_json)]
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

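Because the OpenAI Python SDK merges `extra_body` into the request JSON, the same parameter works from a standard client. A sketch assuming an SGLang server on port 30000 with a model served as "default":

```python
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Say a greeting."}],
    temperature=0,
    extra_body={"ebnf": 'root ::= "Hello" | "Hi" | "Hey"'},
)
print(response.choices[0].message.content)  # constrained to Hello, Hi, or Hey
```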
2 changes: 2 additions & 0 deletions python/sglang/srt/openai_api/protocol.py
@@ -179,6 +179,7 @@ class CompletionRequest(BaseModel):
ignore_eos: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
ebnf: Optional[str] = None


class CompletionResponseChoice(BaseModel):
@@ -288,6 +289,7 @@ class ChatCompletionRequest(BaseModel):
ignore_eos: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
ebnf: Optional[str] = None


class ChatMessage(BaseModel):
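Both request models gain the same optional field. A toy pydantic model showing the pattern (not the real classes, which carry many more fields):

```python
from typing import Optional

from pydantic import BaseModel

class ToyCompletionRequest(BaseModel):
    prompt: str
    ebnf: Optional[str] = None  # unset by default, just like the real request models

print(ToyCompletionRequest(prompt="hi").ebnf)                        # None
print(ToyCompletionRequest(prompt="hi", ebnf='root ::= "ok"').ebnf)  # root ::= "ok"
```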
11 changes: 9 additions & 2 deletions python/sglang/srt/sampling/sampling_params.py
@@ -36,6 +36,7 @@ def __init__(
regex: Optional[str] = None,
n: int = 1,
json_schema: Optional[str] = None,
ebnf: Optional[str] = None,
no_stop_trim: bool = False,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
@@ -60,6 +61,7 @@ def __init__(
self.regex = regex
self.n = n
self.json_schema = json_schema
self.ebnf = ebnf
self.no_stop_trim = no_stop_trim

# Process some special cases
@@ -111,8 +113,13 @@ def verify(self):
f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
f"{self.min_new_tokens}."
)
if self.regex is not None and self.json_schema is not None:
raise ValueError("regex and json_schema cannot be both set.")
grammars = [
self.json_schema,
self.regex,
self.ebnf,
] # since mutually exclusive, only one can be set
if sum(x is not None for x in grammars) > 1:
raise ValueError("Only one of regex, json_schema, or ebnf can be set.")

def normalize(self, tokenizer):
# Process stop strings
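verify() now treats json_schema, regex, and ebnf as mutually exclusive. A minimal sketch of the new rule as a standalone function (the function name is illustrative):

```python
from typing import Optional

def check_grammar_exclusivity(
    json_schema: Optional[str] = None,
    regex: Optional[str] = None,
    ebnf: Optional[str] = None,
) -> None:
    # At most one constrained-decoding grammar may be active per request.
    grammars = [json_schema, regex, ebnf]
    if sum(x is not None for x in grammars) > 1:
        raise ValueError("Only one of regex, json_schema, or ebnf can be set.")

check_grammar_exclusivity(ebnf='root ::= "ok"')  # fine
# check_grammar_exclusivity(regex=r"\d+", ebnf='root ::= "1"')  # raises ValueError
```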
247 changes: 247 additions & 0 deletions test/srt/test_ebnf_constrained.py
@@ -0,0 +1,247 @@
"""
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
"""

import json
import re
import unittest

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)


def setup_class(cls, disable_overlap: bool):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.ebnf_grammar = 'root ::= "test"' # Default grammar

other_args = [
"--max-running-requests",
"10",
"--grammar-backend",
"xgrammar",
]

if disable_overlap:
other_args += ["--disable-overlap-schedule"]

cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)


class TestEBNFConstrained(unittest.TestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, disable_overlap=False)
cls.check_jump_forward = False

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def run_decode(
self,
ebnf,
expected_patterns,
prompt,
return_logprob=False,
top_logprobs_num=0,
n=1,
):
response = requests.post(
self.base_url + "/generate",
json={
"text": prompt,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": 128,
"n": n,
"ebnf": ebnf,
},
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
)

ret = response.json()
print(json.dumps(ret, indent=2))
print("=" * 100)

if not isinstance(ret, list):
self.fail(f"Expected response to be a list, but got {type(ret)}")

for item in ret:
text = item.get("text", "").strip()
if not text:
self.fail("Generated text is empty.")

match = False
for pattern in expected_patterns:
if self.regex_match(text, pattern):
match = True
break
if not match:
self.fail(f"Text '{text}' does not match any of the allowed patterns.")

def regex_match(self, text, pattern):
return re.match(pattern, text) is not None

def test_ebnf_generate_email(self):
self.__class__.ebnf_grammar = 'root ::= "[email protected]"'
allowed_patterns = [r"^user@example\.com$"]
prompt = "Generate an email address:"

self.run_decode(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
n=3,
)

def test_ebnf_generate_greeting(self):
self.__class__.ebnf_grammar = 'root ::= "Hello" | "Hi" | "Hey"'
allowed_patterns = [r"^(Hello|Hi|Hey)$"]
prompt = "Generate a greeting:"

self.run_decode(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
n=3,
)

def test_ebnf_generate_number(self):
self.__class__.ebnf_grammar = """
root ::= digit digit digit
digit ::= [0-9]
"""
allowed_patterns = [r"^\d{3}$"]
prompt = "Generate a three-digit number:"

self.run_decode(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
n=3,
)

def test_ebnf_generate_phone(self):
self.__class__.ebnf_grammar = """
root ::= "(" area ")" " " prefix "-" line
area ::= [0-9] [0-9] [0-9]
prefix ::= [0-9] [0-9] [0-9]
line ::= [0-9] [0-9] [0-9] [0-9]
"""
allowed_patterns = [r"^\(\d{3}\) \d{3}-\d{4}$"]
prompt = "Generate a phone number:"

self.run_decode(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
n=3,
)

def test_ebnf_generate_date(self):
self.__class__.ebnf_grammar = """
root ::= year "-" month "-" day
year ::= "2024"
month ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" | "11" | "12"
day ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" |
"11" | "12" | "13" | "14" | "15" | "16" | "17" | "18" | "19" | "20" |
"21" | "22" | "23" | "24" | "25" | "26" | "27" | "28" | "29" | "30" | "31"
"""
allowed_patterns = [r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"]
prompt = "Generate a date in YYYY-MM-DD format:"

self.run_decode(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
n=3,
)

def test_ebnf_generate_hex_color(self):
self.__class__.ebnf_grammar = """
root ::= "#" hex hex hex hex hex hex
hex ::= [0-9] | [A-F]
"""
allowed_patterns = [r"^#[0-9A-F]{6}$"]
prompt = "Generate a hex color code:"

self.run_decode(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
n=3,
)

def test_ebnf_generate_complex_json(self):
self.__class__.ebnf_grammar = """
root ::= object
object ::= "{" ws pair (ws "," ws pair)* ws "}"
pair ::= "\\"name\\"" ws ":" ws value |
"\\"age\\"" ws ":" ws number |
"\\"city\\"" ws ":" ws string
value ::= string | number
string ::= "\\"" [a-zA-Z0-9 ]+ "\\""
number ::= [1-9] [0-9]*
ws ::= [ ]*
"""
allowed_patterns = [
r'^{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*}$',
]
prompt = "Generate a simple JSON with name, age, and city:"

self.run_decode(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
n=3,
)

def test_ebnf_generate_custom_log_format(self):
self.__class__.ebnf_grammar = """
root ::= logentry
logentry ::= "[" datetime "] " level ": System.process - " message
datetime ::= "2024-01-01T12:00:00Z"
level ::= "INFO"
message ::= "Operation " [a-z]+ " successfully"
"""
allowed_patterns = [
r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$"
]
prompt = "Generate a log entry:"

self.run_decode(
ebnf=self.__class__.ebnf_grammar,
expected_patterns=allowed_patterns,
prompt=prompt,
n=3,
)


class TestJumpForward(TestEBNFConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, disable_overlap=True)
cls.check_jump_forward = True


if __name__ == "__main__":
unittest.main()