Conversation

@LiuXiaoxuanPKU (Collaborator) commented Apr 3, 2025

Task 1 of #15901

Some limitations:

  1. Only tested with a single GPU.
  2. Only works with eager mode for both the target and draft models; compatibility with torch.compile still needs to be checked.
  3. Only supports Llama models.

How to run this PR:
python examples/offline_inference/eagle.py
Ignore the metrics used/printed in that file.


github-actions bot commented Apr 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the documentation and v1 labels on Apr 3, 2025
@markmc (Member) commented Apr 4, 2025

I'm sure this has been discussed elsewhere, but why introduce a V1-specific model (LlamaForCausalLMEagle) instead of using the existing model (EagleModel) ?

For example, it looked like @luyuzhe111 expected that EagleModel could be used in V1 with DeepSeek MTP weights?

Comment on lines +193 to +194:

    # We need to set the vllm_config here to register attention
    # layers in the forward context.

@ekagra-ranjan (Contributor) commented Apr 4, 2025

Do we need to call load_model() from __init__() so that this function runs and the attention layers are registered?

@LiuXiaoxuanPKU (Collaborator, Author) replied:

Could you elaborate a bit on which load_model you are talking about?

@ekagra-ranjan (Contributor) commented Apr 4, 2025, quoting:

    # We need to set the vllm_config here to register attention
    # layers in the forward context.
    with set_default_torch_dtype(
            draft_model_config.dtype), set_current_vllm_config(

  1. I didn't get the part that says setting the current vllm config leads to registering the attention layers. Can you please share more? My understanding is that we did not change self.vllm_config in this function, and the attention layers are registered when the model is initialized, which saves each attention prefix in static_forward_context, which is then used during bind_kv_cache().

  2. So if load_model() is called in __init__ in this file, would it not register the attention layers without the need for set_current_vllm_config(self.vllm_config)?

@LiuXiaoxuanPKU (Collaborator, Author) commented Apr 8, 2025

Which load_model are you referring to? This one?

@ekagra-ranjan (Contributor) commented Apr 9, 2025

The one in spec_decode/eagle.py here: https://github.com/vllm-project/vllm/pull/16035/files/59ee450306d3d719f78ad60c77ba9b739bc5cb11#diff-a4809a837fbf535a8f0999b11087a53ec1c53948b50c0a1fe64396bc86de9461R184

I have broken my question above into two parts, along with my understanding, so that it is easier to explain what I am missing. Looking forward to your response.
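
For readers following this exchange, here is a minimal, self-contained sketch of the registration pattern being discussed: the model is constructed inside a "current config" context, and each attention layer records itself (by prefix) in that config at construction time; a later bind_kv_cache()-style step can then consume that mapping. All names below are illustrative toys, not the actual vLLM classes or signatures.

```python
# Illustrative sketch only -- ToyConfig, ToyAttention, ToyDraftModel and
# set_current_config are hypothetical stand-ins, not vLLM internals.
from contextlib import contextmanager

_CURRENT_CONFIG = None

@contextmanager
def set_current_config(config):
    """Make `config` visible to layers constructed inside the `with` block."""
    global _CURRENT_CONFIG
    prev, _CURRENT_CONFIG = _CURRENT_CONFIG, config
    try:
        yield
    finally:
        _CURRENT_CONFIG = prev

class ToyConfig:
    def __init__(self):
        # prefix -> attention layer; later consumed by a bind_kv_cache()-like step
        self.static_forward_context = {}

class ToyAttention:
    def __init__(self, prefix: str):
        # Registration happens at construction time, so the layer is only
        # recorded if a config is "current" while the model is being built.
        _CURRENT_CONFIG.static_forward_context[prefix] = self

class ToyDraftModel:
    def __init__(self):
        self.attn = ToyAttention(prefix="draft.layers.0.attn")

config = ToyConfig()
with set_current_config(config):      # analogous to set_current_vllm_config(...)
    draft = ToyDraftModel()           # the attention layer registers itself here

print(list(config.static_forward_context))  # ['draft.layers.0.attn']
```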

@LiuXiaoxuanPKU (Collaborator, Author) commented Apr 6, 2025

Hi folks, thanks for all the comments so far @markmc @ekagra-ranjan. I am double-checking correctness and have not yet started addressing the comments; I will start tomorrow. Some updates:

  1. I checked the compatibility with torch.compile; it should work. Concretely, the target model runs with torch.compile and CUDA graphs, while the EAGLE head runs in eager mode.
  2. I fixed some correctness bugs in the model definition.
  3. I benchmarked the performance of the example here on an H100 (I only changed the EAGLE model to yuhuili/EAGLE-LLaMA3-Instruct-8B, downloaded directly from Hugging Face). The numbers are request latency:

| Eagle (k=1) | Eagle (k=2) | Eagle (k=3) | Eagle (k=4) | w/o Eagle |
| --- | --- | --- | --- | --- |
| 1.49 | 1.39 | 1.43 | 1.47 | 1.92 |

Please start review and check correctness. cc @luyuzhe111 @WoosukKwon.

@WoosukKwon mentioned this pull request on Apr 6, 2025
@ekagra-ranjan (Contributor) commented Apr 7, 2025

@LiuXiaoxuanPKU Good results!

I was wondering how we are able to run EAGLE given that Tasks 2 and 3 in #15901 are WIP? What are the implications/assumptions of these with respect to the results shared in this PR?

@LiuXiaoxuanPKU (Collaborator, Author) replied:

> I was wondering how we are able to run EAGLE given that Tasks 2 and 3 in #15901 are WIP? What are the implications/assumptions of these with respect to the results shared in this PR?

Thanks for asking:

  1. Task 2 is about allocating KV cache. When the batch size is small (so there is enough KV cache), the current implementation should not cause errors such as overwriting the KV cache of other requests.
  2. Task 3 is mainly for standard sampling; I'm benchmarking greedy sampling here.

@ekagra-ranjan (Contributor) commented:

@LiuXiaoxuanPKU - are these results on GSM8K? If possible, can you run them on MT-Bench so that we can compare the results with SGL and identify gaps? https://docs.google.com/document/d/18ETJLsnxR88Qq3VDk5Mq-Hb7vuE9o3VNZ-hhz-OqAXk/edit?usp=sharing

@luyuzhe111 (Contributor) commented:

@LiuXiaoxuanPKU Hi Lily, when running with VLLM_USE_V1=1, I am getting errors from the attention backend. Are there additional environment variables that I need to set?


mergify bot commented Apr 8, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LiuXiaoxuanPKU.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@ekagra-ranjan (Contributor) commented Apr 9, 2025

@LiuXiaoxuanPKU @WoosukKwon

I benchmarked this PR on MT-Bench using lmsys/sglang-EAGLE-LLaMA3-Instruct-8B so that we can compare with the SGL benchmark I did some time back, to help us get a sense of direction.

I wired the SD metrics in the Scheduler into EngineCoreOutputs so that we can check correctness via the acceptance length. Currently, the SD metrics in the Scheduler get reinitialized every engine step and are not aggregated, so I fixed that. Here is a dummy PR with the benchmarking script and the changes I made on top of this PR to get the results below. Please let me know if something is incorrect in my setup. I can also raise a PR with the SD metrics, with some extra steps, if that makes sense.

Here is the command used:

    VLLM_USE_V1=1 python examples/offline_inference/eagle.py --dataset="../data/mt_bench/question.jsonl" --num_spec_tokens 4 --max_num_seqs 1

where num_spec_tokens is 2 or 4.

  • vanilla: [02:08<00:00, 1.61s/it, est. speed input: 47.09 toks/s, output: 128.64 toks/s]
  • k=2: [01:34<00:00, 1.18s/it, est. speed input: 64.12 toks/s, output: 175.49 toks/s], Accept Len: 1.89
  • k=4: [01:40<00:00, 1.25s/it, est. speed input: 60.52 toks/s, output: 165.65 toks/s], Accept Len: 2.08

k=2 is 36% faster and k=4 is 28% faster than vanilla.

Compared to the SGL bench:

  • the absolute throughput of vLLM is 14% lower than SGL
  • k=2
    • AL: 1.89 (vLLM) vs 1.72 (SGL)
    • throughput gain: 36% faster (vLLM) vs 3% (SGL)
  • k=4
    • AL: 2.08 (vLLM) vs 2.4 (SGL)
    • throughput gain: 28% faster (vLLM) vs 27% (SGL)

The vLLM AL formula is here and the SGL AL formula is here.
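
For reference, a minimal sketch of the acceptance-length (AL) bookkeeping that such formulas typically reduce to; this paraphrases the standard definition (one target token per engine step plus the accepted draft tokens), not the exact code behind either link.

```python
# Sketch of the standard acceptance-length definition (not the exact vLLM/SGL code).
def acceptance_length(accepted_per_step):
    """accepted_per_step[i] = number of draft tokens accepted at engine step i.

    Each step emits the accepted draft tokens plus one token from the target
    model (the bonus/correction token), so AL = 1 + mean(accepted).
    """
    return 1 + sum(accepted_per_step) / len(accepted_per_step)

# e.g. with k=2 drafts and per-step acceptances of [2, 1, 0, 2, 1, 1]:
print(acceptance_length([2, 1, 0, 2, 1, 1]))  # ~2.17
```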

Trends:

  • SGL has a much lower gain at k=2 compared to vLLM but catches up at k=4
  • vLLM has a slightly higher AL at k=2 but a lower AL at k=4

I feel the AL formulas are similar and shouldn't be the cause of the differences in AL, yet we are getting a lower AL for k=4. Please let me know your thoughts or if I missed something.

@WoosukKwon (Collaborator) commented:

@ekagra-ranjan This PR by itself is not enough to support EAGLE correctly. We need to handle the draft model's KV cache properly.

@luyuzhe111 (Contributor) commented:

@ekagra-ranjan Hi Ekagra,

Thanks for sharing the benchmarking results! I can confirm that the acceptance length you collected for vLLM should be accurate, as I aggregate acceptance length from request-level, per-step acceptance counts in PR #16367.

I suspect you might have used EAGLE-2 in SGLang, which explains the larger acceptance length for k = 4. Can you double check on that?

I will include the acceptance length comparison of vLLM against the EAGLE-1 repo soon.

Regardless, I think we should merge this PR soon to unblock further developments. Without this PR, we can't even debug : )

cc @LiuXiaoxuanPKU @WoosukKwon

@ekagra-ranjan (Contributor) commented:

@luyuzhe111 - thank you for your response. In SGL, I am using chain-based drafting with --speculative-num-steps 4 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4. The difference between EAGLE-1 and EAGLE-2 is that the draft tree is dynamic in EAGLE-2, but since I am using a chain draft, it should be the same as EAGLE-1.

@luyuzhe111 (Contributor) commented:

Hi @ekagra-ranjan, if these are indeed chain drafts, then I don't think the acceptance length in SGL makes sense. Basically, the result is saying that the first two draft positions have 0.72 tokens accepted, and the next two draft positions also have 0.68 tokens accepted. Even in the best-case scenario (mean number of tokens accepted at each step = [1, 0.37, 0.35, 0.34, 0.33], assuming a minimal acceptance-rate drop between draft steps), it's impossible to have 0.68 tokens accepted at the third and fourth positions combined.
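
To make the expectation argument above concrete, here is a tiny sketch with made-up per-position acceptance probabilities, chosen so that the first two positions contribute roughly 0.75 (close to the 0.72 under discussion); with chain drafts, later positions can only contribute a rapidly shrinking amount.

```python
# Illustrative only: the per-position acceptance probabilities below are
# invented to make the chain-draft expectation argument concrete.
def expected_accepted(per_pos_accept_prob):
    """For a chain draft, token i is accepted only if all earlier tokens were,
    so E[#accepted] = p1 + p1*p2 + p1*p2*p3 + ..."""
    total, survive = 0.0, 1.0
    contributions = []
    for p in per_pos_accept_prob:
        survive *= p
        contributions.append(survive)
        total += survive
    return total, contributions

# Example: 50% acceptance at each of 4 draft positions.
total, per_pos = expected_accepted([0.5, 0.5, 0.5, 0.5])
print(per_pos)  # [0.5, 0.25, 0.125, 0.0625] -- positions 1-2 add 0.75,
                # positions 3-4 add only ~0.19, far below 0.68
print(total)    # 0.9375 accepted tokens in expectation
```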

@luyuzhe111 (Contributor) commented:

For reference, these are the acceptance length comparisons between EAGLE repo, vLLM v0, and vLLM v1.

On MT-Bench, assuming single-draft decoding. Acceptance length is computed using #16367.

When the max number of generated tokens = 256:

| Number of Speculated Tokens | 1 | 2 | 3 | 4 | 5 |
| --- | --- | --- | --- | --- | --- |
| EAGLE Repo | 1.64 | 2.0 | 2.14 | 2.25 | 2.28 |
| vLLM v0 | 1.60 | 1.88 | 1.99 | 2.04 | 2.06 |
| vLLM v1 | 1.60 | 1.90 | 2.04 | 2.10 | 2.13 |

When the max number of generated tokens = 512:

| Number of Speculated Tokens | 1 | 2 | 3 | 4 | 5 |
| --- | --- | --- | --- | --- | --- |
| EAGLE Repo | 1.65 | 2.01 | 2.19 | 2.28 | 2.33 |
| vLLM v0 | 1.61 | 1.87 | 1.99 | 2.04 | 2.06 |
| vLLM v1 | 1.61 | 1.91 | 2.05 | 2.11 | 2.14 |

Observations:

  1. EAGLE implementation in v1 has better acceptance length than that in v0! Kudos to all the efforts to fix previously reported bugs!
  2. The acceptance length for vLLM v1 EAGLE is still a bit lower than that from the EAGLE repo. Hopefully [V1][Spec Decode] KV cache slots for eagle heads #16370 will bridge the gap.
  3. In general, a longer generation horizon gives a better acceptance length. Intuitively, it is easier for EAGLE to speculate when there is more context from the base model. In the original EAGLE repo, the speculator performs slightly better when the generation length increases from 256 to 512. This is not the case for v0 EAGLE, but we do observe a tiny improvement for v1 EAGLE.

Hope the numbers here from the EAGLE repo can serve as a reference for future development efforts. cc @LiuXiaoxuanPKU @WoosukKwon

@wwl2755 (Contributor) commented Apr 10, 2025

Hi @luyuzhe111, great work! Nice to see such early benchmarking!

> 2. The acceptance length for vLLM v1 EAGLE is still a bit lower than that from the EAGLE repo. Hopefully [V1][Spec Decode] KV cache slots for eagle heads #16370 will bridge the gap.

IIUC, that PR is intended to allocate cache slots for the draft model, which should not affect the acceptance rate?

If the implementation is correct, I assume the gap is more related to the sampling/rejection method. There are indeed some gaps, such as https://github.com/vllm-project/vllm/blob/main/vllm/v1/spec_decode/eagle.py#L219. Maybe you could double-check whether the default sampling parameters are consistent with those reported in the original EAGLE?
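
For context on the acceptance mechanism being discussed, here is a rough sketch of the greedy acceptance rule: with greedy sampling on both sides, speculative decoding reduces to prefix matching against the target model's argmax. This is illustrative only, not the exact code at the linked line.

```python
# Sketch of greedy acceptance for speculative decoding (illustrative, not the
# actual vLLM rejection-sampler implementation).
def greedy_accept(draft_tokens, target_argmax_tokens):
    """Accept the longest prefix of draft tokens matching the target model's
    argmax at each position; the first mismatch is replaced by the target
    token and drafting stops for this step.

    target_argmax_tokens has len(draft_tokens) + 1 entries: one per draft
    position plus the bonus token emitted when all drafts are accepted.
    """
    accepted = []
    for draft, target in zip(draft_tokens, target_argmax_tokens):
        if draft != target:
            return accepted, target   # correction token from the target model
        accepted.append(draft)
    # All drafts accepted; the target also provides one bonus token.
    return accepted, target_argmax_tokens[len(draft_tokens)]

accepted, next_tok = greedy_accept([11, 42, 7], [11, 42, 9, 20])
print(accepted, next_tok)  # [11, 42] 9
```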

@ekagra-ranjan (Contributor) commented:

> IIUC, that PR is intended to allocate cache slots for the draft model, which should not affect the acceptance rate?

I have the same understanding. Can someone please share why #16370 would improve correctness?

@WoosukKwon (Collaborator) left a review comment:

LGTM! Thanks for the PR!

@WoosukKwon merged commit e8224f3 into vllm-project:main on Apr 10, 2025 (46 checks passed).
@luyuzhe111 (Contributor) commented:

@wwl2755 @ekagra-ranjan Hi Wenlong and Ekagra, thanks for the comments. I actually don't understand #16370 too well and was hoping to dive deeper. Maybe we can keep the discussion under that PR now that this PR is merged?

Regarding the hypothesis on the acceptance mechanism, I don't think the sampling parameters are the issue, since I used greedy sampling for both the EAGLE repo and vLLM.

@ekagra-ranjan (Contributor) commented Apr 14, 2025

@luyuzhe111 Thanks for sharing your observation!

The issue was that SGL uses speculative-num-draft-tokens - 1 draft tokens, so the numbers I got for k=2 in SGL are actually comparable to k=1 in vLLM, and similarly k=4 in SGL corresponds to k=3 in vLLM. The numbers are now better aligned :)

Using 1 draft token, SGL gets an AL of 1.72. Using 3 draft tokens, SGL gets an AL of 2.4.

yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025