|
| 1 | +"""1D LLaMA model compatible with HuggingFace weights.""" |
| 2 | +import os |
| 3 | +import glob |
| 4 | +import filelock |
| 5 | +from tqdm import tqdm |
| 6 | +from typing import Dict, List, Optional, Tuple |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import torch |
| 10 | +from torch import nn |
| 11 | +import torch.nn.functional as F |
| 12 | +from transformers import LlamaConfig |
| 13 | +from transformers import PreTrainedModel |
| 14 | + |
| 15 | +from cacheflow.models import InputMetadata |
| 16 | +from cacheflow.models.attention import OPTCacheFlowAttention |
| 17 | +from cacheflow.models.sample import Sampler |
| 18 | +from cacheflow.parallel_utils.parallel_state import ( |
| 19 | + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) |
| 20 | +from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding, |
| 21 | + ColumnParallelLinear, |
| 22 | + RowParallelLinear) |
| 23 | +from cacheflow.sequence import SequenceOutputs |
| 24 | + |
| 25 | +KVCache = Tuple[torch.Tensor, torch.Tensor] |
| 26 | + |
| 27 | + |
| 28 | +class LlamaRMSNorm(nn.Module): |
| 29 | + |
| 30 | + def __init__(self, hidden_size, eps=1e-6): |
| 31 | + super().__init__() |
| 32 | + self.weight = nn.Parameter(torch.ones(hidden_size)) |
| 33 | + self.variance_epsilon = eps |
| 34 | + |
| 35 | + def forward(self, hidden_states): |
| 36 | + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) |
| 37 | + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
| 38 | + # convert into half-precision if necessary |
| 39 | + if self.weight.dtype in [torch.float16, torch.bfloat16]: |
| 40 | + hidden_states = hidden_states.to(self.weight.dtype) |
| 41 | + return self.weight * hidden_states |
| 42 | + |
| 43 | + |
| 44 | +class LlamaRotaryEmbedding(torch.nn.Module): |
| 45 | + |
| 46 | + def __init__(self, dim, max_position_embeddings=2048, base=10000): |
| 47 | + super().__init__() |
| 48 | + self.max_position_embeddings = max_position_embeddings |
| 49 | + |
| 50 | + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) |
| 51 | + self.register_buffer("inv_freq", inv_freq) |
| 52 | + |
| 53 | + # Create cos and sin embeddings. |
| 54 | + t = torch.arange(max_position_embeddings).float() |
| 55 | + freqs = torch.einsum("i,j->ij", t, self.inv_freq.float()) |
| 56 | + emb = torch.cat((freqs, freqs), dim=-1) |
| 57 | + cos = emb.cos().to(dtype=self.inv_freq.dtype) |
| 58 | + sin = emb.sin().to(dtype=self.inv_freq.dtype) |
| 59 | + self.register_buffer("cos_cached", cos, persistent=False) |
| 60 | + self.register_buffer("sin_cached", sin, persistent=False) |
| 61 | + |
| 62 | + def forward( |
| 63 | + self, |
| 64 | + positions: torch.LongTensor, |
| 65 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 66 | + cos = F.embedding(positions, self.cos_cached) |
| 67 | + sin = F.embedding(positions, self.sin_cached) |
| 68 | + return cos, sin |
| 69 | + |
| 70 | + |
| 71 | +def rotate_half(x): |
| 72 | + """Rotates half the hidden dims of the input.""" |
| 73 | + x1 = x[..., : x.shape[-1] // 2] |
| 74 | + x2 = x[..., x.shape[-1] // 2 :] |
| 75 | + return torch.cat((-x2, x1), dim=-1) |
| 76 | + |
| 77 | + |
| 78 | +def apply_rotary_pos_emb(q, k, cos, sin): |
| 79 | + # TODO: Optimize. |
| 80 | + q_embed = (q * cos) + (rotate_half(q) * sin) |
| 81 | + k_embed = (k * cos) + (rotate_half(k) * sin) |
| 82 | + return q_embed, k_embed |
| 83 | + |
| 84 | + |
| 85 | +class LlamaMLP(nn.Module): |
| 86 | + def __init__( |
| 87 | + self, |
| 88 | + hidden_size: int, |
| 89 | + intermediate_size: int, |
| 90 | + hidden_act: str, |
| 91 | + ): |
| 92 | + super().__init__() |
| 93 | + # TODO: Merge the gate and down linear layers. |
| 94 | + self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size, |
| 95 | + bias=False, gather_output=False, |
| 96 | + perform_initialization=False) |
| 97 | + self.down_proj = RowParallelLinear(intermediate_size, hidden_size, |
| 98 | + bias=False, input_is_parallel=True, |
| 99 | + perform_initialization=False) |
| 100 | + self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size, |
| 101 | + bias=False, gather_output=False, |
| 102 | + perform_initialization=False) |
| 103 | + assert hidden_act == 'silu' |
| 104 | + self.act_fn = nn.SiLU() |
| 105 | + |
| 106 | + def forward(self, x): |
| 107 | + gate, _ = self.gate_proj(x) |
| 108 | + up, _ = self.up_proj(x) |
| 109 | + x = self.act_fn(gate) * up |
| 110 | + x, _ = self.down_proj(x) |
| 111 | + return x |
| 112 | + |
| 113 | + |
| 114 | +class LlamaAttention(nn.Module): |
| 115 | + |
| 116 | + def __init__( |
| 117 | + self, |
| 118 | + hidden_size: int, |
| 119 | + num_heads: int, |
| 120 | + ): |
| 121 | + super().__init__() |
| 122 | + self.hidden_size = hidden_size |
| 123 | + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() |
| 124 | + self.total_num_heads = num_heads |
| 125 | + assert self.total_num_heads % tensor_model_parallel_world_size == 0 |
| 126 | + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size |
| 127 | + self.head_dim = hidden_size // self.total_num_heads |
| 128 | + self.scaling = self.head_dim ** -0.5 |
| 129 | + |
| 130 | + # TODO: Merge the QKV linear layers. |
| 131 | + self.q_proj = ColumnParallelLinear( |
| 132 | + hidden_size, |
| 133 | + self.total_num_heads * self.head_dim, |
| 134 | + bias=False, |
| 135 | + gather_output=False, |
| 136 | + perform_initialization=False, |
| 137 | + ) |
| 138 | + self.k_proj = ColumnParallelLinear( |
| 139 | + hidden_size, |
| 140 | + self.total_num_heads * self.head_dim, |
| 141 | + bias=False, |
| 142 | + gather_output=False, |
| 143 | + perform_initialization=False, |
| 144 | + ) |
| 145 | + self.v_proj = ColumnParallelLinear( |
| 146 | + hidden_size, |
| 147 | + self.total_num_heads * self.head_dim, |
| 148 | + bias=False, |
| 149 | + gather_output=False, |
| 150 | + perform_initialization=False, |
| 151 | + ) |
| 152 | + self.o_proj = RowParallelLinear( |
| 153 | + self.total_num_heads * self.head_dim, |
| 154 | + hidden_size, |
| 155 | + bias=False, |
| 156 | + input_is_parallel=True, |
| 157 | + perform_initialization=False, |
| 158 | + ) |
| 159 | + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim) |
| 160 | + # FIXME(woosuk): Rename this. |
| 161 | + self.attn = OPTCacheFlowAttention(scale=self.scaling) |
| 162 | + |
| 163 | + def forward( |
| 164 | + self, |
| 165 | + positions: torch.LongTensor, |
| 166 | + hidden_states: torch.Tensor, |
| 167 | + kv_cache: KVCache, |
| 168 | + input_metadata: InputMetadata, |
| 169 | + cache_event: Optional[torch.cuda.Event], |
| 170 | + ) -> torch.Tensor: |
| 171 | + q, _ = self.q_proj(hidden_states) |
| 172 | + k, _ = self.k_proj(hidden_states) |
| 173 | + v, _ = self.v_proj(hidden_states) |
| 174 | + |
| 175 | + # Apply rotrary embedding. |
| 176 | + # TODO: Optimize. |
| 177 | + q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1) |
| 178 | + k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1) |
| 179 | + cos, sin = self.rotary_emb(positions) |
| 180 | + q, k = apply_rotary_pos_emb(q, k, cos, sin) |
| 181 | + q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim) |
| 182 | + k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim) |
| 183 | + |
| 184 | + key_cache, value_cache = kv_cache |
| 185 | + attn_output = self.attn( |
| 186 | + q, k, v, key_cache, value_cache, input_metadata, cache_event) |
| 187 | + output, _ = self.o_proj(attn_output) |
| 188 | + return output |
| 189 | + |
| 190 | + |
| 191 | +class LlamaDecoderLayer(nn.Module): |
| 192 | + |
| 193 | + def __init__(self, config: LlamaConfig): |
| 194 | + super().__init__() |
| 195 | + self.hidden_size = config.hidden_size |
| 196 | + self.self_attn = LlamaAttention( |
| 197 | + hidden_size=self.hidden_size, |
| 198 | + num_heads=config.num_attention_heads, |
| 199 | + ) |
| 200 | + self.mlp = LlamaMLP( |
| 201 | + hidden_size=self.hidden_size, |
| 202 | + intermediate_size=config.intermediate_size, |
| 203 | + hidden_act=config.hidden_act, |
| 204 | + ) |
| 205 | + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| 206 | + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| 207 | + |
| 208 | + def forward( |
| 209 | + self, |
| 210 | + positions: torch.LongTensor, |
| 211 | + hidden_states: torch.Tensor, |
| 212 | + kv_cache: KVCache, |
| 213 | + input_metadata: InputMetadata, |
| 214 | + cache_event: Optional[torch.cuda.Event], |
| 215 | + ) -> torch.Tensor: |
| 216 | + # Self Attention |
| 217 | + residual = hidden_states |
| 218 | + hidden_states = self.input_layernorm(hidden_states) |
| 219 | + hidden_states = self.self_attn( |
| 220 | + positions=positions, |
| 221 | + hidden_states=hidden_states, |
| 222 | + kv_cache=kv_cache, |
| 223 | + input_metadata=input_metadata, |
| 224 | + cache_event=cache_event, |
| 225 | + ) |
| 226 | + hidden_states = residual + hidden_states |
| 227 | + |
| 228 | + # Fully Connected |
| 229 | + residual = hidden_states |
| 230 | + hidden_states = self.post_attention_layernorm(hidden_states) |
| 231 | + hidden_states = self.mlp(hidden_states) |
| 232 | + hidden_states = residual + hidden_states |
| 233 | + return hidden_states |
| 234 | + |
| 235 | + |
| 236 | +class LlamaModel(nn.Module): |
| 237 | + |
| 238 | + def __init__(self, config: LlamaConfig): |
| 239 | + super().__init__() |
| 240 | + self.config = config |
| 241 | + self.padding_idx = config.pad_token_id |
| 242 | + self.vocab_size = config.vocab_size |
| 243 | + |
| 244 | + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size, |
| 245 | + perform_initialization=False) |
| 246 | + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) |
| 247 | + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| 248 | + |
| 249 | + def forward( |
| 250 | + self, |
| 251 | + input_ids: torch.LongTensor, |
| 252 | + positions: torch.LongTensor, |
| 253 | + kv_caches: List[KVCache], |
| 254 | + input_metadata: InputMetadata, |
| 255 | + cache_events: Optional[List[torch.cuda.Event]], |
| 256 | + ) -> torch.Tensor: |
| 257 | + hidden_states = self.embed_tokens(input_ids) |
| 258 | + for i in range(len(self.layers)): |
| 259 | + if cache_events is None: |
| 260 | + cache_event = None |
| 261 | + else: |
| 262 | + cache_event = cache_events[i] |
| 263 | + layer = self.layers[i] |
| 264 | + hidden_states = layer( |
| 265 | + positions, |
| 266 | + hidden_states, |
| 267 | + kv_caches[i], |
| 268 | + input_metadata, |
| 269 | + cache_event, |
| 270 | + ) |
| 271 | + hidden_states = self.norm(hidden_states) |
| 272 | + return hidden_states |
| 273 | + |
| 274 | + |
| 275 | +class LlamaForCausalLM(nn.Module): |
| 276 | + def __init__(self, config): |
| 277 | + super().__init__() |
| 278 | + self.config = config |
| 279 | + self.model = LlamaModel(config) |
| 280 | + self.lm_head = ColumnParallelLinear(config.hidden_size, |
| 281 | + config.vocab_size, |
| 282 | + bias=False, |
| 283 | + gather_output=False, |
| 284 | + perform_initialization=False) |
| 285 | + self.sampler = Sampler() |
| 286 | + |
| 287 | + def forward( |
| 288 | + self, |
| 289 | + input_ids: torch.LongTensor, |
| 290 | + positions: torch.LongTensor, |
| 291 | + kv_caches: List[KVCache], |
| 292 | + input_metadata: InputMetadata, |
| 293 | + cache_events: Optional[List[torch.cuda.Event]], |
| 294 | + ) -> Dict[int, SequenceOutputs]: |
| 295 | + hidden_states = self.model( |
| 296 | + input_ids, positions, kv_caches, input_metadata, cache_events) |
| 297 | + next_tokens = self.sampler( |
| 298 | + self.lm_head.weight, hidden_states, input_metadata) |
| 299 | + return next_tokens |
| 300 | + |
| 301 | + _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight", |
| 302 | + "q_proj.weight", "k_proj.weight", |
| 303 | + "v_proj.weight", "gate_proj.weight", |
| 304 | + "up_proj.weight"] |
| 305 | + _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] |
| 306 | + |
| 307 | + def load_weights(self, weights_path: str): |
| 308 | + tensor_model_parallel_rank = get_tensor_model_parallel_rank() |
| 309 | + state_dict = self.state_dict() |
| 310 | + for name, param in state_dict.items(): |
| 311 | + loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, |
| 312 | + name))) |
| 313 | + for p in self._column_parallel_weights: |
| 314 | + if p in name: |
| 315 | + shard_size = param.shape[0] |
| 316 | + loaded_weight = loaded_weight[ |
| 317 | + shard_size * tensor_model_parallel_rank |
| 318 | + :shard_size * (tensor_model_parallel_rank + 1)] |
| 319 | + break |
| 320 | + for p in self._row_parallel_weights: |
| 321 | + if p in name: |
| 322 | + shard_size = param.shape[1] |
| 323 | + loaded_weight = loaded_weight[ |
| 324 | + :, |
| 325 | + shard_size * tensor_model_parallel_rank |
| 326 | + :shard_size * (tensor_model_parallel_rank + 1)] |
| 327 | + break |
| 328 | + |
| 329 | + assert param.shape == loaded_weight.shape |
| 330 | + param.data.copy_(loaded_weight) |
| 331 | + |
| 332 | + @staticmethod |
| 333 | + def get_weights(model_name: str, path: str): |
| 334 | + if not os.path.isfile(os.path.join(model_name, "config.json")): |
| 335 | + raise ValueError("LLaMA model's model_name has to be a path" |
| 336 | + "to the huggingface model's directory.") |
| 337 | + path = os.path.join(model_name, f"np") |
| 338 | + path = os.path.abspath(os.path.expanduser(path)) |
| 339 | + os.makedirs(path, exist_ok=True) |
| 340 | + lock_path = os.path.join(path, "file_lock") |
| 341 | + lock = filelock.FileLock(lock_path) |
| 342 | + |
| 343 | + with lock: |
| 344 | + test_weight_path = os.path.join(path, "model.embed_tokens.weight") |
| 345 | + if os.path.exists(test_weight_path): |
| 346 | + return path |
| 347 | + |
| 348 | + bin_files = glob.glob(os.path.join(model_name, "*.bin")) |
| 349 | + |
| 350 | + for bin_file in tqdm(bin_files, desc="Convert format"): |
| 351 | + state = torch.load(bin_file, map_location="cpu") |
| 352 | + for name, param in tqdm(state.items(), leave=False): |
| 353 | + param_path = os.path.join(path, name) |
| 354 | + with open(param_path, "wb") as f: |
| 355 | + np.save(f, param.cpu().detach().numpy()) |
| 356 | + |
| 357 | + return path |
0 commit comments