diff --git a/requirements-common.txt b/requirements-common.txt index cfa02025629f..aa322ccc2ff1 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -21,7 +21,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines == 0.1.11 lark == 1.2.2 -xgrammar >= 0.1.6; platform_machine == "x86_64" +xgrammar >= 0.1.11; platform_machine == "x86_64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 3eb7d186eb00..77212a1d8cf1 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -49,11 +49,10 @@ def maybe_backend_fallback( "Falling back to use outlines instead.") guided_params.backend = "outlines" - # xgrammar doesn't support regex or choice, fallback to outlines - if guided_params.regex is not None or guided_params.choice is not None: - logger.warning( - "xgrammar only supports json or grammar guided decoding. " - "Falling back to use outlines instead.") + # xgrammar doesn't support regex, fallback to outlines + if guided_params.regex is not None: + logger.warning("xgrammar does not support regex guided decoding. " + "Falling back to use outlines instead.") guided_params.backend = "outlines" # xgrammar doesn't support some JSON schema features diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index fc3a4cd4bebc..329b03a573da 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -5,8 +5,9 @@ import copy import json +import re from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, List import torch from transformers import PreTrainedTokenizerFast @@ -228,11 +229,39 @@ def from_guided_params(cls, max_threads=max_threads, tokenizer_data=tokenizer_data, ) + elif guided_params.choice: + choice_str = GrammarConfig.choice_as_grammar(guided_params.choice) + try: + xgr.Grammar.from_ebnf(choice_str) + except RuntimeError as err: + raise ValueError(str(err)) from err + + return cls( + grammar_str=choice_str, + vocab_size=model_config.hf_text_config.vocab_size, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data, + ) else: raise ValueError( "Currently only support JSON and EBNF grammar mode for xgrammar" ) + @staticmethod + def escape_ebnf_string(s: str) -> str: + """Escape special characters in a EBNF string.""" + # Escape double quotes and backslashes + return re.sub(r'(["\\])', r'\\\1', s) + + @staticmethod + def choice_as_grammar(choice: List[str] | None) -> str: + if choice is None: + raise ValueError("Choice is not set") + escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice) + grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) + return grammar + @dataclass class XGrammarLogitsProcessor: