diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index d9aa4a6deeb..d5bd7985843 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -1351,6 +1351,8 @@ class GuidedDecodingConfig
     {
         /// @brief Enable guided decoding with XGrammar backend.
         kXGRAMMAR = 0,
+        /// @brief Enable guided decoding with LLGuidance backend.
+        kLLGUIDANCE = 1,
     };
 
     explicit GuidedDecodingConfig(GuidedDecodingBackend backend,
diff --git a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
index fdd7f7c1fd3..2af7e7e008e 100644
--- a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
+++ b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
@@ -374,7 +374,8 @@ void initConfigBindings(pybind11::module_& m)
     auto pyGuidedDecodingConfig = py::class_<tle::GuidedDecodingConfig>(m, "GuidedDecodingConfig");
 
     py::enum_<tle::GuidedDecodingConfig::GuidedDecodingBackend>(pyGuidedDecodingConfig, "GuidedDecodingBackend")
-        .value("XGRAMMAR", tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR);
+        .value("XGRAMMAR", tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR)
+        .value("LLGUIDANCE", tle::GuidedDecodingConfig::GuidedDecodingBackend::kLLGUIDANCE);
 
     auto guidedDecodingConfigGetstate = [](tle::GuidedDecodingConfig const& self)
     { return py::make_tuple(
diff --git a/requirements.txt b/requirements.txt
index b6ea19bc2c4..4ce11b7d854 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -57,3 +57,4 @@
 meson
 ninja
 etcd3
 blake3
+llguidance==0.7.29
diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py
new file mode 100644
index 00000000000..536caab6d3a
--- /dev/null
+++ b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py
@@ -0,0 +1,176 @@
+import json
+from abc import ABC, abstractmethod
+
+import llguidance
+import llguidance.torch
+import torch
+import xgrammar
+
+from ...bindings.executor import GuidedDecodingConfig, GuidedDecodingParams
+
+
+class GrammarMatcher(ABC):
+
+    @abstractmethod
+    def accept_token(self, token_id: int) -> bool:
+        pass
+
+    @abstractmethod
+    def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
+                                index: int) -> None:
+        pass
+
+
+class GrammarMatcherFactory(ABC):
+
+    @abstractmethod
+    def create(self,
+               guided_decoding_params: GuidedDecodingParams) -> GrammarMatcher:
+        pass
+
+
+class XGrammarMatcher(GrammarMatcher):
+
+    def __init__(self, matcher: xgrammar.GrammarMatcher):
+        super().__init__()
+        self._matcher = matcher
+
+    def accept_token(self, token_id: int) -> bool:
+        return self._matcher.accept_token(token_id)
+
+    def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
+                                index: int) -> None:
+        self._matcher.fill_next_token_bitmask(next_token_bitmask, index)
+
+
+class XGrammarMatcherFactory(GrammarMatcherFactory):
+
+    def __init__(self, guided_decoding_config: GuidedDecodingConfig,
+                 vocab_size_padded: int):
+        super().__init__()
+        if guided_decoding_config.tokenizer_str is not None:
+            metadata = xgrammar.TokenizerInfo._detect_metadata_from_hf(
+                guided_decoding_config.tokenizer_str)
+            tokenizer_info = xgrammar.TokenizerInfo(
+                guided_decoding_config.encoded_vocab,
+                vocab_type=metadata["vocab_type"],
+                vocab_size=vocab_size_padded,
+                stop_token_ids=guided_decoding_config.stop_token_ids,
+                add_prefix_space=metadata["add_prefix_space"])
+        else:
+            tokenizer_info = xgrammar.TokenizerInfo(
+                guided_decoding_config.encoded_vocab,
+                xgrammar.VocabType.RAW,
+                vocab_size=vocab_size_padded,
+                stop_token_ids=guided_decoding_config.stop_token_ids)
+        self._xgrammar_compiler = xgrammar.GrammarCompiler(tokenizer_info)
+
+    def create(self,
+               guided_decoding_params: GuidedDecodingParams) -> XGrammarMatcher:
+        guide_type = guided_decoding_params.guide_type
+        guide = guided_decoding_params.guide
+        match guide_type:
+            case GuidedDecodingParams.GuideType.JSON:
+                compiled_grammar = self._xgrammar_compiler.compile_builtin_json_grammar(
+                )
+            case GuidedDecodingParams.GuideType.JSON_SCHEMA:
+                compiled_grammar = self._xgrammar_compiler.compile_json_schema(
+                    guide)
+            case GuidedDecodingParams.GuideType.REGEX:
+                grammar = xgrammar.Grammar.from_regex(guide)
+                compiled_grammar = self._xgrammar_compiler.compile_grammar(
+                    grammar)
+            case GuidedDecodingParams.GuideType.EBNF_GRAMMAR:
+                grammar = xgrammar.Grammar.from_ebnf(guide)
+                compiled_grammar = self._xgrammar_compiler.compile_grammar(
+                    grammar)
+            case GuidedDecodingParams.GuideType.STRUCTURAL_TAG:
+                structural_tag_parameters = json.loads(guide)
+                structures = structural_tag_parameters["structures"]
+                structures = [
+                    xgrammar.StructuralTagItem(begin=s["begin"],
+                                               schema=json.dumps(s["schema"]),
+                                               end=s["end"]) for s in structures
+                ]
+                triggers = structural_tag_parameters["triggers"]
+                compiled_grammar = self._xgrammar_compiler.compile_structural_tag(
+                    structures, triggers)
+            case _:
+                raise ValueError(f"Unrecognized guide type: {guide_type}.")
+
+        matcher = xgrammar.GrammarMatcher(compiled_grammar)
+        return XGrammarMatcher(matcher)
+
+
+class LLGuidanceMatcher(GrammarMatcher):
+
+    def __init__(self, matcher: llguidance.LLMatcher):
+        super().__init__()
+        self._matcher = matcher
+
+    def accept_token(self, token_id: int) -> bool:
+        result = self._matcher.consume_token(token_id)
+        self._check_err()
+        return result
+
+    def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
+                                index: int) -> None:
+        llguidance.torch.fill_next_token_bitmask(self._matcher,
+                                                 next_token_bitmask, index)
+        self._check_err()
+
+    def _check_err(self) -> None:
+        if self._matcher.is_error():
+            raise ValueError(
+                f"LLGuidance matcher error: {self._matcher.get_error()}")
+
+
+class LLGuidanceMatcherFactory(GrammarMatcherFactory):
+
+    def __init__(self, guided_decoding_config: GuidedDecodingConfig,
+                 vocab_size_padded: int):
+        super().__init__()
+        tokenizer_str = guided_decoding_config.tokenizer_str
+        stop_token_ids = guided_decoding_config.stop_token_ids
+
+        if tokenizer_str is None:
+            raise ValueError("tokenizer_str is required")
+
+        eos_token = None
+        if stop_token_ids is not None:
+            if len(stop_token_ids) != 1:
+                raise ValueError("expected stop_token_ids size to be 1")
+            eos_token = stop_token_ids[0]
+
+        self._tokenizer = llguidance.LLTokenizer(tokenizer_str,
+                                                 n_vocab=vocab_size_padded,
+                                                 eos_token=eos_token)
+
+    def create(
+            self,
+            guided_decoding_params: GuidedDecodingParams) -> LLGuidanceMatcher:
+        guide_type = guided_decoding_params.guide_type
+        guide = guided_decoding_params.guide
+
+        grammar = None
+        match guide_type:
+            case GuidedDecodingParams.GuideType.JSON:
+                grammar = llguidance.LLMatcher.grammar_from_json_schema(
+                    '{"type": "object"}')
+            case GuidedDecodingParams.GuideType.JSON_SCHEMA:
+                grammar = llguidance.LLMatcher.grammar_from_json_schema(guide)
+            case GuidedDecodingParams.GuideType.REGEX:
+                grammar = llguidance.LLMatcher.grammar_from_regex(guide)
+            case GuidedDecodingParams.GuideType.EBNF_GRAMMAR:
+                # Note: LLGuidance expects Lark grammar format, not standard EBNF.
+                # When using the LLGuidance backend with the EBNF_GRAMMAR guide
+                # type, the guide must therefore be a Lark-formatted grammar.
+                grammar = llguidance.LLMatcher.grammar_from_lark(guide)
+            case _:
+                raise ValueError(f"Unrecognized guide type: {guide_type}.")
+
+        matcher = llguidance.LLMatcher(self._tokenizer, grammar)
+        if matcher.is_error():
+            raise ValueError(f"LLGuidance matcher error: {matcher.get_error()}")
+
+        return LLGuidanceMatcher(matcher)
diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
index 3ff78c4e4c4..fc21a2096e2 100644
--- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
+++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -1,12 +1,12 @@
 import itertools
-import json
 import math
 from typing import List, Optional
 
 import torch
-import xgrammar
 
-from ...bindings.executor import GuidedDecodingConfig, GuidedDecodingParams
+from ...bindings.executor import GuidedDecodingConfig
+from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactory,
+                              LLGuidanceMatcherFactory, XGrammarMatcherFactory)
 from .scheduler import ScheduledRequests
 from .seq_slot_manager import SeqSlotManager
 
@@ -20,33 +20,29 @@ def __init__(self, guided_decoding_config: GuidedDecodingConfig,
         self.max_num_sequences = max_num_sequences
         self.vocab_size_padded = vocab_size_padded
 
+        self.grammar_matcher_factory: Optional[GrammarMatcherFactory] = None
+        self.grammar_matchers: List[
+            Optional[GrammarMatcher]] = [None] * self.max_num_sequences
+
         if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
-            if guided_decoding_config.tokenizer_str is not None:
-                metadata = xgrammar.TokenizerInfo._detect_metadata_from_hf(
-                    guided_decoding_config.tokenizer_str)
-                tokenizer_info = xgrammar.TokenizerInfo(
-                    guided_decoding_config.encoded_vocab,
-                    vocab_type=metadata["vocab_type"],
-                    vocab_size=vocab_size_padded,
-                    stop_token_ids=guided_decoding_config.stop_token_ids,
-                    add_prefix_space=metadata["add_prefix_space"])
-            else:
-                tokenizer_info = xgrammar.TokenizerInfo(
-                    guided_decoding_config.encoded_vocab,
-                    xgrammar.VocabType.RAW,
-                    vocab_size=vocab_size_padded,
-                    stop_token_ids=guided_decoding_config.stop_token_ids)
-            self.xgrammar_compiler = xgrammar.GrammarCompiler(tokenizer_info)
-            self.xgrammar_matchers: List[Optional[
-                xgrammar.GrammarMatcher]] = [None] * self.max_num_sequences
-            self.bitmask = torch.empty(self.max_num_sequences,
-                                       self.bitmask_size,
-                                       dtype=self.bitmask_dtype,
-                                       device='cuda')
-            self.bitmask_host = torch.empty(self.max_num_sequences,
-                                            self.bitmask_size,
-                                            dtype=self.bitmask_dtype,
-                                            pin_memory=True)
+            self.grammar_matcher_factory = XGrammarMatcherFactory(
+                guided_decoding_config, vocab_size_padded)
+        elif self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE:
+            self.grammar_matcher_factory = LLGuidanceMatcherFactory(
+                guided_decoding_config, vocab_size_padded)
+        else:
+            raise ValueError(
+                f"invalid guided_decoding_backend: {self.guided_decoding_backend}"
+            )
+
+        self.bitmask = torch.empty(self.max_num_sequences,
+                                   self.bitmask_size,
+                                   dtype=self.bitmask_dtype,
+                                   device='cuda')
+        self.bitmask_host = torch.empty(self.max_num_sequences,
+                                        self.bitmask_size,
+                                        dtype=self.bitmask_dtype,
+                                        pin_memory=True)
 
         self._stream = torch.cuda.Stream()
 
@@ -56,65 +52,30 @@ def bitmask_size(self) -> int:
 
     def build(self, scheduled_requests: ScheduledRequests,
              resource_manager: SeqSlotManager) -> None:
-        if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
-            for llm_req in itertools.chain(
-                    scheduled_requests.context_requests,
-                    scheduled_requests.generation_requests):
-                if llm_req.guided_decoding_params is None:
-                    continue
-                slot = resource_manager.slot_manager.get_slot(
-                    llm_req.request_id)
-                if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
-                    # The request is in the first context forward step (considering kv cache reuse).
-                    guide_type = llm_req.guided_decoding_params.guide_type
-                    guide = llm_req.guided_decoding_params.guide
-                    match guide_type:
-                        case GuidedDecodingParams.GuideType.JSON:
-                            compiled_grammar = self.xgrammar_compiler.compile_builtin_json_grammar(
-                            )
-                        case GuidedDecodingParams.GuideType.JSON_SCHEMA:
-                            compiled_grammar = self.xgrammar_compiler.compile_json_schema(
-                                guide)
-                        case GuidedDecodingParams.GuideType.REGEX:
-                            grammar = xgrammar.Grammar.from_regex(guide)
-                            compiled_grammar = self.xgrammar_compiler.compile_grammar(
-                                grammar)
-                        case GuidedDecodingParams.GuideType.EBNF_GRAMMAR:
-                            grammar = xgrammar.Grammar.from_ebnf(guide)
-                            compiled_grammar = self.xgrammar_compiler.compile_grammar(
-                                grammar)
-                        case GuidedDecodingParams.GuideType.STRUCTURAL_TAG:
-                            structural_tag_parameters = json.loads(guide)
-                            structures = structural_tag_parameters["structures"]
-                            structures = [
-                                xgrammar.StructuralTagItem(
-                                    begin=s["begin"],
-                                    schema=json.dumps(s["schema"]),
-                                    end=s["end"]) for s in structures
-                            ]
-                            triggers = structural_tag_parameters["triggers"]
-                            compiled_grammar = self.xgrammar_compiler.compile_structural_tag(
-                                structures, triggers)
-                        case _:
-                            raise ValueError(
-                                f"Unrecognized guide type: {guide_type}.")
-                    self.xgrammar_matchers[slot] = xgrammar.GrammarMatcher(
-                        compiled_grammar)
-
-                elif llm_req.is_generation_in_progress_state:
-                    # The request is in a generation forward step.
-                    # Currently, guided decoding does not support with beam search.
-                    self.xgrammar_matchers[slot].accept_token(
-                        llm_req.get_last_tokens(0))
-                else:
-                    continue
-
-                # Fill the bitmask on host and asynchorously copy to device.
-                self.xgrammar_matchers[slot].fill_next_token_bitmask(
-                    self.bitmask_host, slot)
-                with torch.cuda.stream(self._stream):
-                    self.bitmask[slot].copy_(self.bitmask_host[slot],
-                                             non_blocking=True)
+        for llm_req in itertools.chain(scheduled_requests.context_requests,
+                                       scheduled_requests.generation_requests):
+            if llm_req.guided_decoding_params is None:
+                continue
+            slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
+            if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
+                # The request is in the first context forward step (considering kv cache reuse).
+                self.grammar_matchers[
+                    slot] = self.grammar_matcher_factory.create(
+                        llm_req.guided_decoding_params)
+
+            elif llm_req.is_generation_in_progress_state:
+                # The request is in a generation forward step.
+                # Currently, guided decoding is not supported with beam search.
+                self.grammar_matchers[slot].accept_token(
+                    llm_req.get_last_tokens(0))
+            else:
+                continue
+
+            # Fill the bitmask on host and asynchronously copy to device.
+            self.grammar_matchers[slot].fill_next_token_bitmask(
+                self.bitmask_host, slot)
+            with torch.cuda.stream(self._stream):
+                self.bitmask[slot].copy_(self.bitmask_host[slot],
+                                         non_blocking=True)
 
     def execute(self, scheduled_requests: ScheduledRequests,
                 logits: torch.Tensor, resource_manager: SeqSlotManager) -> None:
@@ -122,19 +83,17 @@ def execute(self, scheduled_requests: ScheduledRequests,
                                 scheduled_requests.generation_requests)
 
         torch.cuda.current_stream().wait_stream(self._stream)
 
-        if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
-            batched_logits, batched_bitmask = [], []
-            for i, llm_req in enumerate(
-                    itertools.chain(scheduled_requests.context_requests,
-                                    scheduled_requests.generation_requests)):
-                if llm_req.guided_decoding_params is None:
-                    continue
-                if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
-                    continue
-                batched_logits.append(logits[i])
-                slot = resource_manager.slot_manager.get_slot(
-                    llm_req.request_id)
-                batched_bitmask.append(self.bitmask[slot])
-
-            if len(batched_logits) > 0:
-                torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
+        batched_logits, batched_bitmask = [], []
+        for i, llm_req in enumerate(
+                itertools.chain(scheduled_requests.context_requests,
+                                scheduled_requests.generation_requests)):
+            if llm_req.guided_decoding_params is None:
+                continue
+            if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
+                continue
+            batched_logits.append(logits[i])
+            slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
+            batched_bitmask.append(self.bitmask[slot])
+
+        if len(batched_logits) > 0:
+            torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 4ddf97b665d..75efe250c20 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -36,7 +36,8 @@
 from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                         LlmBuildStats, ModelLoader, _ModelRuntimeContext)
 from .mpi_session import MpiPoolSession, external_mpi_comm_available
-from .tokenizer import TokenizerBase, _xgrammar_tokenizer_info
+from .tokenizer import (TokenizerBase, _llguidance_tokenizer_info,
+                        _xgrammar_tokenizer_info)
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
 from .utils import (append_docstring, exception_handler, get_device_count,
                     print_colored_debug)
@@ -661,6 +662,11 @@ def _build_model(self):
                     backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
                     XGRAMMAR,
                     **_xgrammar_tokenizer_info(self.tokenizer))
+            elif self.args.guided_decoding_backend == 'llguidance':
+                self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
+                    backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
+                    LLGUIDANCE,
+                    **_llguidance_tokenizer_info(self.tokenizer))
             elif self.args.guided_decoding_backend is not None:
                 raise ValueError(
                     f"Unrecognized guided decoding backend {self.args.guided_decoding_backend}"
diff --git a/tensorrt_llm/llmapi/tokenizer.py b/tensorrt_llm/llmapi/tokenizer.py
index d56e9cc8e8a..76c6e1110d6 100644
--- a/tensorrt_llm/llmapi/tokenizer.py
+++ b/tensorrt_llm/llmapi/tokenizer.py
@@ -239,6 +239,13 @@ def _xgrammar_tokenizer_info(tokenizer):
         raise ValueError(f"Unsupported tokenizer type: {type(tokenizer)}")
 
 
+def _llguidance_tokenizer_info(tokenizer):
+    tokenizer_info = _xgrammar_tokenizer_info(tokenizer)
+    if tokenizer_info.get("tokenizer_str") is None:
+        raise ValueError("missing tokenizer_str")
+    return tokenizer_info
+
+
 def load_hf_tokenizer(model_dir: str,
                       trust_remote_code: bool = True,
                       use_fast: bool = True) -> Optional[TransformersTokenizer]:
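
Usage note (not part of the diff above): a minimal sketch of how the new backend is expected to be selected through the LLM API once this change lands. The guided_decoding_backend="llguidance" string is the value llm.py dispatches on above; the model path is a placeholder, and the request-side SamplingParams/GuidedDecodingParams field names are assumptions carried over from the existing xgrammar flow, not anything introduced by this diff.

# Minimal sketch, assuming the llmapi request-side names below match the
# existing xgrammar flow; only guided_decoding_backend="llguidance" is new here.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.sampling_params import GuidedDecodingParams

llm = LLM(model="/path/to/hf/model",            # placeholder model path
          guided_decoding_backend="llguidance")  # maps to GuidedDecodingBackend.LLGUIDANCE

schema = '{"type": "object", "properties": {"answer": {"type": "integer"}}, "required": ["answer"]}'
sampling_params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(json=schema))  # assumed field names

outputs = llm.generate(["Reply in JSON: what is 6 times 7?"], sampling_params)
print(outputs[0].outputs[0].text)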