diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 84594cd473f..f381757d448 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -52,8 +52,9 @@ class StructuralTag(OpenAIBaseModel):
 
 
 class ResponseFormat(OpenAIBaseModel):
-    # type must be "json_object" or "text" or "structural_tag"
-    type: Literal["text", "json_object", "structural_tag"]
+    # type must be one of "text", "json", "json_object", or "structural_tag"
+    type: Literal["text", "json", "json_object", "structural_tag"]
+    schema: Optional[dict] = None
     structures: Optional[List[StructuralTag]] = None
     triggers: Optional[List[str]] = None
 
@@ -142,6 +143,12 @@ def _response_format_to_guided_decoding_params(
         return None
     elif response_format.type == "text":
         return None
+    elif response_format.type == "json":
+        if response_format.schema is None:
+            raise ValueError(
+                "The 'schema' field is required when response_format.type is 'json'."
+            )
+        return GuidedDecodingParams(json=response_format.schema)
     elif response_format.type == "json_object":
         return GuidedDecodingParams(json_object=True)
     elif response_format.type == "structural_tag":
@@ -205,7 +212,7 @@ class CompletionRequest(OpenAIBaseModel):
         default=None,
         description=
         ("Similar to chat completion, this parameter specifies the format of "
-         "output. {'type': 'json_object'}, {'type': 'text' }, {'type': 'structural_tag'} are "
+         "output. {'type': 'text'}, {'type': 'json'}, {'type': 'json_object'}, {'type': 'structural_tag'} are "
          "supported."),
     )
 
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index 9cfd2eed341..780f096e056 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -1443,6 +1443,14 @@ def test_openai_chat_structural_tag_example(llm_venv):
     ])
 
 
+def test_openai_chat_json_example(llm_venv):
+    test_root = unittest_path() / "llmapi" / "apps"
+
+    llm_venv.run_cmd(
+        ["-m", "pytest",
+         str(test_root / "_test_openai_chat_json.py")])
+
+
 @pytest.mark.skip_less_device(2)
 @pytest.mark.skip_less_device_memory(40000)
 def test_openai_multi_chat_example(llm_root, llm_venv):
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml
index 5799ea27945..43100ad1bce 100644
--- a/tests/integration/test_lists/test-db/l0_a10.yml
+++ b/tests/integration/test_lists/test-db/l0_a10.yml
@@ -22,6 +22,7 @@ l0_a10:
   - disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
   - test_e2e.py::test_openai_chat_structural_tag_example
+  - test_e2e.py::test_openai_chat_json_example
   - test_e2e.py::test_openai_chat_multimodal_example
   - test_e2e.py::test_openai_lora
   - test_e2e.py::test_trtllm_serve_multimodal_example
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_json.py b/tests/unittest/llmapi/apps/_test_openai_chat_json.py
new file mode 100644
index 00000000000..a0b263d1fc1
--- /dev/null
+++ b/tests/unittest/llmapi/apps/_test_openai_chat_json.py
@@ -0,0 +1,151 @@
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/aae6927be06dedbda39c6b0c30f6aa3242b84388/tests/entrypoints/openai/test_chat.py
+import json
+import os
+import tempfile
+from typing import Any
+
+import jsonschema
+import openai
+import pytest
+import yaml
+
+from ..test_llm import get_model_path
+from .openai_server import RemoteOpenAIServer
+
+pytestmark = pytest.mark.threadleak(enabled=False)
+
+
+@pytest.fixture(scope="module",
ids=["llama-3.1-model/Llama-3.1-8B-Instruct"]) +def model_name(): + return "llama-3.1-model/Llama-3.1-8B-Instruct" + + +@pytest.fixture(scope="module") +def temp_extra_llm_api_options_file(request): + temp_dir = tempfile.gettempdir() + temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml") + try: + extra_llm_api_options_dict = { + "guided_decoding_backend": "xgrammar", + "disable_overlap_scheduler": + True, # Guided decoding is not supported with overlap scheduler + } + + with open(temp_file_path, "w") as f: + yaml.dump(extra_llm_api_options_dict, f) + + yield temp_file_path + finally: + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + + +@pytest.fixture(scope="module") +def server(model_name: str, temp_extra_llm_api_options_file: str): + model_path = get_model_path(model_name) + args = [ + "--backend", "pytorch", "--extra_llm_api_options", + temp_extra_llm_api_options_file + ] + with RemoteOpenAIServer(model_path, args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server: RemoteOpenAIServer): + return server.get_client() + + +@pytest.fixture(scope="module") +def async_client(server: RemoteOpenAIServer): + return server.get_async_client() + + +@pytest.fixture(scope="module") +def user_profile_schema(): + """Provides a sample JSON schema for a user profile.""" + return { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The full name of the user." + }, + "age": { + "type": "integer", + "description": "The age of the user, in years." + }, + }, + "required": ["name", "age"], + } + + +def test_chat_json_schema(client: openai.OpenAI, model_name: str): + """ + Tests the `json` response format in a multi-turn synchronous conversation. + Adapted from https://github.com/vllm-project/vllm/blob/aae6927be06dedbda39c6b0c30f6aa3242b84388/tests/entrypoints/openai/test_chat.py#L413 + """ + + def _create_and_validate_response( + messages: list[dict[str, Any]]) -> dict[str, any]: + chat_completion = client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=1000, + temperature=0.0, + response_format={ + "type": "json", + "schema": user_profile_schema + }, + ) + message = chat_completion.choices[0].message + assert message.content is not None + + try: + message_json = json.loads(message.content) + except json.JSONDecodeError: + pytest.fail( + f"The output was not a valid JSON string. Output: {output_text}" + ) + + jsonschema.validate(instance=message_json, schema=user_profile_schema) + return message_json + + messages = [ + { + "role": "system", + "content": "you are a helpful assistant" + }, + { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {user_profile_schema}", + }, + ] + + first_json = _create_and_validate_response(messages) + + messages.extend([ + { + "role": "assistant", + "content": first_message.content, + }, + { + "role": "user", + "content": "Give me another one with a different name and age.", + }, + ]) + second_json = _create_and_validate_response(messages) + + assert ( + first_json["name"] != second_json["name"] + ), "The model should have generated a different name in the second turn." + assert ( + first_json["age"] != second_json["age"] + ), "The model should have generated a different age in the second turn."