Merged
3 changes: 3 additions & 0 deletions cpp/include/tensorrt_llm/executor/executor.h
@@ -483,6 +483,9 @@ class GuidedDecodingParams
         /// @brief The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar.
         /// EBNF grammar is widely-used to express context-free grammars.
         kEBNF_GRAMMAR = 3,
+
+        /// @brief The generated text is amenable to the XGrammar structural tag.
+        kSTRUCTURAL_TAG = 4,
     };

     explicit GuidedDecodingParams(GuideType guideType, std::optional<std::string> guide = std::nullopt);
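For context, the new `kSTRUCTURAL_TAG` guide type reuses the existing optional `guide` string to carry its configuration. Based on how the `guided_decoder.py` hunk below parses that string, a guide payload could look like the following sketch; the function-call tag, schema, and trigger values are illustrative, not taken from the PR.

```python
import json

# Hypothetical structural-tag guide: once the model emits a trigger prefix
# ("<function=" here), the text between the begin/end tags must be JSON
# satisfying that structure's schema; outside the tags, decoding is free.
structural_tag_guide = json.dumps({
    "structures": [{
        "begin": "<function=get_weather>",
        "schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
        "end": "</function>",
    }],
    "triggers": ["<function="],
})
```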
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/executor/request.cpp
@@ -462,7 +462,8 @@ void initRequestBindings(pybind11::module_& m)
         .value("JSON", tle::GuidedDecodingParams::GuideType::kJSON)
         .value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA)
         .value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX)
-        .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR);
+        .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR)
+        .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG);

     auto guidedDecodingParamsGetstate
         = [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide()); };

3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -369,6 +369,9 @@ def create_py_executor_instance(dist,
         if spec_config is not None:
             raise ValueError(
                 "Guided decoding is not supported with speculative decoding.")
+        if pytorch_backend_config.enable_overlap_scheduler:
+            raise ValueError(
+                "Guided decoding is not supported with overlap scheduler.")

     logger.info(
         f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}"

13 changes: 13 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -1,4 +1,5 @@
 import itertools
+import json
 import math
 from typing import List, Optional

@@ -82,6 +83,18 @@ def build(self, scheduled_requests: ScheduledRequests,
                     grammar = xgrammar.Grammar.from_ebnf(guide)
                     compiled_grammar = self.xgrammar_compiler.compile_grammar(
                         grammar)
+                case GuidedDecodingParams.GuideType.STRUCTURAL_TAG:
+                    structural_tag_parameters = json.loads(guide)
+                    structures = structural_tag_parameters["structures"]
+                    structures = [
+                        xgrammar.StructuralTagItem(
+                            begin=s["begin"],
+                            schema=json.dumps(s["schema"]),
+                            end=s["end"]) for s in structures
+                    ]
+                    triggers = structural_tag_parameters["triggers"]
+                    compiled_grammar = self.xgrammar_compiler.compile_structural_tag(
+                        structures, triggers)
                 case _:
                     raise ValueError(
                         f"Unrecognized guide type: {guide_type}.")
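The new `STRUCTURAL_TAG` case maps one-to-one onto xgrammar's structural-tag API. A minimal standalone sketch of the same compilation, assuming a `GrammarCompiler` already built from the model's tokenizer and a guide payload shaped as in the earlier sketch:

```python
import json

import xgrammar


def compile_structural_tag_guide(compiler: xgrammar.GrammarCompiler,
                                 guide: str) -> xgrammar.CompiledGrammar:
    params = json.loads(guide)
    # xgrammar expects each schema as a JSON string, hence the dumps round-trip.
    structures = [
        xgrammar.StructuralTagItem(begin=s["begin"],
                                   schema=json.dumps(s["schema"]),
                                   end=s["end"]) for s in params["structures"]
    ]
    # Triggers are the prefixes that switch decoding into constrained tag mode.
    return compiler.compile_structural_tag(structures, params["triggers"])
```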
8 changes: 7 additions & 1 deletion tensorrt_llm/sampling_params.py
@@ -20,11 +20,13 @@ class GuidedDecodingParams:
         regex (str, optional): The generated text is amenable to the user-specified regular expression. Defaults to None.
         grammar (str, optional): The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar. Defaults to None.
         json_object (bool): If True, the generated text is amenable to json format. Defaults to False.
+        structural_tag (str, optional): The generated text is amenable to the user-specified structural tag. Defaults to None.
     """
     json: Optional[Union[str, BaseModel, dict]] = None
     regex: Optional[str] = None
     grammar: Optional[str] = None
     json_object: bool = False
+    structural_tag: Optional[str] = None

     def _validate(self):
         num_guides = 0
@@ -451,7 +453,7 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams:
                 tllme.GuidedDecodingParams.GuideType.JSON)
         elif self.guided_decoding.json is not None:
             json_schema = self.guided_decoding.json
-            if isinstance(json, BaseModel):
+            if isinstance(json_schema, BaseModel):
                 json_schema = json_schema.model_json_schema()
             if isinstance(json_schema, dict):
                 json_schema = json.dumps(json_schema)
@@ -465,5 +467,9 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams:
             return tllme.GuidedDecodingParams(
                 tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR,
                 self.guided_decoding.grammar)
+        elif self.guided_decoding.structural_tag is not None:
+            return tllme.GuidedDecodingParams(
+                tllme.GuidedDecodingParams.GuideType.STRUCTURAL_TAG,
+                self.guided_decoding.structural_tag)
         else:
             return None
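With the new field wired through `SamplingParams`, structural tags can be exercised end-to-end from the LLM API. A minimal sketch, assuming the PyTorch backend with xgrammar guided decoding enabled; the model name and tag definition are placeholders:

```python
import json

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
          guided_decoding_backend="xgrammar")

tag = {
    "structures": [{
        "begin": "<function=get_weather>",
        "schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        },
        "end": "</function>",
    }],
    "triggers": ["<function="],
}
# structural_tag takes the serialized payload, mirroring _get_guided_decoding_params.
sampling_params = SamplingParams(
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(structural_tag=json.dumps(tag)))

output = llm.generate("What is the weather in Paris?", sampling_params)
print(output.outputs[0].text)
```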
57 changes: 36 additions & 21 deletions tensorrt_llm/serve/openai_protocol.py
@@ -13,7 +13,7 @@
 from typing_extensions import Annotated, Required, TypedDict

 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
-from tensorrt_llm.llmapi import SamplingParams
+from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams


 class OpenAIBaseModel(BaseModel):
@@ -44,9 +44,17 @@ class ModelList(OpenAIBaseModel):
     data: List[ModelCard] = Field(default_factory=list)


+class StructuralTag(OpenAIBaseModel):
+    begin: str
+    schema_: Optional[dict[str, Any]] = Field(alias="schema")
+    end: str
+
+
 class ResponseFormat(OpenAIBaseModel):
-    # type must be "json_object" or "text"
-    type: Literal["text", "json_object"]
+    # type must be "json_object" or "text" or "structural_tag"
+    type: Literal["text", "json_object", "structural_tag"]
+    structures: Optional[List[StructuralTag]] = None
+    triggers: Optional[List[str]] = None


 class DisaggregatedParams(OpenAIBaseModel):
@@ -121,6 +129,23 @@ class CompletionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)


+def _response_format_to_guided_decoding_params(
+        response_format: Optional[ResponseFormat]
+) -> Optional[GuidedDecodingParams]:
+    if response_format is None:
+        return None
+    elif response_format.type == "text":
+        return None
+    elif response_format.type == "json_object":
+        return GuidedDecodingParams(json_object=True)
+    elif response_format.type == "structural_tag":
+        return GuidedDecodingParams(
+            structural_tag=response_format.model_dump_json(by_alias=True,
+                                                           exclude_none=True))
+    else:
+        raise ValueError(f"Unsupported response format: {response_format.type}")
+
+
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
@@ -170,10 +195,10 @@ class CompletionRequest(OpenAIBaseModel):
     )
     response_format: Optional[ResponseFormat] = Field(
         default=None,
-        description=(
-            "Similar to chat completion, this parameter specifies the format of "
-            "output. Only {'type': 'json_object'} or {'type': 'text' } is "
-            "supported."),
+        description=
+        ("Similar to chat completion, this parameter specifies the format of "
+         "output. {'type': 'json_object'}, {'type': 'text' }, {'type': 'structural_tag'} are "
+         "supported."),
     )

     disaggregated_params: Optional[DisaggregatedParams] = Field(
@@ -211,6 +236,8 @@ def to_sampling_params(self) -> SamplingParams:
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             return_context_logits=self.return_context_logits,
+            guided_decoding=_response_format_to_guided_decoding_params(
+                self.response_format),

             # completion-extra-params
             add_special_tokens=self.add_special_tokens,
@@ -255,13 +282,6 @@ def verify_multi_responses(cls, data):
             raise ValueError("best_of should not be smaller than n")
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_response_format(cls, data):
-        if data.get("response_format"):
-            raise ValueError("response_format is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_suffix(cls, data):
@@ -520,6 +540,8 @@ def to_sampling_params(self) -> SamplingParams:
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
+            guided_decoding=_response_format_to_guided_decoding_params(
+                self.response_format),

             # chat-completion-extra-params
             add_special_tokens=self.add_special_tokens,
@@ -582,13 +604,6 @@ def verify_logit_processor(cls, data):
             raise ValueError("logit bias is not supported")
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_response_format(cls, data):
-        if data.get("response_format"):
-            raise ValueError("response_format is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_suffix(cls, data):
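Over HTTP, the structural tag rides on the standard `response_format` field, so a stock OpenAI client can exercise it. A hedged sketch against a locally running trtllm-serve OpenAI endpoint; the base URL, model name, and tag values are placeholders:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    messages=[{
        "role": "user",
        "content": "What is the weather in Paris? Reply with a function call.",
    }],
    # Passed through as a plain dict; the server validates it against the
    # ResponseFormat/StructuralTag models added in this file.
    response_format={
        "type": "structural_tag",
        "structures": [{
            "begin": "<function=get_weather>",
            "schema": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
            "end": "</function>",
        }],
        "triggers": ["<function="],
    },
)
print(response.choices[0].message.content)
```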
9 changes: 9 additions & 0 deletions tests/integration/defs/test_e2e.py
@@ -1121,6 +1121,15 @@ def test_openai_chat_multimodal_example(llm_root, llm_venv):
         str(test_root / "_test_openai_chat_multimodal.py")])


+def test_openai_chat_structural_tag_example(llm_venv):
+    test_root = unittest_path() / "llmapi" / "apps"
+
+    llm_venv.run_cmd([
+        "-m", "pytest",
+        str(test_root / "_test_openai_chat_structural_tag.py")
+    ])
+
+
 @pytest.mark.skip_less_device(2)
 @pytest.mark.skip_less_device_memory(40000)
 def test_openai_multi_chat_example(llm_root, llm_venv):

1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
@@ -20,6 +20,7 @@ l0_a10:
       - disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
       - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test]
      - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test]
+      - test_e2e.py::test_openai_chat_structural_tag_example
   - condition:
       ranges:
         system_gpu_count:

@@ -13,5 +13,8 @@ methods:
       regex:
         annotation: Optional[str]
         default: null
+      structural_tag:
+        annotation: Optional[str]
+        default: null
     return_annotation: None
 properties: {}