From 93a16757e95a2e2963df859e0094d957c9424bf3 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 3 Dec 2024 17:29:42 +0000 Subject: [PATCH 1/3] Support Lark grammars for XGrammar Signed-off-by: mgoin --- .../guided_decoding/__init__.py | 19 +-- .../guided_decoding/xgrammar_decoding.py | 10 +- .../guided_decoding/xgrammar_utils.py | 120 ++++++++++++++++++ 3 files changed, 134 insertions(+), 15 deletions(-) create mode 100644 vllm/model_executor/guided_decoding/xgrammar_utils.py diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 23c31fcfd7f05..e31c6262ee44b 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -24,21 +24,12 @@ def maybe_backend_fallback( "Falling back to use xgrammar instead.") guided_params.backend = "xgrammar" - if guided_params.backend == "xgrammar": + if (guided_params.backend == "xgrammar" and + (guided_params.regex is not None or guided_params.choice is not None)): # xgrammar doesn't support regex or choice, fallback to outlines - if guided_params.regex is not None or guided_params.choice is not None: - logger.warning( - "xgrammar only supports json or grammar guided decoding. " - "Falling back to use outlines instead.") - guided_params.backend = "outlines" - - # xgrammar only supports EBNF grammars and uses the GBNF format - # https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md - elif (guided_params.grammar is not None - and "::=" not in guided_params.grammar): - logger.warning("xgrammar only supports EBNF grammars. " - "Falling back to use outlines instead.") - guided_params.backend = "outlines" + logger.warning("xgrammar doesn't support regex guided decoding. 
" + "Falling back to use outlines instead.") + guided_params.backend = "outlines" return guided_params diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 8287cd6cf3aa0..c9b3f0a947710 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -14,6 +14,9 @@ except ImportError: pass +from vllm.model_executor.guided_decoding.xgrammar_utils import ( + convert_lark_to_gbnf, grammar_is_likely_lark) + if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -152,7 +155,12 @@ def from_guided_params(cls, tokenizer_hash=tokenizer_hash, max_threads=max_threads) elif guided_params.grammar: - return cls(grammar_str=guided_params.grammar, + # XGrammar only supports GBNF grammars, so we must convert Lark + if grammar_is_likely_lark(guided_params.grammar): + grammar_str = convert_lark_to_gbnf(guided_params.grammar) + else: + grammar_str = guided_params.grammar + return cls(grammar_str=grammar_str, vocab_size=model_config.hf_config.vocab_size, encoded_vocab=encoded_vocab, stop_token_ids=stop_token_ids, diff --git a/vllm/model_executor/guided_decoding/xgrammar_utils.py b/vllm/model_executor/guided_decoding/xgrammar_utils.py new file mode 100644 index 0000000000000..4ca819f04c65f --- /dev/null +++ b/vllm/model_executor/guided_decoding/xgrammar_utils.py @@ -0,0 +1,120 @@ +import re + + +def grammar_is_likely_lark(grammar_str: str) -> bool: + """ + Check if grammar appears to use Lark syntax. 
+ + Args: + grammar_str: Input grammar string + + Returns: + bool: True if grammar appears to be in Lark format, False otherwise + + Examples: + >>> grammar_is_likely_lark("rule: 'abc'") + True + >>> grammar_is_likely_lark("rule ::= 'abc'") + False + """ + if not grammar_str or not isinstance(grammar_str, str): + return False + + for line in grammar_str.split('\n'): + # Remove both comment styles + line = re.sub(r'(#|//).*$', '', line).strip() + if not line: + continue + + # Look for Lark-style rule definitions + if ':' in line and '::=' not in line: + return True + + # NOTE: '|' and '~' also occur in GBNF productions (alternation, + # char ranges such as [ -~]), so only the rule-definition check + # above is used to classify a line as Lark. + + return False + + +def convert_lark_to_gbnf(grammar_str: str) -> str: + """ + Convert a Lark grammar string to GBNF format. + + Supports: + - Lark rule definitions to GBNF productions + - String literals with proper escaping + - Multi-line rules with alternatives (|) + - Basic terminal definitions + - Comments (both # and // style) + + Args: + grammar_str: Input grammar in Lark format + + Returns: + str: Converted grammar in GBNF format + + Examples: + >>> print(convert_lark_to_gbnf("rule: 'hello'")) + root ::= rule + rule ::= "hello" + """ + # First identify what rule should be used as root + first_rule = None + for line in grammar_str.split('\n'): + # Remove both comment styles + line = re.sub(r'(#|//).*$', '', line).strip() + if not line: + continue + + if ':' in line and not line.startswith('|'): + name = line.split(':', 1)[0].strip().strip('?') + if first_rule is None: + first_rule = name + if name == 'start': # If we find 'start', use it + first_rule = 'start' + break + + if first_rule is None: + raise ValueError("No rules found in grammar") + + # Use provided root_name if specified + root_rule = first_rule + output_lines = [f"root ::= {root_rule}"] + + current_rule = None + current_definition = [] + + for line in grammar_str.split('\n'): + # Remove both comment styles + line = 
re.sub(r'(#|//).*$', '', line).strip() + if not line: + continue + + # Handle rule definition + if ':' in line and not line.startswith('|'): + # If we were building a rule, add it + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Start new rule + name, definition = line.split(':', 1) + current_rule = name.strip().strip('?') + # Convert string literals from single to double quotes if needed + definition = re.sub(r"'([^']*)'", r'"\1"', definition) + current_definition = [definition.strip()] + + # Handle continuation with | + elif line.startswith('|'): + if current_rule: + # Convert string literals in alternatives too + line = re.sub(r"'([^']*)'", r'"\1"', line[1:].strip()) + current_definition.append(line) + + # Add the last rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + return '\n'.join(output_lines) From a484a63d345ddf0fb6fb11c04e2ffb5a38a88c23 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 5 Dec 2024 21:01:31 +0000 Subject: [PATCH 2/3] Add informative error and much more careful exceptions Signed-off-by: mgoin --- .../guided_decoding/xgrammar_decoding.py | 10 +- .../guided_decoding/xgrammar_utils.py | 116 ++++++++++++++---- 2 files changed, 101 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index c9b3f0a947710..5b9a873834c4c 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -157,7 +157,15 @@ def from_guided_params(cls, elif guided_params.grammar: # XGrammar only supports GBNF grammars, so we must convert Lark if grammar_is_likely_lark(guided_params.grammar): - grammar_str = convert_lark_to_gbnf(guided_params.grammar) + try: + grammar_str = convert_lark_to_gbnf(guided_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to 
convert the grammar from Lark to GBNF. " + "Please either use GBNF grammar directly or specify" + " --guided-decoding-backend=outlines.\n" + f"Conversion error: {str(e)}" + ) from e else: grammar_str = guided_params.grammar return cls(grammar_str=grammar_str, diff --git a/vllm/model_executor/guided_decoding/xgrammar_utils.py b/vllm/model_executor/guided_decoding/xgrammar_utils.py index 4ca819f04c65f..1f3a8b5bc0e89 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_utils.py +++ b/vllm/model_executor/guided_decoding/xgrammar_utils.py @@ -40,6 +40,11 @@ def grammar_is_likely_lark(grammar_str: str) -> bool: def convert_lark_to_gbnf(grammar_str: str) -> str: """ Convert a Lark grammar string to GBNF format. + + GBNF reference: + https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + Lark grammar reference: + https://lark-parser.readthedocs.io/en/latest/grammar.html Supports: - Lark rule definitions to GBNF productions @@ -59,62 +64,125 @@ def convert_lark_to_gbnf(grammar_str: str) -> str: root ::= rule rule ::= "hello" """ + if not isinstance(grammar_str, str): + raise ValueError(f"Grammar must be a string, got {type(grammar_str)}") + + if not grammar_str.strip(): + raise ValueError("Grammar string cannot be empty") + + # Track rules for validation while being lenient about rule names + defined_rules = set() + referenced_rules = set() + # First identify what rule should be used as root first_rule = None - for line in grammar_str.split('\n'): + lines = grammar_str.split('\n') + + for line_num, line in enumerate(lines, 1): # Remove both comment styles line = re.sub(r'(#|//).*$', '', line).strip() if not line: continue if ':' in line and not line.startswith('|'): - name = line.split(':', 1)[0].strip().strip('?') - if first_rule is None: - first_rule = name - if name == 'start': # If we find 'start', use it - first_rule = 'start' - break + try: + name = line.split(':', 1)[0].strip().strip('?') + defined_rules.add(name) + + if first_rule is None: + 
first_rule = name + if name == 'start': # If we find 'start', use it + first_rule = 'start' - if first_rule is None: - raise ValueError("No rules found in grammar") + except IndexError as e: + raise ValueError(f"Invalid rule format on line {line_num}. " + "Expected 'rule_name: definition'") from e + + if not defined_rules: + raise ValueError("No valid rules found in grammar") - # Use provided root_name if specified root_rule = first_rule output_lines = [f"root ::= {root_rule}"] current_rule = None current_definition = [] - for line in grammar_str.split('\n'): - # Remove both comment styles + for line_num, line in enumerate(lines, 1): line = re.sub(r'(#|//).*$', '', line).strip() if not line: continue - # Handle rule definition if ':' in line and not line.startswith('|'): - # If we were building a rule, add it if current_rule: output_lines.append( f"{current_rule} ::= {' | '.join(current_definition)}") - # Start new rule - name, definition = line.split(':', 1) - current_rule = name.strip().strip('?') - # Convert string literals from single to double quotes if needed - definition = re.sub(r"'([^']*)'", r'"\1"', definition) - current_definition = [definition.strip()] + try: + name, definition = line.split(':', 1) + current_rule = name.strip().strip('?') + + # Basic quote validation to catch obvious errors + if definition.count("'") % 2 != 0 or definition.count( + '"') % 2 != 0: + raise ValueError("Mismatched quotes in rule " + f"'{current_rule}' on line {line_num}") + + # Convert string literals from single to double quotes + definition = re.sub(r"'([^']*)'", r'"\1"', definition) + + # Extract referenced rules (excluding quoted strings and + # special characters) + # Remove quoted strings + temp = re.sub(r'"[^"]*"', '', definition) + # Remove special chars + temp = re.sub(r'[+*?()|\[\]{}]', ' ', temp) + tokens = re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', temp) + referenced_rules.update(tokens) + + current_definition = [definition.strip()] + + except ValueError as e: + 
raise ValueError("Error parsing rule definition on " + f"line {line_num}: {str(e)}") from e + except Exception as e: + raise ValueError("Unexpected error parsing rule on " + f"line {line_num}: {str(e)}") from e - # Handle continuation with | elif line.startswith('|'): - if current_rule: - # Convert string literals in alternatives too + if not current_rule: + raise ValueError(f"Alternative '|' on line {line_num} " + "without a preceding rule definition") + + try: + # Convert string literals from single to double quotes line = re.sub(r"'([^']*)'", r'"\1"', line[1:].strip()) + + # Basic quote validation + if line.count("'") % 2 != 0 or line.count('"') % 2 != 0: + raise ValueError( + "Mismatched quotes in alternative for " + f"rule '{current_rule}' on line {line_num}") + + # Extract referenced rules (same as above) + temp = re.sub(r'"[^"]*"', '', line) + temp = re.sub(r'[+*?()|\[\]{}]', ' ', temp) + tokens = re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', temp) + referenced_rules.update(tokens) + current_definition.append(line) - # Add the last rule if exists + except ValueError as e: + raise ValueError("Error parsing alternative on line " + f"{line_num}: {str(e)}") from e + if current_rule: output_lines.append( f"{current_rule} ::= {' | '.join(current_definition)}") + # Check for undefined rules, excluding common terminals and special cases + undefined_rules = referenced_rules - defined_rules - {'root'} + if undefined_rules: + raise ValueError("Referenced rules are not defined: " + f"{', '.join(sorted(undefined_rules))}") + return '\n'.join(output_lines) From 346685b43256c73c0b3f98944481c0d55d104a51 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 5 Dec 2024 21:11:54 +0000 Subject: [PATCH 3/3] Improve readability Signed-off-by: mgoin --- .../guided_decoding/xgrammar_decoding.py | 3 +- .../guided_decoding/xgrammar_utils.py | 126 +++++++----------- 2 files changed, 51 insertions(+), 78 deletions(-) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py 
b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 5b9a873834c4c..b59a2269d2cd5 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -164,8 +164,7 @@ def from_guided_params(cls, "Failed to convert the grammar from Lark to GBNF. " "Please either use GBNF grammar directly or specify" " --guided-decoding-backend=outlines.\n" - f"Conversion error: {str(e)}" - ) from e + f"Conversion error: {str(e)}") from e else: grammar_str = guided_params.grammar return cls(grammar_str=grammar_str, diff --git a/vllm/model_executor/guided_decoding/xgrammar_utils.py b/vllm/model_executor/guided_decoding/xgrammar_utils.py index 1f3a8b5bc0e89..12b42245f4e3d 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_utils.py +++ b/vllm/model_executor/guided_decoding/xgrammar_utils.py @@ -46,13 +46,6 @@ def convert_lark_to_gbnf(grammar_str: str) -> str: Lark grammar reference: https://lark-parser.readthedocs.io/en/latest/grammar.html - Supports: - - Lark rule definitions to GBNF productions - - String literals with proper escaping - - Multi-line rules with alternatives (|) - - Basic terminal definitions - - Comments (both # and // style) - Args: grammar_str: Input grammar in Lark format @@ -66,34 +59,46 @@ def convert_lark_to_gbnf(grammar_str: str) -> str: """ if not isinstance(grammar_str, str): raise ValueError(f"Grammar must be a string, got {type(grammar_str)}") - if not grammar_str.strip(): raise ValueError("Grammar string cannot be empty") - # Track rules for validation while being lenient about rule names defined_rules = set() referenced_rules = set() - - # First identify what rule should be used as root + output_lines = [] + + def clean_line(line: str) -> str: + """Remove comments and whitespace from line.""" + return re.sub(r'(#|//).*$', '', line).strip() + + def check_quotes(text: str, rule_name: str, line_num: int) -> None: + """Validate quote matching in text.""" + if text.count("'") 
% 2 != 0 or text.count('"') % 2 != 0: + raise ValueError( + f"Mismatched quotes in {rule_name} on line {line_num}") + + def extract_references(text: str) -> set: + """Extract rule references from text.""" + # Remove quoted strings and special characters + text = re.sub(r'"[^"]*"', '', text) + text = re.sub(r'[+*?()|\[\]{}]', ' ', text) + return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + + # First pass: Find root rule and validate rule definitions + lines = [clean_line(line) for line in grammar_str.split('\n')] first_rule = None - lines = grammar_str.split('\n') for line_num, line in enumerate(lines, 1): - # Remove both comment styles - line = re.sub(r'(#|//).*$', '', line).strip() - if not line: + if not line or line.startswith('|'): continue - if ':' in line and not line.startswith('|'): + if ':' in line: try: name = line.split(':', 1)[0].strip().strip('?') defined_rules.add(name) - if first_rule is None: first_rule = name - if name == 'start': # If we find 'start', use it + if name == 'start': first_rule = 'start' - except IndexError as e: raise ValueError(f"Invalid rule format on line {line_num}. 
" "Expected 'rule_name: definition'") from e @@ -101,85 +106,54 @@ def convert_lark_to_gbnf(grammar_str: str) -> str: if not defined_rules: raise ValueError("No valid rules found in grammar") - root_rule = first_rule - output_lines = [f"root ::= {root_rule}"] + # Add root rule + output_lines.append(f"root ::= {first_rule}") + # Second pass: Process rule definitions and alternatives current_rule = None current_definition = [] for line_num, line in enumerate(lines, 1): - line = re.sub(r'(#|//).*$', '', line).strip() if not line: continue - if ':' in line and not line.startswith('|'): - if current_rule: - output_lines.append( - f"{current_rule} ::= {' | '.join(current_definition)}") + try: + if ':' in line and not line.startswith('|'): + # Save previous rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") - try: + # Process new rule name, definition = line.split(':', 1) current_rule = name.strip().strip('?') - # Basic quote validation to catch obvious errors - if definition.count("'") % 2 != 0 or definition.count( - '"') % 2 != 0: - raise ValueError("Mismatched quotes in rule " - f"'{current_rule}' on line {line_num}") - - # Convert string literals from single to double quotes + check_quotes(definition, f"rule '{current_rule}'", line_num) definition = re.sub(r"'([^']*)'", r'"\1"', definition) - - # Extract referenced rules (excluding quoted strings and - # special characters) - # Remove quoted strings - temp = re.sub(r'"[^"]*"', '', definition) - # Remove special chars - temp = re.sub(r'[+*?()|\[\]{}]', ' ', temp) - tokens = re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', temp) - referenced_rules.update(tokens) - + referenced_rules.update(extract_references(definition)) current_definition = [definition.strip()] - except ValueError as e: - raise ValueError("Error parsing rule definition on " - f"line {line_num}: {str(e)}") from e - except Exception as e: - raise ValueError("Unexpected error parsing rule on " - 
f"line {line_num}: {str(e)}") from e - - elif line.startswith('|'): - if not current_rule: - raise ValueError(f"Alternative '|' on line {line_num} " - "without a preceding rule definition") - - try: - # Convert string literals from single to double quotes - line = re.sub(r"'([^']*)'", r'"\1"', line[1:].strip()) - - # Basic quote validation - if line.count("'") % 2 != 0 or line.count('"') % 2 != 0: - raise ValueError( - "Mismatched quotes in alternative for " - f"rule '{current_rule}' on line {line_num}") - - # Extract referenced rules (same as above) - temp = re.sub(r'"[^"]*"', '', line) - temp = re.sub(r'[+*?()|\[\]{}]', ' ', temp) - tokens = re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', temp) - referenced_rules.update(tokens) + elif line.startswith('|'): + if not current_rule: + raise ValueError(f"Alternative '|' on line {line_num} " + "without a preceding rule definition") - current_definition.append(line) + alt_def = line[1:].strip() + check_quotes(alt_def, f"alternative for rule '{current_rule}'", + line_num) + alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) + referenced_rules.update(extract_references(alt_def)) + current_definition.append(alt_def) - except ValueError as e: - raise ValueError("Error parsing alternative on line " - f"{line_num}: {str(e)}") from e + except ValueError as e: + raise ValueError(f"Error on line {line_num}: {str(e)}") from e + # Add final rule if exists if current_rule: output_lines.append( f"{current_rule} ::= {' | '.join(current_definition)}") - # Check for undefined rules, excluding common terminals and special cases + # Validate all rules are defined undefined_rules = referenced_rules - defined_rules - {'root'} if undefined_rules: raise ValueError("Referenced rules are not defined: "