74 changes: 11 additions & 63 deletions tests/entrypoints/openai/test_chat.py
@@ -20,8 +20,6 @@
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]


@pytest.fixture(scope="module")
def monkeypatch_module():
@@ -487,20 +485,9 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
assert last_completion_tokens == 10


# NOTE: Not sure why, but when I place this after `test_guided_regex_chat`
# (i.e. using the same ordering as in the Completions API tests), the test
# will fail on the second `guided_decoding_backend` even when I swap their order
# (ref: https://github.com/vllm-project/vllm/pull/5526#issuecomment-2173772256)
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_chat(client: openai.AsyncOpenAI,
is_v1_server: bool,
guided_decoding_backend: str,
sample_guided_choice):

if is_v1_server and guided_decoding_backend != 'xgrammar':
pytest.skip("Only xgrammar backend is supported with V1")

messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -515,8 +502,7 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
messages=messages,
max_completion_tokens=10,
temperature=0.7,
extra_body=dict(guided_choice=sample_guided_choice,
guided_decoding_backend=guided_decoding_backend))
extra_body=dict(guided_choice=sample_guided_choice))
choice1 = chat_completion.choices[0].message.content
assert choice1 in sample_guided_choice

@@ -530,22 +516,16 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
messages=messages,
max_completion_tokens=10,
temperature=0.7,
extra_body=dict(guided_choice=sample_guided_choice,
guided_decoding_backend=guided_decoding_backend))
extra_body=dict(guided_choice=sample_guided_choice))
choice2 = chat_completion.choices[0].message.content
assert choice2 in sample_guided_choice
assert choice1 != choice2


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool,
guided_decoding_backend: str,
async def test_guided_json_chat(client: openai.AsyncOpenAI,
sample_json_schema):

if is_v1_server:
pytest.skip("sample_json_schema has features unsupported in V1")

messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -560,8 +540,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool,
model=MODEL_NAME,
messages=messages,
max_completion_tokens=1000,
extra_body=dict(guided_json=sample_json_schema,
guided_decoding_backend=guided_decoding_backend))
extra_body=dict(guided_json=sample_json_schema))
message = chat_completion.choices[0].message
assert message.content is not None
json1 = json.loads(message.content)
@@ -578,8 +557,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool,
model=MODEL_NAME,
messages=messages,
max_completion_tokens=1000,
extra_body=dict(guided_json=sample_json_schema,
guided_decoding_backend=guided_decoding_backend))
extra_body=dict(guided_json=sample_json_schema))
message = chat_completion.choices[0].message
assert message.content is not None
json2 = json.loads(message.content)
@@ -589,13 +567,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, is_v1_server: bool,


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_regex_chat(client: openai.AsyncOpenAI,
is_v1_server: bool,
guided_decoding_backend: str, sample_regex):

if is_v1_server and guided_decoding_backend != 'xgrammar':
pytest.skip("Only xgrammar backend is supported with V1")
async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex):

messages = [{
"role": "system",
@@ -610,8 +582,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI,
model=MODEL_NAME,
messages=messages,
max_completion_tokens=20,
extra_body=dict(guided_regex=sample_regex,
guided_decoding_backend=guided_decoding_backend))
extra_body=dict(guided_regex=sample_regex))
ip1 = chat_completion.choices[0].message.content
assert ip1 is not None
assert re.fullmatch(sample_regex, ip1) is not None
@@ -622,8 +593,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI,
model=MODEL_NAME,
messages=messages,
max_completion_tokens=20,
extra_body=dict(guided_regex=sample_regex,
guided_decoding_backend=guided_decoding_backend))
extra_body=dict(guided_regex=sample_regex))
ip2 = chat_completion.choices[0].message.content
assert ip2 is not None
assert re.fullmatch(sample_regex, ip2) is not None
@@ -652,15 +622,9 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
is_v1_server: bool,
guided_decoding_backend: str,
sample_guided_choice):

if is_v1_server and guided_decoding_backend != 'xgrammar':
pytest.skip("Only xgrammar backend is supported with V1")

messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -676,8 +640,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
max_completion_tokens=10,
logprobs=True,
top_logprobs=5,
extra_body=dict(guided_choice=sample_guided_choice,
guided_decoding_backend=guided_decoding_backend))
extra_body=dict(guided_choice=sample_guided_choice))

assert chat_completion.choices[0].logprobs is not None
assert chat_completion.choices[0].logprobs.content is not None
@@ -689,14 +652,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool,
guided_decoding_backend: str,
sample_json_schema):

if is_v1_server:
pytest.skip("sample_json_schema has features unsupported on V1")

async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -728,7 +684,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool,
"name": "dummy_function_name"
}
},
extra_body=dict(guided_decoding_backend=guided_decoding_backend))
)
message = chat_completion.choices[0].message
assert len(message.content) == 0
json_string = message.tool_calls[0].function.arguments
@@ -763,7 +719,6 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, is_v1_server: bool,
"name": "dummy_function_name"
}
},
extra_body=dict(guided_decoding_backend=guided_decoding_backend),
stream=True)

output = []
@@ -888,7 +843,6 @@ async def test_required_tool_use(client: openai.AsyncOpenAI,
model=model_name,
tools=tools,
tool_choice="required",
extra_body=dict(guided_decoding_backend="outlines"),
)

assert chat_completion.choices[0].message.tool_calls is not None
Expand All @@ -900,7 +854,6 @@ async def test_required_tool_use(client: openai.AsyncOpenAI,
model=model_name,
tools=tools,
tool_choice="required",
extra_body=dict(guided_decoding_backend="outlines"),
stream=True,
)

@@ -914,12 +867,7 @@

@pytest.mark.asyncio
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
is_v1_server: bool,
sample_json_schema):

if is_v1_server:
pytest.skip("sample_json_schema has features unsupported on V1")

messages = [{
"role": "system",
"content": "you are a helpful assistant"
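The test updates above drop the `GUIDED_DECODING_BACKENDS` parametrization, so these tests now exercise whatever backend the server was started with. A request can still pin a backend explicitly through `extra_body`, as the deleted lines did. A minimal client sketch under that assumption (model name and choices are placeholders; assumes a vLLM OpenAI-compatible server on localhost:8000):

```python
import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    chat_completion = await client.chat.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        messages=[{"role": "user", "content": "Pick a color."}],
        max_completion_tokens=10,
        # Constrain output to one of these strings; the backend override
        # is optional now that the server default handles selection.
        extra_body=dict(guided_choice=["red", "green", "blue"],
                        guided_decoding_backend="xgrammar"),
    )
    print(chat_completion.choices[0].message.content)


asyncio.run(main())
```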
4 changes: 2 additions & 2 deletions vllm/config.py
@@ -2888,7 +2888,7 @@ class DecodingConfig:

# Which guided decoding algo to use.
# 'outlines' / 'lm-format-enforcer' / 'xgrammar'
guided_decoding_backend: str = 'xgrammar'
guided_decoding_backend: str = "auto" if envs.VLLM_USE_V1 else "xgrammar"

reasoning_backend: Optional[str] = None

@@ -2913,7 +2913,7 @@ def compute_hash(self) -> str:

def __post_init__(self):
v0_valid_guided_backends = [
'outlines', 'lm-format-enforcer', 'xgrammar'
'outlines', 'lm-format-enforcer', 'xgrammar', 'auto'
]
v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']

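The `DecodingConfig` change ties the default backend to the engine version, and `'auto'` is now accepted by the V0 validator as well. A minimal standalone sketch of the default resolution, with `use_v1` standing in for `envs.VLLM_USE_V1`:

```python
def default_guided_backend(use_v1: bool) -> str:
    # Mirrors the DecodingConfig default above: V1 selects a backend per
    # request ("auto"), while V0 keeps the previous "xgrammar" default,
    # which already has fallbacks in place.
    return "auto" if use_v1 else "xgrammar"


assert default_guided_backend(use_v1=True) == "auto"
assert default_guided_backend(use_v1=False) == "xgrammar"
```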
6 changes: 3 additions & 3 deletions vllm/engine/arg_utils.py
@@ -181,7 +181,7 @@ class EngineArgs:
enable_chunked_prefill: Optional[bool] = None
disable_chunked_mm_input: bool = False

guided_decoding_backend: str = 'xgrammar'
guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
logits_processor_pattern: Optional[str] = None

speculative_config: Optional[Dict[str, Any]] = None
@@ -381,13 +381,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
'--guided-decoding-backend',
type=str,
default='xgrammar',
default=DecodingConfig.guided_decoding_backend,
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently support '
'https://github.com/mlc-ai/xgrammar and '
'https://github.com/guidance-ai/llguidance.'
'Valid backend values are "xgrammar", "guidance", and "auto". '
'With "auto", we will make opinionated choices based on request'
'With "auto", we will make opinionated choices based on request '
'contents and what the backend libraries currently support, so '
'the behavior is subject to change in each release.')
parser.add_argument(
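`EngineArgs` and the CLI flag now both read their default from `DecodingConfig` instead of hardcoding `'xgrammar'`, so the default lives in one place. A short sketch of programmatic use (model name is a placeholder; assumes vLLM is installed):

```python
from vllm.engine.arg_utils import EngineArgs

# Explicitly request the "auto" backend; omitting the argument takes the
# engine-version-dependent default from DecodingConfig instead.
args = EngineArgs(model="HuggingFaceH4/zephyr-7b-beta",
                  guided_decoding_backend="auto")
```

The command-line equivalent is `--guided-decoding-backend auto`.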
6 changes: 6 additions & 0 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -33,6 +33,12 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
logger.warning("%s Falling back to use %s instead.", message, fallback)
guided_params.backend = fallback

# `auto` was added for V1 to explicitly declare a mode that has fallbacks
# in place. If that is specified with V0, treat it as `xgrammar`, as we have
# fallbacks enabled for that and it is the V0 default.
if guided_params.backend == "auto":
guided_params.backend = "xgrammar"

# lm-format-enforce doesn't support grammar, fallback to xgrammar
if guided_params.backend_name == "lm-format-enforcer":
if guided_params.grammar is not None:
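The added block maps `auto` onto `xgrammar` on V0, since V0 has no dedicated `auto` dispatcher and `xgrammar` is the V0 default with fallbacks already wired up. A minimal standalone sketch of that normalization (function name is hypothetical):

```python
def normalize_v0_backend(backend: str) -> str:
    # "auto" is a V1 concept; on the V0 path, treat it as the V0 default
    # backend, which already falls back when a request is unsupported.
    return "xgrammar" if backend == "auto" else backend


assert normalize_v0_backend("auto") == "xgrammar"
assert normalize_v0_backend("outlines") == "outlines"
```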