# SPDX-License-Identifier: Apache-2.0

from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn
from transformers import LlamaConfig

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """LlamaDecoderLayer used by the EAGLE draft model.

    The first draft layer takes the fused embedding / hidden-state
    projection directly, so its input_layernorm is replaced with an
    identity (see the EAGLE reference implementation linked below).
    """

    def __init__(
        self,
        config: LlamaConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
    ) -> None:
        super().__init__(config, prefix=prefix)

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        if disable_input_layernorm:
            del self.input_layernorm
            self.input_layernorm = nn.Identity()


class LlamaModel(nn.Module):

    def __init__(
        self,
        *,
        model_config: ModelConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = model_config.hf_config
        self.vocab_size = self.config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        # Only the first draft layer (i == 0) drops its input_layernorm.
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                self.config,
                i == 0,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
            ) for i in range(self.config.num_hidden_layers)
        ])
        # Projects the concatenation of token embeddings and the target
        # model's hidden states back down to hidden_size.
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        input_embeds = self.embed_tokens(input_ids)
        # Fuse the draft token embeddings with the target model's hidden
        # states before running the draft decoder layers.
        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        # Apply the final residual connection.
        return hidden_states + residual

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleLlamaForCausalLM(LlamaForCausalLM):

    def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
        nn.Module.__init__(self)
        self.config = model_config.hf_config
        self.model = LlamaModel(model_config=model_config,
                                start_layer_id=start_layer_id,
                                prefix="model")

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )

        model_weights = {}
        for name, loaded_weight in weights:
            # Prefix non-lm_head weights with "model." so the names match
            # this wrapper's parameters; lm_head (if present) stays at the
            # top level.
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight

        loader.load_weights(model_weights.items())
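

# A minimal, self-contained sketch (toy sizes, random tensors; not part of
# vLLM's API) of the fusion step that LlamaModel.forward performs: the EAGLE
# draft model concatenates draft token embeddings with the target model's
# hidden states and projects the result back to hidden_size before running
# its decoder layers.
if __name__ == "__main__":
    hidden_size, vocab_size, seq_len = 64, 128, 4
    embed = nn.Embedding(vocab_size, hidden_size)
    fc = nn.Linear(hidden_size * 2, hidden_size, bias=False)
    input_ids = torch.randint(0, vocab_size, (seq_len, ))
    target_hidden = torch.randn(seq_len, hidden_size)
    fused = fc(torch.cat((embed(input_ids), target_hidden), dim=-1))
    assert fused.shape == (seq_len, hidden_size)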