
Conversation

Collaborator

@benchislett benchislett commented Jun 2, 2025

A previous PR allowed the EAGLE vocab embeddings to be omitted from initialization during weight loading, under the assumption that they are always overridden by the embeddings from the target model. However, the EAGLE3 head for Llama 3.3 70B has a different hidden size than the target model, and thus a distinct vocab embedding.

Currently, Llama 70B is unusable on main for this reason, because the vocab embedding must be initialized in that case. I updated the model code to always declare the vocab embedding and to load it when it is present in the checkpoint. The EAGLE load_weights then frees the draft's copy and replaces it with the target model's embeddings only when the two have the same shape.
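
For clarity, a minimal sketch of the intended logic (illustrative only; `draft_model`, `target_model`, and `model.embed_tokens` are stand-in names, not the exact vLLM identifiers):

```python
import torch.nn as nn


def maybe_share_vocab_embedding(draft_model: nn.Module,
                                target_model: nn.Module) -> None:
    """Reuse the target model's vocab embedding in the EAGLE draft head only
    when the shapes match; otherwise keep the draft's own embedding that was
    loaded from the checkpoint (e.g. the EAGLE3 head for Llama 3.3 70B)."""
    draft_embed = draft_model.model.embed_tokens
    target_embed = target_model.model.embed_tokens
    if draft_embed.weight.shape == target_embed.weight.shape:
        # Shapes match: free the draft's copy and share the target's weights.
        del draft_model.model.embed_tokens
        draft_model.model.embed_tokens = target_embed
    # Otherwise the draft head has a different hidden size, so its own
    # (always-declared) vocab embedding stays in place.
```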

Notes

  • I had to fix an unrelated import in order to pass pre-commit.
  • I added gc.collect() and torch.cuda.empty_cache() to the memory profiling so that the deallocated vocab embedding does not affect the measurement. I confirmed locally that this works as intended, showing 1 GB less in the memory profiling stage.

Signed-off-by: Benjamin Chislett <[email protected]>

github-actions bot commented Jun 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jun 2, 2025
def current_memory_usage(self) -> float:
    # Return the memory usage in bytes.
    from vllm.platforms import current_platform
    gc.collect()
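
For context, a hedged sketch of how the profiling helper might look with the added cleanup. Only the lines in the hunk above come from the actual diff; the class name and the platform call below are assumptions.

```python
import gc

import torch


class DeviceMemoryProfiler:  # illustrative stand-in for the real helper
    def __init__(self, device=None) -> None:
        self.device = device

    def current_memory_usage(self) -> float:
        # Return the memory usage in bytes.
        from vllm.platforms import current_platform
        # Added in this PR: force a GC pass and drop cached CUDA blocks so
        # the freed draft vocab embedding does not inflate the measurement.
        gc.collect()
        torch.cuda.empty_cache()
        # Assumption: the platform object reports the current allocated bytes.
        return current_platform.get_current_memory_usage(self.device)
```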
Contributor

@ekagra-ranjan ekagra-ranjan Jun 2, 2025


@WoosukKwon can share whether gc.collect() and torch.cuda.empty_cache() are fine here. Maybe there is some reason why they were not already added.

I believe this was added because we delete a torch tensor after allocation. In case we decide it's better to avoid these new GC calls, an alternative approach would be to first load the draft model weights from the checkpoint, determine whether the draft vocab embedding is needed, and then pass that info to the draft model instantiation so it can skip allocating the draft vocab and achieve the same objective (see the sketch below).
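
A hedged sketch of that alternative, assuming safetensors checkpoints; the helper name and the `allocate_embed_tokens` flag are hypothetical:

```python
# Hypothetical helper: inspect the draft checkpoint before building the model
# so the draft vocab embedding is only allocated when the checkpoint ships one.
from safetensors import safe_open


def draft_checkpoint_has_embed_tokens(checkpoint_path: str) -> bool:
    """Return True if the EAGLE draft checkpoint contains its own
    embed_tokens.weight, i.e. the draft vocab embedding must be kept."""
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        return any(name.endswith("embed_tokens.weight") for name in f.keys())


# The flag could then be threaded into model construction, e.g.:
#   needs_embed = draft_checkpoint_has_embed_tokens(path)
#   draft_model = EagleDraftModel(config, allocate_embed_tokens=needs_embed)
```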

Collaborator Author

@benchislett benchislett Jun 2, 2025


I think the current approach makes sense, as enforcing GC and clearing the torch cache seem like natural choices to improve the accuracy of the memory profiler.

If we foresee any issues with calling GC/cleanup in this way, then I'm on board with doing it the other way.

Collaborator


I agree with @ekagra-ranjan, though I don't see a clear problem. Let's keep this in mind and revisit if any issue arises.

@WoosukKwon
Collaborator

I feel like this kind of bug keeps repeating, probably because 1) we lack tests on EAGLE, and 2) I'm not a great reviewer for EAGLE's weight loading (I don't have enough background).

@benchislett
Collaborator Author

I can contribute additional testing for EAGLE in the next couple of days.

@singh-git10

@benchislett Do you have any comparison of the impact of the reduced hidden size on acceptance length versus the speedup, as was done for the Llama 3.3 70B EAGLE-3 head?

@benchislett
Collaborator Author


No, this was an architecture decision made at training time by the EAGLE3 authors. I do not recall any comments from the authors regarding the ablation of this parameter.

@benchislett
Collaborator Author

@WoosukKwon This PR should be ready, PTAL. I've updated the weight-loading test to cover the edge cases more exhaustively, and also added more thorough acceptance-rate tests in #19104.

@Neo9061

Neo9061 commented Jun 4, 2025

Hi @benchislett, thanks a lot for this PR! Super useful.

Do you see an OTPS improvement of the officially released EAGLE-3 70B head over EAGLE-2 with your code changes?

Previously, we were able to see an improvement in acceptance rate but no OTPS gain. That was based on vLLM without your fix, with the EAGLE-3 70B head. Many thanks!

@WoosukKwon WoosukKwon added this to the v0.9.1 milestone Jun 4, 2025
@Neo9061

Neo9061 commented Jun 5, 2025


Sorry to follow up: I pulled and installed your PR and saw an error when vLLM loads Llama 3.3 70B from its official release. I wonder if you see the same issue, or whether I configured something incorrectly on my end.

Here are the details:
#19174

Thank you very much!

@benchislett
Collaborator Author

@Neo9061 I worry that this might be an inherent limitation of the architecture used for the EAGLE3 heads. If it seems correct to you when running with max_seq_len==2048, I am inclined to think that this is the cause.

Regarding the speedup, I have not done comparison benchmarks with EAGLE1 but the EAGLE3 head does seem quite performant. If you identify a degradation compared to EAGLE1, please share the details and we will work to improve the performance. Thank you for the feedback and early testing.
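
For anyone reproducing this, a hedged example of capping the context length when serving the EAGLE3 head. The model IDs and the exact `speculative_config` keys are assumptions and should be checked against the vLLM version in use.

```python
from vllm import LLM, SamplingParams

# Assumed model IDs and speculative_config keys; verify against your vLLM build.
llm = LLM(
    model="meta-llama/Llama-3.3-70B-Instruct",
    tensor_parallel_size=4,
    max_model_len=2048,  # the cap discussed above for the draft head
    speculative_config={
        "method": "eagle3",
        "model": "yuhuili/EAGLE3-LLaMA3.3-Instruct-70B",
        "num_speculative_tokens": 4,
    },
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```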

@ekagra-ranjan
Contributor

Previously, we were able to see an improvement in acceptance rate but no OTPS gain. That was based on vLLM without your fix, with the EAGLE-3 70B head. Many thanks!

@Neo9061 in the last benchmark, we saw that EAGLE-3 on vLLM has a better AL (acceptance length) than EAGLE-1, but the OTPS is worse in online serving.

offline bench - #16937 (comment)
online bench - #17202 (comment)

Can you please share your exact setup for comparison?

@benchislett
Collaborator Author

This was also before torch.compile support was added to EAGLE/EAGLE3, I think. I expect those numbers would look different now.

@Neo9061

Neo9061 commented Jun 5, 2025

@ekagra-ranjan @benchislett Sure, we will share the data soon. Basically, in online inference for 70B, comparing the official EAGLE-2 head and EAGLE-3 head, we saw the acceptance rate improve by 29% and the head size shrink by 50%, but the OTPS gain is only 26%. We are using llmperf + vLLM to benchmark.

@singh-git10 will provide a GitHub issue soon.

@singh-git10

@ekagra-ranjan @benchislett I have created a GitHub issue here. Please let me know if any other details are needed to reproduce the results.

@WoosukKwon WoosukKwon merged commit 3465b87 into vllm-project:main Jun 6, 2025
69 checks passed
