Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Support Lark grammars for XGrammar #10870

Merged
merged 5 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,6 @@ def maybe_backend_fallback(
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

# xgrammar only supports EBNF grammars and uses the GBNF format
# https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
elif (guided_params.grammar is not None
and "::=" not in guided_params.grammar):
logger.warning("xgrammar only supports EBNF grammars. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

return guided_params


Expand Down
10 changes: 9 additions & 1 deletion vllm/model_executor/guided_decoding/xgrammar_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
except ImportError:
pass

from vllm.model_executor.guided_decoding.xgrammar_utils import (
convert_lark_to_gbnf, grammar_is_likely_lark)

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer

Expand Down Expand Up @@ -152,7 +155,12 @@ def from_guided_params(cls,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads)
elif guided_params.grammar:
return cls(grammar_str=guided_params.grammar,
# XGrammar only supports GBNF grammars, so we must convert Lark
if grammar_is_likely_lark(guided_params.grammar):
grammar_str = convert_lark_to_gbnf(guided_params.grammar)
mgoin marked this conversation as resolved.
Show resolved Hide resolved
else:
grammar_str = guided_params.grammar
return cls(grammar_str=grammar_str,
vocab_size=model_config.hf_config.vocab_size,
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
Expand Down
120 changes: 120 additions & 0 deletions vllm/model_executor/guided_decoding/xgrammar_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import re


def grammar_is_likely_lark(grammar_str: str) -> bool:
    """
    Check if grammar appears to use Lark syntax rather than GBNF.

    A GBNF production is written ``name ::= ...``, so the presence of
    ``::=`` on any non-comment line means the grammar is treated as GBNF.
    Only when no such line exists are Lark markers (a ``:`` rule
    definition, ``|`` alternatives or the ``~`` repetition operator)
    taken as evidence of Lark syntax.

    Args:
        grammar_str: Input grammar string

    Returns:
        bool: True if grammar appears to be in Lark format, False otherwise
        (including for empty, comment-only or non-string input)

    Examples:
        >>> grammar_is_likely_lark("rule: 'abc'")
        True
        >>> grammar_is_likely_lark("rule ::= 'abc'")
        False
        >>> grammar_is_likely_lark('root ::= "a" | "b"')
        False
    """
    if not grammar_str or not isinstance(grammar_str, str):
        return False

    saw_lark_marker = False
    for raw_line in grammar_str.split('\n'):
        # Remove both comment styles
        line = re.sub(r'(#|//).*$', '', raw_line).strip()
        if not line:
            continue

        # A GBNF production anywhere settles the question. This must be
        # checked before the Lark markers because GBNF also uses '|' for
        # alternatives, both inline and on continuation lines.
        if '::=' in line:
            return False

        # Lark-style rule definition ("name:" / "?start:") or operators
        if ':' in line or '|' in line or '~' in line:
            saw_lark_marker = True

    return saw_lark_marker


# Matches a complete quoted string literal, or (group 1) a comment opener
# that is not inside a string literal matched earlier on the line.
_LARK_TOKEN_RE = re.compile(r'"(?:\\.|[^"\\])*"'
                            r"|'(?:\\.|[^'\\])*'"
                            r"|(#|//)")

# Matches a complete double- or single-quoted string literal.
_LARK_STRING_RE = re.compile(r'"(?:\\.|[^"\\])*"'
                             r"|'(?:\\.|[^'\\])*'")


def _strip_lark_comment(line: str) -> str:
    """Return ``line`` without its trailing ``#`` / ``//`` comment, stripped.

    Comment openers that appear inside a quoted string literal are kept.
    """
    for match in _LARK_TOKEN_RE.finditer(line):
        if match.group(1):
            return line[:match.start()].strip()
    return line.strip()


def _to_double_quoted(text: str) -> str:
    """Rewrite single-quoted string literals in ``text`` as double-quoted.

    Double-quoted literals are left untouched. Inside a converted literal,
    bare ``"`` characters are escaped and ``\\'`` is unescaped so the
    literal keeps the same meaning.
    """

    def _fix_inner(match: "re.Match[str]") -> str:
        token = match.group(0)
        if token == "\\'":
            return "'"
        if token == '"':
            return '\\"'
        return token  # any other escape sequence is preserved verbatim

    def _requote(match: "re.Match[str]") -> str:
        literal = match.group(0)
        if literal[0] == '"':
            return literal
        inner = re.sub(r'\\.|"', _fix_inner, literal[1:-1])
        return f'"{inner}"'

    return _LARK_STRING_RE.sub(_requote, text)


def convert_lark_to_gbnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to GBNF format.

    Supports:
    - Lark rule definitions to GBNF productions
    - String literals with proper escaping
    - Multi-line rules with alternatives (|)
    - Basic terminal definitions
    - Comments (both # and // style)

    The GBNF entry point ``root`` is aliased to the Lark rule named
    ``start`` if one exists, otherwise to the first rule. If that rule is
    itself named ``root`` no alias line is emitted. Lines that are neither
    a rule definition nor a ``|`` continuation are ignored.

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in GBNF format

    Raises:
        ValueError: If the grammar contains no rules, a rule has an empty
            name, or a non-entry rule is named ``root`` (which would
            collide with the generated GBNF entry point).

    Examples:
        >>> print(convert_lark_to_gbnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    # Single pass: collect (rule name, list of alternatives) in source order.
    rules = []
    for raw_line in grammar_str.split('\n'):
        line = _strip_lark_comment(raw_line)
        if not line:
            continue

        # Continuation line: another alternative for the current rule
        if line.startswith('|'):
            if rules:
                rules[-1][1].append(_to_double_quoted(line[1:].strip()))
            continue

        # Rule definition: "name: definition" (optionally "?name: ...")
        if ':' in line:
            name, definition = line.split(':', 1)
            name = name.strip().strip('?')
            if not name:
                raise ValueError(f"Rule definition without a name: {line!r}")
            rules.append((name, [_to_double_quoted(definition.strip())]))

    if not rules:
        raise ValueError("No rules found in grammar")

    # Prefer Lark's conventional 'start' rule, else fall back to the first
    names = [name for name, _ in rules]
    root_rule = 'start' if 'start' in names else names[0]

    output_lines = []
    if root_rule != 'root':
        if 'root' in names:
            raise ValueError(
                "Rule name 'root' conflicts with the GBNF entry point; "
                f"the grammar's entry rule is '{root_rule}'")
        output_lines.append(f"root ::= {root_rule}")
    # else: the entry rule is already called 'root'; emitting
    # "root ::= root" would define it twice and recursively.

    output_lines.extend(f"{name} ::= {' | '.join(alternatives)}"
                        for name, alternatives in rules)

    return '\n'.join(output_lines)
Loading