Commit 3d1631a

russellb authored and shreyankg committed
[CI] Speed up V1 structured output tests (vllm-project#15718)
Signed-off-by: Russell Bryant <[email protected]>
1 parent e76bc2a commit 3d1631a

File tree

1 file changed (+89, -133 lines)


tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 89 additions & 133 deletions
@@ -23,20 +23,46 @@
 ]


+class CarType(str, Enum):
+    sedan = "sedan"
+    suv = "SUV"
+    truck = "Truck"
+    coupe = "Coupe"
+
+
+class CarDescription(BaseModel):
+    brand: str
+    model: str
+    car_type: CarType
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
 @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_completion(
+def test_structured_output(
     monkeypatch: pytest.MonkeyPatch,
     sample_json_schema: dict[str, Any],
+    unsupported_json_schema: dict[str, Any],
+    sample_sql_ebnf: str,
+    sample_sql_lark: str,
+    sample_regex: str,
+    sample_guided_choice: str,
     guided_decoding_backend: str,
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    # Use a single LLM instance for several scenarios to
+    # speed up the test suite.
     llm = LLM(model=model_name,
+              enforce_eager=True,
               max_model_len=1024,
               guided_decoding_backend=guided_decoding_backend)
+
+    #
+    # Test 1: Generate JSON output based on a provided schema
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=1000,
@@ -63,20 +89,9 @@ def test_guided_json_completion(
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=sample_json_schema)

-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_object(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 2: Generate JSON object without a schema
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=100,
@@ -111,21 +126,9 @@ def test_guided_json_object(
         allowed_types = (dict, list)
         assert isinstance(parsed_json, allowed_types)

-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_unsupported_schema(
-    monkeypatch: pytest.MonkeyPatch,
-    unsupported_json_schema: dict[str, Any],
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 3: test a jsonschema incompatible with xgrammar
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=1000,
@@ -141,8 +144,6 @@ def test_guided_json_unsupported_schema(
                          sampling_params=sampling_params,
                          use_tqdm=True)
     else:
-        # This should work for both "guidance" and "auto".
-
         outputs = llm.generate(
             prompts=("Give an example JSON object for a grade "
                      "that fits this schema: "
@@ -161,21 +162,9 @@ def test_guided_json_unsupported_schema(
             parsed_json = json.loads(generated_text)
             assert isinstance(parsed_json, dict)

-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_ebnf(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_sql_ebnf: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 4: Generate SQL statement using EBNF grammar
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -205,21 +194,9 @@ def test_guided_grammar_ebnf(

         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_lark(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_sql_lark: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 5: Generate SQL statement using Lark grammar
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -254,20 +231,9 @@ def test_guided_grammar_lark(

         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_ebnf_invalid(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 6: Test invalid grammar input
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -281,21 +247,9 @@ def test_guided_grammar_ebnf_invalid(
             use_tqdm=True,
         )

-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_regex(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_regex: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 7: Generate text based on a regex pattern
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -319,21 +273,9 @@ def test_guided_regex(
         assert re.fullmatch(sample_regex, generated_text) is not None
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_choice_completion(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_guided_choice: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 8: Generate text based on a choices
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -353,33 +295,9 @@ def test_guided_choice_completion(
         assert generated_text in sample_guided_choice
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

-
-class CarType(str, Enum):
-    sedan = "sedan"
-    suv = "SUV"
-    truck = "Truck"
-    coupe = "Coupe"
-
-
-class CarDescription(BaseModel):
-    brand: str
-    model: str
-    car_type: CarType
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_completion_with_enum(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 9: Generate structured output using a Pydantic model with an enum
+    #
     json_schema = CarDescription.model_json_schema()
     sampling_params = SamplingParams(
         temperature=1.0,
@@ -403,3 +321,41 @@ def test_guided_json_completion_with_enum(
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=json_schema)
+
+
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
+def test_structured_output_auto_mode(
+    monkeypatch: pytest.MonkeyPatch,
+    unsupported_json_schema: dict[str, Any],
+    model_name: str,
+):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend="auto")
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
+
+    # This would fail with the default of "xgrammar", but in "auto"
+    # we will handle fallback automatically.
+    outputs = llm.generate(prompts=("Give an example JSON object for a grade "
+                                    "that fits this schema: "
+                                    f"{unsupported_json_schema}"),
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        print(generated_text)
+
+        # Parse to verify it is valid JSON
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)

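The speed-up comes from building the vLLM engine once and reusing it across every structured-output scenario, instead of paying model-load time for each separately parametrized test. The sketch below illustrates that pattern outside pytest; it is not part of this commit. The model name, the Weapon/Character schema, and the regex are illustrative placeholders, while LLM, SamplingParams, GuidedDecodingParams, enforce_eager, and guided_decoding_backend="auto" are the same knobs that appear in the diff above.

# Minimal sketch (assumptions: a locally available model and a vLLM build
# that supports the V1 guided-decoding backends exercised in the diff).
import os
from enum import Enum

from pydantic import BaseModel

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams


class Weapon(str, Enum):  # hypothetical enum, stands in for CarType
    sword = "sword"
    bow = "bow"


class Character(BaseModel):  # hypothetical model, stands in for CarDescription
    name: str
    weapon: Weapon


def main() -> None:
    # The tests force the V1 engine the same way.
    os.environ.setdefault("VLLM_USE_V1", "1")

    # Build the engine once; enforce_eager=True skips CUDA graph capture,
    # the same startup-time trade-off the test makes.
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model name
              enforce_eager=True,
              max_model_len=1024,
              guided_decoding_backend="auto")

    # Scenario 1: JSON constrained by a Pydantic-derived schema.
    json_params = SamplingParams(
        temperature=1.0,
        max_tokens=200,
        guided_decoding=GuidedDecodingParams(
            json=Character.model_json_schema()))
    out = llm.generate("Describe an RPG character as JSON: ", json_params)
    print(out[0].outputs[0].text)

    # Scenario 2: output constrained by a regex, reusing the same engine.
    regex_params = SamplingParams(
        temperature=0.8,
        max_tokens=30,
        guided_decoding=GuidedDecodingParams(regex=r"[a-z]+@[a-z]+\.com"))
    out = llm.generate("Invent an email address: ", regex_params)
    print(out[0].outputs[0].text)


if __name__ == "__main__":
    main()

Amortizing one engine over many scenarios is what removes the repeated model loads, which is the speed-up the commit title refers to.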