From b9dcf782600c61d2f340e51f3a81518e8c10003c Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Wed, 3 May 2023 09:41:16 +0000 Subject: [PATCH 1/6] Add GPT-2 without TP --- cacheflow/models/gpt2.py | 256 ++++++++++++++++++++++++++++ cacheflow/models/memory_analyzer.py | 70 ++++++++ cacheflow/models/model_utils.py | 4 + 3 files changed, 330 insertions(+) create mode 100644 cacheflow/models/gpt2.py diff --git a/cacheflow/models/gpt2.py b/cacheflow/models/gpt2.py new file mode 100644 index 000000000000..186cf48eaaa8 --- /dev/null +++ b/cacheflow/models/gpt2.py @@ -0,0 +1,256 @@ +"""1D GPT-2 model compatible with HuggingFace weights.""" +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import GPT2Config + +from cacheflow.models import InputMetadata +from cacheflow.models.attention import GPTCacheFlowAttention +from cacheflow.models.sample import Sampler +from cacheflow.models.utils import (hf_model_weights_iterator, + load_tensor_parallel_weights) +from cacheflow.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear) +from cacheflow.sequence import SequenceOutputs + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + self.weight = nn.Parameter(torch.empty(nx, nf)) + self.bias = nn.Parameter(torch.zeros(nf)) + nn.init.normal_(self.weight, std=0.02) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +class GPT2Attention(nn.Module): + + def __init__(self, config: GPT2Config): + super().__init__() + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.scale = self.head_dim ** -0.5 + + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + self.attn = GPTCacheFlowAttention(scale=self.scale) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + qkv = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + key_cache, value_cache = kv_cache + attn_output = self.attn( + q, k, v, key_cache, value_cache, input_metadata, cache_event) + attn_output = self.c_proj(attn_output) + return attn_output + + +class GPT2MLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPT2Config, + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = Conv1D(intermediate_size, hidden_size) + self.c_proj = Conv1D(hidden_size, intermediate_size) + assert config.activation_function == 'gelu_new' + self.act = torch.nn.GELU(approximate='tanh') + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + return hidden_states + + +class GPT2Block(nn.Module): + + def __init__(self, config: GPT2Config): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class GPT2Model(nn.Module): + + def __init__(self, config: GPT2Config): + super().__init__() + self.config = config + assert config.add_cross_attention == False + assert config.scale_attn_by_inverse_layer_idx == False + assert config.reorder_and_upcast_attn == False + + self.embed_dim = config.hidden_size + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.h = nn.ModuleList( + [GPT2Block(config) for _ in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.LongTensor, + position_ids: torch.LongTensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + for i in range(len(self.h)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.h[i] + hidden_states = layer( + hidden_states, kv_caches[i], input_metadata, cache_event) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPT2LMHeadModel(nn.Module): + + def __init__(self, config: GPT2Config): + super().__init__() + self.config = config + self.transformer = GPT2Model(config) + # TODO(zhuohan): create a new weight after implementing pipeline + # parallelism + self.lm_head_weight = self.transformer.wte.weight + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.LongTensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.transformer( + input_ids, positions, kv_caches, input_metadata, cache_events) + next_tokens = self.sampler( + self.lm_head_weight, hidden_states, input_metadata) + return next_tokens + + # _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"] + # _row_parallel_weights = ["out_proj.weight", "fc2.weight"] + _column_parallel_weights = [] + _row_parallel_weights = [] + + def load_weights(self, model_name_or_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False): + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + name = "transformer." + name + + # is_attention_weight = False + # for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]): + # if att_weight_name not in name: + # continue + # param = state_dict[name.replace(att_weight_name, "qkv_proj")] + # shard_size = param.shape[0] // 3 + # loaded_weight = loaded_weight[ + # shard_size * tensor_model_parallel_rank + # :shard_size * (tensor_model_parallel_rank + 1)] + # param_slice = param.data[shard_size * stride_id + # :shard_size * (stride_id + 1)] + # assert param_slice.shape == loaded_weight.shape + # param_slice.copy_(loaded_weight) + # is_attention_weight = True + # break + # if is_attention_weight: + # continue + + param = state_dict[name] + load_tensor_parallel_weights(param, loaded_weight, name, + self._column_parallel_weights, + self._row_parallel_weights) + + def initialize_dummy_weights(self) -> None: + for param in self.state_dict().values(): + param.data.uniform_(-1e-3, 1e-3) diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index 738c6d11d023..2f15052aeb38 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -72,6 +72,76 @@ def get_max_num_gpu_blocks( return max_num_blocks +class GPT2MemoryAnalyzer(CacheFlowMemoryAnalyzer): + + def __init__( + self, + model_name: str, + block_size: int, + dtype: torch.dtype, + gpu_memory: int, + cpu_memory: int, + tensor_parallel_size: int, + ) -> None: + self.model_name = model_name + self.block_size = block_size + self.dtype = dtype + self.gpu_memory = gpu_memory + self.cpu_memory = cpu_memory + self.tensor_parallel_size = tensor_parallel_size + + config = AutoConfig.from_pretrained(model_name) + self.num_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = config.hidden_size // self.num_heads + self.ffn_size = config.n_inner if config.n_inner is not None else 4 * self.hidden_size + self.vocab_size = config.vocab_size + self.max_position = config.max_position_embeddings + + def get_param_size(self) -> int: + word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size + position_embedding = self.max_position * self.hidden_size + + ln1 = 2 * self.hidden_size + q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size + k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size + v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size + out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size + mha = ln1 + q + k + v + out + + ln2 = 2 * self.hidden_size + ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size + ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size + ffn = ln2 + ffn1 + ffn2 + + total = (word_embedding + position_embedding + + self.num_layers * (mha + ffn)) + dtype_size = get_dtype_size(self.dtype) + return dtype_size * total + + def get_max_act_size( + self, + max_num_batched_tokens: int, + ) -> int: + # NOTE: We approxmiately calculate the maximum activation size by + # estimating + # 1) the maximum activation tensor size during inference + # 2) the residual tensor size during inference + # Here, we assume that FlashAttention is used and + # thus the attention maps are never materialized in GPU DRAM. + residual = max_num_batched_tokens * self.hidden_size + qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size + ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size + # Double the activation size for input and output. + max_act = 2 * (max(qkv, ffn) + residual) + # Size of output logits. + output_logits = 2 * (max_num_batched_tokens * self.vocab_size) + max_act = max(max_act, output_logits) + dtype_size = get_dtype_size(self.dtype) + return dtype_size * max_act + + class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer): def __init__( diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index ec3f0e006869..67fdd0f214b5 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -5,9 +5,11 @@ from transformers import AutoConfig from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer +from cacheflow.models.memory_analyzer import GPT2MemoryAnalyzer from cacheflow.models.memory_analyzer import GPTNeoXMemoryAnalyzer from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer +from cacheflow.models.gpt2 import GPT2LMHeadModel from cacheflow.models.gpt_neox import GPTNeoXForCausalLM from cacheflow.models.llama import LlamaForCausalLM from cacheflow.models.opt import OPTForCausalLM @@ -15,6 +17,7 @@ _MODELS = { + 'gpt2': GPT2LMHeadModel, 'llama': LlamaForCausalLM, 'opt': OPTForCausalLM, 'stablelm': GPTNeoXForCausalLM, @@ -22,6 +25,7 @@ } _MEMORY_ANALYZERS = { + 'gpt2': GPT2MemoryAnalyzer, 'llama': LlamaMemoryAnalyzer, 'opt': OPTMemoryAnalyzer, 'stablelm': GPTNeoXMemoryAnalyzer, From 8066235ce2eb2d94b3fd9efd9265fdc3855c5f7e Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Wed, 3 May 2023 10:38:50 +0000 Subject: [PATCH 2/6] [WIP] Support TP --- cacheflow/models/gpt2.py | 118 +++++++++++++++-------------------- cacheflow/models/gpt_neox.py | 4 +- cacheflow/models/sample.py | 4 +- 3 files changed, 57 insertions(+), 69 deletions(-) diff --git a/cacheflow/models/gpt2.py b/cacheflow/models/gpt2.py index 186cf48eaaa8..44f536504bdd 100644 --- a/cacheflow/models/gpt2.py +++ b/cacheflow/models/gpt2.py @@ -20,48 +20,24 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] -class Conv1D(nn.Module): - """ - 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). - - Basically works like a linear layer but the weights are transposed. - - Args: - nf (`int`): The number of output features. - nx (`int`): The number of input features. - """ - - def __init__(self, nf, nx): - super().__init__() - self.nf = nf - self.weight = nn.Parameter(torch.empty(nx, nf)) - self.bias = nn.Parameter(torch.zeros(nf)) - nn.init.normal_(self.weight, std=0.02) - - def forward(self, x): - size_out = x.size()[:-1] + (self.nf,) - x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) - x = x.view(size_out) - return x - - class GPT2Attention(nn.Module): def __init__(self, config: GPT2Config): super().__init__() - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads self.scale = self.head_dim ** -0.5 - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - - self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) - self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, bias=True, + gather_output=False, + perform_initialization=False) + self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True, + input_is_parallel=True, + perform_initialization=False) self.attn = GPTCacheFlowAttention(scale=self.scale) def forward( @@ -71,12 +47,12 @@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: - qkv = self.c_attn(hidden_states) + qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) key_cache, value_cache = kv_cache attn_output = self.attn( q, k, v, key_cache, value_cache, input_metadata, cache_event) - attn_output = self.c_proj(attn_output) + attn_output, _ = self.c_proj(attn_output) return attn_output @@ -89,15 +65,20 @@ def __init__( ): super().__init__() hidden_size = config.hidden_size - self.c_fc = Conv1D(intermediate_size, hidden_size) - self.c_proj = Conv1D(hidden_size, intermediate_size) + self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size, + bias=True, gather_output=False, + perform_initialization=False) + self.c_proj = RowParallelLinear(intermediate_size, hidden_size, + bias=True, input_is_parallel=True, + perform_initialization=False) + assert config.activation_function == 'gelu_new' self.act = torch.nn.GELU(approximate='tanh') def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.c_fc(hidden_states) + hidden_states, _ = self.c_fc(hidden_states) hidden_states = self.act(hidden_states) - hidden_states = self.c_proj(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) return hidden_states @@ -149,7 +130,8 @@ def __init__(self, config: GPT2Config): assert config.reorder_and_upcast_attn == False self.embed_dim = config.hidden_size - self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + assert config.vocab_size == 50257 + self.wte = VocabParallelEmbedding(50304, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList( [GPT2Block(config) for _ in range(config.num_hidden_layers)]) @@ -189,7 +171,7 @@ def __init__(self, config: GPT2Config): # TODO(zhuohan): create a new weight after implementing pipeline # parallelism self.lm_head_weight = self.transformer.wte.weight - self.sampler = Sampler() + self.sampler = Sampler(config.vocab_size) def forward( self, @@ -205,10 +187,8 @@ def forward( self.lm_head_weight, hidden_states, input_metadata) return next_tokens - # _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"] - # _row_parallel_weights = ["out_proj.weight", "fc2.weight"] - _column_parallel_weights = [] - _row_parallel_weights = [] + _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"] + _row_parallel_weights = ["c_proj.weight"] def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, @@ -228,25 +208,31 @@ def load_weights(self, model_name_or_path: str, continue name = "transformer." + name - # is_attention_weight = False - # for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]): - # if att_weight_name not in name: - # continue - # param = state_dict[name.replace(att_weight_name, "qkv_proj")] - # shard_size = param.shape[0] // 3 - # loaded_weight = loaded_weight[ - # shard_size * tensor_model_parallel_rank - # :shard_size * (tensor_model_parallel_rank + 1)] - # param_slice = param.data[shard_size * stride_id - # :shard_size * (stride_id + 1)] - # assert param_slice.shape == loaded_weight.shape - # param_slice.copy_(loaded_weight) - # is_attention_weight = True - # break - # if is_attention_weight: - # continue - + # Optimization: While the vocab size of GPT-2 is 50257, we + # extend it to 50304 to make it divisible by 64. + if name == "transformer.wte.weight": + extra_rows = torch.empty(47, loaded_weight.shape[1]).to( + loaded_weight) + loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) + + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weight. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() param = state_dict[name] + + # For the fused QKV linear layer, manually shard the weights. + if "c_attn" in name: + # print(name, param.shape) + shard_size = param.shape[0] + # FIXME + loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank: + shard_size * (tensor_model_parallel_rank + 1)] + # print(name, loaded_weight.shape) load_tensor_parallel_weights(param, loaded_weight, name, self._column_parallel_weights, self._row_parallel_weights) diff --git a/cacheflow/models/gpt_neox.py b/cacheflow/models/gpt_neox.py index 9fe332d8381c..1d8a1231a349 100644 --- a/cacheflow/models/gpt_neox.py +++ b/cacheflow/models/gpt_neox.py @@ -205,8 +205,8 @@ def load_weights(self, model_name_or_path: str, param = state_dict[name] if "query_key_value" in name: # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of - # [num_heads * 3 * head_size, num_heads * head_size], while the - # required shape is [3 * num_heads * head_size, num_heads * head_size]. + # [num_heads * 3 * head_size, hidden_size], while the + # required shape is [3 * num_heads * head_size, hidden_size]. # Thus, we need weight conversion. shard_size = param.shape[0] loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py index 1e358c7e5278..e9290120b5fa 100644 --- a/cacheflow/models/sample.py +++ b/cacheflow/models/sample.py @@ -11,8 +11,9 @@ class Sampler(nn.Module): - def __init__(self) -> None: + def __init__(self, vocab_size: int) -> None: super().__init__() + self.vocab_size = vocab_size def forward( self, @@ -26,6 +27,7 @@ def forward( # Get the logits for the next tokens. logits = torch.matmul(hidden_states, embedding.t()) logits = gather_from_tensor_model_parallel_region(logits) + logits = logits[:, :self.vocab_size] # Apply temperature scaling. temperatures = _get_temperatures(input_metadata) From 268da23b071084e991704f831dba1df7c545cf58 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Wed, 3 May 2023 20:09:10 +0000 Subject: [PATCH 3/6] Support TP for GPT-2 --- cacheflow/models/gpt2.py | 54 +++++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/cacheflow/models/gpt2.py b/cacheflow/models/gpt2.py index 44f536504bdd..01a52cb6a62d 100644 --- a/cacheflow/models/gpt2.py +++ b/cacheflow/models/gpt2.py @@ -128,10 +128,15 @@ def __init__(self, config: GPT2Config): assert config.add_cross_attention == False assert config.scale_attn_by_inverse_layer_idx == False assert config.reorder_and_upcast_attn == False - self.embed_dim = config.hidden_size - assert config.vocab_size == 50257 - self.wte = VocabParallelEmbedding(50304, self.embed_dim) + + # Optimization: While the vocab size of GPT-2 is 50257, we extend it + # to 50304 in order to make it divisible by 64. + # This improves performance since GPUs are faster if the dimension + # is divisible by 64. In addition, it allows us to shard the embedding + # layer across 2, 4, 8, or more GPUs. + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList( [GPT2Block(config) for _ in range(config.num_hidden_layers)]) @@ -193,6 +198,7 @@ def forward( def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, use_np_cache: bool = False): + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tensor_model_parallel_rank = get_tensor_model_parallel_rank() state_dict = self.state_dict() @@ -208,15 +214,8 @@ def load_weights(self, model_name_or_path: str, continue name = "transformer." + name - # Optimization: While the vocab size of GPT-2 is 50257, we - # extend it to 50304 to make it divisible by 64. - if name == "transformer.wte.weight": - extra_rows = torch.empty(47, loaded_weight.shape[1]).to( - loaded_weight) - loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) - # The HF's GPT-2 implementation uses Conv1D instead of Linear. - # Because of this, we need to transpose the weight. + # Because of this, we need to transpose the weights. for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: if conv1d_weight_name not in name: continue @@ -225,14 +224,35 @@ def load_weights(self, model_name_or_path: str, loaded_weight = loaded_weight.t() param = state_dict[name] + if name == "transformer.wte.weight": + # Consider padding in the vocab size. + padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size + num_extra_rows = padded_vocab_size - self.config.vocab_size + extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]) + extra_rows = extra_rows.to(loaded_weight) + loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) + # For the fused QKV linear layer, manually shard the weights. if "c_attn" in name: - # print(name, param.shape) - shard_size = param.shape[0] - # FIXME - loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank: - shard_size * (tensor_model_parallel_rank + 1)] - # print(name, loaded_weight.shape) + # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size]. + # When tensor parallelism is used, we shard the weights along the head dimension. + total_num_heads = self.config.num_attention_heads + hidden_size = self.config.hidden_size + head_size = hidden_size // total_num_heads + num_heads = total_num_heads // tensor_model_parallel_world_size + head_start = tensor_model_parallel_rank * num_heads + head_end = (tensor_model_parallel_rank + 1) * num_heads + + if name.endswith(".weight"): + loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size) + loaded_weight = loaded_weight[:, head_start:head_end, :, :] + loaded_weight = loaded_weight.reshape(-1, hidden_size) + elif name.endswith(".bias"): + loaded_weight = loaded_weight.view(3, total_num_heads, head_size) + loaded_weight = loaded_weight[:, head_start:head_end, :] + loaded_weight = loaded_weight.reshape(-1) + else: + raise ValueError(f"Unexpected parameter name {name}") load_tensor_parallel_weights(param, loaded_weight, name, self._column_parallel_weights, self._row_parallel_weights) From 3cdae7c89f364fffc8ce3f98f5fc979bb05d2ec6 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Wed, 3 May 2023 20:09:44 +0000 Subject: [PATCH 4/6] Minor --- cacheflow/models/sample.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py index e9290120b5fa..dc488c814441 100644 --- a/cacheflow/models/sample.py +++ b/cacheflow/models/sample.py @@ -27,6 +27,7 @@ def forward( # Get the logits for the next tokens. logits = torch.matmul(hidden_states, embedding.t()) logits = gather_from_tensor_model_parallel_region(logits) + # Remove paddings in vocab. logits = logits[:, :self.vocab_size] # Apply temperature scaling. From f57622564e33fecc6b0bb2ade1f168315d196cae Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Wed, 3 May 2023 20:10:40 +0000 Subject: [PATCH 5/6] Fix --- cacheflow/models/gpt_neox.py | 6 +++--- cacheflow/models/llama.py | 2 +- cacheflow/models/opt.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cacheflow/models/gpt_neox.py b/cacheflow/models/gpt_neox.py index 1d8a1231a349..fb85e2f72454 100644 --- a/cacheflow/models/gpt_neox.py +++ b/cacheflow/models/gpt_neox.py @@ -173,7 +173,7 @@ def __init__(self, config): self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size, bias=False, gather_output=False, perform_initialization=False) - self.sampler = Sampler() + self.sampler = Sampler(config.vocab_size) def forward( self, @@ -218,11 +218,11 @@ def load_weights(self, model_name_or_path: str, if 'query_key_value.weight' in name: loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size) loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1, hidden_size).contiguous() + loaded_weight = loaded_weight.reshape(-1, hidden_size) elif 'query_key_value.bias' in name: loaded_weight = loaded_weight.view(-1, 3, head_size) loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1).contiguous() + loaded_weight = loaded_weight.reshape(-1) else: raise ValueError(f"Unexpected weight name: {name}") load_tensor_parallel_weights(param, loaded_weight, name, diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index 1eda7f23d077..706650301ee1 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -192,7 +192,7 @@ def __init__(self, config): bias=False, gather_output=False, perform_initialization=False) - self.sampler = Sampler() + self.sampler = Sampler(config.vocab_size) def forward( self, diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index 15f0f688d1af..79b81cd0e3a4 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -234,7 +234,7 @@ def __init__(self, config): # TODO(zhuohan): create a new weight after implementing pipeline # parallelism self.lm_head_weight = self.model.decoder.embed_tokens.weight - self.sampler = Sampler() + self.sampler = Sampler(config.vocab_size) def forward( self, From 1b4edeab36d267d5e3fa221d7dc6be8be07c15d0 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Wed, 3 May 2023 20:19:41 +0000 Subject: [PATCH 6/6] Minor --- cacheflow/models/gpt2.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cacheflow/models/gpt2.py b/cacheflow/models/gpt2.py index 01a52cb6a62d..1b30ced28aa6 100644 --- a/cacheflow/models/gpt2.py +++ b/cacheflow/models/gpt2.py @@ -72,8 +72,11 @@ def __init__( bias=True, input_is_parallel=True, perform_initialization=False) - assert config.activation_function == 'gelu_new' - self.act = torch.nn.GELU(approximate='tanh') + act_fn = config.activation_function + if act_fn != "gelu_new": + raise ValueError(f"Unsupported activation: {act_fn}. " + "GPT-2 only supports gelu_new for now.") + self.act = torch.nn.GELU(approximate="tanh") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.c_fc(hidden_states)