
[Feature]: Per-sequence speculative decoding #17984

@yesredpig

Description


🚀 The feature, motivation and pitch

1. Problem

Currently, increasing the batch size in vLLM's speculative decoding inference causes inefficiency.
When using a LLaMA 1B SSM (draft) model with a LLaMA 70B target model, a performance reversal occurs at batch size 32.
In addition, when num_speculative_tokens (SL; speculative length) is large, the inefficiency grows even more as the batch size increases (Fig. 2).

vLLM was also aware of the need for optimization for this. (https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit?tab=t.0#heading=h.kk7dq05lc6q8)

2. Previous work in vLLM 

To handle increasing batch sizes in SD, vLLM has introduced the following: Batch Expansion (#3103) and the MQA (Multi-Query Attention) Scorer.

"Batch expansion" expands the batch by the factor of k (num_speculative_tokens). Each original sequence + one proposal token become a separate sequence in the expanded batch for the target model's scoring pass. Because of the "expansion" it has drawn backs as it increases memory usage and attention calculation by factor K.
To overcome the drawbacks, MQAScorer utilizes specialized MQA kernels (when available) to score all k proposal tokens for each sequence without expanding the batch size explicitly.
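To make the shape difference concrete, here is a minimal sketch (toy sizes and plain tensors, not vLLM's actual data structures) of how many scoring rows the target model sees under each approach:

```python
import torch

# Toy setup: B sequences in the batch, k speculative (proposal) tokens per
# sequence from the draft model, plus 1 bonus token scored by the target.
B, k, hidden = 4, 5, 8

# --- Batch expansion (sketch) -----------------------------------------------
# Every proposal position becomes its own single-token scoring row, so the
# target model's scoring pass sees roughly B * (k + 1) rows.
expanded_queries = torch.zeros(B * (k + 1), 1, hidden)
print("batch expansion:", tuple(expanded_queries.shape))   # (24, 1, 8)

# --- MQA-style scoring (sketch) ---------------------------------------------
# The batch stays at B; each sequence scores its k + 1 tokens in a single
# multi-query attention pass against the shared KV cache, with no expansion.
mqa_queries = torch.zeros(B, k + 1, hidden)
print("MQA scorer:", tuple(mqa_queries.shape))              # (4, 6, 8)
```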

Both Batch Expansion and the MQA Scorer face performance degradation with dynamic shapes under CUDA Graphs: if k (num_speculative_tokens) changes every iteration or differs across sequences, the captured CUDA graph cannot handle the changing shapes.

Ultimately, vLLM's current batch processing can only handle a static SL, i.e. a k (num_speculative_tokens) that is fixed before inference starts.

3. Dynamic SL 

Even if a piece-wise CUDA graph is applied to MQA, the inefficiency of every sequence in a batch sharing one fixed SL value is not resolved. Moreover, models such as EAGLE-2 that use tree attention need a large SL (e.g. 32), so as the batch grows, the MQA padding space grows with it.
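To illustrate the point, here is a toy calculation (hypothetical per-sequence lengths, not measured data) of how much of the MQA query space turns into padding when every sequence is padded to a static SL:

```python
import torch

# Hypothetical "useful" speculative lengths for the sequences of one batch.
per_seq_sl = torch.tensor([2, 4, 1, 7, 3, 2, 5, 1])
static_sl = 32                      # e.g. a tree-attention setup needing SL ~32

useful = per_seq_sl.sum().item()
padded = static_sl * len(per_seq_sl)
print(f"useful tokens: {useful}, scored with padding: {padded}, "
      f"waste: {1 - useful / padded:.1%}")
# -> useful tokens: 25, scored with padding: 256, waste: 90.2%
# The padded space (and hence the waste) grows linearly with the batch size.
```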

The following experiment sets max SL to 60 and checks how many proposals are accepted before the first rejection. The results show that OracleSL covers a wide range, from 1 to 60, and that it performs 117 fewer inference steps than running with the best static SL value of 4. Since the target model is only 8B, there was no significant difference in speed, but with a larger target model the gap between OracleSL and StaticSL will be large.

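For reference, a minimal sketch of how OracleSL can be measured for one sequence and one step under greedy acceptance (the token ids are made up; vLLM's actual speculative sampling uses a probabilistic accept/reject rule):

```python
import torch

def oracle_sl(draft_tokens: torch.Tensor, target_tokens: torch.Tensor) -> int:
    """Number of draft proposals accepted before the first rejection,
    i.e. the length of the matching prefix under greedy verification.

    Both tensors hold (max_sl,) token ids for one sequence; in the
    experiment above, max_sl = 60.
    """
    mismatch = (draft_tokens != target_tokens).nonzero()
    return int(mismatch[0]) if mismatch.numel() else draft_tokens.numel()

# Toy example: the 4th proposal is rejected, so OracleSL for this step is 3.
draft  = torch.tensor([11, 42, 7, 99, 5])
target = torch.tensor([11, 42, 7, 13, 5])
print(oracle_sl(draft, target))    # 3
```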

Previous papers have shown that it is effective to use a different SL for each sequence and a different SL for each iteration:

  • BASS
  • DISCO
  • SPRINTER
  • AdaEDL

4. CONCLUSION 

In conclusion, per-sequence decoding is necessary to apply a dynamic SL to each iteration and each sequence in the batch, and to resolve the inefficiency caused by increasing batch sizes.

My team is implementing per-sequence decoding in the FlashAttention-2 kernel. However, we are currently developing against vLLM 0.8.4, so it would be nice if the schedule for SD updates in vLLM V1 could be kept in sync.
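As a rough illustration of what a per-sequence SL interface could look like at the kernel boundary, the sketch below packs variable-length query chunks into the cumulative-length (cu_seqlens) layout that FlashAttention-2 style varlen kernels consume; the lengths and variable names are hypothetical, and this is not our actual implementation:

```python
import torch

# Hypothetical speculative lengths chosen per sequence for this iteration
# (+1 for the bonus token scored together with the proposals).
per_seq_sl = torch.tensor([3, 1, 6, 2])
query_lens = per_seq_sl + 1

# Pack all query tokens of the batch into one flat dimension and describe the
# per-sequence boundaries with cumulative lengths -- the layout consumed by
# varlen attention entry points such as flash_attn_varlen_func, so no sequence
# needs to be padded to a global static SL.
cu_seqlens_q = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(query_lens, dim=0)
max_seqlen_q = int(query_lens.max())

hidden = 128
q = torch.zeros(int(cu_seqlens_q[-1]), hidden)   # (sum of query lens, hidden)

print(cu_seqlens_q.tolist())       # [0, 4, 6, 13, 16]
print(q.shape[0], max_seqlen_q)    # 16 7
```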

Alternatives

No response

Additional context

No response

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
