From 801532120831b416f3a66cdab2073aaa1b0e7e09 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Wed, 3 Sep 2025 15:10:14 -0700 Subject: [PATCH 1/9] Set py_draft_token to [] instead of None when spec decode is off. Fix test in test_dynamic_spec_decode(patch is not called at all). Signed-off-by: Zheyu Fu --- .../_torch/speculative/test_dynamic_spec_decode.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py index 92937b34835..64a063bba6a 100644 --- a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py +++ b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py @@ -30,12 +30,12 @@ def test_dynamic_spec_decode(enforce_single_worker, total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if total_mem_gb < 35: pytest.skip("Not enough memory to load target + draft model") - models_path = llm_models_root() eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" - max_batch_size = 1 + # Allow with 3 concurrent requests + max_batch_size = 3 max_draft_len = 4 kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192) cuda_graph_config = CudaGraphConfig(batch_sizes=[1]) @@ -47,11 +47,7 @@ def test_dynamic_spec_decode(enforce_single_worker, cuda_graph_config=cuda_graph_config, max_batch_size=max_batch_size, kv_cache_config=kv_cache_config, - # This max_seq_len is larger than the one specified - # in the llama 3 8B eagle's config. We want to make sure - # that the draft model won't go above its max in warmup - # in this test. - max_seq_len=8192, + max_seq_len=4096, ) spec_config = EagleDecodingConfig( @@ -59,6 +55,8 @@ def test_dynamic_spec_decode(enforce_single_worker, speculative_model_dir=eagle_model_dir, # Llama 3 does not support one model eagle. eagle3_one_model=False, + # allow speculation only when <= 2 effective request + max_concurrency=2, ) llm_spec = LLM(**llm_common_config, speculative_config=spec_config) From 4366ed4970eb07fa4b1b955c7d75a7bd9c2a9ddb Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Fri, 5 Sep 2025 13:56:08 -0700 Subject: [PATCH 2/9] [None][feat] Turn off spec decode when rolling acceptance drops below threshold. 
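The rule this patch implements, as a minimal self-contained sketch (illustrative only; apart from avg_decoded_tokens_per_iter the names here are made up, and the real logic lives in the SpeculationGate class added below):

    from collections import deque

    WINDOW, THRESHOLD = 5, 0.6      # example values; both come from the decoding config
    history = deque(maxlen=WINDOW)  # accepted draft tokens per completed request

    def should_disable_speculation(avg_decoded_tokens_per_iter):
        """Return True once the rolling average of accepted draft tokens is too low."""
        if avg_decoded_tokens_per_iter is None:
            return False
        # avg_decoded_tokens_per_iter also counts the one non-draft token produced each
        # iteration, so accepted draft tokens per iteration is that value minus 1.
        history.append(max(0.0, float(avg_decoded_tokens_per_iter) - 1.0))
        return len(history) == WINDOW and sum(history) / WINDOW < THRESHOLD

Once the predicate fires for a completed request, the executor stops proposing draft tokens for the rest of the run.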
Signed-off-by: Zheyu Fu --- .../_torch/pyexecutor/model_engine.py | 13 ++ tensorrt_llm/_torch/pyexecutor/py_executor.py | 28 +++- .../_torch/speculative/speculation_gate.py | 72 ++++++++ tensorrt_llm/llmapi/llm_args.py | 37 +++++ .../_torch/speculative/test_spec_gate.py | 154 ++++++++++++++++++ 5 files changed, 300 insertions(+), 4 deletions(-) create mode 100644 tensorrt_llm/_torch/speculative/speculation_gate.py create mode 100644 tests/unittest/_torch/speculative/test_spec_gate.py diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index bcd95020bb8..200e7d16a01 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -166,6 +166,19 @@ def __init__( self.spec_config = spec_config self.is_spec_decode = spec_config is not None self.enable_spec_decode = self.is_spec_decode + # Rolling acceptance tracking + self.acceptance_window = getattr( + spec_config, 'acceptance_window', + None) if spec_config is not None else None + self.acceptance_length_threshold = getattr( + spec_config, 'acceptance_length_threshold', + None) if spec_config is not None else None + # Initialize speculation gate early since it only depends on config + self.speculation_permanently_disabled = False + self.speculation_gate = None + if self.acceptance_window and self.acceptance_length_threshold is not None: + self.speculation_gate = SpeculationGate( + self.acceptance_window, self.acceptance_length_threshold) self.is_draft_model = is_draft_model self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index aa0902484a5..dc4b0f63267 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -969,10 +969,16 @@ def _prepare_and_schedule_batch(self): self._pad_attention_dp_dummy_request() if self.drafter is not None: - self.use_spec_decode = self.drafter.should_use_spec_decode( - self.active_requests, self.max_batch_size, - self.model_engine.max_num_tokens, - self.model_engine.spec_config.max_draft_len) + # Honor permanent disable flag based on rolling acceptance first + if getattr(self.model_engine, 'speculation_permanently_disabled', + False): + self.use_spec_decode = False + else: + self.use_spec_decode = self.drafter.should_use_spec_decode( + self.active_requests, self.max_batch_size, + self.model_engine.max_num_tokens, + self.model_engine.spec_config.max_draft_len) + self.model_engine.enable_spec_decode = self.use_spec_decode # Set up draft_tokens in active_requests, because they could be used in the scheduling stage. 
@@ -1920,6 +1926,20 @@ def _handle_responses(self): new_responses.append((req_id, response)) if request_done: + if (self.model_engine.enable_spec_decode and + not self.model_engine.speculation_permanently_disabled + and not request.is_dummy and not self.is_warmup): + if self.model_engine.speculation_gate is not None: + avg_decoded = getattr(request, + 'avg_decoded_tokens_per_iter', + None) + disabled_now, _ = self.model_engine.speculation_gate.record_avg_decoded( + avg_decoded, + request_id=getattr(request, 'py_request_id', None)) + if disabled_now: + # disable speculation permanently + # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False + self.model_engine.speculation_permanently_disabled = True if request.is_disagg_context_transmission_state: self.ctx_in_transmission_requests.append(request) else: diff --git a/tensorrt_llm/_torch/speculative/speculation_gate.py b/tensorrt_llm/_torch/speculative/speculation_gate.py new file mode 100644 index 00000000000..ccd534f0855 --- /dev/null +++ b/tensorrt_llm/_torch/speculative/speculation_gate.py @@ -0,0 +1,72 @@ +from collections import deque +from typing import Optional, Tuple + +from tensorrt_llm.logger import logger + + +class SpeculationGate: + """ + Tracks rolling average of accepted draft tokens per iteration over the last N completed requests. + Permanently disables speculation when average falls below a threshold. + """ + + def __init__(self, window: int, threshold: float): + self.window = window + self.threshold = threshold + self.acceptance_history: Deque[float] = deque() + self.acceptance_sum: float = 0.0 + self.num_completed_for_acceptance = 0 + self.disabled = False + logger.debug( + f"[SpeculationGate] SpeculationGate initialized with window={self.window}, threshold={self.threshold}" + ) + + def reset(self) -> None: + self.acceptance_history.clear() + self.acceptance_sum = 0.0 + self.num_completed_for_acceptance = 0 + self.disabled = False + + def record_avg_decoded( + self, + avg_decoded_tokens_per_iter: Optional[float], + request_id: Optional[int] = None) -> Tuple[bool, Optional[float]]: + """ + Record a completed request's avg_decoded_tokens_per_iter. + Returns (disabled_now, current_avg_accept) where disabled_now is True only when the call causes disable. 
+ """ + if self.disabled or self.window is None or self.window <= 0 or self.threshold is None: + return False, None + + accepted_len = 0.0 + if avg_decoded_tokens_per_iter is not None: + accepted_len = max(0.0, float(avg_decoded_tokens_per_iter) - 1.0) + + # Log per-request completion for debug + if request_id is not None: + logger.debug( + f"[SpeculationGate] Request {request_id} completed: avg_decoded={avg_decoded_tokens_per_iter if avg_decoded_tokens_per_iter is not None else 'None'}, accepted_len={accepted_len:.3f}" + ) + + # O(1) rolling update + self.acceptance_history.append(accepted_len) + self.acceptance_sum += accepted_len + if len(self.acceptance_history) > self.window: + removed = self.acceptance_history.popleft() + self.acceptance_sum -= removed + + self.num_completed_for_acceptance += 1 + + if self.num_completed_for_acceptance >= self.window: + avg_accept = self.acceptance_sum / len(self.acceptance_history) + if avg_accept < self.threshold: + self.disabled = True + logger.info( + f"[SpeculationGate] Speculative decoding disabled: rolling acceptance avg {avg_accept:.3f} < threshold {self.threshold} over last {self.window} requests" + ) + return True, avg_accept + else: + # speculation is still enabled + return False, avg_accept + + return False, None diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 5a05ee741f3..3ed20d9485d 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -359,6 +359,43 @@ class DecodingBaseConfig(StrictBaseModel): max_concurrency: Optional[int] = None load_format: Optional[str] = None + # PyTorch only. + # Rolling average window size (N) for acceptance length across completed requests. + # If not set or set to 0, the feature is disabled. + acceptance_window: Optional[int] = None + # PyTorch only. + # Threshold for average acceptance length; speculation will be disabled + # permanently once the rolling average over the last N completed requests + # (N = acceptance_window) drops below this value. + acceptance_length_threshold: Optional[float] = None + + # Upper bound to avoid accidental huge windows + MAX_ACCEPTANCE_WINDOW: ClassVar[int] = 100000 + + # Validate acceptance controls at field level so they run on model creation + @field_validator('acceptance_window') + @classmethod + def _validate_acceptance_window(cls, v: Optional[int]): + if v is None: + return v + if v < 0: + raise ValueError( + f"acceptance_window must be >= 0 (0 disables), got {v}") + if v > cls.MAX_ACCEPTANCE_WINDOW: + raise ValueError( + f"acceptance_window must be <= {cls.MAX_ACCEPTANCE_WINDOW}, got {v}" + ) + return v + + @field_validator('acceptance_length_threshold') + @classmethod + def _validate_acceptance_length_threshold(cls, v: Optional[float]): + if v is None: + return v + if v < 0: + raise ValueError( + f"acceptance_length_threshold must be >= 0, got {v}") + return v # If set, drafting uses greedy sampling, irrespective of sampling parameters. 
_allow_greedy_draft_tokens: bool = PrivateAttr(True) diff --git a/tests/unittest/_torch/speculative/test_spec_gate.py b/tests/unittest/_torch/speculative/test_spec_gate.py new file mode 100644 index 00000000000..77837f5dfdd --- /dev/null +++ b/tests/unittest/_torch/speculative/test_spec_gate.py @@ -0,0 +1,154 @@ +import os +import sys +import unittest + +import pytest +import torch +from utils.llm_data import llm_models_root + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate +from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, + KvCacheConfig) +from tensorrt_llm.llmapi.llm_args import SamplerType + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + + +# It tests the end-to-end functionality of the SpeculationGate, +# which will turn off spec decode when the average acceptance length is below the threshold. +# It is set with acceptance window and acceptance threshold in spec_config. +# This test set the max_concurrency to a large value to prevent spec decode turned off due to number of effective requests > max_concurrency, +# So that we can only focus on the turning off effect from the SpeculationGate. +@pytest.mark.high_cuda_memory +def test_spec_gate_e2e(): + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if total_mem_gb < 35: + pytest.skip("Not enough memory to load target + draft model") + models_path = llm_models_root() + eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" + target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" + + max_batch_size = 2 + max_draft_len = 4 + kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192) + cuda_graph_config = CudaGraphConfig(batch_sizes=[1]) + + llm_common_config = dict( + model=target_model_dir, + attn_backend="TRTLLM", + disable_overlap_scheduler=True, + cuda_graph_config=cuda_graph_config, + max_batch_size=max_batch_size, + kv_cache_config=kv_cache_config, + max_seq_len=4096, + # Force TRTLLMSampler for testing avg_decoded_tokens_per_iter from C++ path + sampler_type=SamplerType.TRTLLMSampler, + ) + + spec_config = EagleDecodingConfig( + max_draft_len=max_draft_len, + speculative_model_dir=eagle_model_dir, + # Llama 3 does not support one model eagle. + eagle3_one_model=False, + max_concurrency=10000, + acceptance_window=5, + acceptance_length_threshold=0.6, + ) + + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + # Output tests + prompts = [ + "The capital of France is", + "The president of the United States is", + "What is the capital of Australia?", + "Explain in one sentence why the sky is blue.", + "Who wrote the book 'Pride and Prejudice'?", + "List three U.S. 
national holidays in the year 2025.", + "What is the currency of Japan?", + "How many players are on a basketball court for one team?", + "List three primary colors.", + "The Roman Empire fell in the year", + ] + sampling_params = SamplingParams(max_tokens=10, temperature=0) + + results_spec = llm_spec.generate(prompts, sampling_params) + generated_text_spec = [result.outputs[0].text for result in results_spec] + llm_spec.shutdown() + + llm_ref = LLM(**llm_common_config) + results_ref = llm_ref.generate(prompts, sampling_params) + generated_text_ref = [result.outputs[0].text for result in results_ref] + llm_ref.shutdown() + + i = 0 + for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): + print(f"prompt: {prompts[i]}") + print(f"spec: {text_spec}") + print(f"ref: {text_ref}") + print("-" * 100) + i += 1 + + for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): + # The spec decode algorithm currently guarantees identical results + assert text_spec == text_ref + + +def test_returns_none_until_window_and_enabled_when_above_threshold(): + gate = SpeculationGate(window=3, threshold=0.5) + + disabled, avg = gate.record_avg_decoded(2.0, request_id=1) + assert disabled is False and avg is None + assert gate.disabled is False + + disabled, avg = gate.record_avg_decoded(2.0, request_id=2) + assert disabled is False and avg is None + assert gate.disabled is False + + disabled, avg = gate.record_avg_decoded(2.0, request_id=3) + assert disabled is False + assert avg == pytest.approx(1.0, rel=1e-6) + assert gate.disabled is False + + +def test_disables_when_avg_below_threshold_and_stays_disabled(): + gate = SpeculationGate(window=3, threshold=0.7) + + gate.record_avg_decoded(1.1) + gate.record_avg_decoded(1.2) + + disabled, avg = gate.record_avg_decoded(1.3) + assert disabled is True + assert avg == pytest.approx(0.2, rel=1e-6) + assert gate.disabled is True + + # Once disabled, subsequent calls do nothing and return (False, None) + disabled, avg = gate.record_avg_decoded(100.0) + assert disabled is False and avg is None + assert gate.disabled is True + + disabled, avg = gate.record_avg_decoded(200.0) + assert disabled is False and avg is None + assert gate.disabled is True + + +def test_rolling_window_and_disable_on_drop(): + gate = SpeculationGate(window=3, threshold=0.8) + + # First three high-acceptance requests keep it enabled + gate.record_avg_decoded(2.0) + gate.record_avg_decoded(2.0) + disabled, avg = gate.record_avg_decoded(2.0) + assert disabled is False + assert avg == pytest.approx(1.0, rel=1e-6) + assert gate.disabled is False + + # Fourth lower value enters window -> average drops below threshold -> disable + disabled, avg = gate.record_avg_decoded(1.2) + assert disabled is True + assert avg == pytest.approx((1.0 + 1.0 + 0.2) / 3.0, rel=1e-6) + assert gate.disabled is True + + +if __name__ == "__main__": + unittest.main() From 253a70a6d441a4bc78cbf996f1ae0b2c28e9aeb3 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Fri, 5 Sep 2025 14:01:21 -0700 Subject: [PATCH 3/9] Clean. 
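The comment corrected here documents the dynamic gating math. Read as a hedged sketch (the real check is Drafter.should_use_spec_decode, whose signature may differ), the "small token budget ON" case is:

    def would_use_spec_decode(num_effective_requests, max_batch_size,
                              max_num_tokens, max_draft_len, max_concurrency):
        # Each speculating request consumes one target token plus max_draft_len draft tokens.
        token_cap = max_num_tokens // (1 + max_draft_len)
        return min(num_effective_requests, max_batch_size, token_cap) <= max_concurrency

    # token_cap = 28 // (1 + 4) = 5, and min(12, 8, 5) = 5 <= 6, so speculation stays on.
    assert would_use_spec_decode(12, 8, 28, 4, max_concurrency=6)

The max_concurrency of 6 is only implied by the "<= 6" in the test comment.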
Signed-off-by: Zheyu Fu --- tests/unittest/_torch/speculative/test_dynamic_spec_decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py index 64a063bba6a..5be2e215df6 100644 --- a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py +++ b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py @@ -203,7 +203,7 @@ def prepare_draft_tokens(self, max_num_tokens=4096 * 8, max_draft_len=4) - # Small token budget ON case: token_cap = 28 // (1+4) = 5 → min(8, 12, 5) = 5 <= 6 → True + # Small token budget ON case: token_cap = 28 // (1+4) = 5 → min(12, 8, 5) = 5 <= 6 → True active_requests = [object()] * 12 assert drafter.should_use_spec_decode(active_requests, max_batch_size=8, From 7448b54748813f6825ba4e808657464f0d6ad77a Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Fri, 5 Sep 2025 15:37:20 -0700 Subject: [PATCH 4/9] Add PP guard to prevent overcounting average acceptance in py_executor. Also clean. Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 24 +++++++++++-------- .../_torch/speculative/test_spec_gate.py | 11 --------- 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index dc4b0f63267..95aaff0c115 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1930,16 +1930,20 @@ def _handle_responses(self): not self.model_engine.speculation_permanently_disabled and not request.is_dummy and not self.is_warmup): if self.model_engine.speculation_gate is not None: - avg_decoded = getattr(request, - 'avg_decoded_tokens_per_iter', - None) - disabled_now, _ = self.model_engine.speculation_gate.record_avg_decoded( - avg_decoded, - request_id=getattr(request, 'py_request_id', None)) - if disabled_now: - # disable speculation permanently - # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False - self.model_engine.speculation_permanently_disabled = True + # Response handling runs on multiple PP ranks. Only the last PP rank performs + # sampling; restrict rolling stat updates to it to avoid overcounting. 
+ if (not getattr(self.dist, 'has_pp', + False)) or self.dist.is_last_pp_rank: + avg_decoded = getattr( + request, 'avg_decoded_tokens_per_iter', None) + disabled_now, _ = self.model_engine.speculation_gate.record_avg_decoded( + avg_decoded, + request_id=getattr(request, 'py_request_id', + None)) + if disabled_now: + # disable speculation permanently + # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False + self.model_engine.speculation_permanently_disabled = True if request.is_disagg_context_transmission_state: self.ctx_in_transmission_requests.append(request) else: diff --git a/tests/unittest/_torch/speculative/test_spec_gate.py b/tests/unittest/_torch/speculative/test_spec_gate.py index 77837f5dfdd..4319133c711 100644 --- a/tests/unittest/_torch/speculative/test_spec_gate.py +++ b/tests/unittest/_torch/speculative/test_spec_gate.py @@ -10,7 +10,6 @@ from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig, KvCacheConfig) -from tensorrt_llm.llmapi.llm_args import SamplerType sys.path.append(os.path.join(os.path.dirname(__file__), '..')) @@ -42,8 +41,6 @@ def test_spec_gate_e2e(): max_batch_size=max_batch_size, kv_cache_config=kv_cache_config, max_seq_len=4096, - # Force TRTLLMSampler for testing avg_decoded_tokens_per_iter from C++ path - sampler_type=SamplerType.TRTLLMSampler, ) spec_config = EagleDecodingConfig( @@ -81,14 +78,6 @@ def test_spec_gate_e2e(): generated_text_ref = [result.outputs[0].text for result in results_ref] llm_ref.shutdown() - i = 0 - for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): - print(f"prompt: {prompts[i]}") - print(f"spec: {text_spec}") - print(f"ref: {text_ref}") - print("-" * 100) - i += 1 - for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): # The spec decode algorithm currently guarantees identical results assert text_spec == text_ref From 0c0314c2f7a724223ef14b935b1430d5395f623a Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Fri, 5 Sep 2025 15:49:31 -0700 Subject: [PATCH 5/9] Easier test case. 
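For reference, the pipeline-parallel guard added in the previous commit reduces to the predicate below (a hedged paraphrase that reuses the attribute names from that diff; it is not a new executor API):

    def should_record_acceptance(dist) -> bool:
        # Without pipeline parallelism every rank sees each finished request exactly once.
        if not getattr(dist, 'has_pp', False):
            return True
        # With PP, only the last rank runs sampling, so only it updates the rolling stats;
        # letting every PP rank record the same request would overcount it.
        return dist.is_last_pp_rank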
Signed-off-by: Zheyu Fu --- tests/unittest/_torch/speculative/test_spec_gate.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unittest/_torch/speculative/test_spec_gate.py b/tests/unittest/_torch/speculative/test_spec_gate.py index 4319133c711..b2ec7685580 100644 --- a/tests/unittest/_torch/speculative/test_spec_gate.py +++ b/tests/unittest/_torch/speculative/test_spec_gate.py @@ -65,9 +65,8 @@ def test_spec_gate_e2e(): "What is the currency of Japan?", "How many players are on a basketball court for one team?", "List three primary colors.", - "The Roman Empire fell in the year", ] - sampling_params = SamplingParams(max_tokens=10, temperature=0) + sampling_params = SamplingParams(max_tokens=5, temperature=0) results_spec = llm_spec.generate(prompts, sampling_params) generated_text_spec = [result.outputs[0].text for result in results_spec] From 403b4610ace8af593033de6762d3d2414379d896 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Wed, 10 Sep 2025 13:23:03 -0700 Subject: [PATCH 6/9] Address Mike's comments Signed-off-by: Zheyu Fu --- .../_torch/pyexecutor/model_engine.py | 13 ----- tensorrt_llm/_torch/pyexecutor/py_executor.py | 47 +++++++++++++------ .../_torch/speculative/speculation_gate.py | 17 ++++--- tensorrt_llm/llmapi/llm_args.py | 7 --- 4 files changed, 44 insertions(+), 40 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 200e7d16a01..bcd95020bb8 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -166,19 +166,6 @@ def __init__( self.spec_config = spec_config self.is_spec_decode = spec_config is not None self.enable_spec_decode = self.is_spec_decode - # Rolling acceptance tracking - self.acceptance_window = getattr( - spec_config, 'acceptance_window', - None) if spec_config is not None else None - self.acceptance_length_threshold = getattr( - spec_config, 'acceptance_length_threshold', - None) if spec_config is not None else None - # Initialize speculation gate early since it only depends on config - self.speculation_permanently_disabled = False - self.speculation_gate = None - if self.acceptance_window and self.acceptance_length_threshold is not None: - self.speculation_gate = SpeculationGate( - self.acceptance_window, self.acceptance_length_threshold) self.is_draft_model = is_draft_model self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 95aaff0c115..391331f2ebe 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -38,6 +38,7 @@ from ..models.modeling_utils import DecoderModelForCausalLM from ..modules.decoder_layer import DecoderLayer from ..speculative.drafter import Drafter +from ..speculative.speculation_gate import SpeculationGate from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem from .guided_decoder import GuidedDecoder from .handle_logits import HandleLogits @@ -208,6 +209,20 @@ def __init__(self, self.num_fetch_requests = 0 self.shutdown_event = threading.Event() + # Rolling acceptance tracking for spec decode (disable speculation if rolling acceptance is below threshold) + spec_config = getattr(self.model_engine, 'spec_config', None) + self.acceptance_window = getattr( + spec_config, 'acceptance_window', + None) if spec_config is not None else None + 
self.acceptance_length_threshold = getattr( + spec_config, 'acceptance_length_threshold', + None) if spec_config is not None else None + self.speculation_permanently_disabled = False + self.speculation_gate = None + if self.acceptance_window and self.acceptance_length_threshold is not None: + self.speculation_gate = SpeculationGate( + self.acceptance_window, self.acceptance_length_threshold) + # response used data self.response_lock = threading.Lock() self.response_cv = threading.Condition(self.response_lock) @@ -970,15 +985,14 @@ def _prepare_and_schedule_batch(self): if self.drafter is not None: # Honor permanent disable flag based on rolling acceptance first - if getattr(self.model_engine, 'speculation_permanently_disabled', - False): + if getattr(self, 'speculation_permanently_disabled', False): self.use_spec_decode = False else: self.use_spec_decode = self.drafter.should_use_spec_decode( self.active_requests, self.max_batch_size, self.model_engine.max_num_tokens, self.model_engine.spec_config.max_draft_len) - + logger.debug(f"Use spec decode: {self.use_spec_decode}") self.model_engine.enable_spec_decode = self.use_spec_decode # Set up draft_tokens in active_requests, because they could be used in the scheduling stage. @@ -1926,24 +1940,29 @@ def _handle_responses(self): new_responses.append((req_id, response)) if request_done: - if (self.model_engine.enable_spec_decode and - not self.model_engine.speculation_permanently_disabled + if (self.model_engine.enable_spec_decode + and not self.speculation_permanently_disabled and not request.is_dummy and not self.is_warmup): - if self.model_engine.speculation_gate is not None: + if self.speculation_gate is not None: # Response handling runs on multiple PP ranks. Only the last PP rank performs # sampling; restrict rolling stat updates to it to avoid overcounting. if (not getattr(self.dist, 'has_pp', False)) or self.dist.is_last_pp_rank: avg_decoded = getattr( request, 'avg_decoded_tokens_per_iter', None) - disabled_now, _ = self.model_engine.speculation_gate.record_avg_decoded( - avg_decoded, - request_id=getattr(request, 'py_request_id', - None)) - if disabled_now: - # disable speculation permanently - # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False - self.model_engine.speculation_permanently_disabled = True + if avg_decoded is not None: + disabled_now, _ = self.speculation_gate.record_avg_decoded( + avg_decoded, + request_id=getattr(request, 'py_request_id', + None)) + if disabled_now: + # disable speculation permanently + # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False + self.speculation_permanently_disabled = True + else: + logger.debug( + f"Request {request.py_request_id} has no avg_decoded_tokens_per_iter" + ) if request.is_disagg_context_transmission_state: self.ctx_in_transmission_requests.append(request) else: diff --git a/tensorrt_llm/_torch/speculative/speculation_gate.py b/tensorrt_llm/_torch/speculative/speculation_gate.py index ccd534f0855..69b4fa22e99 100644 --- a/tensorrt_llm/_torch/speculative/speculation_gate.py +++ b/tensorrt_llm/_torch/speculative/speculation_gate.py @@ -29,18 +29,21 @@ def reset(self) -> None: def record_avg_decoded( self, - avg_decoded_tokens_per_iter: Optional[float], + avg_decoded_tokens_per_iter: float, request_id: Optional[int] = None) -> Tuple[bool, Optional[float]]: """ - Record a completed request's avg_decoded_tokens_per_iter. 
- Returns (disabled_now, current_avg_accept) where disabled_now is True only when the call causes disable. - """ + Record a completed request's avg_decoded_tokens_per_iter. + Returns (disabled_now, current_avg_accept) where disabled_now is True only when the call causes disable. + """ if self.disabled or self.window is None or self.window <= 0 or self.threshold is None: return False, None + # Extra Guard: if caller passed None, skip updating the rolling stats + if avg_decoded_tokens_per_iter is None: + return False, None + accepted_len = 0.0 - if avg_decoded_tokens_per_iter is not None: - accepted_len = max(0.0, float(avg_decoded_tokens_per_iter) - 1.0) + accepted_len = max(0.0, float(avg_decoded_tokens_per_iter) - 1.0) # Log per-request completion for debug if request_id is not None: @@ -50,6 +53,8 @@ def record_avg_decoded( # O(1) rolling update self.acceptance_history.append(accepted_len) + logger.debug( + f"[SpeculationGate] Acceptance history: {self.acceptance_history}") self.acceptance_sum += accepted_len if len(self.acceptance_history) > self.window: removed = self.acceptance_history.popleft() diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 3ed20d9485d..ef6d9757bdc 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -369,9 +369,6 @@ class DecodingBaseConfig(StrictBaseModel): # (N = acceptance_window) drops below this value. acceptance_length_threshold: Optional[float] = None - # Upper bound to avoid accidental huge windows - MAX_ACCEPTANCE_WINDOW: ClassVar[int] = 100000 - # Validate acceptance controls at field level so they run on model creation @field_validator('acceptance_window') @classmethod @@ -381,10 +378,6 @@ def _validate_acceptance_window(cls, v: Optional[int]): if v < 0: raise ValueError( f"acceptance_window must be >= 0 (0 disables), got {v}") - if v > cls.MAX_ACCEPTANCE_WINDOW: - raise ValueError( - f"acceptance_window must be <= {cls.MAX_ACCEPTANCE_WINDOW}, got {v}" - ) return v @field_validator('acceptance_length_threshold') From 3f4781cf612bb20a9bb1d9a0147ed48dfcb8bc26 Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Tue, 23 Sep 2025 17:45:02 +0000 Subject: [PATCH 7/9] Clean. 
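For reference, the rolling-acceptance knobs added earlier in this series are configured on the speculative decoding config. The sketch below mirrors the values used in tests/unittest/_torch/speculative/test_spec_gate.py; the draft model directory is a placeholder:

    from tensorrt_llm.llmapi import EagleDecodingConfig

    spec_config = EagleDecodingConfig(
        max_draft_len=4,
        speculative_model_dir="<path to EAGLE3 draft model>",
        eagle3_one_model=False,           # Llama 3 does not support one-model eagle
        acceptance_window=5,              # rolling window of completed requests
        acceptance_length_threshold=0.6,  # disable speculation below this average
    )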
Signed-off-by: Zheyu Fu --- .../_torch/speculative/test_dynamic_spec_decode.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py index 5be2e215df6..dfe3ab5bc0c 100644 --- a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py +++ b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py @@ -30,12 +30,12 @@ def test_dynamic_spec_decode(enforce_single_worker, total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if total_mem_gb < 35: pytest.skip("Not enough memory to load target + draft model") + models_path = llm_models_root() eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" - # Allow with 3 concurrent requests - max_batch_size = 3 + max_batch_size = 1 max_draft_len = 4 kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192) cuda_graph_config = CudaGraphConfig(batch_sizes=[1]) @@ -47,7 +47,11 @@ def test_dynamic_spec_decode(enforce_single_worker, cuda_graph_config=cuda_graph_config, max_batch_size=max_batch_size, kv_cache_config=kv_cache_config, - max_seq_len=4096, + # This max_seq_len is larger than the one specified + # in the llama 3 8B eagle's config. We want to make sure + # that the draft model won't go above its max in warmup + # in this test. + max_seq_len=8192, ) spec_config = EagleDecodingConfig( @@ -55,8 +59,6 @@ def test_dynamic_spec_decode(enforce_single_worker, speculative_model_dir=eagle_model_dir, # Llama 3 does not support one model eagle. eagle3_one_model=False, - # allow speculation only when <= 2 effective request - max_concurrency=2, ) llm_spec = LLM(**llm_common_config, speculative_config=spec_config) From 03ef65f037bacf7a825efc3a8b163b17ef6a3cce Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Fri, 10 Oct 2025 00:17:45 +0000 Subject: [PATCH 8/9] Change to similarity check. 
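The test now relies on utils.util.similar, whose implementation is not part of this diff. A check of this kind is typically a fuzzy ratio comparison, e.g. the sketch below; the 0.8 cut-off is an assumption, not the helper's actual default:

    import difflib

    def similar_texts(a: str, b: str, threshold: float = 0.8) -> bool:
        # Ratio of matching characters between the speculative and reference generations.
        return difflib.SequenceMatcher(None, a, b).ratio() >= threshold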
Signed-off-by: Zheyu Fu --- tests/unittest/_torch/speculative/test_spec_gate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unittest/_torch/speculative/test_spec_gate.py b/tests/unittest/_torch/speculative/test_spec_gate.py index b2ec7685580..ad0d3d3190e 100644 --- a/tests/unittest/_torch/speculative/test_spec_gate.py +++ b/tests/unittest/_torch/speculative/test_spec_gate.py @@ -5,6 +5,7 @@ import pytest import torch from utils.llm_data import llm_models_root +from utils.util import similar from tensorrt_llm import LLM, SamplingParams from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate @@ -66,7 +67,7 @@ def test_spec_gate_e2e(): "How many players are on a basketball court for one team?", "List three primary colors.", ] - sampling_params = SamplingParams(max_tokens=5, temperature=0) + sampling_params = SamplingParams(max_tokens=32, temperature=0) results_spec = llm_spec.generate(prompts, sampling_params) generated_text_spec = [result.outputs[0].text for result in results_spec] @@ -78,8 +79,7 @@ def test_spec_gate_e2e(): llm_ref.shutdown() for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): - # The spec decode algorithm currently guarantees identical results - assert text_spec == text_ref + assert similar(text_spec, text_ref) def test_returns_none_until_window_and_enabled_when_above_threshold(): From c10fcae06a884ea78281c765698a3d51415e379f Mon Sep 17 00:00:00 2001 From: Zheyu Fu Date: Fri, 10 Oct 2025 18:25:51 +0000 Subject: [PATCH 9/9] Add defensive Signed-off-by: Zheyu Fu --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 0148add1c61..e14703840de 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -2066,7 +2066,8 @@ def _handle_responses(self): new_responses.append((req_id, response)) if request_done: - if (self.model_engine.enable_spec_decode + if (self.drafter is not None and getattr( + self.model_engine, 'enable_spec_decode', False) and not self.speculation_permanently_disabled and not request.is_dummy and not self.is_warmup): if self.speculation_gate is not None:
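A quick recap of the gate's contract, mirroring the unit tests above (the import path matches the module added in PATCH 2):

    from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate

    gate = SpeculationGate(window=3, threshold=0.7)
    print(gate.record_avg_decoded(1.1))    # (False, None)  window not full yet
    print(gate.record_avg_decoded(1.2))    # (False, None)
    print(gate.record_avg_decoded(1.3))    # (True, ~0.2)   avg accepted below 0.7 -> disable
    print(gate.record_avg_decoded(100.0))  # (False, None)  already disabled, call is a no-op
    print(gate.disabled)                   # True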