Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 50 additions & 34 deletions tests/v1/entrypoints/llm/test_struct_output_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ def test_structured_output(
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
(f"Give an example JSON for an employee profile that fits this "
f"schema. Make the response as short as possible. Schema: "
f"{sample_json_schema}")
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
Expand Down Expand Up @@ -136,7 +137,8 @@ def test_structured_output(

outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old."),
"name and age fields for John Smith who is 31 years old. "
"Make the response as short as possible."),
sampling_params=sampling_params,
use_tqdm=True)

Expand Down Expand Up @@ -165,19 +167,20 @@ def test_structured_output(
with pytest.raises(ValueError,
match="The provided JSON schema contains features "
"not supported by xgrammar."):
llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {unsupported_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
llm.generate(
prompts=[(f"Give an example JSON for an employee profile that "
f"fits this schema: {unsupported_json_schema}. "
f"Make the response as short as possible.")] * 2,
sampling_params=sampling_params,
use_tqdm=True)
else:
outputs = llm.generate(
prompts=("Give an example JSON object for a grade "
"that fits this schema: "
f"{unsupported_json_schema}"),
sampling_params=sampling_params,
use_tqdm=True)
outputs = llm.generate(prompts=(
"Give an example JSON object for a grade "
"that fits this schema: "
f"{unsupported_json_schema}. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
Expand All @@ -199,8 +202,10 @@ def test_structured_output(
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)
Expand Down Expand Up @@ -231,8 +236,10 @@ def test_structured_output(
max_tokens=1000,
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True,
)
Expand Down Expand Up @@ -269,8 +276,10 @@ def test_structured_output(
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
with pytest.raises(ValueError, match="Failed to convert the grammar "):
llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
prompts=(
"Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1. Make the response as short "
"as possible."),
sampling_params=sampling_params,
use_tqdm=True,
)
Expand All @@ -284,7 +293,8 @@ def test_structured_output(
guided_decoding=GuidedDecodingParams(regex=sample_regex))
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
(f"Give an example IPv4 address with this regex: {sample_regex}. "
f"Make the response as short as possible.")
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
Expand All @@ -309,7 +319,8 @@ def test_structured_output(
top_p=0.95,
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
prompts=("The best language for type-safe systems programming is "
"(Make the response as short as possible.) "),
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
Expand All @@ -331,11 +342,12 @@ def test_structured_output(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=json_schema))
outputs = llm.generate(
prompts="Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's",
sampling_params=sampling_params,
use_tqdm=True)
outputs = llm.generate(prompts=(
"Generate a JSON with the brand, model and car_type of the most "
"iconic car from the 90's. Make the response as short as "
"possible."),
sampling_params=sampling_params,
use_tqdm=True)

assert outputs is not None

Expand Down Expand Up @@ -373,7 +385,8 @@ def test_structured_output(
guided_decoding=GuidedDecodingParams(json=json_schema))

outputs = llm.generate(
prompts="Generate a description of a frog using 50 characters.",
prompts=("Generate a description of a frog using 50 characters. "
"Make the response as short as possible."),
sampling_params=sampling_params,
use_tqdm=True)

Expand Down Expand Up @@ -452,7 +465,8 @@ def test_structured_output(

You are a helpful assistant.

Given the previous instructions, what is the weather in New York City?
Given the previous instructions, what is the weather in New York City? \
Make the response as short as possible.
"""

# Change this once other backends support structural_tag
Expand Down Expand Up @@ -509,9 +523,10 @@ def test_structured_output_auto_mode(
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))

prompts = ("Give an example JSON object for a grade "
"that fits this schema: "
f"{unsupported_json_schema}")
prompts = (
"Give an example JSON object for a grade "
"that fits this schema: "
f"{unsupported_json_schema}. Make the response as short as possible.")
# This would fail with the default of "xgrammar", but in "auto"
# we will handle fallback automatically.
outputs = llm.generate(prompts=prompts,
Expand Down Expand Up @@ -566,7 +581,8 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
prompt = (
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
"helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
"large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
"large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20. "
"Make the response as short as possible."
"<|im_end|>\n<|im_start|>assistant\n")

def generate_with_backend(backend):
Expand Down
Loading