diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index e492c1beaad..b2aafd02701 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -483,6 +483,9 @@ class GuidedDecodingParams
         /// @brief The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar.
         /// EBNF grammar is widely-used to express context-free grammars.
         kEBNF_GRAMMAR = 3,
+
+        /// @brief The generated text is amenable to the XGrammar structural tag.
+        kSTRUCTURAL_TAG = 4,
     };

     explicit GuidedDecodingParams(GuideType guideType, std::optional<std::string> guide = std::nullopt);
diff --git a/cpp/tensorrt_llm/pybind/executor/request.cpp b/cpp/tensorrt_llm/pybind/executor/request.cpp
index 9d504e0c264..e0a5eedd4da 100644
--- a/cpp/tensorrt_llm/pybind/executor/request.cpp
+++ b/cpp/tensorrt_llm/pybind/executor/request.cpp
@@ -462,7 +462,8 @@ void initRequestBindings(pybind11::module_& m)
         .value("JSON", tle::GuidedDecodingParams::GuideType::kJSON)
         .value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA)
         .value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX)
-        .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR);
+        .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR)
+        .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG);

     auto guidedDecodingParamsGetstate = [](tle::GuidedDecodingParams const& self)
     { return py::make_tuple(self.getGuideType(), self.getGuide()); };
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 865975b4c4b..8a1b757c52a 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -369,6 +369,9 @@ def create_py_executor_instance(dist,
         if spec_config is not None:
             raise ValueError(
                 "Guided decoding is not supported with speculative decoding.")
+        if pytorch_backend_config.enable_overlap_scheduler:
+            raise ValueError(
+                "Guided decoding is not supported with overlap scheduler.")

     logger.info(
         f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}"
diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
index 2e0b0b886bb..a4ea0c1f7cd 100644
--- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
+++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -1,4 +1,5 @@
 import itertools
+import json
 import math
 from typing import List, Optional

@@ -82,6 +83,18 @@ def build(self, scheduled_requests: ScheduledRequests,
                     grammar = xgrammar.Grammar.from_ebnf(guide)
                     compiled_grammar = self.xgrammar_compiler.compile_grammar(
                         grammar)
+                case GuidedDecodingParams.GuideType.STRUCTURAL_TAG:
+                    structural_tag_parameters = json.loads(guide)
+                    structures = structural_tag_parameters["structures"]
+                    structures = [
+                        xgrammar.StructuralTagItem(
+                            begin=s["begin"],
+                            schema=json.dumps(s["schema"]),
+                            end=s["end"]) for s in structures
+                    ]
+                    triggers = structural_tag_parameters["triggers"]
+                    compiled_grammar = self.xgrammar_compiler.compile_structural_tag(
+                        structures, triggers)
                 case _:
                     raise ValueError(
                         f"Unrecognized guide type: {guide_type}.")
diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py
index 1e3ee819edc..7e0c756bed4 100644
--- a/tensorrt_llm/sampling_params.py
+++ b/tensorrt_llm/sampling_params.py
@@ -20,11 +20,13 @@ class GuidedDecodingParams:
         regex (str, optional): The generated text is amenable to the user-specified regular expression. Defaults to None.
         grammar (str, optional): The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar. Defaults to None.
         json_object (bool): If True, the generated text is amenable to json format. Defaults to False.
+        structural_tag (str, optional): The generated text is amenable to the user-specified structural tag. Defaults to None.
     """
     json: Optional[Union[str, BaseModel, dict]] = None
     regex: Optional[str] = None
     grammar: Optional[str] = None
     json_object: bool = False
+    structural_tag: Optional[str] = None

     def _validate(self):
         num_guides = 0
@@ -451,7 +453,7 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams:
                 tllme.GuidedDecodingParams.GuideType.JSON)
         elif self.guided_decoding.json is not None:
             json_schema = self.guided_decoding.json
-            if isinstance(json, BaseModel):
+            if isinstance(json_schema, BaseModel):
                 json_schema = json_schema.model_json_schema()
             if isinstance(json_schema, dict):
                 json_schema = json.dumps(json_schema)
@@ -465,5 +467,9 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams:
             return tllme.GuidedDecodingParams(
                 tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR,
                 self.guided_decoding.grammar)
+        elif self.guided_decoding.structural_tag is not None:
+            return tllme.GuidedDecodingParams(
+                tllme.GuidedDecodingParams.GuideType.STRUCTURAL_TAG,
+                self.guided_decoding.structural_tag)
         else:
             return None
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index ba9c586c4f6..de27f942e37 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -13,7 +13,7 @@
 from typing_extensions import Annotated, Required, TypedDict

 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
-from tensorrt_llm.llmapi import SamplingParams
+from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams


 class OpenAIBaseModel(BaseModel):
@@ -44,9 +44,17 @@ class ModelList(OpenAIBaseModel):
     data: List[ModelCard] = Field(default_factory=list)


+class StructuralTag(OpenAIBaseModel):
+    begin: str
+    schema_: Optional[dict[str, Any]] = Field(alias="schema")
+    end: str
+
+
 class ResponseFormat(OpenAIBaseModel):
-    # type must be "json_object" or "text"
-    type: Literal["text", "json_object"]
+    # type must be "json_object", "text" or "structural_tag"
+    type: Literal["text", "json_object", "structural_tag"]
+    structures: Optional[List[StructuralTag]] = None
+    triggers: Optional[List[str]] = None


 class DisaggregatedParams(OpenAIBaseModel):
@@ -121,6 +129,23 @@ class CompletionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)


+def _response_format_to_guided_decoding_params(
+        response_format: Optional[ResponseFormat]
+) -> Optional[GuidedDecodingParams]:
+    if response_format is None:
+        return None
+    elif response_format.type == "text":
+        return None
+    elif response_format.type == "json_object":
+        return GuidedDecodingParams(json_object=True)
+    elif response_format.type == "structural_tag":
+        return GuidedDecodingParams(
+            structural_tag=response_format.model_dump_json(by_alias=True,
+                                                           exclude_none=True))
+    else:
+        raise ValueError(f"Unsupported response format: {response_format.type}")
+
+
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
@@ -170,10 +195,10 @@ class CompletionRequest(OpenAIBaseModel):
     )
     response_format: Optional[ResponseFormat] = Field(
         default=None,
-        description=(
-            "Similar to chat completion, this parameter specifies the format of "
-            "output. Only {'type': 'json_object'} or {'type': 'text' } is "
-            "supported."),
+        description=
+        ("Similar to chat completion, this parameter specifies the format of "
+         "output. {'type': 'json_object'}, {'type': 'text'} and {'type': 'structural_tag'} are "
+         "supported."),
     )

     disaggregated_params: Optional[DisaggregatedParams] = Field(
@@ -211,6 +236,8 @@ def to_sampling_params(self) -> SamplingParams:
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             return_context_logits=self.return_context_logits,
+            guided_decoding=_response_format_to_guided_decoding_params(
+                self.response_format),

             # completion-extra-params
             add_special_tokens=self.add_special_tokens,
@@ -255,13 +282,6 @@ def verify_multi_responses(cls, data):
             raise ValueError("best_of should not be smaller than n")
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_response_format(cls, data):
-        if data.get("response_format"):
-            raise ValueError("response_format is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_suffix(cls, data):
@@ -520,6 +540,8 @@ def to_sampling_params(self) -> SamplingParams:
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
+            guided_decoding=_response_format_to_guided_decoding_params(
+                self.response_format),

             # chat-completion-extra-params
             add_special_tokens=self.add_special_tokens,
@@ -582,13 +604,6 @@ def verify_logit_processor(cls, data):
             raise ValueError("logit bias is not supported")
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_response_format(cls, data):
-        if data.get("response_format"):
-            raise ValueError("response_format is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_suffix(cls, data):
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index 7c54d38fabc..6a530c1a302 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -1121,6 +1121,15 @@ def test_openai_chat_multimodal_example(llm_root, llm_venv):
         str(test_root / "_test_openai_chat_multimodal.py")])


+def test_openai_chat_structural_tag_example(llm_venv):
+    test_root = unittest_path() / "llmapi" / "apps"
+
+    llm_venv.run_cmd([
+        "-m", "pytest",
+        str(test_root / "_test_openai_chat_structural_tag.py")
+    ])
+
+
 @pytest.mark.skip_less_device(2)
 @pytest.mark.skip_less_device_memory(40000)
 def test_openai_multi_chat_example(llm_root, llm_venv):
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml
index 07e4747e5c2..8cb0a1d7ec2 100644
--- a/tests/integration/test_lists/test-db/l0_a10.yml
+++ b/tests/integration/test_lists/test-db/l0_a10.yml
@@ -20,6 +20,7 @@ l0_a10:
   - disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
   - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test]
   - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test]
+  - test_e2e.py::test_openai_chat_structural_tag_example
 - condition:
     ranges:
       system_gpu_count:
diff --git a/tests/unittest/api_stability/references/guided_decoding_params.yaml b/tests/unittest/api_stability/references/guided_decoding_params.yaml
index 5b9dd4b16bb..3ec8a2eda41 100644
--- a/tests/unittest/api_stability/references/guided_decoding_params.yaml
+++ b/tests/unittest/api_stability/references/guided_decoding_params.yaml
@@ -13,5 +13,8 @@ methods:
       regex:
         annotation: Optional[str]
         default: null
+      structural_tag:
+        annotation: Optional[str]
+        default: null
     return_annotation: None
 properties: {}
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_structural_tag.py b/tests/unittest/llmapi/apps/_test_openai_chat_structural_tag.py
new file mode 100644
index 00000000000..cd298967a31
--- /dev/null
+++ b/tests/unittest/llmapi/apps/_test_openai_chat_structural_tag.py
@@ -0,0 +1,191 @@
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/aae6927be06dedbda39c6b0c30f6aa3242b84388/tests/entrypoints/openai/test_chat.py
+import os
+import tempfile
+
+import openai
+import pytest
+import yaml
+
+from ..test_llm import get_model_path, similar
+from .openai_server import RemoteOpenAIServer
+
+pytestmark = pytest.mark.threadleak(enabled=False)
+
+
+@pytest.fixture(scope="module", ids=["TinyLlama-1.1B-Chat"])
+def model_name():
+    return "llama-3.1-model/Llama-3.1-8B-Instruct"
+
+
+@pytest.fixture(scope="module")
+def temp_extra_llm_api_options_file(request):
+    temp_dir = tempfile.gettempdir()
+    temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
+    try:
+        extra_llm_api_options_dict = {
+            "guided_decoding_backend": "xgrammar",
+            "pytorch_backend_config": {
+                "enable_overlap_scheduler": False,
+            }
+        }
+
+        with open(temp_file_path, 'w') as f:
+            yaml.dump(extra_llm_api_options_dict, f)
+
+        yield temp_file_path
+    finally:
+        if os.path.exists(temp_file_path):
+            os.remove(temp_file_path)
+
+
+@pytest.fixture(scope="module")
+def server(model_name: str, temp_extra_llm_api_options_file: str):
+    model_path = get_model_path(model_name)
+    args = [
+        "--backend", "pytorch", "--extra_llm_api_options",
+        temp_extra_llm_api_options_file
+    ]
+    with RemoteOpenAIServer(model_path, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server: RemoteOpenAIServer):
+    return server.get_client()
+
+
+@pytest.fixture(scope="module")
+def async_client(server: RemoteOpenAIServer):
+    return server.get_async_client()
+
+
+@pytest.fixture(scope="module")
+def tool_get_current_weather():
+    return {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type":
+                        "string",
+                        "description":
+                        "The city to find the weather for, e.g. 'San Francisco'",
+                    },
+                    "state": {
+                        "type":
+                        "string",
+                        "description":
+                        "the two-letter abbreviation for the state that the city is"
+                        " in, e.g. 'CA' which would mean 'California'",
+                    },
+                    "unit": {
+                        "type": "string",
+                        "description": "The unit to fetch the temperature in",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                },
+                "required": ["city", "state", "unit"],
+            },
+        },
+    }
+
+
+@pytest.fixture(scope="module")
+def tool_get_current_date():
+    return {
+        "type": "function",
+        "function": {
+            "name": "get_current_date",
+            "description": "Get the current date and time for a given timezone",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "timezone": {
+                        "type":
+                        "string",
+                        "description":
+                        "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
+                    }
+                },
+                "required": ["timezone"],
+            },
+        },
+    }
+
+
+def test_chat_structural_tag(client: openai.OpenAI, model_name: str,
+                             tool_get_current_weather, tool_get_current_date):
+    messages = [
+        {
+            "role":
+            "system",
+            "content":
+            f"""
+# Tool Instructions
+- Always execute python code in messages that you share.
+- When looking for real time information use relevant functions if available else fallback to brave_search
+You have access to the following functions:
+Use the function 'get_current_weather' to: Get the current weather in a given location
+{tool_get_current_weather["function"]}
+Use the function 'get_current_date' to: Get the current date and time for a given timezone
+{tool_get_current_date["function"]}
+If you choose to call a function ONLY reply in the following format:
+<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
+where
+start_tag => `<function`
+parameters => a JSON dict with the function argument name as key and function argument value as value.
+end_tag => `</function>`
+Here is an example,
+<function=example_function_name>{{"example_name": "example_value"}}</function>
+Reminder:
+- Function calls MUST follow the specified format
+- Required parameters MUST be specified
+- Only call one function at a time
+- Put the entire function call reply on one line
+- Always add your sources when using search results to answer the user query
+You are a helpful assistant.""",
+        },
+        {
+            "role":
+            "user",
+            "content":
+            "You are in New York. Please get the current date and time, and the weather.",
+        },
+    ]
+
+    chat_completion = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=100,
+        response_format={
+            "type":
+            "structural_tag",
+            "structures": [
+                {
+                    "begin": "<function=get_current_weather>",
+                    "schema":
+                    tool_get_current_weather["function"]["parameters"],
+                    "end": "</function>",
+                },
+                {
+                    "begin": "<function=get_current_date>",
+                    "schema": tool_get_current_date["function"]["parameters"],
+                    "end": "</function>",
+                },
+            ],
+            "triggers": ["