
Commit bcb72e0

Auto-enable ngram with concurrency <= 32.

Signed-off-by: Simeng Liu <[email protected]>
1 parent: 9645814

4 files changed: +46 -8 lines

examples/llm-api/quickstart_advanced.py
Lines changed: 7 additions & 2 deletions

```diff
@@ -108,9 +108,9 @@ def add_llm_args(parser):
 
     # Speculative decoding
     parser.add_argument('--spec_decode_algo', type=str, default=None)
-    parser.add_argument('--spec_decode_max_draft_len', type=int, default=1)
+    parser.add_argument('--spec_decode_max_draft_len', type=int, default=0)
     parser.add_argument('--draft_model_dir', type=str, default=None)
-    parser.add_argument('--max_matching_ngram_size', type=int, default=5)
+    parser.add_argument('--max_matching_ngram_size', type=int, default=0)
     parser.add_argument('--use_one_model', default=False, action='store_true')
 
     # Relaxed acceptance
@@ -152,6 +152,11 @@ def setup_llm(args, **kwargs):
     spec_decode_algo = args.spec_decode_algo.upper(
     ) if args.spec_decode_algo is not None else None
 
+    # Update spec_decode_max_draft_len to 1 if unset by the user for a non-NGRAM spec_decode_algo.
+    # The NGRAM algorithm uses the default heuristic to set spec_decode_max_draft_len and max_matching_ngram_size.
+    if spec_decode_algo != "NGRAM" and args.spec_decode_max_draft_len == 0:
+        args.spec_decode_max_draft_len = 1
+
     if spec_decode_algo == 'MTP':
         if not args.use_one_model:
             print(
```
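Changing these CLI defaults from 1 and 5 to 0 turns 0 into a sentinel meaning "not set by the user": non-NGram algorithms immediately restore the old draft length of 1, while NGram leaves the 0 in place so the heuristic in llm.py can fill both fields later. A minimal self-contained sketch of that sentinel pattern (the parsed argument values below are illustrative, not from the repo):

```python
import argparse

# Sentinel-default pattern: 0 means "the user did not pass this flag".
parser = argparse.ArgumentParser()
parser.add_argument('--spec_decode_algo', type=str, default=None)
parser.add_argument('--spec_decode_max_draft_len', type=int, default=0)

args = parser.parse_args(['--spec_decode_algo', 'mtp'])
spec_decode_algo = args.spec_decode_algo.upper() if args.spec_decode_algo else None

# Non-NGram algorithms get the old default of 1; for NGram, 0 survives so
# downstream code can tell "unset" apart from an explicit user value.
if spec_decode_algo != "NGRAM" and args.spec_decode_max_draft_len == 0:
    args.spec_decode_max_draft_len = 1

print(args.spec_decode_max_draft_len)  # -> 1
```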

tensorrt_llm/_torch/pyexecutor/py_executor.py
Lines changed: 4 additions & 3 deletions

```diff
@@ -922,7 +922,7 @@ def _executor_loop(self):
                 self._pad_attention_dp_dummy_request()
 
                 if self.drafter is not None:
-                    self._prepare_draft_requests(self.active_requests)
+                    self._prepare_draft_requests()
 
                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
                 )
@@ -1009,14 +1009,15 @@ def _executor_loop(self):
                         iter_stats=iter_stats,
                         iter_start_time=iter_start_time))
 
-    def _prepare_draft_requests(self, requests):
+    def _prepare_draft_requests(self):
         try:
             # Set draft tokens here to make the KV cache manager
             # and scheduler aware of them.
-            for req in requests:
+            for req in self.active_requests:
                 if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
                                      LlmRequestState.DISAGG_GENERATION_INIT):
                     continue
+
                 req.py_last_draft_tokens = req.py_draft_tokens
                 max_draft_len = self.model_engine.spec_config.max_draft_len
```
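The second hunk drops the `requests` parameter: the only call site passed `self.active_requests`, so the method now reads it directly. The hunk is truncated right after `max_draft_len`, so the padding step and fill value in this self-contained sketch of the loop's pattern are assumptions for illustration:

```python
from dataclasses import dataclass, field

@dataclass
class FakeRequest:
    # Stand-in for LlmRequest; only the fields the loop touches.
    state: str = "GENERATION_IN_PROGRESS"
    py_draft_tokens: list = field(default_factory=list)
    py_last_draft_tokens: list = None

def prepare_draft_requests(active_requests, max_draft_len):
    for req in active_requests:
        # Only in-progress generation (or disagg-init) requests carry draft tokens.
        if req.state not in ("GENERATION_IN_PROGRESS", "DISAGG_GENERATION_INIT"):
            continue
        req.py_last_draft_tokens = req.py_draft_tokens
        # Reserve worst-case space so the scheduler and KV cache manager
        # account for every potential draft token (assumed fill value).
        req.py_draft_tokens = [0] * max_draft_len

reqs = [FakeRequest(), FakeRequest(state="CONTEXT_INIT")]
prepare_draft_requests(reqs, max_draft_len=3)
print(reqs[0].py_draft_tokens)  # -> [0, 0, 0]
print(reqs[1].py_draft_tokens)  # -> [] (skipped)
```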

tensorrt_llm/llmapi/llm.py
Lines changed: 29 additions & 1 deletion

```diff
@@ -964,13 +964,41 @@ def _build_model(self):
             self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
                 self.args.cache_transceiver_config)
         from tensorrt_llm._torch.pyexecutor.config import update_executor_config
+
+        spec_config = self.args.speculative_config
+        max_batch_size = self._executor_config.max_batch_size
+        # Apply a heuristic to an incomplete NGramDecodingConfig, based on benchmark results:
+        #   with concurrency <= 4,  max_draft_len = 5, max_matching_ngram_size = 3
+        #   with concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
+        if spec_config.spec_dec_mode() == "NGRAM" and max_batch_size <= 32:
+            if not self.args.disable_overlap_scheduler:
+                logger.info(
+                    "Disable overlap scheduler to enable NGram speculative decoding."
+                )
+                # Benchmark results show that NGram speculative decoding outperforms the overlap scheduler at low concurrency (<= 32).
+                # Therefore, the overlap scheduler is disabled to enable NGram speculative decoding.
+                self.args.disable_overlap_scheduler = True
+
+            if spec_config.max_draft_len != 0 and spec_config.max_matching_ngram_size != 0:
+                pass
+            else:
+                if max_batch_size <= 4:
+                    spec_config.max_draft_len = 5 if spec_config.max_draft_len == 0 else spec_config.max_draft_len
+                    spec_config.max_matching_ngram_size = 3 if spec_config.max_matching_ngram_size == 0 else spec_config.max_matching_ngram_size
+                elif max_batch_size <= 32:
+                    spec_config.max_draft_len = 3 if spec_config.max_draft_len == 0 else spec_config.max_draft_len
+                    spec_config.max_matching_ngram_size = 5 if spec_config.max_matching_ngram_size == 0 else spec_config.max_matching_ngram_size
+                logger.info(
+                    f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
+                )
+
         update_executor_config(
             self._executor_config,
             backend=self.args.backend,
             pytorch_backend_config=self.args.get_pytorch_backend_config()
             if self.args.backend in ["pytorch", "_autodeploy"] else None,
             mapping=self.args.parallel_config.to_mapping(),
-            speculative_config=self.args.speculative_config,
+            speculative_config=spec_config,
             hf_model_dir=self._hf_model_dir,
             max_input_len=self.args.max_input_len,
             max_seq_len=max_seq_len,
```
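The user-wins resolution above is easier to reason about as a small pure function. A standalone sketch of the same rule, assuming 0 marks a field the user left unset (`resolve_ngram_defaults` is an illustrative name, not an API in the repo):

```python
def resolve_ngram_defaults(max_batch_size: int,
                           max_draft_len: int = 0,
                           max_matching_ngram_size: int = 0) -> tuple:
    """Fill unset (0) NGram fields from the concurrency-based heuristic."""
    if max_batch_size <= 4:
        default_draft, default_ngram = 5, 3
    elif max_batch_size <= 32:
        default_draft, default_ngram = 3, 5
    else:
        # The heuristic only applies at low concurrency; leave values as-is.
        return max_draft_len, max_matching_ngram_size
    # `or` keeps any explicit user value and replaces only the 0 sentinel.
    return (max_draft_len or default_draft,
            max_matching_ngram_size or default_ngram)

assert resolve_ngram_defaults(4) == (5, 3)
assert resolve_ngram_defaults(16) == (3, 5)
assert resolve_ngram_defaults(16, max_draft_len=7) == (7, 5)  # user value wins
```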

tensorrt_llm/llmapi/llm_args.py
Lines changed: 6 additions & 2 deletions

```diff
@@ -385,8 +385,12 @@ class NGramDecodingConfig(DecodingBaseConfig):
     is_public_pool: bool = True
         Whether to use a common pool for all requests, or the pool is private for each request if False.
     """
-
-    max_matching_ngram_size: int = 4
+    # If max_draft_len or max_matching_ngram_size is not set by the user,
+    # the default heuristic will be used:
+    #   with concurrency <= 4,  max_draft_len = 5, max_matching_ngram_size = 3
+    #   with concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
+    max_draft_len: int = 0
+    max_matching_ngram_size: int = 0
     is_keep_all: bool = True
     is_use_oldest: bool = True
     is_public_pool: bool = True
```
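With both fields defaulting to 0, a config constructed without them now opts into the heuristic. A hedged usage sketch of the LLM API path this commit targets; the import location and the LLM keyword arguments (model path, max_batch_size) are assumptions for illustration:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import NGramDecodingConfig  # assumed export path

# Leaving max_draft_len / max_matching_ngram_size at their 0 defaults
# marks them "unset", so _build_model applies the concurrency heuristic.
spec = NGramDecodingConfig()

llm = LLM(model="<path/to/model>",  # placeholder path
          speculative_config=spec,
          max_batch_size=16)        # <= 32, so NGram resolves to (3, 5)
```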
