Skip to content

Commit 156c571

Browse files
noiji authored and nv-guomingz committed
feat: Support JSON Schema in OpenAI-Compatible API
Signed-off-by: noiji <[email protected]>
1 parent 428e340 commit 156c571

File tree

4 files changed

+167
-3
lines changed

4 files changed

+167
-3
lines changed

tensorrt_llm/serve/openai_protocol.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ class StructuralTag(OpenAIBaseModel):
5252

5353

5454
class ResponseFormat(OpenAIBaseModel):
55-
# type must be "json_object" or "text" or "structural_tag"
56-
type: Literal["text", "json_object", "structural_tag"]
55+
# type must be one of "text", "json", "json_object", or "structural_tag"
56+
type: Literal["text", "json", "json_object", "structural_tag"]
57+
schema: Optional[dict] = None
5758
structures: Optional[List[StructuralTag]] = None
5859
triggers: Optional[List[str]] = None
5960

@@ -142,6 +143,12 @@ def _response_format_to_guided_decoding_params(
142143
return None
143144
elif response_format.type == "text":
144145
return None
146+
elif response_format.type == "json":
147+
if response_format.schema is None:
148+
raise ValueError(
149+
"The 'schema' field is required when response_format.type is 'json'."
150+
)
151+
return GuidedDecodingParams(json=response_format.schema)
145152
elif response_format.type == "json_object":
146153
return GuidedDecodingParams(json_object=True)
147154
elif response_format.type == "structural_tag":
@@ -205,7 +212,7 @@ class CompletionRequest(OpenAIBaseModel):
205212
default=None,
206213
description=
207214
("Similar to chat completion, this parameter specifies the format of "
208-
"output. {'type': 'json_object'}, {'type': 'text' }, {'type': 'structural_tag'} are "
215+
"output. {'type': 'json_object'}, {'type': 'text' }, {'type': 'structural_tag'}, {'type': 'json'} are "
209216
"supported."),
210217
)
211218

tests/integration/defs/test_e2e.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,14 @@ def test_openai_chat_structural_tag_example(llm_venv):
14431443
])
14441444

14451445

1446+
def test_openai_chat_json_example(llm_venv):
1447+
test_root = unittest_path() / "llmapi" / "apps"
1448+
1449+
llm_venv.run_cmd(
1450+
["-m", "pytest",
1451+
str(test_root / "_test_openai_chat_json.py")])
1452+
1453+
14461454
@pytest.mark.skip_less_device(2)
14471455
@pytest.mark.skip_less_device_memory(40000)
14481456
def test_openai_multi_chat_example(llm_root, llm_venv):

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ l0_a10:
2222
- disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0]
2323
- disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
2424
- test_e2e.py::test_openai_chat_structural_tag_example
25+
- test_e2e.py::test_openai_chat_json_example
2526
- test_e2e.py::test_openai_chat_multimodal_example
2627
- test_e2e.py::test_openai_lora
2728
- test_e2e.py::test_trtllm_serve_multimodal_example
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Adapted from
2+
# https://github.com/vllm-project/vllm/blob/aae6927be06dedbda39c6b0c30f6aa3242b84388/tests/entrypoints/openai/test_chat.py
3+
import json
4+
import os
5+
import tempfile
6+
from typing import Any
7+
8+
import jsonschema
9+
import openai
10+
import pytest
11+
import yaml
12+
13+
from ..test_llm import get_model_path
14+
from .openai_server import RemoteOpenAIServer
15+
16+
# Module-level mark: disable the thread-leak checker for every test in this file.
# NOTE(review): presumably needed because the remote-server fixture leaves
# background threads running across tests — confirm against the threadleak plugin.
pytestmark = pytest.mark.threadleak(enabled=False)
17+
18+
19+
@pytest.fixture(scope="module")
def model_name():
    """Return the HF model identifier used by every test in this module."""
    # NOTE: the original passed `ids=["TinyLlama-1.1B-Chat"]` here; pytest
    # ignores `ids` on a fixture that has no `params`, so it was dropped.
    return "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
