14 changes: 12 additions & 2 deletions examples/serve/openai_completion_client_json_schema.py
@@ -1,5 +1,9 @@
 ### :title OpenAI Completion Client with JSON Schema
 
+# This example requires specifying `guided_decoding_backend` as
+# `xgrammar` or `llguidance` in the extra_llm_api_options.yaml file.
+import json
+
 from openai import OpenAI
 
 client = OpenAI(
@@ -18,7 +22,6 @@
"content":
f"Give me the information of the biggest city of China in the JSON format.",
}],
max_tokens=100,
temperature=0,
response_format={
"type": "json",
@@ -39,4 +42,11 @@
         }
     },
 )
-print(response.choices[0].message.content)
+
+content = response.choices[0].message.content
+try:
+    response_json = json.loads(content)
+    assert "name" in response_json and "population" in response_json
+    print(content)
+except json.JSONDecodeError:
+    print("Failed to decode JSON response")
44 changes: 37 additions & 7 deletions tests/unittest/llmapi/apps/_test_trtllm_serve_example.py
@@ -1,8 +1,11 @@
+import json
 import os
 import subprocess
 import sys
+import tempfile
 
 import pytest
+import yaml
 
 from .openai_server import RemoteOpenAIServer

@@ -16,10 +19,26 @@ def model_name():


 @pytest.fixture(scope="module")
-def server(model_name: str):
+def temp_extra_llm_api_options_file():
+    temp_dir = tempfile.gettempdir()
+    temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
+    try:
+        extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}
+        with open(temp_file_path, 'w') as f:
+            yaml.dump(extra_llm_api_options_dict, f)
+
+        yield temp_file_path
+    finally:
+        if os.path.exists(temp_file_path):
+            os.remove(temp_file_path)
+
+
+@pytest.fixture(scope="module")
+def server(model_name: str, temp_extra_llm_api_options_file: str):
     model_path = get_model_path(model_name)
     # fix port to facilitate concise trtllm-serve examples
-    with RemoteOpenAIServer(model_path, port=8000) as remote_server:
+    args = ["--extra_llm_api_options", temp_extra_llm_api_options_file]
+    with RemoteOpenAIServer(model_path, args, port=8000) as remote_server:
         yield remote_server
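
A design note on the `temp_extra_llm_api_options_file` fixture above: it writes to a fixed filename in the shared system temp directory, so two concurrent test sessions on the same machine would race on the same file. A hedged alternative sketch using a unique per-session file, behavior otherwise unchanged:

import os
import tempfile

import pytest
import yaml


@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file():
    # mkstemp yields a unique path, avoiding the collisions that the fixed
    # "extra_llm_api_options.yaml" name permits across parallel runs.
    fd, temp_file_path = tempfile.mkstemp(suffix=".yaml")
    try:
        with os.fdopen(fd, "w") as f:
            yaml.dump({"guided_decoding_backend": "xgrammar"}, f)
        yield temp_file_path
    finally:
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)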


@@ -40,8 +59,19 @@ def test_trtllm_serve_examples(exe: str, script: str,
                                server: RemoteOpenAIServer, example_root: str):
     client_script = os.path.join(example_root, script)
     # CalledProcessError will be raised if any errors occur
-    subprocess.run([exe, client_script],
-                   stdout=subprocess.PIPE,
-                   stderr=subprocess.PIPE,
-                   text=True,
-                   check=True)
+    result = subprocess.run([exe, client_script],
+                            stdout=subprocess.PIPE,
+                            stderr=subprocess.PIPE,
+                            text=True,
+                            check=True)
+    if script.startswith("curl"):
+        # For curl scripts, we expect a JSON response
+        result_stdout = result.stdout.strip()
+        try:
+            data = json.loads(result_stdout)
+            assert "code" not in data or data[
+                "code"] == 200, f"Unexpected response: {data}"
+        except json.JSONDecodeError as e:
+            pytest.fail(
+                f"Failed to parse JSON response from {script}: {e}\nStdout: {result_stdout}\nStderr: {result.stderr}"
+            )
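
The `"code"` check in the new curl branch treats a body without a `code` field as success (a normal completion payload) and otherwise requires 200. A small self-contained sketch of that predicate; the sample payloads are illustrative assumptions, not captured server output:

import json

def curl_response_ok(stdout: str) -> bool:
    # Mirror the test's acceptance rule for curl script output.
    data = json.loads(stdout)
    return "code" not in data or data["code"] == 200

# A completion-style body without a "code" field passes.
assert curl_response_ok(json.dumps({"id": "cmpl-1", "choices": []}))
# An explicit success code passes; an error code fails.
assert curl_response_ok(json.dumps({"code": 200, "message": "ok"}))
assert not curl_response_ok(json.dumps({"code": 500, "message": "error"}))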