2 changes: 2 additions & 0 deletions cpp/include/tensorrt_llm/executor/executor.h
@@ -1351,6 +1351,8 @@ class GuidedDecodingConfig
{
/// @brief Enable guided decoding with XGrammar backend.
kXGRAMMAR = 0,
/// @brief Enable guided decoding with LLGuidance backend.
kLLGUIDANCE = 1,
};

explicit GuidedDecodingConfig(GuidedDecodingBackend backend,
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
@@ -374,7 +374,8 @@ void initConfigBindings(pybind11::module_& m)
auto pyGuidedDecodingConfig = py::class_<tle::GuidedDecodingConfig>(m, "GuidedDecodingConfig");

py::enum_<tle::GuidedDecodingConfig::GuidedDecodingBackend>(pyGuidedDecodingConfig, "GuidedDecodingBackend")
.value("XGRAMMAR", tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR);
.value("XGRAMMAR", tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR)
.value("LLGUIDANCE", tle::GuidedDecodingConfig::GuidedDecodingBackend::kLLGUIDANCE);

auto guidedDecodingConfigGetstate = [](tle::GuidedDecodingConfig const& self) {
return py::make_tuple(
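With these bindings in place, the new backend becomes selectable from Python. A minimal sketch, assuming the constructor keywords mirror the attributes the matcher factories read (encoded_vocab, tokenizer_str, stop_token_ids); the tokenizer file and EOS id below are placeholders:

import tensorrt_llm.bindings.executor as trtllm_executor

# Placeholder inputs: a serialized HF tokenizer (tokenizer.json contents) and its EOS token id.
tokenizer_json_str = open("tokenizer.json").read()
eos_token_id = 2

# Select the LLGuidance backend; a tokenizer_str and a single stop token are
# what LLGuidanceMatcherFactory requires (see grammar_matcher.py below).
config = trtllm_executor.GuidedDecodingConfig(
    backend=trtllm_executor.GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE,
    tokenizer_str=tokenizer_json_str,
    stop_token_ids=[eos_token_id],
)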
1 change: 1 addition & 0 deletions requirements.txt
@@ -57,3 +57,4 @@ meson
ninja
etcd3
blake3
llguidance==0.7.29
176 changes: 176 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/grammar_matcher.py
@@ -0,0 +1,176 @@
import json
from abc import ABC, abstractmethod

import llguidance
import llguidance.torch
import torch
import xgrammar

from ...bindings.executor import GuidedDecodingConfig, GuidedDecodingParams


class GrammarMatcher(ABC):

@abstractmethod
def accept_token(self, token_id: int) -> bool:
pass

@abstractmethod
def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
index: int) -> None:
pass


class GrammarMatcherFactory(ABC):

@abstractmethod
def create(self,
guided_decoding_params: GuidedDecodingParams) -> GrammarMatcher:
pass


class XGrammarMatcher(GrammarMatcher):

def __init__(self, matcher: xgrammar.GrammarMatcher):
super().__init__()
self._matcher = matcher

def accept_token(self, token_id: int) -> bool:
return self._matcher.accept_token(token_id)

def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
index: int) -> None:
self._matcher.fill_next_token_bitmask(next_token_bitmask, index)


class XGrammarMatcherFactory(GrammarMatcherFactory):

def __init__(self, guided_decoding_config: GuidedDecodingConfig,
vocab_size_padded: int):
super().__init__()
if guided_decoding_config.tokenizer_str is not None:
metadata = xgrammar.TokenizerInfo._detect_metadata_from_hf(
guided_decoding_config.tokenizer_str)
tokenizer_info = xgrammar.TokenizerInfo(
guided_decoding_config.encoded_vocab,
vocab_type=metadata["vocab_type"],
vocab_size=vocab_size_padded,
stop_token_ids=guided_decoding_config.stop_token_ids,
add_prefix_space=metadata["add_prefix_space"])
else:
tokenizer_info = xgrammar.TokenizerInfo(
guided_decoding_config.encoded_vocab,
xgrammar.VocabType.RAW,
vocab_size=vocab_size_padded,
stop_token_ids=guided_decoding_config.stop_token_ids)
self._xgrammar_compiler = xgrammar.GrammarCompiler(tokenizer_info)

def create(self,
guided_decoding_params: GuidedDecodingParams) -> XGrammarMatcher:
guide_type = guided_decoding_params.guide_type
guide = guided_decoding_params.guide
match guide_type:
case GuidedDecodingParams.GuideType.JSON:
compiled_grammar = self._xgrammar_compiler.compile_builtin_json_grammar(
)
case GuidedDecodingParams.GuideType.JSON_SCHEMA:
compiled_grammar = self._xgrammar_compiler.compile_json_schema(
guide)
case GuidedDecodingParams.GuideType.REGEX:
grammar = xgrammar.Grammar.from_regex(guide)
compiled_grammar = self._xgrammar_compiler.compile_grammar(
grammar)
case GuidedDecodingParams.GuideType.EBNF_GRAMMAR:
grammar = xgrammar.Grammar.from_ebnf(guide)
compiled_grammar = self._xgrammar_compiler.compile_grammar(
grammar)
case GuidedDecodingParams.GuideType.STRUCTURAL_TAG:
structural_tag_parameters = json.loads(guide)
structures = structural_tag_parameters["structures"]
structures = [
xgrammar.StructuralTagItem(begin=s["begin"],
schema=json.dumps(s["schema"]),
end=s["end"]) for s in structures
]
triggers = structural_tag_parameters["triggers"]
compiled_grammar = self._xgrammar_compiler.compile_structural_tag(
structures, triggers)
case _:
raise ValueError(f"Unrecognized guide type: {guide_type}.")

matcher = xgrammar.GrammarMatcher(compiled_grammar)
return XGrammarMatcher(matcher)


class LLGuidanceMatcher(GrammarMatcher):

def __init__(self, matcher: llguidance.LLMatcher):
super().__init__()
self._matcher = matcher

def accept_token(self, token_id: int) -> bool:
result = self._matcher.consume_token(token_id)
self._check_err()
return result

def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
index: int) -> None:
llguidance.torch.fill_next_token_bitmask(self._matcher,
next_token_bitmask, index)
self._check_err()

def _check_err(self) -> None:
if self._matcher.is_error():
raise ValueError(
f"LLGuidance matcher error: {self._matcher.get_error()}")


class LLGuidanceMatcherFactory(GrammarMatcherFactory):

def __init__(self, guided_decoding_config: GuidedDecodingConfig,
vocab_size_padded: int):
super().__init__()
tokenizer_str = guided_decoding_config.tokenizer_str
stop_token_ids = guided_decoding_config.stop_token_ids

if tokenizer_str is None:
raise ValueError("tokenizer_str is required")

eos_token = None
if stop_token_ids is not None:
if len(stop_token_ids) != 1:
raise ValueError("expected stop_token_ids size to be 1")
eos_token = stop_token_ids[0]

self._tokenizer = llguidance.LLTokenizer(tokenizer_str,
n_vocab=vocab_size_padded,
eos_token=eos_token)

def create(
self,
guided_decoding_params: GuidedDecodingParams) -> LLGuidanceMatcher:
guide_type = guided_decoding_params.guide_type
guide = guided_decoding_params.guide

grammar = None
match guide_type:
case GuidedDecodingParams.GuideType.JSON:
grammar = llguidance.LLMatcher.grammar_from_json_schema(
'{"type": "object"}')
case GuidedDecodingParams.GuideType.JSON_SCHEMA:
grammar = llguidance.LLMatcher.grammar_from_json_schema(guide)
case GuidedDecodingParams.GuideType.REGEX:
grammar = llguidance.LLMatcher.grammar_from_regex(guide)
case GuidedDecodingParams.GuideType.EBNF_GRAMMAR:
# Note: LLGuidance expects Lark grammar format, not standard EBNF.
# When using LLGuidance backend with EBNF_GRAMMAR type, users must
# provide Lark-formatted grammar instead of standard EBNF.
grammar = llguidance.LLMatcher.grammar_from_lark(guide)
case _:
raise ValueError(f"Unrecognized guide type: {guide_type}.")

matcher = llguidance.LLMatcher(self._tokenizer, grammar)
if matcher.is_error():
raise ValueError(f"LLGuidance matcher error: {matcher.get_error()}")

return LLGuidanceMatcher(matcher)
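Taken together, the factory/matcher pair gives the guided decoder a backend-agnostic interface: one factory per model, one matcher per request, one packed bitmask row per sequence slot. A minimal usage sketch, assuming a GuidedDecodingConfig built as in the earlier sketch, keyword construction of GuidedDecodingParams, and the packed int32 layout (ceil(vocab/32) words per row) that the fill helpers expect; vocab_size_padded and sampled_token_id stand in for values supplied by the runtime:

import math
import torch
from tensorrt_llm.bindings.executor import GuidedDecodingParams
from tensorrt_llm._torch.pyexecutor.grammar_matcher import LLGuidanceMatcherFactory

# One factory per model/tokenizer; config comes from the earlier sketch.
factory = LLGuidanceMatcherFactory(config, vocab_size_padded)

# One matcher per request, built from that request's guided decoding params.
params = GuidedDecodingParams(guide_type=GuidedDecodingParams.GuideType.JSON_SCHEMA,
                              guide='{"type": "object"}')
matcher = factory.create(params)

# Per decoding step: fill one packed bitmask row with the allowed tokens,
# then report the sampled token back to the matcher.
bitmask = torch.zeros(1, math.ceil(vocab_size_padded / 32), dtype=torch.int32)
matcher.fill_next_token_bitmask(bitmask, 0)
accepted = matcher.accept_token(sampled_token_id)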
167 changes: 63 additions & 104 deletions tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -1,12 +1,12 @@
import itertools
import json
import math
from typing import List, Optional

import torch
import xgrammar

from ...bindings.executor import GuidedDecodingConfig, GuidedDecodingParams
from ...bindings.executor import GuidedDecodingConfig
from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactory,
LLGuidanceMatcherFactory, XGrammarMatcherFactory)
from .scheduler import ScheduledRequests
from .seq_slot_manager import SeqSlotManager

@@ -20,33 +20,29 @@ def __init__(self, guided_decoding_config: GuidedDecodingConfig,
self.max_num_sequences = max_num_sequences
self.vocab_size_padded = vocab_size_padded

self.grammar_matcher_factory: Optional[GrammarMatcherFactory] = None
self.grammar_matchers: List[
Optional[GrammarMatcher]] = [None] * self.max_num_sequences

if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
if guided_decoding_config.tokenizer_str is not None:
metadata = xgrammar.TokenizerInfo._detect_metadata_from_hf(
guided_decoding_config.tokenizer_str)
tokenizer_info = xgrammar.TokenizerInfo(
guided_decoding_config.encoded_vocab,
vocab_type=metadata["vocab_type"],
vocab_size=vocab_size_padded,
stop_token_ids=guided_decoding_config.stop_token_ids,
add_prefix_space=metadata["add_prefix_space"])
else:
tokenizer_info = xgrammar.TokenizerInfo(
guided_decoding_config.encoded_vocab,
xgrammar.VocabType.RAW,
vocab_size=vocab_size_padded,
stop_token_ids=guided_decoding_config.stop_token_ids)
self.xgrammar_compiler = xgrammar.GrammarCompiler(tokenizer_info)
self.xgrammar_matchers: List[Optional[
xgrammar.GrammarMatcher]] = [None] * self.max_num_sequences
self.bitmask = torch.empty(self.max_num_sequences,
self.bitmask_size,
dtype=self.bitmask_dtype,
device='cuda')
self.bitmask_host = torch.empty(self.max_num_sequences,
self.bitmask_size,
dtype=self.bitmask_dtype,
pin_memory=True)
self.grammar_matcher_factory = XGrammarMatcherFactory(
guided_decoding_config, vocab_size_padded)
elif self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE:
self.grammar_matcher_factory = LLGuidanceMatcherFactory(
guided_decoding_config, vocab_size_padded)
else:
raise ValueError(
f"invalid guided_decoding_backend: {self.guided_decoding_backend}"
)

self.bitmask = torch.empty(self.max_num_sequences,
self.bitmask_size,
dtype=self.bitmask_dtype,
device='cuda')
self.bitmask_host = torch.empty(self.max_num_sequences,
self.bitmask_size,
dtype=self.bitmask_dtype,
pin_memory=True)

self._stream = torch.cuda.Stream()

Expand All @@ -56,85 +52,48 @@ def bitmask_size(self) -> int:

def build(self, scheduled_requests: ScheduledRequests,
resource_manager: SeqSlotManager) -> None:
if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
for llm_req in itertools.chain(
scheduled_requests.context_requests,
scheduled_requests.generation_requests):
if llm_req.guided_decoding_params is None:
continue
slot = resource_manager.slot_manager.get_slot(
llm_req.request_id)
if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
# The request is in the first context forward step (considering kv cache reuse).
guide_type = llm_req.guided_decoding_params.guide_type
guide = llm_req.guided_decoding_params.guide
match guide_type:
case GuidedDecodingParams.GuideType.JSON:
compiled_grammar = self.xgrammar_compiler.compile_builtin_json_grammar(
)
case GuidedDecodingParams.GuideType.JSON_SCHEMA:
compiled_grammar = self.xgrammar_compiler.compile_json_schema(
guide)
case GuidedDecodingParams.GuideType.REGEX:
grammar = xgrammar.Grammar.from_regex(guide)
compiled_grammar = self.xgrammar_compiler.compile_grammar(
grammar)
case GuidedDecodingParams.GuideType.EBNF_GRAMMAR:
grammar = xgrammar.Grammar.from_ebnf(guide)
compiled_grammar = self.xgrammar_compiler.compile_grammar(
grammar)
case GuidedDecodingParams.GuideType.STRUCTURAL_TAG:
structural_tag_parameters = json.loads(guide)
structures = structural_tag_parameters["structures"]
structures = [
xgrammar.StructuralTagItem(
begin=s["begin"],
schema=json.dumps(s["schema"]),
end=s["end"]) for s in structures
]
triggers = structural_tag_parameters["triggers"]
compiled_grammar = self.xgrammar_compiler.compile_structural_tag(
structures, triggers)
case _:
raise ValueError(
f"Unrecognized guide type: {guide_type}.")
self.xgrammar_matchers[slot] = xgrammar.GrammarMatcher(
compiled_grammar)

elif llm_req.is_generation_in_progress_state:
# The request is in a generation forward step.
# Currently, guided decoding is not supported with beam search.
self.xgrammar_matchers[slot].accept_token(
llm_req.get_last_tokens(0))
else:
continue

# Fill the bitmask on host and asynchronously copy it to the device.
self.xgrammar_matchers[slot].fill_next_token_bitmask(
self.bitmask_host, slot)
with torch.cuda.stream(self._stream):
self.bitmask[slot].copy_(self.bitmask_host[slot],
non_blocking=True)
for llm_req in itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
if llm_req.guided_decoding_params is None:
continue
slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
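# The request is in the first context forward step (considering kv cache reuse).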
self.grammar_matchers[
slot] = self.grammar_matcher_factory.create(
llm_req.guided_decoding_params)

elif llm_req.is_generation_in_progress_state:
# The request is in a generation forward step.
# Currently, guided decoding is not supported with beam search.
self.grammar_matchers[slot].accept_token(
llm_req.get_last_tokens(0))
else:
continue

# Fill the bitmask on host and asynchronously copy it to the device.
self.grammar_matchers[slot].fill_next_token_bitmask(
self.bitmask_host, slot)
with torch.cuda.stream(self._stream):
self.bitmask[slot].copy_(self.bitmask_host[slot],
non_blocking=True)

def execute(self, scheduled_requests: ScheduledRequests,
logits: torch.Tensor, resource_manager: SeqSlotManager) -> None:
assert logits.size(0) == len(scheduled_requests.context_requests) + len(
scheduled_requests.generation_requests)
torch.cuda.current_stream().wait_stream(self._stream)

if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
batched_logits, batched_bitmask = [], []
for i, llm_req in enumerate(
itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests)):
if llm_req.guided_decoding_params is None:
continue
if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
continue
batched_logits.append(logits[i])
slot = resource_manager.slot_manager.get_slot(
llm_req.request_id)
batched_bitmask.append(self.bitmask[slot])

if len(batched_logits) > 0:
torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
batched_logits, batched_bitmask = [], []
for i, llm_req in enumerate(
itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests)):
if llm_req.guided_decoding_params is None:
continue
if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
continue
batched_logits.append(logits[i])
slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
batched_bitmask.append(self.bitmask[slot])

if len(batched_logits) > 0:
torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
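torch.ops.trtllm.logits_bitmask applies each packed bitmask row to the corresponding logits row in a fused kernel. As an illustration only (not the actual kernel), the per-row semantics under the usual convention that a set bit marks an allowed token:

import torch

def apply_token_bitmask(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
    # logits: (vocab_size,) float tensor; bitmask: (ceil(vocab_size / 32),) int32 tensor.
    vocab_size = logits.size(0)
    bit_positions = torch.arange(32, device=bitmask.device, dtype=torch.int32)
    # Unpack 32 tokens per int32 word, then keep only the first vocab_size bits.
    unpacked = (bitmask.unsqueeze(1) >> bit_positions) & 1
    allowed = unpacked.reshape(-1)[:vocab_size].bool()
    # Disallowed tokens get -inf so they can never be sampled.
    logits.masked_fill_(~allowed, float("-inf"))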