22+
23+
24+
@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file(request):
    """Yield the path of a temporary YAML file with extra LLM-API options.

    The file enables the xgrammar guided-decoding backend and disables the
    overlap scheduler (guided decoding is not supported with it).

    Uses ``tempfile.mkstemp`` instead of a fixed name under the shared temp
    directory so that concurrent test sessions cannot clobber each other's
    file. The file is removed on fixture teardown.
    """
    fd, temp_file_path = tempfile.mkstemp(prefix="extra_llm_api_options_",
                                          suffix=".yaml")
    try:
        extra_llm_api_options_dict = {
            "guided_decoding_backend": "xgrammar",
            # Guided decoding is not supported with the overlap scheduler.
            "disable_overlap_scheduler": True,
        }
        # os.fdopen takes ownership of the descriptor returned by mkstemp.
        with os.fdopen(fd, "w") as f:
            yaml.dump(extra_llm_api_options_dict, f)

        yield temp_file_path
    finally:
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
42+
43+
44+
@pytest.fixture(scope="module")
def server(model_name: str, temp_extra_llm_api_options_file: str):
    """Start a module-scoped OpenAI-compatible server for the test model."""
    cli_args = [
        "--backend",
        "pytorch",
        "--extra_llm_api_options",
        temp_extra_llm_api_options_file,
    ]
    resolved_model_path = get_model_path(model_name)
    with RemoteOpenAIServer(resolved_model_path, cli_args) as remote_server:
        yield remote_server
53+
54+
55+
@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
    """Synchronous OpenAI client bound to the module-scoped server."""
    return server.get_client()
58+
59+
60+
@pytest.fixture(scope="module")
def async_client(server: RemoteOpenAIServer):
    """Asynchronous OpenAI client bound to the module-scoped server."""
    return server.get_async_client()
63+
64+
65+
@pytest.fixture(scope="module")
def user_profile_schema():
    """Provides a sample JSON schema for a user profile."""
    name_property = {
        "type": "string",
        "description": "The full name of the user.",
    }
    age_property = {
        "type": "integer",
        "description": "The age of the user, in years.",
    }
    return {
        "type": "object",
        "properties": {
            "name": name_property,
            "age": age_property,
        },
        "required": ["name", "age"],
    }
82+
83+
84+
def test_chat_json_schema(client: openai.OpenAI, model_name: str,
                          user_profile_schema):
    """
    Tests the `json` response format in a multi-turn synchronous conversation.
    Adapted from https://github.com/vllm-project/vllm/blob/aae6927be06dedbda39c6b0c30f6aa3242b84388/tests/entrypoints/openai/test_chat.py#L413

    Fixes relative to the draft version:
    - `user_profile_schema` is now injected as a fixture parameter; before,
      the body referenced the fixture *function object*, not the schema dict.
    - The helper returns the raw assistant content alongside the decoded
      JSON, so the second turn can replay it (`first_message` was undefined).
    """

    def _create_and_validate_response(
            messages: list[dict[str, Any]]) -> tuple[dict[str, Any], str]:
        # Issue a completion constrained by the JSON schema.
        chat_completion = client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_tokens=1000,
            temperature=0.0,
            response_format={
                "type": "json",
                "schema": user_profile_schema
            },
        )
        content = chat_completion.choices[0].message.content
        assert content is not None

        try:
            message_json = json.loads(content)
        except json.JSONDecodeError:
            pytest.fail(
                f"The output was not a valid JSON string. Output: {content}")

        # The decoded object must conform to the schema, not merely be JSON.
        jsonschema.validate(instance=message_json, schema=user_profile_schema)
        return message_json, content

    messages = [
        {
            "role": "system",
            "content": "you are a helpful assistant"
        },
        {
            "role":
            "user",
            "content":
            f"Give an example JSON for an employee profile that "
            f"fits this schema: {user_profile_schema}",
        },
    ]

    first_json, first_content = _create_and_validate_response(messages)

    # Replay the first answer verbatim so the model sees its own output,
    # then ask for a distinct profile.
    messages.extend([
        {
            "role": "assistant",
            "content": first_content,
        },
        {
            "role": "user",
            "content": "Give me another one with a different name and age.",
        },
    ])
    second_json, _ = _create_and_validate_response(messages)

    assert (
        first_json["name"] != second_json["name"]
    ), "The model should have generated a different name in the second turn."
    assert (
        first_json["age"] != second_json["age"]
    ), "The model should have generated a different age in the second turn."

0 commit comments

Comments
 (0)