Skip to content

Conversation

ekagra-ranjan
Copy link
Contributor

@ekagra-ranjan ekagra-ranjan commented Apr 28, 2025

This PR:

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added documentation Improvements or additions to documentation v1 labels Apr 28, 2025
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ekagra-ranjan Thanks for the PR!

One issue with the PR is that it assumes PP=1. Can you please handle PP > 1 as well (at least for llama)?

Copy link

mergify bot commented Apr 29, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ekagra-ranjan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@WoosukKwon
Copy link
Collaborator

@ekagra-ranjan Could you please update the PR? If handling PP is tricky, you can simply check the pipeline_parallel_size and raise an error if it's not 1 (for now). We can fix it later.

Comment on lines -145 to -146
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eagle model def doesnt have lm_head nor the weights to removed it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ekagra-ranjan Do you mean EAGLE1 doesn't have the LM head? I'm wondering because some EAGLE3 weights do include the LM head.

Copy link
Contributor Author

@ekagra-ranjan ekagra-ranjan May 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EAGLE1 reuses the lm_head of target model for each spec step whereas EAGLE3 does not. For e.g.,

yuhuili/EAGLE-LLaMA3-Instruct-8B has these weights

Number of weights: 10
Key: layers.0.self_attn.q_proj.weight, Shape: torch.Size([4096, 4096]), Dtype: torch.float16
Key: layers.0.self_attn.k_proj.weight, Shape: torch.Size([1024, 4096]), Dtype: torch.float16
Key: layers.0.self_attn.v_proj.weight, Shape: torch.Size([1024, 4096]), Dtype: torch.float16
Key: layers.0.self_attn.o_proj.weight, Shape: torch.Size([4096, 4096]), Dtype: torch.float16
Key: layers.0.mlp.gate_proj.weight, Shape: torch.Size([14336, 4096]), Dtype: torch.float16
Key: layers.0.mlp.up_proj.weight, Shape: torch.Size([14336, 4096]), Dtype: torch.float16
Key: layers.0.mlp.down_proj.weight, Shape: torch.Size([4096, 14336]), Dtype: torch.float16
Key: layers.0.post_attention_layernorm.weight, Shape: torch.Size([4096]), Dtype: torch.float16
Key: embed_tokens.weight, Shape: torch.Size([128256, 4096]), Dtype: torch.float16
Key: fc.weight, Shape: torch.Size([4096, 8192]), Dtype: torch.float16

EAGLE1 sets the target lm_head as draft's lm_head here

EAGLE 3's lm_head is not the same as the target model. It's noted in this PR as well #16937 (comment)

root and others added 2 commits May 10, 2025 19:28
@ekagra-ranjan
Copy link
Contributor Author

ekagra-ranjan commented May 10, 2025

@WoosukKwon Done! For PP > 1, the target embed would be on rank 0 whereas the drafter will run on last rank so the drafter's embed cannot be shared with target. In the case, the current code will expect the embed weights to be present in draft checkpoint during weight loading when using PP and raise an exception if that's not the case.

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ekagra-ranjan Left some minor comments. Please check them out.

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label May 13, 2025
@WoosukKwon
Copy link
Collaborator

@ekagra-ranjan Please fix the lint errors.

@ekagra-ranjan
Copy link
Contributor Author

@WoosukKwon - Done!

@WoosukKwon WoosukKwon merged commit 418d2f8 into vllm-project:main May 14, 2025
61 of 62 checks passed
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
…aft model to free ~1GB for llama 3 model (vllm-project#17326)

Co-authored-by: root <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>
@singh-git10
Copy link

@ekagra-ranjan @WoosukKwon I believe the scenario where the EAGLE-3 draft model has different embedding weights than the target model is not being properly handled in the current implementation. This issue specifically applies to the EAGLE-3 head for the Llama 3.3 70B model.(yuhuili/EAGLE3-LLaMA3.3-Instruct-70B.

@ekagra-ranjan
Copy link
Contributor Author

@singh-git10 - its being addressed here: #19033

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants