diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 4b5c9b7ec640..f7179385ebb7 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -329,16 +329,24 @@ def load_model(self, target_model: nn.Module) -> None:
         self.attn_layer_names = list(draft_attn_layer_names)

+        if supports_multimodal(target_model):
+            # handle multimodality
+            self.model.config.image_token_index = (
+                target_model.config.image_token_index)
+            target_language_model = target_model.get_language_model()
+        else:
+            target_language_model = target_model
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1 \
             and self.model.model.embed_tokens.weight.shape \
-                == target_model.model.embed_tokens.weight.shape:
+                == target_language_model.model.embed_tokens.weight.shape:
             logger.info(
                 "Assuming the EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
             del self.model.model.embed_tokens
-            self.model.model.embed_tokens = target_model.model.embed_tokens
+            self.model.model.embed_tokens = (
+                target_language_model.model.embed_tokens)
         else:
             logger.info(
                 "The EAGLE head's vocab embedding will be loaded separately" \
                 " from the target model."
@@ -349,12 +357,9 @@ def load_model(self, target_model: nn.Module) -> None:
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
         if self.vllm_config.speculative_config.method != "eagle3" and \
-                hasattr(target_model, "lm_head"):
+                hasattr(target_language_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
-            if supports_multimodal(target_model):
-                self.model.lm_head = target_model.get_language_model().lm_head
-            else:
-                self.model.lm_head = target_model.lm_head
+            self.model.lm_head = target_language_model.lm_head

     @torch.inference_mode()
     def dummy_run(
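
Note on the refactor above: the diff replaces two separate supports_multimodal() branches (one for embed_tokens sharing, one for lm_head sharing) with a single target_language_model reference resolved up front. Below is a minimal, self-contained sketch of that pattern; the toy classes and the resolve_language_model() helper are hypothetical stand-ins for illustration, not vLLM's actual types or its supports_multimodal() check.

    import torch.nn as nn


    class ToyLanguageModel(nn.Module):
        """Text-only target: owns embed_tokens and lm_head directly."""

        def __init__(self, vocab_size: int = 32, hidden: int = 8):
            super().__init__()
            self.model = nn.Module()
            self.model.embed_tokens = nn.Embedding(vocab_size, hidden)
            self.lm_head = nn.Linear(hidden, vocab_size, bias=False)


    class ToyMultimodalModel(nn.Module):
        """Multimodal target: wraps an inner language model."""

        def __init__(self):
            super().__init__()
            self.language_model = ToyLanguageModel()

        def get_language_model(self) -> nn.Module:
            return self.language_model


    def resolve_language_model(target: nn.Module) -> nn.Module:
        # Stand-in for the supports_multimodal() branch in the diff:
        # multimodal wrappers expose get_language_model(); text-only
        # models are used as-is.
        if hasattr(target, "get_language_model"):
            return target.get_language_model()
        return target


    # Both target kinds now funnel through one reference, so the
    # embedding- and lm_head-sharing logic needs no per-branch cases.
    for target in (ToyLanguageModel(), ToyMultimodalModel()):
        lm = resolve_language_model(target)
        assert isinstance(lm.model.embed_tokens, nn.Embedding)
        assert isinstance(lm.lm_head, nn.Linear)

Resolving the reference once also keeps the embed_tokens shape check and the weight sharing consistent with each other, since both read from the same object.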