diff --git a/cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp b/cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
index 11b24e7a989..ad4588a6ce5 100644
--- a/cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
+++ b/cpp/tensorrt_llm/thop/logitsBitmaskOp.cpp
@@ -54,8 +54,8 @@ void logitsBitmask(std::vector<torch::Tensor> const& logits, std::vector<torch::Tensor> const& bitmask)
         ...(bitmask[i].data_ptr());
     }
 
-    auto logitsPtrs = logitsPtrsHost.to(torch::kCUDA);
-    auto bitmaskPtrs = bitmaskPtrsHost.to(torch::kCUDA);
+    auto logitsPtrs = logitsPtrsHost.to(torch::kCUDA, /*non_blocking=*/true);
+    auto bitmaskPtrs = bitmaskPtrsHost.to(torch::kCUDA, /*non_blocking=*/true);
 
     auto stream = at::cuda::getCurrentCUDAStream(logits[0].get_device()).stream();
diff --git a/docs/source/torch/features/feature_combination_matrix.md b/docs/source/torch/features/feature_combination_matrix.md
index 8f8d5defe80..f62c1d33aa4 100644
--- a/docs/source/torch/features/feature_combination_matrix.md
+++ b/docs/source/torch/features/feature_combination_matrix.md
@@ -15,4 +15,4 @@
 | KV Cache Reuse | Yes | Yes | Yes | Untested | Untested | Untested | Yes | No | Yes | Yes | --- | | | |
 | Slide Window Attention | Yes | Yes | Yes | Untested | Untested | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
 | Logits Post Processor | No | Yes | Yes | No | Untested | No | No | No | Yes | Yes | Yes | Yes | --- | |
-| Guided Decoding | No | Yes | Yes | Untested | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |
+| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |
diff --git a/examples/llm-api/llm_guided_decoding.py b/examples/llm-api/llm_guided_decoding.py
index a5e0f89244d..e5df98e5da3 100644
--- a/examples/llm-api/llm_guided_decoding.py
+++ b/examples/llm-api/llm_guided_decoding.py
@@ -7,12 +7,9 @@
 
 
 def main():
-    # Specify the guided decoding backend; xgrammar is supported currently.
-    llm = LLM(
-        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-        guided_decoding_backend='xgrammar',
-        disable_overlap_scheduler=True  # Not supported by xgrammar mode
-    )
+    # Specify the guided decoding backend; xgrammar and llguidance are supported currently.
+    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+              guided_decoding_backend='xgrammar')
 
     # An example from json-mode-eval
     schema = '{"title": "WirelessAccessPoint", "type": "object", "properties": {"ssid": {"title": "SSID", "type": "string"}, "securityProtocol": {"title": "SecurityProtocol", "type": "string"}, "bandwidth": {"title": "Bandwidth", "type": "string"}}, "required": ["ssid", "securityProtocol", "bandwidth"]}'
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 979bc83f218..a8119c5ad25 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -21,6 +21,7 @@
 from ..speculative import get_spec_decoder
 from .config import PyTorchConfig
 from .config_utils import is_mla, is_nemotron_hybrid
+from .guided_decoder import GuidedDecoder
 from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
 from .llm_request import ExecutorResponse
 from .model_engine import PyTorchModelEngine
@@ -414,19 +415,12 @@ def create_py_executor_instance(
         start_worker,
         sampler,
         drafter,
+        guided_decoder: Optional[GuidedDecoder] = None,
         lora_config: Optional[LoraConfig] = None,
         garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
 
     spec_config = model_engine.spec_config
-    if mapping.is_last_pp_rank(
-    ) and executor_config.guided_decoding_config is not None:
-        if spec_config is not None:
-            raise ValueError(
-                "Guided decoding is not supported with speculative decoding.")
-        if not pytorch_backend_config.disable_overlap_scheduler:
-            raise ValueError(
-                "Guided decoding is not supported with overlap scheduler.")
 
     logger.info(
         f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}, max_batch_size={executor_config.max_batch_size}"
@@ -544,6 +538,7 @@ def create_py_executor_instance(
         if spec_config is not None else 0,
         kv_cache_transceiver=kv_cache_transceiver,
         draft_model_engine=draft_model_engine,
+        guided_decoder=guided_decoder,
         start_worker=start_worker,
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
index 756c177a6ea..f1b21339b9a 100644
--- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
+++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -3,11 +3,11 @@
 
 import torch
 
+from ..._utils import nvtx_range
 from ...bindings.executor import GuidedDecodingConfig
 from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactory,
                               LLGuidanceMatcherFactory, XGrammarMatcherFactory)
 from .scheduler import ScheduledRequests
-from .seq_slot_manager import SeqSlotManager
 
 
 class GuidedDecoder:
@@ -49,12 +49,12 @@ def __init__(self, guided_decoding_config: GuidedDecodingConfig,
     def bitmask_size(self) -> int:
         return math.ceil(self.vocab_size_padded / 32)
 
-    def build(self, scheduled_requests: ScheduledRequests,
-              resource_manager: SeqSlotManager) -> None:
+    @nvtx_range("GuidedDecoder.build")
+    def build(self, scheduled_requests: ScheduledRequests) -> None:
         for llm_req in scheduled_requests.all_requests():
             if llm_req.guided_decoding_params is None:
                 continue
-            slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
+            slot = llm_req.py_seq_slot
             if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
                 self.grammar_matchers[
                     slot] = self.grammar_matcher_factory.create(
@@ -75,8 +75,9 @@ def build(self, scheduled_requests: ScheduledRequests,
                 self.bitmask[slot].copy_(self.bitmask_host[slot],
                                          non_blocking=True)
 
+    @nvtx_range("GuidedDecoder.execute")
     def execute(self, scheduled_requests: ScheduledRequests,
-                logits: torch.Tensor, resource_manager: SeqSlotManager) -> None:
+                logits: torch.Tensor) -> None:
         assert logits.size(0) == len(scheduled_requests.context_requests) + len(
             scheduled_requests.generation_requests)
         torch.cuda.current_stream().wait_stream(self._stream)
@@ -88,8 +89,7 @@ def execute(self, scheduled_requests: ScheduledRequests,
             if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
                 continue
             batched_logits.append(logits[i])
-            slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
-            batched_bitmask.append(self.bitmask[slot])
+            batched_bitmask.append(self.bitmask[llm_req.py_seq_slot])
 
         if len(batched_logits) > 0:
             torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 635787a0324..6e1ced85456 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -25,7 +25,6 @@
 from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank,
                                  local_mpi_size, nvtx_range, release_gc,
                                  torch_dtype_to_str, trace_func)
-from tensorrt_llm.bindings.executor import GuidedDecodingConfig
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
@@ -57,7 +56,6 @@
 from .config import LoadFormat, PyTorchConfig
 from .config_utils import is_mla
 from .cuda_graph_runner import DecodingCUDAGraphRunner
-from .guided_decoder import GuidedDecoder
 from .layerwise_nvtx_marker import LayerwiseNvtxMarker
 from .resource_manager import (BaseResourceManager, KVCacheManager,
                                ResourceManager, ResourceManagerType)
@@ -354,7 +352,6 @@ def __init__(
         attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
         dist: Optional[MPIDist] = None,
         spec_config: Optional["DecodingBaseConfig"] = None,
-        guided_decoding_config: Optional[GuidedDecodingConfig] = None,
         lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
     ):
@@ -408,13 +405,6 @@ def __init__(
             self.dtype = self.model.config.torch_dtype
             self._init_model_capacity()
 
-        self.guided_decoder: Optional[GuidedDecoder] = None
-        if self.mapping.is_last_pp_rank(
-        ) and guided_decoding_config is not None:
-            self.guided_decoder = GuidedDecoder(guided_decoding_config,
-                                                self.batch_size,
-                                                self.model.vocab_size_padded)
-
         self._torch_compile_backend = None
 
         try:
@@ -2170,18 +2160,6 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                 with MoeLoadBalancerIterContext(moe_load_balancer):
                     outputs = maybe_graph.run(inputs)
 
-        # Note: To overlap the CPU and GPU computation as much as possible,
-        # guided_decoder.build should be called immediately after the launch of the single step;
-        # while guided_decoder.execute should be called right before the samplings.
-        # We can insert other CPU computation between them in the future.
-        if self.mapping.is_last_pp_rank(
-        ) and self.guided_decoder is not None:
-            seq_slot_manager = resource_manager.get_resource_manager(
-                ResourceManagerType.SEQ_SLOT_MANAGER)
-            self.guided_decoder.build(scheduled_requests, seq_slot_manager)
-            self.guided_decoder.execute(scheduled_requests,
-                                        outputs['logits'], seq_slot_manager)
-
         self._execute_logit_post_processors(scheduled_requests, outputs)
 
         return outputs
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index dc7b79c265c..eefd09d6c53 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -31,6 +31,7 @@
 
 from ..distributed import Distributed
 from ..speculative.drafter import Drafter
+from .guided_decoder import GuidedDecoder
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
                           LlmResponse, executor_request_to_llm_request)
@@ -204,6 +205,7 @@ def __init__(self,
                  max_draft_len: int = 0,
                  kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
                  draft_model_engine: Optional[ModelEngine] = None,
+                 guided_decoder: Optional[GuidedDecoder] = None,
                  garbage_collection_gen0_threshold: Optional[int] = None,
                  start_worker: bool = True):
         super(PyExecutor, self).__init__()
@@ -225,6 +227,7 @@ def __init__(self,
         self.enable_attention_dp = model_engine.enable_attention_dp
         self.sampler = sampler
         self.drafter = drafter
+        self.guided_decoder = guided_decoder
         self.dist = dist
         self.disable_overlap_scheduler = disable_overlap_scheduler
 
@@ -801,6 +804,12 @@ def _executor_loop_pp(self):
                     if self._need_return_logits(scheduled_batch):
                         logits_host = batch_outputs["logits"].to(
                             "cpu", non_blocking=True)
+
+                    if self.guided_decoder is not None:
+                        self.guided_decoder.build(scheduled_batch)
+                        self.guided_decoder.execute(
+                            scheduled_batch, batch_outputs['logits'])
+
                     sample_state = self._sample_async(
                         scheduled_batch, batch_outputs)
                     sample_state.host.logits = logits_host
@@ -975,6 +984,11 @@ def _executor_loop(self):
 
                 batch_outputs = self._forward_step(scheduled_batch)
 
+                if self.guided_decoder is not None:
+                    self.guided_decoder.build(scheduled_batch)
+                    self.guided_decoder.execute(scheduled_batch,
+                                                batch_outputs['logits'])
+
                 sample_state = self._sample_async(scheduled_batch,
                                                   batch_outputs)
 
@@ -1123,6 +1137,14 @@ def _executor_loop_overlap(self):
                 batch_outputs = self._forward_step(scheduled_batch,
                                                    previous_tensors_device)
 
+                if self.previous_batch is not None:
+                    self._update_requests(self.previous_batch.sample_state)
+
+                if self.guided_decoder is not None:
+                    self.guided_decoder.build(scheduled_batch)
+                    self.guided_decoder.execute(scheduled_batch,
+                                                batch_outputs['logits'])
+
                 sample_state = self._sample_async(scheduled_batch,
                                                   batch_outputs)
                 assert sample_state is not None, "Sampling failed"
@@ -1156,8 +1178,6 @@ def _executor_loop_overlap(self):
             self._terminate_ctx_finished_requests()
 
     def _process_previous_batch(self):
-        self._update_requests(self.previous_batch.sample_state)
-
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
                 req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index b6893d69e26..de1978bb45f 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -24,6 +24,7 @@
                     create_py_executor_instance, instantiate_sampler, is_mla)
 from .config import PyTorchConfig
 from .config_utils import is_mla
+from .guided_decoder import GuidedDecoder
 from .model_engine import PyTorchModelEngine
 from .py_executor import PyExecutor
 
@@ -237,7 +238,6 @@ def create_py_executor(
         attn_runtime_features=attn_runtime_features,
         dist=dist,
         spec_config=spec_config,
-        guided_decoding_config=executor_config.guided_decoding_config,
         lora_config=lora_config,
     )
 
@@ -342,6 +342,17 @@ def create_py_executor(
     sampler = instantiate_sampler(model_engine, executor_config,
                                   pytorch_backend_config, mapping)
 
+    guided_decoder: Optional[GuidedDecoder] = None
+    if executor_config.guided_decoding_config is not None:
+        if spec_config is not None:
+            raise ValueError(
+                "Guided decoding is not supported with speculative decoding.")
+        if mapping.is_last_pp_rank():
+            guided_decoder = GuidedDecoder(
+                executor_config.guided_decoding_config,
+                executor_config.max_batch_size,
+                model_engine.model.vocab_size_padded)
+
     resources = {}
     estimating_kv_cache = False
     kv_cache_creator = None
@@ -385,6 +396,7 @@ def create_py_executor(
             start_worker=False,
             sampler=sampler,
             drafter=drafter,
+            guided_decoder=guided_decoder,
             lora_config=lora_config,
             garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
         )
@@ -427,6 +439,7 @@ def create_py_executor(
                 start_worker=False,
                 sampler=sampler,
                 drafter=drafter,
+                guided_decoder=guided_decoder,
                 lora_config=lora_config,
                 garbage_collection_gen0_threshold=
                 garbage_collection_gen0_threshold,
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 701461b19d1..e98075ffc9a 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -288,7 +288,6 @@ def test_guided_decoding(self, backend: str, mocker):
         mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
         llm = LLM(self.MODEL_PATH,
                   guided_decoding_backend=backend,
-                  disable_overlap_scheduler=True,
                   cuda_graph_config=CudaGraphConfig())
         with llm:
             task = JsonModeEval(self.MODEL_NAME)
@@ -301,7 +300,6 @@ def test_guided_decoding_4gpus(self, backend: str, mocker):
         mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
         with LLM(self.MODEL_PATH,
                  guided_decoding_backend=backend,
-                 disable_overlap_scheduler=True,
                  cuda_graph_config=CudaGraphConfig(),
                  tensor_parallel_size=2,
                  pipeline_parallel_size=2) as llm:
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_structural_tag.py b/tests/unittest/llmapi/apps/_test_openai_chat_structural_tag.py
index aeb46a8a0b0..edf6243c912 100644
--- a/tests/unittest/llmapi/apps/_test_openai_chat_structural_tag.py
+++ b/tests/unittest/llmapi/apps/_test_openai_chat_structural_tag.py
@@ -23,10 +23,7 @@ def temp_extra_llm_api_options_file(request):
     temp_dir = tempfile.gettempdir()
     temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
    try:
-        extra_llm_api_options_dict = {
-            "guided_decoding_backend": "xgrammar",
-            "disable_overlap_scheduler": True,
-        }
+        extra_llm_api_options_dict = {"guided_decoding_backend": "xgrammar"}
         with open(temp_file_path, 'w') as f:
             yaml.dump(extra_llm_api_options_dict, f)
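
For reference, a minimal sketch of how the path touched by this patch can be exercised end to end through the LLM API. It is not part of the patch itself; it mirrors examples/llm-api/llm_guided_decoding.py and the accuracy tests above, and assumes the overlap scheduler is left at its default (enabled). The model name and the JSON schema are illustrative only.

# Hypothetical smoke test: guided decoding with the overlap scheduler enabled (the default).
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.sampling_params import GuidedDecodingParams


def main():
    # With this change, disable_overlap_scheduler=True is no longer required;
    # 'xgrammar' and 'llguidance' are the currently supported backends.
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
              guided_decoding_backend='xgrammar')

    # Constrain the output to a small JSON schema, as in the json-mode-eval example.
    schema = ('{"title": "WirelessAccessPoint", "type": "object", '
              '"properties": {"ssid": {"title": "SSID", "type": "string"}}, '
              '"required": ["ssid"]}')
    output = llm.generate(
        "Return the wireless access point configuration as JSON:",
        sampling_params=SamplingParams(
            max_tokens=64, guided_decoding=GuidedDecodingParams(json=schema)))
    print(output.outputs[0].text)


if __name__ == '__main__':
    main()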