
Commit a760ddc

Address Mike's comments
Signed-off-by: Zheyu Fu <[email protected]>
1 parent 52b17b5 commit a760ddc

4 files changed: +44 -40 lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 0 additions & 13 deletions
@@ -297,19 +297,6 @@ def __init__(
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
         self.enable_spec_decode = self.is_spec_decode
-        # Rolling acceptance tracking
-        self.acceptance_window = getattr(
-            spec_config, 'acceptance_window',
-            None) if spec_config is not None else None
-        self.acceptance_length_threshold = getattr(
-            spec_config, 'acceptance_length_threshold',
-            None) if spec_config is not None else None
-        # Initialize speculation gate early since it only depends on config
-        self.speculation_permanently_disabled = False
-        self.speculation_gate = None
-        if self.acceptance_window and self.acceptance_length_threshold is not None:
-            self.speculation_gate = SpeculationGate(
-                self.acceptance_window, self.acceptance_length_threshold)
         self.is_draft_model = is_draft_model

         self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
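Before this block reappears in py_executor.py below, note that the construction guard treats the two settings asymmetrically: acceptance_window is tested for truthiness, so a window of 0 (or None) leaves the gate unconstructed, while acceptance_length_threshold is compared against None explicitly, so a threshold of 0.0 still counts as configured. A tiny illustration with hypothetical values:

# Illustrates the guard
# `if self.acceptance_window and self.acceptance_length_threshold is not None:`
window, threshold = 0, 1.5
print(bool(window and threshold is not None))   # False -> gate stays None
window, threshold = 100, 0.0
print(bool(window and threshold is not None))   # True  -> gate is constructed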

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 33 additions & 14 deletions
@@ -38,6 +38,7 @@
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.decoder_layer import DecoderLayer
 from ..speculative.drafter import Drafter
+from ..speculative.speculation_gate import SpeculationGate
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
 from .handle_logits import HandleLogits
@@ -207,6 +208,20 @@ def __init__(self,
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()

+        # Rolling acceptance tracking for spec decode (disable speculation if rolling acceptance is below threshold)
+        spec_config = getattr(self.model_engine, 'spec_config', None)
+        self.acceptance_window = getattr(
+            spec_config, 'acceptance_window',
+            None) if spec_config is not None else None
+        self.acceptance_length_threshold = getattr(
+            spec_config, 'acceptance_length_threshold',
+            None) if spec_config is not None else None
+        self.speculation_permanently_disabled = False
+        self.speculation_gate = None
+        if self.acceptance_window and self.acceptance_length_threshold is not None:
+            self.speculation_gate = SpeculationGate(
+                self.acceptance_window, self.acceptance_length_threshold)
+
         # response used data
         self.response_lock = threading.Lock()
         self.response_cv = threading.Condition(self.response_lock)
@@ -969,15 +984,14 @@ def _prepare_and_schedule_batch(self):

         if self.drafter is not None:
             # Honor permanent disable flag based on rolling acceptance first
-            if getattr(self.model_engine, 'speculation_permanently_disabled',
-                       False):
+            if getattr(self, 'speculation_permanently_disabled', False):
                 self.use_spec_decode = False
             else:
                 self.use_spec_decode = self.drafter.should_use_spec_decode(
                     self.active_requests, self.max_batch_size,
                     self.model_engine.max_num_tokens,
                     self.model_engine.spec_config.max_draft_len)
-
+            logger.debug(f"Use spec decode: {self.use_spec_decode}")
             self.model_engine.enable_spec_decode = self.use_spec_decode

         # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
@@ -1911,24 +1925,29 @@ def _handle_responses(self):
                     new_responses.append((req_id, response))

                 if request_done:
-                    if (self.model_engine.enable_spec_decode and
-                            not self.model_engine.speculation_permanently_disabled
+                    if (self.model_engine.enable_spec_decode
+                            and not self.speculation_permanently_disabled
                             and not request.is_dummy and not self.is_warmup):
-                        if self.model_engine.speculation_gate is not None:
+                        if self.speculation_gate is not None:
                             # Response handling runs on multiple PP ranks. Only the last PP rank performs
                             # sampling; restrict rolling stat updates to it to avoid overcounting.
                             if (not getattr(self.dist, 'has_pp',
                                             False)) or self.dist.is_last_pp_rank:
                                 avg_decoded = getattr(
                                     request, 'avg_decoded_tokens_per_iter', None)
-                                disabled_now, _ = self.model_engine.speculation_gate.record_avg_decoded(
-                                    avg_decoded,
-                                    request_id=getattr(request, 'py_request_id',
-                                                       None))
-                                if disabled_now:
-                                    # disable speculation permanently
-                                    # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False
-                                    self.model_engine.speculation_permanently_disabled = True
+                                if avg_decoded is not None:
+                                    disabled_now, _ = self.speculation_gate.record_avg_decoded(
+                                        avg_decoded,
+                                        request_id=getattr(request, 'py_request_id',
+                                                           None))
+                                    if disabled_now:
+                                        # disable speculation permanently
+                                        # starting from next iteration, _prepare_and_schedule_batch will set self.use_spec_decode to False
+                                        self.speculation_permanently_disabled = True
+                                else:
+                                    logger.debug(
+                                        f"Request {request.py_request_id} has no avg_decoded_tokens_per_iter"
+                                    )
                     if request.is_disagg_context_transmission_state:
                         self.ctx_in_transmission_requests.append(request)
                     else:
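Taken together, these hunks split the responsibility between two methods: _prepare_and_schedule_batch consults the latched flag at the top of every iteration, while _handle_responses feeds completed requests into the gate (on the last PP rank only) and latches the flag when the gate trips. Below is a minimal, self-contained sketch of that control flow with the executor and distributed plumbing stubbed out; the class and method names are hypothetical, and the gate is any object exposing record_avg_decoded as in speculation_gate.py further down:

class ExecutorGatingSketch:
    """Hypothetical stand-in for PyExecutor; only gating state is modeled."""

    def __init__(self, speculation_gate=None):
        self.speculation_gate = speculation_gate  # None if gating unconfigured
        self.speculation_permanently_disabled = False
        self.use_spec_decode = False

    def prepare_and_schedule_batch_sketch(self, drafter_wants_spec: bool):
        # Mirrors the _prepare_and_schedule_batch hunk: the latched disable
        # flag takes precedence over the drafter heuristic for this iteration.
        if self.speculation_permanently_disabled:
            self.use_spec_decode = False
        else:
            self.use_spec_decode = drafter_wants_spec

    def handle_response_sketch(self, avg_decoded_tokens_per_iter, request_id=None):
        # Mirrors the _handle_responses hunk (last PP rank only): skip requests
        # that carry no avg_decoded_tokens_per_iter, otherwise update the
        # rolling stats and latch the permanent-disable flag if the gate trips.
        if self.speculation_gate is None or avg_decoded_tokens_per_iter is None:
            return
        disabled_now, _ = self.speculation_gate.record_avg_decoded(
            avg_decoded_tokens_per_iter, request_id=request_id)
        if disabled_now:
            self.speculation_permanently_disabled = True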

tensorrt_llm/_torch/speculative/speculation_gate.py

Lines changed: 11 additions & 6 deletions
@@ -29,18 +29,21 @@ def reset(self) -> None:

     def record_avg_decoded(
             self,
-            avg_decoded_tokens_per_iter: Optional[float],
+            avg_decoded_tokens_per_iter: float,
             request_id: Optional[int] = None) -> Tuple[bool, Optional[float]]:
         """
-            Record a completed request's avg_decoded_tokens_per_iter.
-            Returns (disabled_now, current_avg_accept) where disabled_now is True only when the call causes disable.
-            """
+        Record a completed request's avg_decoded_tokens_per_iter.
+        Returns (disabled_now, current_avg_accept) where disabled_now is True only when the call causes disable.
+        """
         if self.disabled or self.window is None or self.window <= 0 or self.threshold is None:
             return False, None

+        # Extra Guard: if caller passed None, skip updating the rolling stats
+        if avg_decoded_tokens_per_iter is None:
+            return False, None
+
         accepted_len = 0.0
-        if avg_decoded_tokens_per_iter is not None:
-            accepted_len = max(0.0, float(avg_decoded_tokens_per_iter) - 1.0)
+        accepted_len = max(0.0, float(avg_decoded_tokens_per_iter) - 1.0)

         # Log per-request completion for debug
         if request_id is not None:
@@ -50,6 +53,8 @@ def record_avg_decoded(

         # O(1) rolling update
         self.acceptance_history.append(accepted_len)
+        logger.debug(
+            f"[SpeculationGate] Acceptance history: {self.acceptance_history}")
         self.acceptance_sum += accepted_len
         if len(self.acceptance_history) > self.window:
             removed = self.acceptance_history.popleft()
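For reference, here is a runnable sketch of the rolling-average mechanics record_avg_decoded implements, consistent with the hunks above: per-request acceptance is avg_decoded_tokens_per_iter - 1.0 clamped at zero, accumulated in O(1) over a fixed-size window. The class name is hypothetical, and the judge-only-once-the-window-is-full rule is an assumption not shown in this diff:

from collections import deque
from typing import Optional, Tuple


class RollingAcceptanceGate:
    """Illustrative stand-in for SpeculationGate's rolling update."""

    def __init__(self, window: int, threshold: float):
        self.window = window
        self.threshold = threshold
        self.disabled = False
        self.acceptance_history = deque()
        self.acceptance_sum = 0.0

    def record_avg_decoded(
            self,
            avg_decoded_tokens_per_iter: float,
            request_id: Optional[int] = None) -> Tuple[bool, Optional[float]]:
        if self.disabled or avg_decoded_tokens_per_iter is None:
            return False, None
        # avg_decoded_tokens_per_iter counts the target token too, hence -1.0.
        accepted_len = max(0.0, float(avg_decoded_tokens_per_iter) - 1.0)
        # O(1) rolling update over a fixed-size window.
        self.acceptance_history.append(accepted_len)
        self.acceptance_sum += accepted_len
        if len(self.acceptance_history) > self.window:
            self.acceptance_sum -= self.acceptance_history.popleft()
        # Assumption: only judge once the window is full.
        if len(self.acceptance_history) < self.window:
            return False, None
        avg_accept = self.acceptance_sum / self.window
        if avg_accept < self.threshold:
            self.disabled = True
            return True, avg_accept  # disabled_now only on the tripping call
        return False, avg_accept

With a window of 3 and a threshold of 1.0, three requests averaging 1.5 decoded tokens per iteration (acceptance 0.5 each) trip the gate on the third call:

gate = RollingAcceptanceGate(window=3, threshold=1.0)
for avg in (1.5, 1.5, 1.5):
    disabled_now, avg_accept = gate.record_avg_decoded(avg)
print(disabled_now, avg_accept)  # True 0.5 -> speculation disabled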

tensorrt_llm/llmapi/llm_args.py

Lines changed: 0 additions & 7 deletions
@@ -367,9 +367,6 @@ class DecodingBaseConfig(StrictBaseModel):
     # (N = acceptance_window) drops below this value.
     acceptance_length_threshold: Optional[float] = None

-    # Upper bound to avoid accidental huge windows
-    MAX_ACCEPTANCE_WINDOW: ClassVar[int] = 100000
-
     # Validate acceptance controls at field level so they run on model creation
     @field_validator('acceptance_window')
     @classmethod
@@ -379,10 +376,6 @@ def _validate_acceptance_window(cls, v: Optional[int]):
         if v < 0:
             raise ValueError(
                 f"acceptance_window must be >= 0 (0 disables), got {v}")
-        if v > cls.MAX_ACCEPTANCE_WINDOW:
-            raise ValueError(
-                f"acceptance_window must be <= {cls.MAX_ACCEPTANCE_WINDOW}, got {v}"
-            )
         return v

     @field_validator('acceptance_length_threshold')
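With the upper bound removed, the window validator reduces to a sign check. An illustrative standalone restatement of its semantics (not the actual Pydantic validator; the None pass-through is assumed):

from typing import Optional


def validate_acceptance_window(v: Optional[int]) -> Optional[int]:
    # None leaves gating unconfigured; 0 explicitly disables it; any positive
    # window is accepted, with no MAX_ACCEPTANCE_WINDOW cap anymore.
    if v is None:
        return v
    if v < 0:
        raise ValueError(f"acceptance_window must be >= 0 (0 disables), got {v}")
    return v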
