[SpecDecode][Kernel] Use Flashinfer for Rejection Sampling in Speculative Decoding #7244
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these (for example, comment /ready on the PR). 🚀
cadedaniel left a comment:
LGTM. Two questions:
- Can we run the correctness test for both paths? Specifically the convergence test (all of the e2e tests depend on this for temp > 0); see the parametrization sketch below.
- Can we make sure there is no perf regression for the non-FlashInfer path?
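A minimal sketch of how the convergence test could be parametrized to cover both paths. This is only an illustration, not the PR's actual test code; the constructor argument follows the diffs below, and the test body is elided:

```python
import pytest

from vllm.model_executor.layers.rejection_sampler import RejectionSampler


@pytest.mark.parametrize("use_flashinfer", [False, True])
def test_rejection_sampling_approximates_target_distribution(
        use_flashinfer: bool):
    """Run the same convergence check against the Torch-native path and the
    FlashInfer-backed path, so temp > 0 behavior is validated for both."""
    sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    # ... draw many samples through `sampler` and compare the empirical
    # output distribution with the target distribution, as the existing
    # convergence test already does for the default path.
```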
Note: there's a bugfix for a correctness issue in this sampling kernel (flashinfer-ai/flashinfer#425), so we may want to bump FlashInfer to the next release.
Yes, @LiuXiaoxuanPKU's numbers were measured with the FlashInfer main branch, where #425 was already merged. This PR depends on FlashInfer v0.1.4.
@cadedaniel Updates:
This PR is ready. CI tests might fail because we need FlashInfer to cut a release and then add the latest FlashInfer to CI. Tests passed locally.
comaniac left a comment:
Overall LGTM. Just nits.
    rejection_sampler = RejectionSampler(
        disable_bonus_tokens=disable_bonus_tokens)
    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
I feel you can leave this test untouched, and just rename the following test to "test_flashinfer_backed".
                         strict_mode=strict_mode)
        self.use_flashinfer = use_flashinfer
        if self.use_flashinfer:
            assert not disable_bonus_tokens, \
Could this just be a warning?
Ideally, when disable_bonus_tokens is set, the bonus token should be -1.
However, if we use FlashInfer and set disable_bonus_tokens, the bonus token will still have values (!= -1), which makes the results incorrect. I guess it might be better to just fail here?
We can remove the disable_bonus_tokens path completely now that #4212 is fixed.
But if that's too much work, let's just leave it as an assert; that way "no failure" means the user gets the experience we planned for them, instead of missing a warning and getting subpar performance.
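For illustration, a self-contained sketch of the guard being discussed (a toy stand-in, not vLLM's actual class; the constructor defaults and the assert message are assumptions based on the diff above):

```python
class RejectionSamplerSketch:
    """Toy stand-in for RejectionSampler, showing only the guard at issue."""

    def __init__(self,
                 disable_bonus_tokens: bool = False,
                 strict_mode: bool = False,
                 use_flashinfer: bool = False):
        self.disable_bonus_tokens = disable_bonus_tokens
        self.strict_mode = strict_mode
        self.use_flashinfer = use_flashinfer
        if self.use_flashinfer:
            # The FlashInfer kernel always emits a bonus token, so silently
            # honoring disable_bonus_tokens would return incorrect results
            # (bonus slots would hold real token ids rather than -1).
            # Failing fast was preferred over the softer alternative of a
            # warnings.warn(...) call.
            assert not disable_bonus_tokens, (
                "disable_bonus_tokens is not supported with the FlashInfer "
                "rejection sampling backend")
```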
Will review today.
FYI: flashinfer v0.1.6 wheels are ready: https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.1.6
/ready
                     batch_size: int, device: str):

    def get_seeded_seqs():
        seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= 1.0
I think this needs to go outside the helper function; otherwise the rand result will be different each time the helper is called.
I realize there's an error here -- it should be torch.rand(...) <= 0.5
I think I will just remove the rand; we should fix the generator for each request in the batch, instead of seeding requests with 50% probability.
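A minimal sketch of the direction described above: drop the random mask and give every request in the batch a fixed-seed generator. The name mirrors the diff, while the seed values and return type are illustrative assumptions:

```python
import torch


def get_seeded_seqs(batch_size: int,
                    device: str) -> dict[int, torch.Generator]:
    """Seed every request in the batch deterministically, instead of
    seeding a random ~50% subset via torch.rand(...) <= 0.5."""
    return {
        i: torch.Generator(device=device).manual_seed(i + 1)
        for i in range(batch_size)
    }
```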
    # num_emitted_tokens returned by flashinfer
    # does not include the bonus token
    # Flashinfer stops at the first token that violates
Why not just align flashinfer's behavior with this API's?
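As an illustration of the adjustment the comment above describes (not the PR's actual code), here is a sketch that adds the bonus token back into FlashInfer's per-sequence counts, assuming a [batch, k + 1] output layout where rejected slots hold -1 and the bonus slot is filled only when every draft token is accepted:

```python
import torch


def align_emitted_token_counts(
        flashinfer_num_emitted: torch.Tensor,  # [batch], excludes bonus token
        output_token_ids: torch.Tensor,        # [batch, k + 1], -1 = rejected
) -> torch.Tensor:
    """Add one to each sequence whose bonus slot was actually emitted, so the
    counts match an API whose num_emitted_tokens includes the bonus token."""
    bonus_emitted = (output_token_ids[:, -1] != -1).to(flashinfer_num_emitted)
    return flashinfer_num_emitted + bonus_emitted
```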
End-to-end speculative decoding performance (request latency). Draft: Llama-160M, Target: Vicuna-7B, batch size = 8, input_len = 256, output_len = 512.
Before this PR:
After this PR: