Commit 18d49b2

[None][feat] Optimize CUDA graph memory usage for spec decode cases
Signed-off-by: Mike Iovine <[email protected]>
1 parent: 3b2dd40

2 files changed: 12 additions and 4 deletions

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 2 deletions
@@ -722,8 +722,11 @@ def disable_optimization(backend: Backend):
         # For non-draft model, we also capture the CUDA graph instance for draft length 0,
         # so that when we disable spec decode at runtime, we can still run the captured graph.
         # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-        if not self.is_draft_model and self.max_draft_len > 0 and not self.spec_config.spec_dec_mode.use_one_engine(
-        ):
+        if (not self.is_draft_model and self.max_draft_len > 0
+                and not self.spec_config.spec_dec_mode.use_one_engine()
+                # Assume that speculation is always on if the user didn't give us a max_concurrency
+                # value. This will save on memory.
+                and self.spec_config.max_concurrency is not None):
             draft_lengths.append(0)

         for bs in cuda_graph_batch_sizes:
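
In plain terms: the non-draft engine used to capture an extra CUDA graph at draft length 0 for every batch size, as a fallback for disabling speculation at runtime. Since should_use_spec_decode (below) can only return False when max_concurrency is set, that fallback graph is dead weight whenever the user leaves max_concurrency unset, and each captured graph pins its own memory. A minimal sketch of the resulting capture plan; plan_draft_lengths, uses_one_engine, and this SpecConfig are simplified stand-ins, not the real TensorRT-LLM API:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class SpecConfig:
        """Stand-in spec config; only the fields used here."""
        max_draft_len: int
        max_concurrency: Optional[int] = None  # None => speculation assumed always on

    def plan_draft_lengths(cfg: SpecConfig, is_draft_model: bool,
                           uses_one_engine: bool) -> List[int]:
        draft_lengths = [cfg.max_draft_len]
        # Capture the draft_len == 0 fallback graph only when spec decode can
        # actually be turned off at runtime, i.e. when max_concurrency is set.
        if (not is_draft_model and cfg.max_draft_len > 0 and not uses_one_engine
                and cfg.max_concurrency is not None):
            draft_lengths.append(0)
        return draft_lengths

    # max_concurrency unset: one graph per batch size instead of two.
    assert plan_draft_lengths(SpecConfig(4), False, False) == [4]
    # max_concurrency set: keep the fallback graph so speculation can be
    # disabled under high concurrency.
    assert plan_draft_lengths(SpecConfig(4, max_concurrency=8), False, False) == [4, 0]
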

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 7 additions & 2 deletions
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, final

 from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import ResourceManager
@@ -26,8 +26,13 @@ def prepare_draft_tokens(
         """
         raise NotImplementedError

+    @final
     def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
-        """Check if spec decode should be used for the current iteration."""
+        """
+        You probably don't want to override this. ModelEngine
+        assumes that speculation is always on if max_concurrency
+        is not specified by the user's spec config.
+        """
         if self.max_concurrency is not None:
             return len(requests) <= self.max_concurrency
         return True