
Conversation

ekagra-ranjan
Contributor

@ekagra-ranjan ekagra-ranjan commented Sep 5, 2025

Addresses: #18633
Adds a new SD approach combining the best of Ngram and EAGLE. Besides work on the algorithm itself, this needed additional work on data and metrics, as explained below.

Algorithm

The RFC discusses the motivation and proposed algorithm in more detail. The major change is that we can now run multiple drafters in a single step. This PR only allows Ngram and EAGLE to run simultaneously. This can be generalized to other combinations, but at the moment I think only pairing ngram (a weight-free approach) with an approach that needs trained weights would be of value. This PR does not add Ngram + EAGLE3; that is left for future work if someone is interested.
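For intuition, here is a minimal sketch of the per-step combination rule (prefer the ngram draft when the lookup hits, otherwise fall back to the EAGLE draft). The function and variable names are illustrative, not the PR's actual implementation:

```python
# Sketch only: combine per-request drafts from two proposers in one step.
# draft_token_ids_ngram / draft_token_ids_eagle stand for the per-request
# token lists produced by the ngram and EAGLE proposers for the same batch.
def combine_drafts(
    draft_token_ids_ngram: list[list[int]],
    draft_token_ids_eagle: list[list[int]],
) -> list[list[int]]:
    combined: list[list[int]] = []
    for ngram_draft, eagle_draft in zip(draft_token_ids_ngram,
                                        draft_token_ids_eagle):
        # Prefer the (typically longer) ngram draft when the lookup hits;
        # fall back to the EAGLE draft otherwise.
        combined.append(ngram_draft if ngram_draft else eagle_draft)
    return combined
```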

Dataset

We use Blazedit with a max normalized edit distance of 0.25, which means at most 25% of the output differs from the input. The min edit distance in this dataset is 0.1. Check out this PR for more detail on the need for this dataset: #23605
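For intuition only, here is a rough sketch of what the distance filter means. The exact distance definition used by the dataset flags is an assumption here; this assumes a character-level Levenshtein distance normalized by the longer of the two strings, mirroring the --blazedit-min-distance / --blazedit-max-distance flags used in the commands below:

```python
# Sketch: keep only samples whose normalized edit distance between the
# input and output text falls inside [min_dist, max_dist].
def levenshtein(a: str, b: str) -> int:
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                  # deletion
                           cur[j - 1] + 1,               # insertion
                           prev[j - 1] + (ca != cb)))    # substitution
        prev = cur
    return prev[-1]

def keep_sample(inp: str, out: str,
                min_dist: float = 0.01, max_dist: float = 0.25) -> bool:
    norm = levenshtein(inp, out) / max(len(inp), len(out), 1)
    return min_dist <= norm <= max_dist
```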

Metric

The AL reported in offline_inference/spec_decode.py only accounts for the steps in which a draft was actually made for a sequence. Assume:

  • the vanilla model takes 10 forward passes to output 10 tokens
  • With SD:
    • a drafter model drafts 2 tokens, i.e., K=2
    • the drafter has a 100% AR, i.e., the draft is always accepted by the target whenever the drafter drafts
    • the drafter was able to draft only during the first 2 forward passes, i.e., 3 (K+1=2+1) tokens were produced in the 1st fwd pass and another 3 tokens in the 2nd fwd pass. The remaining 4 tokens (10-3-3) were produced 1 token at a time and took 4 fwd passes.

The current formula for AL would measure this as 3, i.e., K+1. However, my intuition is that AL should measure how many tokens are accepted per forward pass, normalized across the whole sequence generation. The sequence normalized AL would be = num of total tokens generated / num of fwd passes = 10 / (2+4) = 10/6 = 1.66. This sequence normalized AL is more realistic in the sense that it gives the expected speedup from SD assuming zero overhead and inefficiency from the SD method. This correction is important to make. Without it, ngram has a much higher AL than Eagle on datasets like Instruct Coder, but TPOT doesn't reflect that. The current AL computation is more like computing precision, whereas to judge which method is better just by looking at AL (and assuming zero overhead), we need the sequence normalized AL.

This PR estimates the sequence normalized AL by tracking how many tokens were generated in total, how many of them were generated via SD, how many times a draft was made, and how many tokens were generated without SD. More specifically,

num_tokens_generated_without_sd = total_tokens_generated - (num_drafts + num_accepted_tokens)
seq_normalized_acceptance_length = (total_tokens_generated) / (num_drafts + num_tokens_generated_without_sd)
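
For concreteness, here is a small sketch that recomputes both metrics from the counters printed by examples/offline_inference/spec_decode.py. It assumes mean AL is accepted tokens per draft plus the bonus token, which matches the numbers reported below:

```python
# Sketch: derive mean AL and sequence-normalized AL from the printed counters.
def acceptance_lengths(total_tokens: int, num_drafts: int,
                       num_accepted: int) -> tuple[float, float]:
    # mean AL counts only the steps where a draft was made (+1 bonus token)
    mean_al = num_accepted / num_drafts + 1
    tokens_without_sd = total_tokens - (num_drafts + num_accepted)
    seq_norm_al = total_tokens / (num_drafts + tokens_without_sd)
    return mean_al, seq_norm_al

# ngram on Blazedit (numbers from the Benchmarks section below):
print(acceptance_lengths(5376, 993, 3562))  # -> (~4.59, ~2.96)
```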

In some cases, num_tokens_generated_without_sd is negative. This is due to a boundary condition where fewer than K tokens remain to be generated, the drafter still drafts K tokens, and all K are accepted, but the final output contains fewer than K of them. This error is bounded by ((K-1) * num of samples) / (num of samples * output len per sample) = (K-1) / (output len per sample). Not all samples run into this boundary condition, so this is an upper bound; for K=5 and an output length of 256 it is ~1.5%. Empirically, it was found to be <1% in the benchmarks below. Therefore, the impact of this estimation on the final results is negligible.

Benchmarks

Offline Inference (AL)

Blazedit max edit norm distance: 0.25

method: ngram-eagle
cmd: python3 examples/offline_inference/spec_decode.py --method ngram-eagle --num-speculative-tokens-per-method "{\"ngram\": 5, \"eagle\": 3}" --prompt_lookup_max 5 --prompt_lookup_min 2 --tp 1 --dataset-name hf --dataset-path vdaita/edit_5k_char --num-prompts 90 --hf-output-len 2048 --blazedit-min-distance 0.01 --blazedit-max-distance 0.25 --no-oversample --print-output

output

num generation tokens: 5376
--------------------------------------------------
total_num_output_tokens: 5376
num_drafts: 1334
num_draft_tokens: 5710
num_accepted_tokens: 4055
mean acceptance length: 4.04
num_tokens_generated_without_sd: -13
seq normalized acceptance length: 4.07
--------------------------------------------------
acceptance at token 0: 0.83
acceptance at token 1: 0.66
acceptance at token 2: 0.58
acceptance at token 3: 0.49
acceptance at token 4: 0.48

higher precision ngram-eagle by increasing --prompt_lookup_min to 5
cmd: python3 examples/offline_inference/spec_decode.py --method ngram-eagle --num-speculative-tokens-per-method "{\"ngram\": 5, \"eagle\": 3}" --prompt_lookup_max 5 --prompt_lookup_min 5 --tp 1 --dataset-name hf --dataset-path vdaita/edit_5k_char --num-prompts 90 --hf-output-len 2048 --blazedit-min-distance 0.01 --blazedit-max-distance 0.25 --no-oversample

output

--------------------------------------------------
total_num_output_tokens: 5376
num_drafts: 1355
num_draft_tokens: 5463
num_accepted_tokens: 4046
mean acceptance length: 3.99
num_tokens_generated_without_sd: -25
seq normalized acceptance length: 4.04
--------------------------------------------------
acceptance at token 0: 0.85
acceptance at token 1: 0.67
acceptance at token 2: 0.56
acceptance at token 3: 0.45
acceptance at token 4: 0.45

method: eagle
cmd: python3 examples/offline_inference/spec_decode.py --method eagle --num_spec_tokens 3 --tp 1 --dataset-name hf --dataset-path vdaita/edit_5k_char --num-prompts 90 --hf-output-len 2048 --blazedit-min-distance 0.01 --blazedit-max-distance 0.25 --no-oversample --print-output

output

--------------------------------------------------
num generation tokens: 5376
--------------------------------------------------
total_num_output_tokens: 5376
num_drafts: 2189
num_draft_tokens: 6567
num_accepted_tokens: 3187
mean acceptance length: 2.46
num_tokens_generated_without_sd: 0
seq normalized acceptance length: 2.46
--------------------------------------------------
acceptance at token 0: 0.77
acceptance at token 1: 0.45
acceptance at token 2: 0.24

method: ngram
cmd: python3 examples/offline_inference/spec_decode.py --method ngram --num_spec_tokens 5 --prompt_lookup_max 5 --prompt_lookup_min 2 --tp 1 --dataset-name hf --dataset-path vdaita/edit_5k_char --num-prompts 90 --hf-output-len 2048 --blazedit-min-distance 0.01 --blazedit-max-distance 0.25 --no-oversample --print-output

output

--------------------------------------------------
num generation tokens: 5376
--------------------------------------------------
total_num_output_tokens: 5376
num_drafts: 993
num_draft_tokens: 4960
num_accepted_tokens: 3562
mean acceptance length: 4.59
num_tokens_generated_without_sd: 821
seq normalized acceptance length: 2.96
--------------------------------------------------
acceptance at token 0: 0.81
acceptance at token 1: 0.74
acceptance at token 2: 0.70
acceptance at token 3: 0.68
acceptance at token 4: 0.66

MTBench

method: ngram-eagle
cmd: time VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --method ngram-eagle --num-speculative-tokens-per-method "{\"ngram\": 5, \"eagle\": 3}" --prompt_lookup_max 5 --prompt_lookup_min 2 --tp 1 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --print-output
output

--------------------------------------------------
num generation tokens: 16947
--------------------------------------------------
total_num_output_tokens: 16947
num_drafts: 7646
num_draft_tokens: 25784
num_accepted_tokens: 9300
mean acceptance length: 2.22
num_tokens_generated_without_sd: 1
seq normalized acceptance length: 2.22
--------------------------------------------------
acceptance at token 0: 0.63
acceptance at token 1: 0.35
acceptance at token 2: 0.19
acceptance at token 3: 0.02
acceptance at token 4: 0.02

higher precision ngram-eagle by increasing --prompt_lookup_min to 5
cmd: time VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --method ngram-eagle --num-speculative-tokens-per-method "{\"ngram\": 5, \"eagle\": 3}" --prompt_lookup_max 5 --prompt_lookup_min 5 --tp 1 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --print-output
output

num generation tokens: 16926
--------------------------------------------------
total_num_output_tokens: 16926
num_drafts: 7359
num_draft_tokens: 22653
num_accepted_tokens: 9571
mean acceptance length: 2.30
num_tokens_generated_without_sd: -4
seq normalized acceptance length: 2.30
--------------------------------------------------
acceptance at token 0: 0.68
acceptance at token 1: 0.39
acceptance at token 2: 0.21
acceptance at token 3: 0.01
acceptance at token 4: 0.01

method: eagle
cmd: python3 examples/offline_inference/spec_decode.py --method eagle --num_spec_tokens 3 --tp 1 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --print-output
output

num generation tokens: 16956
--------------------------------------------------
total_num_output_tokens: 16956
num_drafts: 7416
num_draft_tokens: 22248
num_accepted_tokens: 9543
mean acceptance length: 2.29
num_tokens_generated_without_sd: -3
seq normalized acceptance length: 2.29
--------------------------------------------------
acceptance at token 0: 0.68
acceptance at token 1: 0.39
acceptance at token 2: 0.21

method: ngram
cmd: time VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --method ngram --num_spec_tokens 5 --prompt_lookup_max 5 --prompt_lookup_min 2 --tp 1 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --print-output
output

num generation tokens: 17097
--------------------------------------------------
total_num_output_tokens: 17097
num_drafts: 2578
num_draft_tokens: 12864
num_accepted_tokens: 2579
mean acceptance length: 2.00
num_tokens_generated_without_sd: 11940
seq normalized acceptance length: 1.18
--------------------------------------------------
acceptance at token 0: 0.44
acceptance at token 1: 0.25
acceptance at token 2: 0.15
acceptance at token 3: 0.10
acceptance at token 4: 0.07

Instruct Coder

method: ngram-eagle
cmd: time VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --method ngram-eagle --num-speculative-tokens-per-method "{\"ngram\": 5, \"eagle\": 3}" --prompt_lookup_max 5 --prompt_lookup_min 2 --tp 1 --dataset-name hf --dataset-path likaixin/InstructCoder --num-prompts 1000 --print-output
output

num generation tokens: 163128
--------------------------------------------------
total_num_output_tokens: 163128
num_drafts: 50345
num_draft_tokens: 196610
num_accepted_tokens: 113627
mean acceptance length: 3.26
num_tokens_generated_without_sd: -844
seq normalized acceptance length: 3.30
--------------------------------------------------
acceptance at token 0: 0.76
acceptance at token 1: 0.57
acceptance at token 2: 0.43
acceptance at token 3: 0.25
acceptance at token 4: 0.24

method: eagle
cmd: time VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --method eagle --num_spec_tokens 3 --tp 1 --dataset-name hf --dataset-path likaixin/InstructCoder --num-prompts 1000 --print-output
output

num generation tokens: 162992
--------------------------------------------------
total_num_output_tokens: 162992
num_drafts: 58129
num_draft_tokens: 174387
num_accepted_tokens: 105583
mean acceptance length: 2.82
num_tokens_generated_without_sd: -720
seq normalized acceptance length: 2.84
--------------------------------------------------
acceptance at token 0: 0.82
acceptance at token 1: 0.60
acceptance at token 2: 0.39

method: ngram
cmd: time VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py --method ngram --num_spec_tokens 5 --prompt_lookup_max 5 --prompt_lookup_min 2 --tp 1 --dataset-name hf --dataset-path likaixin/InstructCoder --num-prompts 1000 --print-output
output

num generation tokens: 163128
--------------------------------------------------
total_num_output_tokens: 163128
num_drafts: 30249
num_draft_tokens: 150775
num_accepted_tokens: 79552
mean acceptance length: 3.63
num_tokens_generated_without_sd: 53327
seq normalized acceptance length: 1.95
--------------------------------------------------
acceptance at token 0: 0.69
acceptance at token 1: 0.57
acceptance at token 2: 0.50
acceptance at token 3: 0.45
acceptance at token 4: 0.42

Online Inference median TPOT (ms)

TPOT (ms) client cmds:

MTBench

vllm bench serve --port 9001 --save-result --save-detailed \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --endpoint-type openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts 80 \
    --max-concurrency 1 \
    --result-dir "./log/EAGLE-1"

instruct coder

vllm bench serve --port 9001 --save-result --save-detailed \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --endpoint-type openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path likaixin/InstructCoder \
    --num-prompts 1000 \
    --max-concurrency 1 \
    --result-dir "./log/EAGLE-1"

Blazedit

vllm bench serve --port 9001 --save-result --save-detailed \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --endpoint-type openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path vdaita/edit_5k_char \
    --num-prompts 90 \
    --blazedit-min-distance 0.01 \
    --blazedit-max-distance 0.25 \
    --no-oversample \
    --max-concurrency 1 \
    --hf-output-len 256 \
    --result-dir "./log/EAGLE-1"

vanilla
server cmd: VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --port 9001

  • MTBench: 7.00
  • blazedit: 6.95
  • instruct coder: 6.98

eagle
server cmd:

VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --disable-log-requests --port 9001 \
    --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 3}'

  • MTBench: 4.19
  • blazedit: 3.96
  • instruct coder (100 samples): 3.41
  • instruct coder (1000 samples): 3.41

ngram
server cmd:

VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --disable-log-requests --port 9001 \
    --speculative_config '{"method": "ngram", "num_speculative_tokens": 5, "prompt_lookup_max": 5, "prompt_lookup_min": 2}'

  • MTBench: 6.18
  • blazedit: 1.90
  • instruct coder (100 samples): 3.21
  • instruct coder (1000 samples): 3.32

ngram-eagle
server cmd:

VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --disable-log-requests --port 9001 \
    --speculative_config '{"method": "ngram-eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens_per_method": "{\"ngram\": 5, \"eagle\": 3}", "prompt_lookup_max": 5, "prompt_lookup_min": 5}'

  • MTBench: 4.30
  • blazedit: 2.13
  • instruct coder: 2.96

Analysis

Impact on AL

The sequence normalized AL of ngram-eagle is much higher for the editing task since:

  • when the output is equal to the input: it uses the ngram draft and hence is better than eagle and equal to ngram in this region
  • when the output is not the same as the input: it uses the eagle draft and hence is better than ngram and equal to eagle in this region

Therefore, overall the AL is much better for ngram-eagle.

Overhead:

ngram-eagle has a similar overhead to eagle, since the EAGLE drafter still has to run. Its overhead is higher than ngram's, as ngram doesn't need to run any drafter auto-regressively.

Performance analysis on Datasets

  • MTBench: perf neutral compared to eagle on general tasks like MTBench. This is achieved by setting a higher prompt_lookup_min, which increases precision at the cost of fewer ngram drafts
  • Blazedit: perf is better than eagle or ngram because the AL of ngram-eagle is higher than that of ngram or eagle. This higher AL covers the drafting cost of eagle. It is understandable why ngram-eagle has a higher AL than eagle for an editing task. Compared with ngram, the ngram-eagle approach benefits from ngram's matches when the output is the same as the input. However, when the input and output don't match, ngram has no accepted drafts whereas ngram-eagle still accepts tokens thanks to eagle. This makes the average AL higher for ngram-eagle compared to ngram

Empirically, the sequence normalized AL on Blazedit when the edit distance norm is between [0.1, 0.25], as reported above in the Benchmarks section:

  • eagle: 2.46
  • ngram: 2.96 (the precision AL was 4.59, but ngram doesn't always draft, so the seq norm AL is 2.96)
  • ngram-eagle: 4.04

Theoretical Analysis

This is in line with my theoretical calculation of ngram-eagle's AL as the edit distance norm changes.
Precision AL means the AL over only those steps where a draft was made. vLLM computes AL in this manner, which is incomplete information; hence, in this PR I introduce a new metric in offline inference called sequence normalized AL, which represents the AL across a whole sequence and gives the expected speedup from SD assuming zero draft overhead and implementation inefficiency. More detail on the metric is in the Metric section above.

When the output is the same as the input, ngram-eagle follows ngram's AL, and when it diverges it follows eagle's AL.

  • The formula for est ngram-eagle AL is = 1000/((edit_norm*1000/eagle_AL) + ((1-edit_norm)*1000/ngram_AL))
  • The formula for est ngram AL is = 1000/((edit_norm*1000/1) + ((1-edit_norm)*1000/ngram_AL))

The numerator is 1000 tokens and the denominator is the number of steps needed to produce them. The assumption of 1000 tokens doesn't matter since it cancels out.


| edit distance | ngram precision AL | eagle AL | est ngram AL | est ngram-eagle AL |
|---|---|---|---|---|
| 0.1 | 4.59 | 2.46 | 3.377483444 | 4.224242424 |
| 0.2 | 4.59 | 2.46 | 2.671711292 | 3.912474012 |
| 0.3 | 4.59 | 2.46 | 2.209918151 | 3.643562439 |
| 0.4 | 4.59 | 2.46 | 1.884236453 | 3.40923913 |
| 0.5 | 4.59 | 2.46 | 1.642218247 | 3.203234043 |
| 0.6 | 4.59 | 2.46 | 1.455294864 | 3.02070626 |
| 0.7 | 4.59 | 2.46 | 1.306575576 | 2.85785877 |
| 0.8 | 4.59 | 2.46 | 1.185433884 | 2.71167147 |
| 0.9 | 4.59 | 2.46 | 1.084849917 | 2.579712132 |

As we can see, the estimated ngram-eagle AL is strictly better than both ngram and eagle.
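
For reference, a short sketch (illustrative helper names) that reproduces the estimated columns of the table above from the two formulas; the 1000-token numerator cancels, so it is dropped:

```python
# Sketch: estimated ALs as a function of the normalized edit distance,
# using the measured ngram precision AL (4.59) and eagle AL (2.46).
NGRAM_PRECISION_AL = 4.59
EAGLE_AL = 2.46

def est_ngram_al(edit_norm: float) -> float:
    # ngram accepts nothing on the edited fraction (1 token per step there)
    return 1.0 / (edit_norm / 1.0 + (1.0 - edit_norm) / NGRAM_PRECISION_AL)

def est_ngram_eagle_al(edit_norm: float) -> float:
    # eagle covers the edited fraction, ngram covers the unchanged fraction
    return 1.0 / (edit_norm / EAGLE_AL + (1.0 - edit_norm) / NGRAM_PRECISION_AL)

for d in [i / 10 for i in range(1, 10)]:
    print(f"{d:.1f}  {est_ngram_al(d):.2f}  {est_ngram_eagle_al(d):.2f}")
# e.g. 0.1 -> est ngram AL ~3.38, est ngram-eagle AL ~4.22
```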

Final end to end results

[Figure: end-to-end comparison of median TPOT (ms) across methods and datasets]

y-axis: TPOT (ms), lower is better.
ngram-eagle is consistently among the fastest across the different types of datasets.

TODO

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new speculative decoding method, ngram-eagle, which combines n-gram based proposals with EAGLE proposals. The changes include adding the new method to configuration options, updating the speculative decoding logic to handle the combined approach, and modifying the example script to support it. The implementation correctly initializes both n-gram and EAGLE proposers when ngram-eagle is selected and combines their outputs. My review found one critical issue in the configuration validation logic that should be addressed.

ekagra-ranjan and others added 2 commits September 5, 2025 21:36

mergify bot commented Sep 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ekagra-ranjan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 8, 2025
@ekagra-ranjan ekagra-ranjan marked this pull request as ready for review September 10, 2025 05:48
@mergify mergify bot removed the needs-rebase label Sep 10, 2025

mergify bot commented Sep 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ekagra-ranjan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 11, 2025
@Neo9061

Neo9061 commented Sep 23, 2025

This is great enablement! When can we merge this into vllm v1 mainline?

@Neo9061

Neo9061 commented Sep 24, 2025

Wonder if your implementation works with EAGLE 3 in vllm v1? And whether the performance gain you established will hold true at higher concurrency levels? Many thanks!

@ekagra-ranjan
Contributor Author

This is great enablement! When can we merge this into vllm v1 mainline?

Thanks @Neo9061 . I am waiting for reviews from @WoosukKwon and @LiuXiaoxuanPKU.

Wonder if your implementation works with EAGLE 3 in vllm v1? Many thanks!

Eagle 3 is left for a future PR. It will be straightforward, and I leave it to the OSS community.

And whether the performance gain you established will hold true at higher concurrency levels?

SD in general doesn't hold up well at very high concurrency; that limitation is not specific to this method.

# combine ngram and eagle drafts
# prefer ngram drafts when available
# choose eagle drafts when ngram drafts are empty
for bid in range(len(draft_token_ids_ngram)):
    if draft_token_ids_ngram[bid]:
        draft_token_ids.append(draft_token_ids_ngram[bid])
    else:
        draft_token_ids.append(draft_token_ids_eagle[bid])


Does this mean you always use both ngram and eagle to generate speculation proposals? Isn't it more efficient to generate eagle proposals only when there are no valid ngram proposals?

Contributor Author


In a multi-batch setting, if even one sequence is running EAGLE, the cost is almost the same as all requests running it. The current implementation is also simpler. Future improvements can be made to further optimize the low-batch setting.



Hi @ekagra-ranjan, similar question: in your code, for each request you first use n-gram to generate draft_token_ids_ngram and then use eagle to generate draft_token_ids_eagle. Can you share insights into why such a hybrid approach can be faster than EAGLE alone? I think the speedup from the hybrid approach would come from skipping the EAGLE auto-regressive drafting when we already have results from n-gram.

Collaborator


@Neo9061 While the EAGLE drafting cost is certainly non-trivial, I think the main benefit of this method is that it can verify both long-but-rare speculated sequences (ngram) and short-but-accurate speculated sequences (eagle) together. This way, a deployment gets the benefits of either ngram or eagle, whichever has better prediction accuracy on each token. I think of it less as an EAGLE-taxed ngram deployment and more as an ngram-augmented EAGLE deployment that gets the widespread speedup from EAGLE as a baseline, and in some cases gets to leverage ngram for a much higher AL.

Collaborator


That said, it does look like a free win to skip EAGLE drafting when all requests in the batch get an ngram hit. This might not be likely for BS >> 1, but for low-latency single-request serving it might actually pay off noticeably.
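
A minimal sketch of that idea, under the assumption that the per-request ngram drafts are available before the EAGLE call (the names, including eagle_propose, are hypothetical):

```python
from typing import Callable

# Sketch: only run the EAGLE drafter when at least one request in the
# batch has no ngram draft; otherwise return the ngram drafts directly.
def propose_with_optional_eagle(
    draft_token_ids_ngram: list[list[int]],
    eagle_propose: Callable[[], list[list[int]]],
) -> list[list[int]]:
    if all(draft_token_ids_ngram):  # every request got an ngram hit
        return draft_token_ids_ngram
    draft_token_ids_eagle = eagle_propose()
    # fall back to the EAGLE draft for requests with an empty ngram draft
    return [ng if ng else eg
            for ng, eg in zip(draft_token_ids_ngram, draft_token_ids_eagle)]
```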


@Neo9061 Neo9061 Oct 3, 2025


I see, thanks! @benchislett

Another question, probably for @ekagra-ranjan: if we want to use a potentially longer drafting length from n-gram, would making the number of speculative tokens for the n-gram part higher than 5 give you a lower TPOT? I saw in your benchmarking above that you used the same drafting length as EAGLE.

I am sure you are aware of the Suffix Decoding PR. Compared to n-gram, suffix decoding without cache (an equivalent comparison, i.e., not building a global tree that caches results from previous requests) can search over both the prompt and the intermediately generated tokens, whereas n-gram can only search over the prompt part (

context_token_ids = token_ids_cpu[idx, :num_tokens]
). Also, for long-context prompts, the time complexity of suffix decoding is constant while n-gram is O(prompt length).

Wonder if there is any plan to integrate suffix decoding with eagle?

@simon-mo
Collaborator

Please help fix the merge conflict, ty

@mergify mergify bot removed the needs-rebase label Oct 1, 2025
Comment on lines 248 to 269
if self.num_speculative_tokens_per_method is not None:
    if isinstance(self.num_speculative_tokens_per_method, str):
        self.num_speculative_tokens_per_method = json.loads(
            self.num_speculative_tokens_per_method)
    assert isinstance(self.num_speculative_tokens_per_method, dict), (
        "num_speculative_tokens_per_method must be a dict or a json "
        "string that can be converted to a dict.")
    assert all(
        isinstance(v, int) and v > 0
        for v in self.num_speculative_tokens_per_method.values()), (
            "All values in num_speculative_tokens_per_method must be "
            "positive integers.")
    max_num_speculative_tokens = max(
        self.num_speculative_tokens_per_method.values())
    if self.num_speculative_tokens is None:
        self.num_speculative_tokens = max_num_speculative_tokens
    else:
        assert self.num_speculative_tokens <= \
            max_num_speculative_tokens, (
                "num_speculative_tokens should be None or must be"
                " less than or equal to the "
                "max value in num_speculative_tokens_per_method.")
Member


Why bother with str? The CLI can parse JSON

Contributor Author


Thanks! I didn't know about that. Fixed.

Comment on lines +464 to +468
if self.use_ngram() and not self.disable_padded_drafter_batch:
    logger.warning(
        "padded_drafter_batch has to be disabled with ngram. "
        "Setting disable_padded_drafter_batch to True.")
    self.disable_padded_drafter_batch = True
Contributor Author


@benchislett - jfyi, padded_drafter_batch has been disabled by default for ngram and ngram-eagle.

Collaborator

@benchislett benchislett Oct 1, 2025


I don't think it has to be disabled, but it's likely a good decision to do so.

@ekagra-ranjan ekagra-ranjan changed the title [Spec Decode] Add ngram-eagle SD method [Spec Decode][Hybrid] Add ngram-eagle SD method Oct 1, 2025
self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32),
             np.zeros((1024, self.max_model_len), dtype=np.int32),
             set())
logger.info(
Collaborator


Is this intended to be left in?

Comment on lines +285 to +286
# use ifs and not elifs to allow multiple
# draft models to be initialized
Collaborator


nit: clarity

Suggested change
# use ifs and not elifs to allow multiple
# draft models to be initialized
# allow multiple draft methods to be used together

use_padded_batch_for_eagle = self.speculative_config and \
    self.speculative_config.use_eagle() and \
    not self.speculative_config.disable_padded_drafter_batch
    not self.speculative_config.disable_padded_drafter_batch and \
Collaborator


Have you considered keeping padded_drafter_batch for eagle drafting and then doing ngram separately? If both methods are going to be used anyway, this seems possible. Do you think there's a benefit to keeping padded_drafter_batch in this case?


mergify bot commented Oct 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ekagra-ranjan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 5, 2025
Labels: documentation, needs-rebase, performance, speculative-decoding, v1