diff --git a/README.md b/README.md
index ffa890bd30e3..ee21e711cf0a 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,7 @@ vLLM seamlessly supports many Huggingface models, including the following archit
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
 - LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
 
 Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 203b8644a662..5389584e8661 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -395,6 +395,9 @@ void single_query_cached_kv_attention_launcher(
     case 96:
       LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
       break;
+    case 112:
+      LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
+      break;
     case 128:
       LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
       break;
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index e29f27fcd70f..75b283e92f0e 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -29,6 +29,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`LlamaForCausalLM`
     - LLaMA, Vicuna, Alpaca, Koala, Guanaco
     - :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
+  * - :code:`MPTForCausalLM`
+    - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
+    - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
   * - :code:`OPTForCausalLM`
     - OPT, OPT-IML
     - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
diff --git a/vllm/config.py b/vllm/config.py
index a5732bd0b0b4..2bc8e3f21dab 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1,9 +1,10 @@
 from typing import Optional
 
 import torch
-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig
 
 from vllm.logger import init_logger
+from vllm.transformers_utils.config import get_config
 from vllm.utils import get_cpu_memory
 
 logger = init_logger(__name__)
@@ -49,7 +50,7 @@ def __init__(
         self.use_dummy_weights = use_dummy_weights
         self.seed = seed
 
-        self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
+        self.hf_config = get_config(model)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_tokenizer_mode()
 
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index b350de8f5941..40f979b17c1d 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -12,7 +12,7 @@
 from vllm import pos_encoding_ops
 from vllm.model_executor.input_metadata import InputMetadata
 
-_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]
 
 
 class PagedAttention(nn.Module):
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index f5e2793b3330..1d48baab7dac 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -16,7 +16,8 @@
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
-    "LLaMAForCausalLM": LlamaForCausalLM,
+    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
+    "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
 }
 
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 251052a29620..64a4e6282f78 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -3,6 +3,7 @@
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
 from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.model_executor.models.mpt import MPTForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
 
 __all__ = [
@@ -11,5 +12,6 @@
     "GPTBigCodeForCausalLM",
     "GPTNeoXForCausalLM",
     "LlamaForCausalLM",
+    "MPTForCausalLM",
     "OPTForCausalLM",
 ]
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
new file mode 100644
index 000000000000..09a0d7ce44f6
--- /dev/null
+++ b/vllm/model_executor/models/mpt.py
@@ -0,0 +1,279 @@
+# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                              load_tensor_parallel_weights)
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.sequence import SequenceOutputs
+from vllm.transformers_utils.configs.mpt import MPTConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _get_alibi_slopes(
+    total_num_heads: int,
+    alibi_bias_max: int,
+) -> torch.Tensor:
+    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
+    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
+    m = m.mul(alibi_bias_max / next_power_of_2)
+    slopes = 1.0 / torch.pow(2, m)
+    if next_power_of_2 != total_num_heads:
+        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
+    return slopes
+
+
+class MPTAttention(nn.Module):
+
+    def __init__(self, config: MPTConfig):
+        super().__init__()
+        self.d_model = config.d_model
+        self.total_num_heads = config.n_heads
+        self.clip_qkv = config.attn_config["clip_qkv"]
+        self.qk_ln = config.attn_config["qk_ln"]
+        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
+        assert not config.attn_config["prefix_lm"]
+        assert config.attn_config["alibi"]
+
+        self.qkv_proj = ColumnParallelLinear(
+            self.d_model,
+            3 * self.d_model,
+            bias=not config.no_bias,
+            gather_output=False,
+            perform_initialization=False,
+        )
+        if self.qk_ln:
+            self.q_ln = nn.LayerNorm(self.d_model)
+            self.k_ln = nn.LayerNorm(self.d_model)
+        self.out_proj = RowParallelLinear(
+            self.d_model,
+            self.d_model,
+            bias=not config.no_bias,
+            input_is_parallel=True,
+            perform_initialization=False,
+        )
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_world_size == 0
+        self.num_heads = self.total_num_heads // tp_world_size
+
+        # Create the alibi slopes and slice them.
+        tp_rank = get_tensor_model_parallel_rank()
+        head_start = tp_rank * self.num_heads
+        head_end = (tp_rank + 1) * self.num_heads
+        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
+                                         self.alibi_bias_max)
+        alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+        self.head_dim = self.d_model // self.total_num_heads
+        scaling = self.head_dim**-0.5
+        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
+                                            scaling, alibi_slopes)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        del position_ids  # unused.
+        qkv, _ = self.qkv_proj(hidden_states)
+        if self.clip_qkv is not None:
+            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        if self.qk_ln:
+            q = self.q_ln(q)
+            k = self.k_ln(k)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                cache_event)
+        output, _ = self.out_proj(attn_output)
+        return output
+
+
+class MPTMLP(nn.Module):
+
+    def __init__(self, config: MPTConfig):
+        super().__init__()
+        hidden_size = config.d_model
+        expansion_ratio = config.expansion_ratio
+        intermediate_size = expansion_ratio * hidden_size
+        self.up_proj = ColumnParallelLinear(hidden_size,
+                                            intermediate_size,
+                                            bias=not config.no_bias,
+                                            gather_output=False,
+                                            perform_initialization=False)
+        self.act = get_act_fn("gelu")
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=not config.no_bias,
+                                           input_is_parallel=True,
+                                           perform_initialization=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, _ = self.up_proj(x)
+        x = self.act(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class MPTBlock(nn.Module):
+
+    def __init__(self, config: MPTConfig):
+        super().__init__()
+        hidden_size = config.d_model
+        self.norm_1 = nn.LayerNorm(hidden_size)
+        self.attn = MPTAttention(config)
+        self.norm_2 = nn.LayerNorm(hidden_size)
+        self.ffn = MPTMLP(config)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        x = self.norm_1(hidden_states)
+        x = self.attn(
+            position_ids=position_ids,
+            hidden_states=x,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event,
+        )
+        hidden_states = hidden_states + x
+        x = self.norm_2(hidden_states)
+        x = self.ffn(x)
+        hidden_states = hidden_states + x
+        return hidden_states
+
+
+class MPTModel(nn.Module):
+
+    def __init__(self, config: MPTConfig):
+        super().__init__()
+        assert config.embedding_fraction == 1.0
+        assert config.norm_type == "low_precision_layernorm"
+
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          config.d_model,
+                                          perform_initialization=False)
+        self.blocks = nn.ModuleList(
+            [MPTBlock(config) for _ in range(config.n_layers)])
+        self.norm_f = nn.LayerNorm(config.d_model)
+        if config.no_bias:
+            for module in self.modules():
+                if hasattr(module, "bias"):
+                    if isinstance(module.bias, nn.Parameter):
+                        # Remove the bias term in Linear and LayerNorm.
+                        module.register_parameter("bias", None)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> torch.Tensor:
+        hidden_states = self.wte(input_ids)
+        for i in range(len(self.blocks)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            block = self.blocks[i]
+            hidden_states = block(
+                position_ids,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                cache_event,
+            )
+        hidden_states = self.norm_f(hidden_states)
+        return hidden_states
+
+
+class MPTForCausalLM(nn.Module):
+
+    def __init__(self, config: MPTConfig):
+        super().__init__()
+        self.config = config
+        assert config.tie_word_embeddings
+
+        self.transformer = MPTModel(config)
+        # TODO(zhuohan): create a new weight after implementing pipeline
+        # parallelism
+        self.lm_head_weight = self.transformer.wte.weight
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, SequenceOutputs]:
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata, cache_events)
+        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
+                                   input_metadata)
+        return next_tokens
+
+    _column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
+    _row_parallel_weights = ["out_proj.weight", "down_proj.weight"]
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
+        tp_world_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        state_dict = self.state_dict()
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if "Wqkv" in name:
+                # NOTE(woosuk): MPT's fused QKV has the shape of
+                # [3 * num_heads * head_size, hidden_size].
+                # When tensor model parallelism is used, we need to shard
+                # the weight along the hidden dimension.
+                total_num_heads = self.config.num_attention_heads
+                hidden_size = self.config.hidden_size
+                head_size = hidden_size // total_num_heads
+                num_heads = total_num_heads // tp_world_size
+                head_start = tp_rank * num_heads
+                head_end = (tp_rank + 1) * num_heads
+
+                if name.endswith(".weight"):
+                    loaded_weight = loaded_weight.view(3, total_num_heads,
+                                                       head_size, hidden_size)
+                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
+                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
+                elif name.endswith(".bias"):
+                    loaded_weight = loaded_weight.view(3, total_num_heads,
+                                                       head_size)
+                    loaded_weight = loaded_weight[:, head_start:head_end, :]
+                    loaded_weight = loaded_weight.reshape(-1)
+                else:
+                    raise ValueError(f"Unexpected parameter name {name}")
+                name = name.replace("Wqkv", "qkv_proj")
+            param = state_dict[name]
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights, tp_rank)
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
new file mode 100644
index 000000000000..866b23bff098
--- /dev/null
+++ b/vllm/transformers_utils/config.py
@@ -0,0 +1,15 @@
+from transformers import AutoConfig, PretrainedConfig
+
+from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import
+
+_CONFIG_REGISTRY = {
+    "mpt": MPTConfig,
+}
+
+
+def get_config(model: str) -> PretrainedConfig:
+    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+    if config.model_type in _CONFIG_REGISTRY:
+        config_class = _CONFIG_REGISTRY[config.model_type]
+        config = config_class.from_pretrained(model)
+    return config
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
new file mode 100644
index 000000000000..fef4caf93e43
--- /dev/null
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -0,0 +1,5 @@
+from vllm.transformers_utils.configs.mpt import MPTConfig
+
+__all__ = [
+    "MPTConfig",
+]
diff --git a/vllm/transformers_utils/configs/mpt.py b/vllm/transformers_utils/configs/mpt.py
new file mode 100644
index 000000000000..3909f710d44d
--- /dev/null
+++ b/vllm/transformers_utils/configs/mpt.py
@@ -0,0 +1,74 @@
+# Adapted from
+# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py
+from typing import Any, Dict, Optional, Union
+
+from transformers import PretrainedConfig
+
+_ATTN_CONFIG_DEFAULTS = {
+    "attn_type": "multihead_attention",
+    "attn_pdrop": 0.0,
+    "attn_impl": "triton",
+    "qk_ln": False,
+    "clip_qkv": None,
+    "softmax_scale": None,
+    "prefix_lm": False,
+    "attn_uses_sequence_id": False,
+    "alibi": False,
+    "alibi_bias_max": 8,
+}
+
+
+class MPTConfig(PretrainedConfig):
+    model_type = "mpt"
+    attribute_map = {
+        "hidden_size": "d_model",
+        "num_attention_heads": "n_heads",
+        "num_hidden_layers": "n_layers",
+    }
+
+    def __init__(
+        self,
+        d_model: int = 2048,
+        n_heads: int = 16,
+        n_layers: int = 24,
+        expansion_ratio: int = 4,
+        max_seq_len: int = 2048,
+        vocab_size: int = 50368,
+        resid_pdrop: float = 0.0,
+        emb_pdrop: float = 0.0,
+        learned_pos_emb: bool = True,
+        attn_config: Optional[Dict[str, Any]] = None,
+        init_device: str = "cpu",
+        logit_scale: Optional[Union[float, str]] = None,
+        no_bias: bool = False,
+        verbose: int = 0,
+        embedding_fraction: float = 1.0,
+        norm_type: str = "low_precision_layernorm",
+        use_cache: bool = False,
+        **kwargs,
+    ) -> None:
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.expansion_ratio = expansion_ratio
+        self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.resid_pdrop = resid_pdrop
+        self.emb_pdrop = emb_pdrop
+        self.learned_pos_emb = learned_pos_emb
+        if attn_config is None:
+            self.attn_config = _ATTN_CONFIG_DEFAULTS
+        else:
+            self.attn_config = attn_config
+        self.init_device = init_device
+        self.logit_scale = logit_scale
+        self.no_bias = no_bias
+        self.verbose = verbose
+        self.embedding_fraction = embedding_fraction
+        self.norm_type = norm_type
+        self.use_cache = use_cache
+        if "name" in kwargs:
+            del kwargs["name"]
+        if "loss_fn" in kwargs:
+            del kwargs["loss_fn"]
+        super().__init__(**kwargs)
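
A minimal usage sketch, not part of the patch above: with these files in place, an MPT checkpoint should run through vLLM's existing offline LLM API. The model name, prompt, and sampling values below are illustrative assumptions rather than anything this diff pins down.

    from vllm import LLM, SamplingParams

    # MPT relies on ALiBi instead of rotary embeddings, and MPT-30B has a head
    # size of 7168 / 64 = 112, which is why the attention kernel above gains a
    # head_size == 112 case.
    llm = LLM(model="mosaicml/mpt-7b")
    sampling_params = SamplingParams(temperature=0.8, max_tokens=64)
    outputs = llm.generate(["MosaicML's MPT models are"], sampling_params)
    print(outputs[0].outputs[0].text)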