52 changes: 48 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -40,6 +40,7 @@
from ..models.modeling_utils import DecoderModelForCausalLM
from ..modules.decoder_layer import DecoderLayer
from ..speculative.drafter import Drafter
from ..speculative.speculation_gate import SpeculationGate
from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
from .guided_decoder import GuidedDecoder
from .handle_additional_outputs import HandleAdditionalOutputs
@@ -211,6 +212,20 @@ def __init__(self,
self.num_fetch_requests = 0
self.shutdown_event = threading.Event()

# Rolling acceptance tracking for spec decode (disable speculation if rolling acceptance is below threshold)
spec_config = getattr(self.model_engine, 'spec_config', None)
self.acceptance_window = getattr(
spec_config, 'acceptance_window',
None) if spec_config is not None else None
self.acceptance_length_threshold = getattr(
spec_config, 'acceptance_length_threshold',
None) if spec_config is not None else None
self.speculation_permanently_disabled = False
self.speculation_gate = None
if self.acceptance_window and self.acceptance_length_threshold is not None:
self.speculation_gate = SpeculationGate(
self.acceptance_window, self.acceptance_length_threshold)

# response used data
self.response_lock = threading.Lock()
self.response_cv = threading.Condition(self.response_lock)
@@ -1018,10 +1033,15 @@ def _prepare_and_schedule_batch(self):
self._pad_attention_dp_dummy_request()

if self.drafter is not None:
self.use_spec_decode = self.drafter.should_use_spec_decode(
self.active_requests, self.max_batch_size,
self.model_engine.max_num_tokens,
self.model_engine.spec_config.max_draft_len)
# Honor permanent disable flag based on rolling acceptance first
if getattr(self, 'speculation_permanently_disabled', False):
self.use_spec_decode = False
else:
self.use_spec_decode = self.drafter.should_use_spec_decode(
self.active_requests, self.max_batch_size,
self.model_engine.max_num_tokens,
self.model_engine.spec_config.max_draft_len)
logger.debug(f"Use spec decode: {self.use_spec_decode}")
self.model_engine.enable_spec_decode = self.use_spec_decode

# Set up draft_tokens in active_requests, because they could be used in the scheduling stage.
@@ -2056,6 +2076,30 @@ def _handle_responses(self):
new_responses.append((req_id, response))

if request_done:
if (self.drafter is not None and getattr(
self.model_engine, 'enable_spec_decode', False)
and not self.speculation_permanently_disabled
and not request.is_dummy and not self.is_warmup):
if self.speculation_gate is not None:
# Response handling runs on multiple PP ranks. Only the last PP rank performs
# sampling; restrict rolling stat updates to it to avoid overcounting.
if (not getattr(self.dist, 'has_pp',
False)) or self.dist.is_last_pp_rank:
avg_decoded = getattr(
request, 'avg_decoded_tokens_per_iter', None)
if avg_decoded is not None:
disabled_now, _ = self.speculation_gate.record_avg_decoded(
avg_decoded,
request_id=getattr(request, 'py_request_id',
None))
if disabled_now:
# disable speculation permanently
# starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False
self.speculation_permanently_disabled = True
else:
logger.debug(
f"Request {request.py_request_id} has no avg_decoded_tokens_per_iter"
)
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa:
requests_to_terminate.append(request)
else:
77 changes: 77 additions & 0 deletions tensorrt_llm/_torch/speculative/speculation_gate.py
@@ -0,0 +1,77 @@
from collections import deque
from typing import Deque, Optional, Tuple

from tensorrt_llm.logger import logger


class SpeculationGate:
"""
    Tracks a rolling average of accepted draft tokens per iteration over the last N completed requests.
    Permanently disables speculation when the average falls below a threshold.
"""

def __init__(self, window: int, threshold: float):
self.window = window
self.threshold = threshold
self.acceptance_history: Deque[float] = deque()
self.acceptance_sum: float = 0.0
self.num_completed_for_acceptance = 0
self.disabled = False
logger.debug(
f"[SpeculationGate] SpeculationGate initialized with window={self.window}, threshold={self.threshold}"
)

def reset(self) -> None:
self.acceptance_history.clear()
self.acceptance_sum = 0.0
self.num_completed_for_acceptance = 0
self.disabled = False

def record_avg_decoded(
self,
avg_decoded_tokens_per_iter: float,
request_id: Optional[int] = None) -> Tuple[bool, Optional[float]]:
"""
Record a completed request's avg_decoded_tokens_per_iter.
        Returns (disabled_now, current_avg_accept), where disabled_now is True only on the call that disables speculation.
"""
if self.disabled or self.window is None or self.window <= 0 or self.threshold is None:
return False, None

# Extra Guard: if caller passed None, skip updating the rolling stats
if avg_decoded_tokens_per_iter is None:
return False, None

        accepted_len = max(0.0, float(avg_decoded_tokens_per_iter) - 1.0)

# Log per-request completion for debug
if request_id is not None:
logger.debug(
f"[SpeculationGate] Request {request_id} completed: avg_decoded={avg_decoded_tokens_per_iter if avg_decoded_tokens_per_iter is not None else 'None'}, accepted_len={accepted_len:.3f}"
)

# O(1) rolling update
self.acceptance_history.append(accepted_len)
logger.debug(
f"[SpeculationGate] Acceptance history: {self.acceptance_history}")
self.acceptance_sum += accepted_len
if len(self.acceptance_history) > self.window:
removed = self.acceptance_history.popleft()
self.acceptance_sum -= removed

self.num_completed_for_acceptance += 1

if self.num_completed_for_acceptance >= self.window:
avg_accept = self.acceptance_sum / len(self.acceptance_history)
if avg_accept < self.threshold:
self.disabled = True
logger.info(
f"[SpeculationGate] Speculative decoding disabled: rolling acceptance avg {avg_accept:.3f} < threshold {self.threshold} over last {self.window} requests"
)
return True, avg_accept
else:
# speculation is still enabled
return False, avg_accept

return False, None
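
For reference (not part of the diff), a minimal usage sketch of the gate; the numbers are illustrative and the behavior mirrors the unit tests added below:

from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate

# window=3: average over the last 3 completed requests; threshold=0.7 on acceptance length
gate = SpeculationGate(window=3, threshold=0.7)

# avg_decoded_tokens_per_iter includes the target token, so acceptance length = value - 1.0
gate.record_avg_decoded(1.1)                  # (False, None): window not yet full
gate.record_avg_decoded(1.2)                  # (False, None)
disabled, avg = gate.record_avg_decoded(1.3)  # rolling avg = (0.1 + 0.2 + 0.3) / 3 = 0.2 < 0.7
assert disabled and gate.disabled             # speculation is now permanently disabled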
30 changes: 30 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
@@ -365,6 +365,36 @@ class DecodingBaseConfig(StrictBaseModel):
max_concurrency: Optional[int] = None

load_format: Optional[str] = None
# PyTorch only.
# Rolling average window size (N) for acceptance length across completed requests.
# If not set or set to 0, the feature is disabled.
acceptance_window: Optional[int] = None
# PyTorch only.
# Threshold for average acceptance length; speculation will be disabled
# permanently once the rolling average over the last N completed requests
# (N = acceptance_window) drops below this value.
acceptance_length_threshold: Optional[float] = None

# Validate acceptance controls at field level so they run on model creation
@field_validator('acceptance_window')
@classmethod
def _validate_acceptance_window(cls, v: Optional[int]):
if v is None:
return v
if v < 0:
raise ValueError(
f"acceptance_window must be >= 0 (0 disables), got {v}")
return v

@field_validator('acceptance_length_threshold')
@classmethod
def _validate_acceptance_length_threshold(cls, v: Optional[float]):
if v is None:
return v
if v < 0:
raise ValueError(
f"acceptance_length_threshold must be >= 0, got {v}")
return v

# If set, drafting is allowed to use chain drafter.
_allow_chain_drafter: bool = PrivateAttr(True)
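For illustration (not part of the diff): the two new fields live on DecodingBaseConfig, so they can be set on a speculative decoding config such as EagleDecodingConfig. This mirrors the spec_config used by the e2e test added below; the model path is a placeholder:

from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=4,
    speculative_model_dir="<path to EAGLE3 draft model>",  # placeholder path
    eagle3_one_model=False,
    # Track the rolling average acceptance length over the last 5 completed requests
    # and permanently disable speculation once it drops below 0.6.
    acceptance_window=5,
    acceptance_length_threshold=0.6,
)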
@@ -207,7 +207,7 @@ def prepare_draft_tokens(self,
max_num_tokens=4096 * 8,
max_draft_len=4)

# Small token budget ON case: token_cap = 28 // (1+4) = 5 → min(8, 12, 5) = 5 <= 6 → True
# Small token budget ON case: token_cap = 28 // (1+4) = 5 → min(12, 8, 5) = 5 <= 6 → True
active_requests = [object()] * 12
assert drafter.should_use_spec_decode(active_requests,
max_batch_size=8,
142 changes: 142 additions & 0 deletions tests/unittest/_torch/speculative/test_spec_gate.py
@@ -0,0 +1,142 @@
import os
import sys
import unittest

import pytest
import torch
from utils.llm_data import llm_models_root
from utils.util import similar

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.speculative.speculation_gate import SpeculationGate
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig)

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


# Tests the end-to-end functionality of the SpeculationGate,
# which turns off spec decode when the average acceptance length falls below the threshold.
# The gate is configured via acceptance_window and acceptance_length_threshold in spec_config.
# This test sets max_concurrency to a large value so that spec decode is not turned off because
# the number of effective requests exceeds max_concurrency, keeping the focus solely on the
# disabling behavior of the SpeculationGate.
@pytest.mark.high_cuda_memory
def test_spec_gate_e2e():
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
pytest.skip("Not enough memory to load target + draft model")
models_path = llm_models_root()
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"

max_batch_size = 2
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192)
cuda_graph_config = CudaGraphConfig(batch_sizes=[1])

llm_common_config = dict(
model=target_model_dir,
attn_backend="TRTLLM",
disable_overlap_scheduler=True,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
max_seq_len=4096,
)

spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=False,
max_concurrency=10000,
acceptance_window=5,
acceptance_length_threshold=0.6,
)

llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
# Output tests
prompts = [
"The capital of France is",
"The president of the United States is",
"What is the capital of Australia?",
"Explain in one sentence why the sky is blue.",
"Who wrote the book 'Pride and Prejudice'?",
"List three U.S. national holidays in the year 2025.",
"What is the currency of Japan?",
"How many players are on a basketball court for one team?",
"List three primary colors.",
]
sampling_params = SamplingParams(max_tokens=32, temperature=0)

results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()

llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()

for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
assert similar(text_spec, text_ref)


def test_returns_none_until_window_and_enabled_when_above_threshold():
gate = SpeculationGate(window=3, threshold=0.5)

disabled, avg = gate.record_avg_decoded(2.0, request_id=1)
assert disabled is False and avg is None
assert gate.disabled is False

disabled, avg = gate.record_avg_decoded(2.0, request_id=2)
assert disabled is False and avg is None
assert gate.disabled is False

disabled, avg = gate.record_avg_decoded(2.0, request_id=3)
assert disabled is False
assert avg == pytest.approx(1.0, rel=1e-6)
assert gate.disabled is False


def test_disables_when_avg_below_threshold_and_stays_disabled():
gate = SpeculationGate(window=3, threshold=0.7)

gate.record_avg_decoded(1.1)
gate.record_avg_decoded(1.2)

disabled, avg = gate.record_avg_decoded(1.3)
assert disabled is True
assert avg == pytest.approx(0.2, rel=1e-6)
assert gate.disabled is True

# Once disabled, subsequent calls do nothing and return (False, None)
disabled, avg = gate.record_avg_decoded(100.0)
assert disabled is False and avg is None
assert gate.disabled is True

disabled, avg = gate.record_avg_decoded(200.0)
assert disabled is False and avg is None
assert gate.disabled is True


def test_rolling_window_and_disable_on_drop():
gate = SpeculationGate(window=3, threshold=0.8)

# First three high-acceptance requests keep it enabled
gate.record_avg_decoded(2.0)
gate.record_avg_decoded(2.0)
disabled, avg = gate.record_avg_decoded(2.0)
assert disabled is False
assert avg == pytest.approx(1.0, rel=1e-6)
assert gate.disabled is False

# Fourth lower value enters window -> average drops below threshold -> disable
disabled, avg = gate.record_avg_decoded(1.2)
assert disabled is True
assert avg == pytest.approx((1.0 + 1.0 + 0.2) / 3.0, rel=1e-6)
assert gate.disabled is True


if __name__ == "__main__":
unittest.main()