 ]
 
 
+class CarType(str, Enum):
+    sedan = "sedan"
+    suv = "SUV"
+    truck = "Truck"
+    coupe = "Coupe"
+
+
+class CarDescription(BaseModel):
+    brand: str
+    model: str
+    car_type: CarType
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
 @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_completion(
+def test_structured_output(
     monkeypatch: pytest.MonkeyPatch,
     sample_json_schema: dict[str, Any],
+    unsupported_json_schema: dict[str, Any],
+    sample_sql_ebnf: str,
+    sample_sql_lark: str,
+    sample_regex: str,
+    sample_guided_choice: str,
     guided_decoding_backend: str,
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    # Use a single LLM instance for several scenarios to
+    # speed up the test suite.
     llm = LLM(model=model_name,
+              enforce_eager=True,
               max_model_len=1024,
               guided_decoding_backend=guided_decoding_backend)
+
+    #
+    # Test 1: Generate JSON output based on a provided schema
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=1000,
@@ -63,20 +89,9 @@ def test_guided_json_completion(
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=sample_json_schema)
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_object(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 2: Generate JSON object without a schema
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=100,
@@ -111,21 +126,9 @@ def test_guided_json_object(
         allowed_types = (dict, list)
         assert isinstance(parsed_json, allowed_types)
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_unsupported_schema(
-    monkeypatch: pytest.MonkeyPatch,
-    unsupported_json_schema: dict[str, Any],
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 3: Test a JSON schema incompatible with xgrammar
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=1000,
@@ -141,8 +144,6 @@ def test_guided_json_unsupported_schema(
                          sampling_params=sampling_params,
                          use_tqdm=True)
     else:
-        # This should work for both "guidance" and "auto".
-
         outputs = llm.generate(
             prompts=("Give an example JSON object for a grade "
                      "that fits this schema: "
@@ -161,21 +162,9 @@ def test_guided_json_unsupported_schema(
         parsed_json = json.loads(generated_text)
         assert isinstance(parsed_json, dict)
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_ebnf(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_sql_ebnf: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 4: Generate SQL statement using EBNF grammar
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -205,21 +194,9 @@ def test_guided_grammar_ebnf(
 
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_lark(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_sql_lark: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 5: Generate SQL statement using Lark grammar
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -254,20 +231,9 @@ def test_guided_grammar_lark(
 
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_ebnf_invalid(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 6: Test invalid grammar input
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -281,21 +247,9 @@ def test_guided_grammar_ebnf_invalid(
             use_tqdm=True,
         )
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_regex(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_regex: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 7: Generate text based on a regex pattern
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -319,21 +273,9 @@ def test_guided_regex(
         assert re.fullmatch(sample_regex, generated_text) is not None
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_choice_completion(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_guided_choice: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 8: Generate text based on a list of choices
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -353,33 +295,9 @@ def test_guided_choice_completion(
         assert generated_text in sample_guided_choice
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-class CarType(str, Enum):
-    sedan = "sedan"
-    suv = "SUV"
-    truck = "Truck"
-    coupe = "Coupe"
-
-
-class CarDescription(BaseModel):
-    brand: str
-    model: str
-    car_type: CarType
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_completion_with_enum(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 9: Generate structured output using a Pydantic model with an enum
+    #
     json_schema = CarDescription.model_json_schema()
     sampling_params = SamplingParams(
         temperature=1.0,
@@ -403,3 +321,41 @@ def test_guided_json_completion_with_enum(
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=json_schema)
+
+
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
+def test_structured_output_auto_mode(
+    monkeypatch: pytest.MonkeyPatch,
+    unsupported_json_schema: dict[str, Any],
+    model_name: str,
+):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend="auto")
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
+
+    # This would fail with the default of "xgrammar", but in "auto"
+    # we will handle fallback automatically.
+    outputs = llm.generate(prompts=("Give an example JSON object for a grade "
+                                    "that fits this schema: "
+                                    f"{unsupported_json_schema}"),
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        print(generated_text)
+
+        # Parse to verify it is valid JSON
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
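
For reference, below is a minimal standalone sketch (not part of this change) of the offline structured-output flow that the consolidated test exercises. It reuses the LLM, SamplingParams, and GuidedDecodingParams calls visible in the diff; the import paths, the model name, and the JSON schema are illustrative assumptions rather than values taken from this PR.

# Minimal sketch of guided JSON decoding with the "auto" backend.
# Model name and schema are placeholders, not values from this PR.
import json

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "grade": {"type": "string"},
    },
    "required": ["name", "grade"],
}

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
          max_model_len=1024,
          guided_decoding_backend="auto")

sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=200,
    guided_decoding=GuidedDecodingParams(json=schema))

outputs = llm.generate(
    prompts=f"Give an example JSON object that fits this schema: {schema}",
    sampling_params=sampling_params)

# The constrained output should parse as JSON that satisfies the schema.
print(json.loads(outputs[0].outputs[0].text))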