
Commit 80a2f81

Implement LLaMA (#9)
Co-authored-by: Zhuohan Li <[email protected]>
1 parent a1b3de8 commit 80a2f81

File tree: 7 files changed, +500 -35 lines

README.md

Lines changed: 4 additions & 2 deletions
@@ -3,8 +3,10 @@
 ## Installation
 
 ```bash
-pip install psutil numpy torch transformers
-pip install flash-attn # This may take up to 10 mins.
+pip install psutil numpy ray torch
+pip install git+https://github.com/huggingface/transformers # Required for LLaMA.
+pip install sentencepiece # Required for LlamaTokenizer.
+pip install flash-attn # This may take up to 20 mins.
 pip install -e .
 ```
 

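As a quick check that the updated dependencies are in place, here is a hypothetical post-install snippet (not part of the commit) that imports the packages added above; LlamaConfig and LlamaTokenizer are only available in the transformers build installed from git at this point.

```python
# Hypothetical post-install sanity check (not part of the commit).
import ray
import sentencepiece
import torch
import transformers
from transformers import LlamaConfig, LlamaTokenizer  # Requires the git install above.

print(transformers.__version__)
print(LlamaConfig().hidden_size)  # LLaMA-7B-style default: 4096.
```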
cacheflow/master/simple_frontend.py

Lines changed: 1 addition & 0 deletions
@@ -61,4 +61,5 @@ def print_response(
         for seq in seq_group.seqs:
             token_ids = seq.get_token_ids()
             output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
+            output = output.strip()
             print(f'Seq {seq.seq_id}: {output!r}')

cacheflow/models/llama.py

Lines changed: 357 additions & 0 deletions
@@ -0,0 +1,357 @@
"""1D LLaMA model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers import PreTrainedModel

from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
                                                      ColumnParallelLinear,
                                                      RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Convert into half-precision if necessary.
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states


class LlamaRotaryEmbedding(torch.nn.Module):

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        self.max_position_embeddings = max_position_embeddings

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Create cos and sin embeddings.
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, self.inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().to(dtype=self.inv_freq.dtype)
        sin = emb.sin().to(dtype=self.inv_freq.dtype)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cos = F.embedding(positions, self.cos_cached)
        sin = F.embedding(positions, self.sin_cached)
        return cos, sin


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    # TODO: Optimize.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class LlamaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        # TODO: Merge the gate and down linear layers.
        self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
                                              bias=False, gather_output=False,
                                              perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                           bias=False, input_is_parallel=True,
                                           perform_initialization=False)
        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
                                            bias=False, gather_output=False,
                                            perform_initialization=False)
        assert hidden_act == 'silu'
        self.act_fn = nn.SiLU()

    def forward(self, x):
        gate, _ = self.gate_proj(x)
        up, _ = self.up_proj(x)
        x = self.act_fn(gate) * up
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads
        self.scaling = self.head_dim ** -0.5

        # TODO: Merge the QKV linear layers.
        self.q_proj = ColumnParallelLinear(
            hidden_size,
            self.total_num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.k_proj = ColumnParallelLinear(
            hidden_size,
            self.total_num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.v_proj = ColumnParallelLinear(
            hidden_size,
            self.total_num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
        # FIXME(woosuk): Rename this.
        self.attn = OPTCacheFlowAttention(scale=self.scaling)

    def forward(
        self,
        positions: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        q, _ = self.q_proj(hidden_states)
        k, _ = self.k_proj(hidden_states)
        v, _ = self.v_proj(hidden_states)

        # Apply rotary embedding.
        # TODO: Optimize.
        q = q.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(-1, self.num_heads, self.head_dim).transpose(0, 1)
        cos, sin = self.rotary_emb(positions)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        q = q.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)
        k = k.transpose(0, 1).contiguous().view(-1, self.num_heads * self.head_dim)

        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self attention.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully connected.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
                                                   perform_initialization=False)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class LlamaForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = LlamaModel(config)
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head.weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
                                "q_proj.weight", "k_proj.weight",
                                "v_proj.weight", "gate_proj.weight",
                                "up_proj.weight"]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self, weights_path: str):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, param in state_dict.items():
            loaded_weight = torch.from_numpy(
                np.load(os.path.join(weights_path, name)))
            for p in self._column_parallel_weights:
                if p in name:
                    shard_size = param.shape[0]
                    loaded_weight = loaded_weight[
                        shard_size * tensor_model_parallel_rank
                        :shard_size * (tensor_model_parallel_rank + 1)]
                    break
            for p in self._row_parallel_weights:
                if p in name:
                    shard_size = param.shape[1]
                    loaded_weight = loaded_weight[
                        :,
                        shard_size * tensor_model_parallel_rank
                        :shard_size * (tensor_model_parallel_rank + 1)]
                    break

            assert param.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)

    @staticmethod
    def get_weights(model_name: str, path: str):
        if not os.path.isfile(os.path.join(model_name, "config.json")):
            raise ValueError("LLaMA model's model_name has to be a path "
                             "to the huggingface model's directory.")
        path = os.path.join(model_name, "np")
        path = os.path.abspath(os.path.expanduser(path))
        os.makedirs(path, exist_ok=True)
        lock_path = os.path.join(path, "file_lock")
        lock = filelock.FileLock(lock_path)

        with lock:
            test_weight_path = os.path.join(path, "model.embed_tokens.weight")
            if os.path.exists(test_weight_path):
                return path

            bin_files = glob.glob(os.path.join(model_name, "*.bin"))

            for bin_file in tqdm(bin_files, desc="Convert format"):
                state = torch.load(bin_file, map_location="cpu")
                for name, param in tqdm(state.items(), leave=False):
                    param_path = os.path.join(path, name)
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())

            return path
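For reference, a minimal standalone sketch (not part of the commit) of the rotary-embedding shape flow in LlamaAttention.forward: the cached cos/sin tables are indexed by token position and broadcast over the head dimension. The sizes (2 heads, 4 tokens, head_dim = 8, 16 cached positions) are arbitrary illustration values.

```python
# Standalone illustration (not from the commit) of the rotary embedding shape
# flow in LlamaAttention.forward. Sizes are arbitrary illustration values.
import torch
import torch.nn.functional as F

num_heads, num_tokens, head_dim, max_pos = 2, 4, 8, 16
base = 10000

# Build the cached cos/sin tables, as LlamaRotaryEmbedding.__init__ does.
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(max_pos).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)           # [max_pos, head_dim]
cos_cached, sin_cached = emb.cos(), emb.sin()

# Look up the tables by token position.
positions = torch.tensor([0, 1, 2, 3])
cos = F.embedding(positions, cos_cached)          # [num_tokens, head_dim]
sin = F.embedding(positions, sin_cached)

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# q is reshaped to [num_heads, num_tokens, head_dim] before the rotation,
# so cos/sin broadcast over the head dimension.
q = torch.randn(num_heads, num_tokens, head_dim)
q_rot = (q * cos) + (rotate_half(q) * sin)
print(q_rot.shape)                                # torch.Size([2, 4, 8])
```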

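Similarly, a toy sketch (not from the commit) of the sharding rule used in LlamaForCausalLM.load_weights: weights listed in _column_parallel_weights are sliced along dim 0 by the rank's shard offset, and weights in _row_parallel_weights along dim 1. The world size, rank, and weight shapes below are made-up illustration values.

```python
# Toy illustration (not from the commit) of the tensor-parallel sharding in
# load_weights: each rank keeps one contiguous slice of the full weight.
import torch

world_size, rank = 2, 1                # pretend we are rank 1 of 2

full_col = torch.randn(8, 4)           # column-parallel, e.g. q_proj.weight
shard_rows = full_col.shape[0] // world_size
col_shard = full_col[shard_rows * rank:shard_rows * (rank + 1)]      # [4, 4]

full_row = torch.randn(4, 8)           # row-parallel, e.g. o_proj.weight
shard_cols = full_row.shape[1] // world_size
row_shard = full_row[:, shard_cols * rank:shard_cols * (rank + 1)]   # [4, 4]

print(col_shard.shape, row_shard.shape)
```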