From 5a3ddefc5a509b4eab536f6ccb000dc1b8356019 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 4 Mar 2024 19:36:52 +0000 Subject: [PATCH 01/31] updated code for jais --- README.md | 1 + docs/source/models/supported_models.rst | 3 + tests/models/test_models.py | 1 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/jais.py | 320 ++++++++++++++++++++ vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/jais.py | 196 ++++++++++++ 8 files changed, 525 insertions(+) create mode 100644 vllm/model_executor/models/jais.py create mode 100644 vllm/transformers_utils/configs/jais.py diff --git a/README.md b/README.md index 064faa550f26..11786095b34e 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) - InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) - InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) +- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 9d4ec663a16e..bae87255e922 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -53,6 +53,9 @@ Alongside each architecture, we include some popular models that use it. * - :code:`InternLM2ForCausalLM` - InternLM2 - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc. + * - :code:`JAISLMHeadModel` + - Jais + - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. * - :code:`LlamaForCausalLM` - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. 
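For reference, once the Jais support added by this patch series is in place, the checkpoints listed above can be loaded through vLLM's standard offline-inference API. The snippet below is a minimal usage sketch and is not part of the patch itself; the model name is taken from the supported-model list above, and the prompt and sampling settings are purely illustrative.

    from vllm import LLM, SamplingParams

    # Load one of the newly supported Jais checkpoints.
    # Depending on the checkpoint, trust_remote_code=True may still be
    # required for the tokenizer even though the config is handled by vLLM.
    llm = LLM(model="core42/jais-13b-chat")

    # Illustrative sampling settings; adjust as needed.
    sampling_params = SamplingParams(temperature=0.7, max_tokens=128)

    outputs = llm.generate(["What is the capital of the UAE?"], sampling_params)
    for output in outputs:
        print(output.outputs[0].text)
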
diff --git a/tests/models/test_models.py b/tests/models/test_models.py index fb567e837d28..5488149227df 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -20,6 +20,7 @@ "stabilityai/stablelm-3b-4e1t", "allenai/OLMo-1B", "bigcode/starcoder2-3b", + "core42/jais-13b", ] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 75c2ae1e9f48..f25a22e3c213 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -27,6 +27,7 @@ "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), + "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py new file mode 100644 index 000000000000..09fc1ff67222 --- /dev/null +++ b/vllm/model_executor/models/jais.py @@ -0,0 +1,320 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py +# Copyright 2023 The vLLM team. +# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights +# reserved. +# Copyright 2023 Cerebras Systems. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Inference-only Jais model compatible with HuggingFace weights.""" + +import math +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import LayerNorm +from vllm.transformers_utils.configs import JAISConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_act_fn +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class SwiGLUActivation(nn.Module): + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + return x1 * nn.functional.silu(x2) + + +def _get_alibi_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + _get_alibi_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + +class JAISAttention(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5 + self.scale = self.head_dim**-attn_scale_power + + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + bias=True, + linear_method=linear_method, + ) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + linear_method=linear_method, + ) + + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + key_cache, value_cache = kv_cache + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class JAISMLP(nn.Module): + + def __init__( + self, + intermediate_size: 
int, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.swiglu = config.activation_function == "swiglu" + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) + self.c_fc2 = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) if self.swiglu else None + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + linear_method=linear_method, + ) + quant_config = getattr(linear_method, "quant_config", None) + self.act_gpt2 = get_act_fn(config.activation_function, quant_config, + intermediate_size) + self.act = SwiGLUActivation() if self.swiglu else self.act_gpt2 + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.swiglu: + hidden_states2, _ = self.c_fc2(hidden_states) + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states, hidden_states2) if self.swiglu else self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + +class JAISBlock(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = JAISAttention(config, linear_method) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = JAISMLP(inner_dim, config, linear_method) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class JAISModel(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + assert not config.add_cross_attention + assert not config.scale_attn_by_inverse_layer_idx + assert not config.reorder_and_upcast_attn + self.embed_dim = config.hidden_size + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) if config.position_embedding_type != "alibi" else None + self.embeddings_scale = config.mup_embeddings_scale + self.h = nn.ModuleList([ + JAISBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + inputs_embeds = self.wte(input_ids) + if self.wpe is not None: + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + hidden_states = inputs_embeds + + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer(hidden_states, kv_caches[i], 
input_metadata) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + +class JAISLMHeadModel(nn.Module): + + def __init__( + self, + config: JAISConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.transformer = JAISModel(config, linear_method) + self.lm_head_weight = self.transformer.wte.weight + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc", "c_fc2"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) \ No newline at end of file diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5e1f0439aec5..081e81768b23 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -10,6 +10,7 @@ "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) "starcoder2": Starcoder2Config, + "jais": JAISConfig, } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 4966526f1518..150ee2ce97ad 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -5,10 +5,12 @@ # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config +from vllm.transformers_utils.configs.jais import JAISConfig __all__ = [ "ChatGLMConfig", "MPTConfig", "RWConfig", "Starcoder2Config", + "JAISConfig", ] diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py new file mode 100644 index 000000000000..2cc2d6d9a1c7 --- /dev/null +++ b/vllm/transformers_utils/configs/jais.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team. 
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2023 Cerebras Systems. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" JAIS configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +class JAISConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`JAISModel`]. It is used to instantiate a JAIS + model according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the JAIS model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`JAISModel`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new", "swiglu"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. 
+ reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + position_embedding_type (`str`, *optional*, defaults to `"learned"`): + Positional embedding can be either `"alibi"` or `"learned"`. + mup_width_scale (`float`, *optional*, defaults to 1.0): + muP parameter to scale learning rate and initializers. Calculated as (`d_model,0 / d_model`), where + `d_model` is the model's width and `d_model,0` is the proxy model's width. + mup_embeddings_scale (`float`, *optional*, defaults to 1.0): + muP parameter to scale token and position embeddings. + mup_output_alpha (`float`, *optional*, defaults to 1.0): + muP parameter to scale output logits (`output_logits_scale = mup_output_alpha * mup_width_scale`). + mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`): + Scale attention weights by dividing by hidden_size instead of sqrt(hidden_size). Need to set + scale_attn_weights to `True` as well. + alibi_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for ALiBi embeddings. Currently only supports linear + scaling strategy. Can specify either the scaling `factor` (must be a float greater than 1) for fixed scaling + or `train_seq_len` for dynamic scaling on input samples with sequence length > `train_seq_len`. The expected + formats are `{"type": strategy name, "factor": scaling factor}` or + `{"type": strategy name, "train_seq_len": training sequence length}`. + + Example: + + ```python + >>> from transformers import JAISConfig, JAISModel + + >>> # Initializing a JAIS configuration + >>> configuration = JAISConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = JAISModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "jais" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1024, + n_embd=768, + n_layer=12, + n_head=12, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + position_embedding_type="learned", + mup_width_scale=1.0, + mup_embeddings_scale=1.0, + mup_output_alpha=1.0, + mup_scale_qk_dot_by_d=False, + alibi_scaling=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.position_embedding_type = position_embedding_type + self.mup_width_scale 
= mup_width_scale + self.mup_embeddings_scale = mup_embeddings_scale + self.mup_output_alpha = mup_output_alpha + self.mup_scale_qk_dot_by_d = mup_scale_qk_dot_by_d + + self.alibi_scaling = alibi_scaling + self._alibi_scaling_validation() + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + def _alibi_scaling_validation(self): + """ + Validate the `alibi_scaling` configuration. + """ + if self.alibi_scaling is None: + return + + if not isinstance(self.alibi_scaling, dict) or len(self.alibi_scaling) != 2: + raise ValueError( + "`alibi_scaling` must be a dictionary with two fields, `type` and `factor` or `type` and `train_seq_len`, " + f"got {self.alibi_scaling}" + ) + alibi_scaling_type = self.alibi_scaling.get("type", None) + alibi_scaling_factor = self.alibi_scaling.get("factor", None) + alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) + if alibi_scaling_type is None or alibi_scaling_type != "linear": + raise ValueError( + f"`alibi_scaling`'s type field must be 'linear', got {alibi_scaling_type}" + ) + if alibi_scaling_factor is not None: + if not isinstance(alibi_scaling_factor, float) or alibi_scaling_factor <= 1.0: + raise ValueError(f"`alibi_scaling`'s factor field must be a float > 1.0, got {alibi_scaling_factor}") + if alibi_dynamic_scaling is not None: + if not isinstance(alibi_dynamic_scaling, int) or alibi_dynamic_scaling <= 1: + raise ValueError(f"`alibi_scaling`'s `train_seq_len` field must be an integer > 1, got {alibi_dynamic_scaling}") \ No newline at end of file From b5feaa698d877cf80477c156f604459ddfb39f3e Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 4 Mar 2024 20:23:49 +0000 Subject: [PATCH 02/31] updated flake-8 --- vllm/model_executor/models/jais.py | 29 ++++++++++++++----------- vllm/transformers_utils/configs/jais.py | 28 ++++++++++++++---------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 09fc1ff67222..d7e1b8074fe3 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -20,7 +20,7 @@ """ Inference-only Jais model compatible with HuggingFace weights.""" import math -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -39,8 +39,6 @@ VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) @@ -50,24 +48,23 @@ class SwiGLUActivation(nn.Module): + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) def _get_alibi_slopes(n): def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) + start = 2**(-(2**-(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + _get_alibi_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] - ) + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( + 2 * 
closest_power_of_2)[0::2][:n - closest_power_of_2]) class JAISAttention(nn.Module): @@ -86,7 +83,7 @@ def __init__( self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5 - self.scale = self.head_dim**-attn_scale_power + self.scale = self.head_dim**-self.attn_scale_power self.c_attn = QKVParallelLinear( self.hidden_size, @@ -158,17 +155,20 @@ def __init__( ) quant_config = getattr(linear_method, "quant_config", None) self.act_gpt2 = get_act_fn(config.activation_function, quant_config, - intermediate_size) + intermediate_size) self.act = SwiGLUActivation() if self.swiglu else self.act_gpt2 def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.swiglu: hidden_states2, _ = self.c_fc2(hidden_states) hidden_states, _ = self.c_fc(hidden_states) - hidden_states = self.act(hidden_states, hidden_states2) if self.swiglu else self.act(hidden_states) + hidden_states = self.act( + hidden_states, + hidden_states2) if self.swiglu else self.act(hidden_states) hidden_states, _ = self.c_proj(hidden_states) return hidden_states + class JAISBlock(nn.Module): def __init__( @@ -224,7 +224,9 @@ def __init__( assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) - self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) if config.position_embedding_type != "alibi" else None + self.wpe = nn.Embedding( + config.max_position_embeddings, self.embed_dim + ) if config.position_embedding_type != "alibi" else None self.embeddings_scale = config.mup_embeddings_scale self.h = nn.ModuleList([ JAISBlock(config, linear_method) @@ -253,6 +255,7 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states + class JAISLMHeadModel(nn.Module): def __init__( diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 2cc2d6d9a1c7..880fc607f7bc 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -19,9 +19,9 @@ from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging - logger = logging.get_logger(__name__) + class JAISConfig(PretrainedConfig): """ This is the configuration class to store the configuration of a [`JAISModel`]. 
It is used to instantiate a JAIS @@ -167,7 +167,9 @@ def __init__( self.alibi_scaling = alibi_scaling self._alibi_scaling_validation() - super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + super().__init__(bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs) def _alibi_scaling_validation(self): """ @@ -176,11 +178,11 @@ def _alibi_scaling_validation(self): if self.alibi_scaling is None: return - if not isinstance(self.alibi_scaling, dict) or len(self.alibi_scaling) != 2: + if not isinstance(self.alibi_scaling, + dict) or len(self.alibi_scaling) != 2: raise ValueError( "`alibi_scaling` must be a dictionary with two fields, `type` and `factor` or `type` and `train_seq_len`, " - f"got {self.alibi_scaling}" - ) + f"got {self.alibi_scaling}") alibi_scaling_type = self.alibi_scaling.get("type", None) alibi_scaling_factor = self.alibi_scaling.get("factor", None) alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) @@ -188,9 +190,13 @@ def _alibi_scaling_validation(self): raise ValueError( f"`alibi_scaling`'s type field must be 'linear', got {alibi_scaling_type}" ) - if alibi_scaling_factor is not None: - if not isinstance(alibi_scaling_factor, float) or alibi_scaling_factor <= 1.0: - raise ValueError(f"`alibi_scaling`'s factor field must be a float > 1.0, got {alibi_scaling_factor}") - if alibi_dynamic_scaling is not None: - if not isinstance(alibi_dynamic_scaling, int) or alibi_dynamic_scaling <= 1: - raise ValueError(f"`alibi_scaling`'s `train_seq_len` field must be an integer > 1, got {alibi_dynamic_scaling}") \ No newline at end of file + if alibi_scaling_factor is not None and not isinstance(alibi_scaling_factor, + float) or alibi_scaling_factor <= 1.0: + raise ValueError( + f"`alibi_scaling`'s factor field must be a float > 1.0, got {alibi_scaling_factor}" + ) + if alibi_dynamic_scaling is not None and not isinstance(alibi_dynamic_scaling, + int) or alibi_dynamic_scaling <= 1: + raise ValueError( + f"`alibi_scaling`'s `train_seq_len` field must be an integer > 1, got {alibi_dynamic_scaling}" + ) \ No newline at end of file From 4d5b65e0bc682271cafed705629b5b61c94e723c Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 4 Mar 2024 20:32:02 +0000 Subject: [PATCH 03/31] fixed formatting --- vllm/model_executor/models/jais.py | 2 +- vllm/transformers_utils/configs/jais.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index d7e1b8074fe3..13f31f23ed33 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -24,7 +24,6 @@ import torch from torch import nn -from torch.nn import LayerNorm from vllm.transformers_utils.configs import JAISConfig from vllm.model_executor.input_metadata import InputMetadata @@ -54,6 +53,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: def _get_alibi_slopes(n): + def get_slopes_power_of_2(n): start = 2**(-(2**-(math.log2(n) - 3))) ratio = start diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 880fc607f7bc..2c5e286f7f15 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -190,13 +190,13 @@ def _alibi_scaling_validation(self): raise ValueError( f"`alibi_scaling`'s type field must be 'linear', got {alibi_scaling_type}" ) - if alibi_scaling_factor is not None and not isinstance(alibi_scaling_factor, - float) or alibi_scaling_factor <= 1.0: + if 
alibi_scaling_factor is not None and not isinstance( + alibi_scaling_factor, float) or alibi_scaling_factor <= 1.0: raise ValueError( f"`alibi_scaling`'s factor field must be a float > 1.0, got {alibi_scaling_factor}" ) - if alibi_dynamic_scaling is not None and not isinstance(alibi_dynamic_scaling, - int) or alibi_dynamic_scaling <= 1: + if alibi_dynamic_scaling is not None and not isinstance( + alibi_dynamic_scaling, int) or alibi_dynamic_scaling <= 1: raise ValueError( f"`alibi_scaling`'s `train_seq_len` field must be an integer > 1, got {alibi_dynamic_scaling}" ) \ No newline at end of file From 9ad30617bf3b544e96b3aecbe695e63a7257da0a Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 4 Mar 2024 20:40:10 +0000 Subject: [PATCH 04/31] fixed formatting --- vllm/transformers_utils/configs/jais.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 2c5e286f7f15..22b0dc074444 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -191,12 +191,12 @@ def _alibi_scaling_validation(self): f"`alibi_scaling`'s type field must be 'linear', got {alibi_scaling_type}" ) if alibi_scaling_factor is not None and not isinstance( - alibi_scaling_factor, float) or alibi_scaling_factor <= 1.0: - raise ValueError( - f"`alibi_scaling`'s factor field must be a float > 1.0, got {alibi_scaling_factor}" + alibi_scaling_factor, float) or alibi_scaling_factor <= 1.0: + raise ValueError( + f"`alibi_scaling`'s factor field must be a float > 1.0, got {alibi_scaling_factor}" ) if alibi_dynamic_scaling is not None and not isinstance( - alibi_dynamic_scaling, int) or alibi_dynamic_scaling <= 1: - raise ValueError( + alibi_dynamic_scaling, int) or alibi_dynamic_scaling <= 1: + raise ValueError( f"`alibi_scaling`'s `train_seq_len` field must be an integer > 1, got {alibi_dynamic_scaling}" ) \ No newline at end of file From 7b015e648e37204e755283c3e6bf97d66905d1d4 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 4 Mar 2024 20:44:32 +0000 Subject: [PATCH 05/31] fixed formatting --- vllm/model_executor/models/jais.py | 2 -- vllm/transformers_utils/configs/jais.py | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 13f31f23ed33..e188a124cf11 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -47,13 +47,11 @@ class SwiGLUActivation(nn.Module): - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) def _get_alibi_slopes(n): - def get_slopes_power_of_2(n): start = 2**(-(2**-(math.log2(n) - 3))) ratio = start diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 22b0dc074444..cd31f28c9e80 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -194,9 +194,9 @@ def _alibi_scaling_validation(self): alibi_scaling_factor, float) or alibi_scaling_factor <= 1.0: raise ValueError( f"`alibi_scaling`'s factor field must be a float > 1.0, got {alibi_scaling_factor}" - ) + ) if alibi_dynamic_scaling is not None and not isinstance( alibi_dynamic_scaling, int) or alibi_dynamic_scaling <= 1: raise ValueError( - f"`alibi_scaling`'s `train_seq_len` field must be an integer > 1, got {alibi_dynamic_scaling}" - ) \ No newline at end of file + f"`alibi_scaling`'s `train_seq_len` field must be an 
integer > 1, got {alibi_dynamic_scaling}" + ) \ No newline at end of file From b595d396a1204d0326d9222f56163d850bcd5fd5 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 4 Mar 2024 20:46:23 +0000 Subject: [PATCH 06/31] fixed formatting --- vllm/model_executor/models/jais.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index e188a124cf11..13f31f23ed33 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -47,11 +47,13 @@ class SwiGLUActivation(nn.Module): + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) def _get_alibi_slopes(n): + def get_slopes_power_of_2(n): start = 2**(-(2**-(math.log2(n) - 3))) ratio = start From c976954fcbb9cd82640504f76f31ec25190c7a3c Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 4 Mar 2024 20:51:20 +0000 Subject: [PATCH 07/31] fixed formatting --- vllm/model_executor/models/jais.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 13f31f23ed33..94764ef64194 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -47,13 +47,13 @@ class SwiGLUActivation(nn.Module): - + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) def _get_alibi_slopes(n): - + def get_slopes_power_of_2(n): start = 2**(-(2**-(math.log2(n) - 3))) ratio = start From 7cb2757765b641213e4d7da749855ea341f54668 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 06:33:30 +0000 Subject: [PATCH 08/31] fixed inference bugs --- vllm/model_executor/models/jais.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 94764ef64194..7058ffee8c35 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -32,7 +32,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.rotary_embedding import get_act_fn +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -102,7 +102,7 @@ def __init__( tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads head_end = (tp_rank + 1) * self.num_heads - alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end].tolist() self.attn = PagedAttention(self.num_heads, self.head_dim, @@ -154,9 +154,7 @@ def __init__( linear_method=linear_method, ) quant_config = getattr(linear_method, "quant_config", None) - self.act_gpt2 = get_act_fn(config.activation_function, quant_config, - intermediate_size) - self.act = SwiGLUActivation() if self.swiglu else self.act_gpt2 + self.act = SwiGLUActivation()# if self.swiglu else self.act_gpt2 def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.swiglu: @@ -306,13 +304,15 @@ def load_weights(self, # Skip attention mask. # NOTE: "c_attn.bias" should not be skipped. continue + if "relative_pe" in name: + continue if not name.startswith("transformer."): name = "transformer." + name param = params_dict[name] # The HF's GPT-2 implementation uses Conv1D instead of Linear. 
# Because of this, we need to transpose the weights. # Note(zhuohan): the logic below might break quantized models. - for conv1d_weight_name in ["c_attn", "c_proj", "c_fc", "c_fc2"]: + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: if conv1d_weight_name not in name: continue if not name.endswith(".weight"): From 452227eb410d16ef80065009bda29e14ba72bb32 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 06:35:45 +0000 Subject: [PATCH 09/31] apply ruff --- vllm/model_executor/models/jais.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 7058ffee8c35..2f54c207d8b9 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -32,7 +32,7 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.activation import get_act_fn +# from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -153,7 +153,7 @@ def __init__( bias=True, linear_method=linear_method, ) - quant_config = getattr(linear_method, "quant_config", None) + # quant_config = getattr(linear_method, "quant_config", None) self.act = SwiGLUActivation()# if self.swiglu else self.act_gpt2 def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: From 3776a669a952e8759c596e01e87d83c878943b43 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 06:37:43 +0000 Subject: [PATCH 10/31] apply yapf --- vllm/model_executor/models/jais.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 2f54c207d8b9..414932bd051a 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -154,7 +154,7 @@ def __init__( linear_method=linear_method, ) # quant_config = getattr(linear_method, "quant_config", None) - self.act = SwiGLUActivation()# if self.swiglu else self.act_gpt2 + self.act = SwiGLUActivation() # if self.swiglu else self.act_gpt2 def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.swiglu: From 689c3ecfbbce0dca5afd7661588e2d1679d583a8 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 14:34:21 +0000 Subject: [PATCH 11/31] bug fixes --- vllm/model_executor/models/jais.py | 116 ++++++++++++++++++++++++++--- 1 file changed, 107 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 414932bd051a..5a8d5b1ac725 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -32,8 +32,15 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) -# from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import (Sampler, + _prune_hidden_states, + _apply_logits_processors, + _apply_penalties, + _apply_top_k_top_p, + _apply_min_p, + _sample, + _get_logprobs, + _build_sampler_output) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( @@ -42,6 +49,7 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata, 
SamplingTensors KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -66,6 +74,86 @@ def get_slopes_power_of_2(n): return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) +class JAISSampler(Sampler): + + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None) -> None: + super().__init__(vocab_size, org_vocab_size) + + def forward( + self, + embedding: torch.Tensor, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + output_logits_scale: float, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[SamplerOutput]: + # Get the hidden states that we use for sampling. + if self.logits_as_hidden_states: + logits = hidden_states + else: + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) + + # Get the logits for the next tokens. + logits = self._get_logits(hidden_states, embedding, embedding_bias) + logits *= torch.tensor( + float(output_logits_scale), dtype=logits.dtype + ) + + + # Only perform sampling in the driver worker. + # Note: `_get_logits` is still distributed across TP workers because + # the `embedding` weight is distributed across TP workers. + # TODO(zhuohan): Change the get_logits part to a separate stage. + if not sampling_metadata.perform_sampling: + return None + + assert logits is not None + _, vocab_size = logits.shape + + # Apply logits processors (if any). + logits = _apply_logits_processors(logits, sampling_metadata) + + # Prepare sampling tensors with pinned memory to avoid blocking. + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) + + # Apply presence and frequency penalties. + if do_penalties: + logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) + + # Apply temperature scaling. + # Use in-place division to avoid creating a new tensor. + logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) + + if do_top_p_top_k: + logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) + + if do_min_p: + logits = _apply_min_p(logits, sampling_tensors.min_ps) + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities. + # Use log_softmax to ensure numerical stability. + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + + # Sample the next tokens. + sample_results = _sample(probs, logprobs, sampling_metadata) + # Get the logprobs query results. 
+ prompt_logprobs, sample_logprobs = _get_logprobs( + logprobs, sampling_metadata, sample_results) + return _build_sampler_output(sample_results, sampling_metadata, + prompt_logprobs, sample_logprobs) class JAISAttention(nn.Module): @@ -103,7 +191,7 @@ def __init__( head_start = tp_rank * self.num_heads head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(total_num_heads) - alibi_slopes = alibi_slopes[head_start:head_end].tolist() + alibi_slopes = alibi_slopes[head_start:head_end] self.attn = PagedAttention(self.num_heads, self.head_dim, scale=self.scale, @@ -153,8 +241,8 @@ def __init__( bias=True, linear_method=linear_method, ) - # quant_config = getattr(linear_method, "quant_config", None) - self.act = SwiGLUActivation() # if self.swiglu else self.act_gpt2 + + self.act = SwiGLUActivation() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.swiglu: @@ -225,7 +313,10 @@ def __init__( self.wpe = nn.Embedding( config.max_position_embeddings, self.embed_dim ) if config.position_embedding_type != "alibi" else None - self.embeddings_scale = config.mup_embeddings_scale + if hasattr(config, 'embeddings_scale'): + self.embeddings_scale = config.embeddings_scale + else: + self.embeddings_scale = config.mup_embeddings_scale self.h = nn.ModuleList([ JAISBlock(config, linear_method) for _ in range(config.num_hidden_layers) @@ -245,6 +336,9 @@ def forward( hidden_states = inputs_embeds + position_embeds else: hidden_states = inputs_embeds + hidden_states *= torch.tensor( + float(self.embeddings_scale), dtype=hidden_states.dtype + ) for i in range(len(self.h)): layer = self.h[i] @@ -266,8 +360,12 @@ def __init__( self.linear_method = linear_method self.transformer = JAISModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight - self.sampler = Sampler(config.vocab_size) - + if hasattr(config, 'width_scale'): + self.output_logits_scale = config.width_scale + else: + self.output_logits_scale = config.mup_output_alpha * config.mup_width_scale + self.sampler = JAISSampler(config.vocab_size) + def forward( self, input_ids: torch.Tensor, @@ -285,7 +383,7 @@ def sample( sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata) + sampling_metadata, self.output_logits_scale) return next_tokens def load_weights(self, From 697969d9f4156ca650b25542ce7f5e9363b0fe73 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 14:42:06 +0000 Subject: [PATCH 12/31] ruff and yapf --- vllm/model_executor/models/jais.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 5a8d5b1ac725..872e51a11fb6 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -32,20 +32,14 @@ LinearMethodBase, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.sampler import (Sampler, - _prune_hidden_states, - _apply_logits_processors, - _apply_penalties, - _apply_top_k_top_p, - _apply_min_p, - _sample, - _get_logprobs, - _build_sampler_output) +from vllm.model_executor.layers.sampler import ( + Sampler, _prune_hidden_states, _apply_logits_processors, _apply_penalties, + _apply_top_k_top_p, _apply_min_p, _sample, _get_logprobs, + _build_sampler_output) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( 
get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -74,6 +68,7 @@ def get_slopes_power_of_2(n): return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + class JAISSampler(Sampler): def __init__(self, @@ -98,9 +93,8 @@ def forward( # Get the logits for the next tokens. logits = self._get_logits(hidden_states, embedding, embedding_bias) - logits *= torch.tensor( - float(output_logits_scale), dtype=logits.dtype - ) + logits *= torch.tensor(float(output_logits_scale), + dtype=logits.dtype) # Only perform sampling in the driver worker. @@ -155,6 +149,7 @@ def forward( return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs) + class JAISAttention(nn.Module): def __init__( @@ -241,7 +236,7 @@ def __init__( bias=True, linear_method=linear_method, ) - + self.act = SwiGLUActivation() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -336,9 +331,8 @@ def forward( hidden_states = inputs_embeds + position_embeds else: hidden_states = inputs_embeds - hidden_states *= torch.tensor( - float(self.embeddings_scale), dtype=hidden_states.dtype - ) + hidden_states *= torch.tensor(float(self.embeddings_scale), + dtype=hidden_states.dtype) for i in range(len(self.h)): layer = self.h[i] @@ -365,7 +359,7 @@ def __init__( else: self.output_logits_scale = config.mup_output_alpha * config.mup_width_scale self.sampler = JAISSampler(config.vocab_size) - + def forward( self, input_ids: torch.Tensor, From 1d430438d78e23977d8be11caeacdcf8664f2f25 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 14:44:58 +0000 Subject: [PATCH 13/31] ruff and yapf --- vllm/model_executor/models/jais.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 872e51a11fb6..e8f7c614894b 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -95,7 +95,6 @@ def forward( logits = self._get_logits(hidden_states, embedding, embedding_bias) logits *= torch.tensor(float(output_logits_scale), dtype=logits.dtype) - # Only perform sampling in the driver worker. 
# Note: `_get_logits` is still distributed across TP workers because From 6e4b06e417f29f24a8d2b2d5c87d363759dc3189 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 15:13:31 +0000 Subject: [PATCH 14/31] fixed bug in config.scale_qk_dot_by_d --- vllm/model_executor/models/jais.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index e8f7c614894b..6d4cefe4f414 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -164,6 +164,8 @@ def __init__( assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads + if hasattr(config, "scale_qk_dot_by_d"): + config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5 self.scale = self.head_dim**-self.attn_scale_power From 4321fc4d6780be628bdeba07010be0991afd1f6e Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 15:39:30 +0000 Subject: [PATCH 15/31] updated architectures in config --- vllm/transformers_utils/configs/jais.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index cd31f28c9e80..fb4c81dfe5a5 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -136,6 +136,7 @@ def __init__( mup_output_alpha=1.0, mup_scale_qk_dot_by_d=False, alibi_scaling=None, + architectures=['JAISLMHeadModel'], **kwargs, ): self.vocab_size = vocab_size @@ -169,6 +170,7 @@ def __init__( super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, + architectures=architectures, **kwargs) def _alibi_scaling_validation(self): From a6166d1e27412332e657afd7135fc63f542694be Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 15:44:46 +0000 Subject: [PATCH 16/31] apply ruff --- vllm/transformers_utils/configs/jais.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index fb4c81dfe5a5..9ec2ee2eba9f 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -85,6 +85,8 @@ class JAISConfig(PretrainedConfig): or `train_seq_len` for dynamic scaling on input samples with sequence length > `train_seq_len`. The expected formats are `{"type": strategy name, "factor": scaling factor}` or `{"type": strategy name, "train_seq_len": training sequence length}`. + architectures (`List`, *optional*, defaults to ['JAISLMHeadModel']): + architecture names for Jais. 
Example: From b68e2b170292f64202d9e26a69403483b3ebca89 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 15:47:40 +0000 Subject: [PATCH 17/31] apply ruff --- vllm/transformers_utils/configs/jais.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 9ec2ee2eba9f..1485c5932c8d 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -138,7 +138,7 @@ def __init__( mup_output_alpha=1.0, mup_scale_qk_dot_by_d=False, alibi_scaling=None, - architectures=['JAISLMHeadModel'], + architectures=None, **kwargs, ): self.vocab_size = vocab_size @@ -169,6 +169,8 @@ def __init__( self.alibi_scaling = alibi_scaling self._alibi_scaling_validation() + if architectures==None: + architectures=['JAISLMHeadModel'] super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, From 51b745a177aaea275c734355ae76a8da7e09bfbc Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 15:49:25 +0000 Subject: [PATCH 18/31] apply ruff --- vllm/transformers_utils/configs/jais.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 1485c5932c8d..d5d46636b03b 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -169,7 +169,7 @@ def __init__( self.alibi_scaling = alibi_scaling self._alibi_scaling_validation() - if architectures==None: + if architectures is None: architectures=['JAISLMHeadModel'] super().__init__(bos_token_id=bos_token_id, From 4fb9fe90e9a095da56ad96477f5a16ce07ad8c9c Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Tue, 5 Mar 2024 15:51:49 +0000 Subject: [PATCH 19/31] apply yapf --- vllm/transformers_utils/configs/jais.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index d5d46636b03b..447bd03ba82e 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -170,7 +170,7 @@ def __init__( self.alibi_scaling = alibi_scaling self._alibi_scaling_validation() if architectures is None: - architectures=['JAISLMHeadModel'] + architectures = ['JAISLMHeadModel'] super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, From e04e56d981bc3fc81f94307c5c274ed065a13bd3 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Wed, 6 Mar 2024 19:50:23 +0000 Subject: [PATCH 20/31] fixed bug in multi GPU setting --- vllm/model_executor/models/jais.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 6d4cefe4f414..5cbc38ece9fa 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -93,8 +93,9 @@ def forward( # Get the logits for the next tokens. logits = self._get_logits(hidden_states, embedding, embedding_bias) - logits *= torch.tensor(float(output_logits_scale), - dtype=logits.dtype) + if logits is not None: + logits *= torch.tensor(float(output_logits_scale), + dtype=logits.dtype) # Only perform sampling in the driver worker. 
# Note: `_get_logits` is still distributed across TP workers because From 8fd0aecb5d332326302bd939b2cac36dad313c1d Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Sat, 9 Mar 2024 05:39:31 +0000 Subject: [PATCH 21/31] adapted to PR #3005 --- vllm/model_executor/models/jais.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 5cbc38ece9fa..6848eb904a98 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -27,7 +27,7 @@ from vllm.transformers_utils.configs import JAISConfig from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -189,7 +189,7 @@ def __init__( head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end] - self.attn = PagedAttention(self.num_heads, + self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale, alibi_slopes=alibi_slopes) From a80e2dc506fe6c4a7769eeb4016fe850d4354820 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Sat, 9 Mar 2024 05:43:16 +0000 Subject: [PATCH 22/31] apply yapf --- vllm/model_executor/models/jais.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 6848eb904a98..8fce719bc3f1 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -190,9 +190,9 @@ def __init__( alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end] self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scale, - alibi_slopes=alibi_slopes) + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes) def forward( self, From ade4c0ab24fd767eca95cc8b7bb842151e3ea0f5 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 11 Mar 2024 07:13:34 +0000 Subject: [PATCH 23/31] apply ruff --- vllm/model_executor/models/jais.py | 224 ++++++++++++++---------- vllm/transformers_utils/configs/jais.py | 41 +++-- 2 files changed, 162 insertions(+), 103 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 8fce719bc3f1..0942a0de5c96 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" Inference-only Jais model compatible with HuggingFace weights.""" +"""Inference-only Jais model compatible with HuggingFace weights.""" import math from typing import List, Optional, Tuple @@ -28,52 +28,70 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.sampler import ( - Sampler, _prune_hidden_states, _apply_logits_processors, _apply_penalties, - _apply_top_k_top_p, _apply_min_p, _sample, _get_logprobs, - _build_sampler_output) + Sampler, + _prune_hidden_states, + _apply_logits_processors, + _apply_penalties, + _apply_top_k_top_p, + _apply_min_p, + _sample, + _get_logprobs, + _build_sampler_output, +) from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + VocabParallelEmbedding, +) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank) -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, +) +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput -from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors +from vllm.model_executor.sampling_metadata import ( + SamplingMetadata, + SamplingTensors, +) KVCache = Tuple[torch.Tensor, torch.Tensor] class SwiGLUActivation(nn.Module): - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) def _get_alibi_slopes(n): - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) + start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( - 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + _get_alibi_slopes(2 * closest_power_of_2)[0::2][ + : n - closest_power_of_2 + ] + ) class JAISSampler(Sampler): - - def __init__(self, - vocab_size: int, - org_vocab_size: Optional[int] = None) -> None: + def __init__( + self, vocab_size: int, org_vocab_size: Optional[int] = None + ) -> None: super().__init__(vocab_size, org_vocab_size) def forward( @@ -88,14 +106,16 @@ def forward( if self.logits_as_hidden_states: logits = hidden_states else: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) + hidden_states = _prune_hidden_states( + hidden_states, sampling_metadata + ) # Get the logits for the next tokens. logits = self._get_logits(hidden_states, embedding, embedding_bias) if logits is not None: - logits *= torch.tensor(float(output_logits_scale), - dtype=logits.dtype) + logits *= torch.tensor( + float(output_logits_scale), dtype=logits.dtype + ) # Only perform sampling in the driver worker. 
# Note: `_get_logits` is still distributed across TP workers because @@ -111,25 +131,31 @@ def forward( logits = _apply_logits_processors(logits, sampling_metadata) # Prepare sampling tensors with pinned memory to avoid blocking. - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) + (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) = ( + SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype + ) + ) # Apply presence and frequency penalties. if do_penalties: - logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) + logits = _apply_penalties( + logits, + sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties, + ) # Apply temperature scaling. # Use in-place division to avoid creating a new tensor. logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) if do_top_p_top_k: - logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) + logits = _apply_top_k_top_p( + logits, sampling_tensors.top_ps, sampling_tensors.top_ks + ) if do_min_p: logits = _apply_min_p(logits, sampling_tensors.min_ps) @@ -145,13 +171,14 @@ def forward( sample_results = _sample(probs, logprobs, sampling_metadata) # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, sampling_metadata, - prompt_logprobs, sample_logprobs) + logprobs, sampling_metadata, sample_results + ) + return _build_sampler_output( + sample_results, sampling_metadata, prompt_logprobs, sample_logprobs + ) class JAISAttention(nn.Module): - def __init__( self, config: JAISConfig, @@ -161,7 +188,8 @@ def __init__( self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + get_tensor_model_parallel_world_size() + ) assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads @@ -189,10 +217,12 @@ def __init__( head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end] - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scale, - alibi_slopes=alibi_slopes) + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + ) def forward( self, @@ -203,14 +233,12 @@ def forward( qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) key_cache, value_cache = kv_cache - attn_output = self.attn(q, k, v, key_cache, value_cache, - input_metadata) + attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata) attn_output, _ = self.c_proj(attn_output) return attn_output class JAISMLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -226,12 +254,16 @@ def __init__( bias=True, linear_method=linear_method, ) - self.c_fc2 = ColumnParallelLinear( - hidden_size, - intermediate_size, - bias=True, - linear_method=linear_method, - ) if self.swiglu else None + 
self.c_fc2 = ( + ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) + if self.swiglu + else None + ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, @@ -245,15 +277,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.swiglu: hidden_states2, _ = self.c_fc2(hidden_states) hidden_states, _ = self.c_fc(hidden_states) - hidden_states = self.act( - hidden_states, - hidden_states2) if self.swiglu else self.act(hidden_states) + hidden_states = ( + self.act(hidden_states, hidden_states2) + if self.swiglu + else self.act(hidden_states) + ) hidden_states, _ = self.c_proj(hidden_states) return hidden_states class JAISBlock(nn.Module): - def __init__( self, config: JAISConfig, @@ -261,8 +294,9 @@ def __init__( ): super().__init__() hidden_size = config.hidden_size - inner_dim = (config.n_inner if config.n_inner is not None else 4 * - hidden_size) + inner_dim = ( + config.n_inner if config.n_inner is not None else 4 * hidden_size + ) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = JAISAttention(config, linear_method) @@ -294,7 +328,6 @@ def forward( class JAISModel(nn.Module): - def __init__( self, config: JAISConfig, @@ -307,17 +340,21 @@ def __init__( assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) - self.wpe = nn.Embedding( - config.max_position_embeddings, self.embed_dim - ) if config.position_embedding_type != "alibi" else None - if hasattr(config, 'embeddings_scale'): + self.wpe = ( + nn.Embedding(config.max_position_embeddings, self.embed_dim) + if config.position_embedding_type != "alibi" + else None + ) + if hasattr(config, "embeddings_scale"): self.embeddings_scale = config.embeddings_scale else: self.embeddings_scale = config.mup_embeddings_scale - self.h = nn.ModuleList([ - JAISBlock(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) + self.h = nn.ModuleList( + [ + JAISBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ] + ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -333,8 +370,9 @@ def forward( hidden_states = inputs_embeds + position_embeds else: hidden_states = inputs_embeds - hidden_states *= torch.tensor(float(self.embeddings_scale), - dtype=hidden_states.dtype) + hidden_states *= torch.tensor( + float(self.embeddings_scale), dtype=hidden_states.dtype + ) for i in range(len(self.h)): layer = self.h[i] @@ -345,7 +383,6 @@ def forward( class JAISLMHeadModel(nn.Module): - def __init__( self, config: JAISConfig, @@ -356,10 +393,12 @@ def __init__( self.linear_method = linear_method self.transformer = JAISModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight - if hasattr(config, 'width_scale'): + if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale else: - self.output_logits_scale = config.mup_output_alpha * config.mup_width_scale + self.output_logits_scale = ( + config.mup_output_alpha * config.mup_width_scale + ) self.sampler = JAISSampler(config.vocab_size) def forward( @@ -369,8 +408,9 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + hidden_states = self.transformer( + input_ids, positions, kv_caches, input_metadata + ) return hidden_states def sample( @@ -378,18 +418,25 @@ def sample( 
hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, hidden_states, - sampling_metadata, self.output_logits_scale) + next_tokens = self.sampler( + self.lm_head_weight, + hidden_states, + sampling_metadata, + self.output_logits_scale, + ) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. @@ -412,6 +459,7 @@ def load_weights(self, if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) \ No newline at end of file + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 447bd03ba82e..6b463ce30d7c 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" JAIS configuration""" +"""JAIS configuration""" from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -81,7 +81,7 @@ class JAISConfig(PretrainedConfig): scale_attn_weights to `True` as well. alibi_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for ALiBi embeddings. Currently only supports linear - scaling strategy. Can specify either the scaling `factor` (must be a float greater than 1) for fixed scaling + scaling strategy. Can specify either the scaling `factor` (must be a float greater than 1) for fixed scaling or `train_seq_len` for dynamic scaling on input samples with sequence length > `train_seq_len`. The expected formats are `{"type": strategy name, "factor": scaling factor}` or `{"type": strategy name, "train_seq_len": training sequence length}`. 
@@ -170,12 +170,14 @@ def __init__( self.alibi_scaling = alibi_scaling self._alibi_scaling_validation() if architectures is None: - architectures = ['JAISLMHeadModel'] + architectures = ["JAISLMHeadModel"] - super().__init__(bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - architectures=architectures, - **kwargs) + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + architectures=architectures, + **kwargs, + ) def _alibi_scaling_validation(self): """ @@ -184,11 +186,14 @@ def _alibi_scaling_validation(self): if self.alibi_scaling is None: return - if not isinstance(self.alibi_scaling, - dict) or len(self.alibi_scaling) != 2: + if ( + not isinstance(self.alibi_scaling, dict) + or len(self.alibi_scaling) != 2 + ): raise ValueError( "`alibi_scaling` must be a dictionary with two fields, `type` and `factor` or `type` and `train_seq_len`, " - f"got {self.alibi_scaling}") + f"got {self.alibi_scaling}" + ) alibi_scaling_type = self.alibi_scaling.get("type", None) alibi_scaling_factor = self.alibi_scaling.get("factor", None) alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) @@ -196,13 +201,19 @@ def _alibi_scaling_validation(self): raise ValueError( f"`alibi_scaling`'s type field must be 'linear', got {alibi_scaling_type}" ) - if alibi_scaling_factor is not None and not isinstance( - alibi_scaling_factor, float) or alibi_scaling_factor <= 1.0: + if ( + alibi_scaling_factor is not None + and not isinstance(alibi_scaling_factor, float) + or alibi_scaling_factor <= 1.0 + ): raise ValueError( f"`alibi_scaling`'s factor field must be a float > 1.0, got {alibi_scaling_factor}" ) - if alibi_dynamic_scaling is not None and not isinstance( - alibi_dynamic_scaling, int) or alibi_dynamic_scaling <= 1: + if ( + alibi_dynamic_scaling is not None + and not isinstance(alibi_dynamic_scaling, int) + or alibi_dynamic_scaling <= 1 + ): raise ValueError( f"`alibi_scaling`'s `train_seq_len` field must be an integer > 1, got {alibi_dynamic_scaling}" - ) \ No newline at end of file + ) From b4012a2135e727cbb5ae104585aebdd1417197fe Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 11 Mar 2024 07:27:43 +0000 Subject: [PATCH 24/31] applied ruff --- vllm/transformers_utils/configs/jais.py | 86 ++++++++++++++++--------- 1 file changed, 56 insertions(+), 30 deletions(-) diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 6b463ce30d7c..11f5c00a4c75 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -24,32 +24,40 @@ class JAISConfig(PretrainedConfig): """ - This is the configuration class to store the configuration of a [`JAISModel`]. It is used to instantiate a JAIS - model according to the specified arguments, defining the model architecture. + This is the configuration class to store the configuration of a + [`JAISModel`]. It is used to instantiate a JAIS model according to the + specified arguments, defining the model architecture. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 50257): - Vocabulary size of the JAIS model. 
Defines the number of different tokens that can be represented by the + Vocabulary size of the JAIS model. Defines the number of different + tokens that can be represented by the `inputs_ids` passed when calling [`JAISModel`]. n_positions (`int`, *optional*, defaults to 1024): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). + The maximum sequence length that this model might ever be used + with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). n_embd (`int`, *optional*, defaults to 768): Dimensionality of the embeddings and hidden states. n_layer (`int`, *optional*, defaults to 12): Number of hidden layers in the Transformer encoder. n_head (`int`, *optional*, defaults to 12): - Number of attention heads for each attention layer in the Transformer encoder. + Number of attention heads for each attention layer in the + Transformer encoder. n_inner (`int`, *optional*, defaults to None): - Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + Dimensionality of the inner feed-forward layers. `None` will set + it to 4 times n_embd activation_function (`str`, *optional*, defaults to `"gelu"`): - Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new", "swiglu"]`. + Activation function, to be selected in the list + `["relu", "silu", "gelu", "tanh", "gelu_new", "swiglu"]`. resid_pdrop (`float`, *optional*, defaults to 0.1): - The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + The dropout probability for all fully connected layers in + the embeddings, encoder, and pooler. embd_pdrop (`float`, *optional*, defaults to 0.1): The dropout ratio for the embeddings. attn_pdrop (`float`, *optional*, defaults to 0.1): @@ -57,34 +65,48 @@ class JAISConfig(PretrainedConfig): layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): The epsilon to use in the layer normalization layers. initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. scale_attn_weights (`bool`, *optional*, defaults to `True`): Scale attention weights by dividing by sqrt(hidden_size).. use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). - scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): - Whether to additionally scale attention weights by `1 / layer_idx + 1`. + Whether or not the model should return the last key/values + attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, + defaults to `False`): + Whether to additionally scale attention weights by + `1 / layer_idx + 1`. reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): - Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention - dot-product/softmax to float() when training with mixed precision. + Whether to scale keys (K) prior to computing attention + (dot-product) + and upcast attention dot-product/softmax to float() when training + with mixed precision. position_embedding_type (`str`, *optional*, defaults to `"learned"`): Positional embedding can be either `"alibi"` or `"learned"`. 
mup_width_scale (`float`, *optional*, defaults to 1.0): - muP parameter to scale learning rate and initializers. Calculated as (`d_model,0 / d_model`), where - `d_model` is the model's width and `d_model,0` is the proxy model's width. + muP parameter to scale learning rate and initializers. Calculated + as (`d_model,0 / d_model`), where + `d_model` is the model's width and `d_model,0` is the proxy + model's width. mup_embeddings_scale (`float`, *optional*, defaults to 1.0): muP parameter to scale token and position embeddings. mup_output_alpha (`float`, *optional*, defaults to 1.0): - muP parameter to scale output logits (`output_logits_scale = mup_output_alpha * mup_width_scale`). + muP parameter to scale output logits + (`output_logits_scale = mup_output_alpha * mup_width_scale`). mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`): - Scale attention weights by dividing by hidden_size instead of sqrt(hidden_size). Need to set - scale_attn_weights to `True` as well. + Scale attention weights by dividing by hidden_size instead of + sqrt(hidden_size). Need to set scale_attn_weights to `True` as + well. alibi_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for ALiBi embeddings. Currently only supports linear - scaling strategy. Can specify either the scaling `factor` (must be a float greater than 1) for fixed scaling - or `train_seq_len` for dynamic scaling on input samples with sequence length > `train_seq_len`. The expected + Dictionary containing the scaling configuration for ALiBi + embeddings. Currently only supports linear + scaling strategy. Can specify either the scaling `factor` (must be + a float greater than 1) for fixed scaling + or `train_seq_len` for dynamic scaling on input samples with + sequence length > `train_seq_len`. The expected formats are `{"type": strategy name, "factor": scaling factor}` or - `{"type": strategy name, "train_seq_len": training sequence length}`. + `{"type": strategy name, + "train_seq_len": training sequence length}`. architectures (`List`, *optional*, defaults to ['JAISLMHeadModel']): architecture names for Jais. 
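# --- illustrative sketch, not part of the patch ---------------------------
# The two `alibi_scaling` layouts described in the docstring above, shown as
# plain dictionaries with made-up values; only the field names and the
# "linear" strategy come from the docstring itself.
fixed_scaling = {"type": "linear", "factor": 2.0}            # fixed float factor > 1.0
dynamic_scaling = {"type": "linear", "train_seq_len": 2048}  # dynamic scaling past train_seq_len
# ---------------------------------------------------------------------------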
@@ -191,7 +213,8 @@ def _alibi_scaling_validation(self): or len(self.alibi_scaling) != 2 ): raise ValueError( - "`alibi_scaling` must be a dictionary with two fields, `type` and `factor` or `type` and `train_seq_len`, " + "`alibi_scaling` must be a dictionary with two fields," + "`type` and `factor` or `type` and `train_seq_len`, " f"got {self.alibi_scaling}" ) alibi_scaling_type = self.alibi_scaling.get("type", None) @@ -199,7 +222,8 @@ def _alibi_scaling_validation(self): alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) if alibi_scaling_type is None or alibi_scaling_type != "linear": raise ValueError( - f"`alibi_scaling`'s type field must be 'linear', got {alibi_scaling_type}" + f"`alibi_scaling`'s type field must be 'linear'," + f"got {alibi_scaling_type}" ) if ( alibi_scaling_factor is not None @@ -207,7 +231,8 @@ def _alibi_scaling_validation(self): or alibi_scaling_factor <= 1.0 ): raise ValueError( - f"`alibi_scaling`'s factor field must be a float > 1.0, got {alibi_scaling_factor}" + f"`alibi_scaling`'s factor field must be a float > 1.0," + f"got {alibi_scaling_factor}" ) if ( alibi_dynamic_scaling is not None @@ -215,5 +240,6 @@ def _alibi_scaling_validation(self): or alibi_dynamic_scaling <= 1 ): raise ValueError( - f"`alibi_scaling`'s `train_seq_len` field must be an integer > 1, got {alibi_dynamic_scaling}" + f"`alibi_scaling`'s `train_seq_len` field must be an" + f"integer > 1, got {alibi_dynamic_scaling}" ) From 07d4e5d85a77c45e6c43cc7b03acb9f412ebd538 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 11 Mar 2024 07:37:52 +0000 Subject: [PATCH 25/31] applied yapf --- vllm/model_executor/models/jais.py | 135 ++++++++++++----------------- 1 file changed, 57 insertions(+), 78 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 0942a0de5c96..261f570284ec 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -46,8 +46,7 @@ _build_sampler_output, ) from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, -) + VocabParallelEmbedding, ) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, @@ -66,32 +65,31 @@ class SwiGLUActivation(nn.Module): + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) def _get_alibi_slopes(n): + def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) + start = 2**(-(2**-(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + _get_alibi_slopes(2 * closest_power_of_2)[0::2][ - : n - closest_power_of_2 - ] - ) + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) class JAISSampler(Sampler): - def __init__( - self, vocab_size: int, org_vocab_size: Optional[int] = None - ) -> None: + + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None) -> None: super().__init__(vocab_size, org_vocab_size) def forward( @@ -106,16 +104,14 @@ def forward( if self.logits_as_hidden_states: logits = hidden_states else: - hidden_states = _prune_hidden_states( - hidden_states, sampling_metadata - ) + hidden_states = 
_prune_hidden_states(hidden_states, + sampling_metadata) # Get the logits for the next tokens. logits = self._get_logits(hidden_states, embedding, embedding_bias) if logits is not None: - logits *= torch.tensor( - float(output_logits_scale), dtype=logits.dtype - ) + logits *= torch.tensor(float(output_logits_scale), + dtype=logits.dtype) # Only perform sampling in the driver worker. # Note: `_get_logits` is still distributed across TP workers because @@ -131,11 +127,9 @@ def forward( logits = _apply_logits_processors(logits, sampling_metadata) # Prepare sampling tensors with pinned memory to avoid blocking. - (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) = ( - SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype - ) - ) + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = (SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype)) # Apply presence and frequency penalties. if do_penalties: @@ -153,9 +147,8 @@ def forward( logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) if do_top_p_top_k: - logits = _apply_top_k_top_p( - logits, sampling_tensors.top_ps, sampling_tensors.top_ks - ) + logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) if do_min_p: logits = _apply_min_p(logits, sampling_tensors.min_ps) @@ -171,14 +164,13 @@ def forward( sample_results = _sample(probs, logprobs, sampling_metadata) # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, sampling_metadata, sample_results - ) - return _build_sampler_output( - sample_results, sampling_metadata, prompt_logprobs, sample_logprobs - ) + logprobs, sampling_metadata, sample_results) + return _build_sampler_output(sample_results, sampling_metadata, + prompt_logprobs, sample_logprobs) class JAISAttention(nn.Module): + def __init__( self, config: JAISConfig, @@ -188,8 +180,7 @@ def __init__( self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size() - ) + get_tensor_model_parallel_world_size()) assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads @@ -233,12 +224,14 @@ def forward( qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) key_cache, value_cache = kv_cache - attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata) + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) attn_output, _ = self.c_proj(attn_output) return attn_output class JAISMLP(nn.Module): + def __init__( self, intermediate_size: int, @@ -254,16 +247,12 @@ def __init__( bias=True, linear_method=linear_method, ) - self.c_fc2 = ( - ColumnParallelLinear( - hidden_size, - intermediate_size, - bias=True, - linear_method=linear_method, - ) - if self.swiglu - else None - ) + self.c_fc2 = (ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) if self.swiglu else None) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, @@ -277,16 +266,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.swiglu: hidden_states2, _ = self.c_fc2(hidden_states) hidden_states, _ = self.c_fc(hidden_states) - hidden_states = ( - self.act(hidden_states, hidden_states2) - if self.swiglu - else self.act(hidden_states) - ) + 
hidden_states = (self.act(hidden_states, hidden_states2) + if self.swiglu else self.act(hidden_states)) hidden_states, _ = self.c_proj(hidden_states) return hidden_states class JAISBlock(nn.Module): + def __init__( self, config: JAISConfig, @@ -294,9 +281,8 @@ def __init__( ): super().__init__() hidden_size = config.hidden_size - inner_dim = ( - config.n_inner if config.n_inner is not None else 4 * hidden_size - ) + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = JAISAttention(config, linear_method) @@ -328,6 +314,7 @@ def forward( class JAISModel(nn.Module): + def __init__( self, config: JAISConfig, @@ -340,21 +327,17 @@ def __init__( assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) - self.wpe = ( - nn.Embedding(config.max_position_embeddings, self.embed_dim) - if config.position_embedding_type != "alibi" - else None - ) + self.wpe = (nn.Embedding(config.max_position_embeddings, + self.embed_dim) + if config.position_embedding_type != "alibi" else None) if hasattr(config, "embeddings_scale"): self.embeddings_scale = config.embeddings_scale else: self.embeddings_scale = config.mup_embeddings_scale - self.h = nn.ModuleList( - [ - JAISBlock(config, linear_method) - for _ in range(config.num_hidden_layers) - ] - ) + self.h = nn.ModuleList([ + JAISBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -370,9 +353,8 @@ def forward( hidden_states = inputs_embeds + position_embeds else: hidden_states = inputs_embeds - hidden_states *= torch.tensor( - float(self.embeddings_scale), dtype=hidden_states.dtype - ) + hidden_states *= torch.tensor(float(self.embeddings_scale), + dtype=hidden_states.dtype) for i in range(len(self.h)): layer = self.h[i] @@ -383,6 +365,7 @@ def forward( class JAISLMHeadModel(nn.Module): + def __init__( self, config: JAISConfig, @@ -396,9 +379,8 @@ def __init__( if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale else: - self.output_logits_scale = ( - config.mup_output_alpha * config.mup_width_scale - ) + self.output_logits_scale = (config.mup_output_alpha * + config.mup_width_scale) self.sampler = JAISSampler(config.vocab_size) def forward( @@ -408,9 +390,8 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = self.transformer( - input_ids, positions, kv_caches, input_metadata - ) + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) return hidden_states def sample( @@ -435,8 +416,7 @@ def load_weights( ): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + model_name_or_path, cache_dir, load_format, revision): if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. 
@@ -459,7 +439,6 @@ def load_weights( if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) From 159f7f90a0821907e02d140e6d1c5f0ccf9abd1f Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 11 Mar 2024 07:42:11 +0000 Subject: [PATCH 26/31] applied yapf --- vllm/transformers_utils/configs/jais.py | 37 +++++++++---------------- 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 11f5c00a4c75..94f438716f8b 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -208,38 +208,27 @@ def _alibi_scaling_validation(self): if self.alibi_scaling is None: return - if ( - not isinstance(self.alibi_scaling, dict) - or len(self.alibi_scaling) != 2 - ): + if (not isinstance(self.alibi_scaling, dict) + or len(self.alibi_scaling) != 2): raise ValueError( "`alibi_scaling` must be a dictionary with two fields," "`type` and `factor` or `type` and `train_seq_len`, " - f"got {self.alibi_scaling}" - ) + f"got {self.alibi_scaling}") alibi_scaling_type = self.alibi_scaling.get("type", None) alibi_scaling_factor = self.alibi_scaling.get("factor", None) alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) if alibi_scaling_type is None or alibi_scaling_type != "linear": - raise ValueError( - f"`alibi_scaling`'s type field must be 'linear'," - f"got {alibi_scaling_type}" - ) - if ( - alibi_scaling_factor is not None - and not isinstance(alibi_scaling_factor, float) - or alibi_scaling_factor <= 1.0 - ): + raise ValueError(f"`alibi_scaling`'s type field must be 'linear'," + f"got {alibi_scaling_type}") + if (alibi_scaling_factor is not None + and not isinstance(alibi_scaling_factor, float) + or alibi_scaling_factor <= 1.0): raise ValueError( f"`alibi_scaling`'s factor field must be a float > 1.0," - f"got {alibi_scaling_factor}" - ) - if ( - alibi_dynamic_scaling is not None - and not isinstance(alibi_dynamic_scaling, int) - or alibi_dynamic_scaling <= 1 - ): + f"got {alibi_scaling_factor}") + if (alibi_dynamic_scaling is not None + and not isinstance(alibi_dynamic_scaling, int) + or alibi_dynamic_scaling <= 1): raise ValueError( f"`alibi_scaling`'s `train_seq_len` field must be an" - f"integer > 1, got {alibi_dynamic_scaling}" - ) + f"integer > 1, got {alibi_dynamic_scaling}") From 4965c5666e9e29eff5f77ea9ddf538205e16dbf2 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Mon, 11 Mar 2024 08:16:51 +0000 Subject: [PATCH 27/31] adapted to #3299 --- docs/source/models/supported_models.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 5dd624bcb8d6..af4eb81646eb 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -66,9 +66,11 @@ Alongside each architecture, we include some popular models that use it. * - :code:`InternLM2ForCausalLM` - InternLM2 - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc. + - * - :code:`JAISLMHeadModel` - Jais - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. 
+ - * - :code:`LlamaForCausalLM` - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. From 33a3a8c39b7d0b4d2de66e55cedc4a04837156d7 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Thu, 21 Mar 2024 07:27:24 +0000 Subject: [PATCH 28/31] adapted to #3233 and bug fix for gpt2 --- tests/models/test_models.py | 1 - vllm/model_executor/models/gpt2.py | 3 +- vllm/model_executor/models/jais.py | 116 ++++------------------------- 3 files changed, 14 insertions(+), 106 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 5488149227df..fb567e837d28 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -20,7 +20,6 @@ "stabilityai/stablelm-3b-4e1t", "allenai/OLMo-1B", "bigcode/starcoder2-3b", - "core42/jais-13b", ] diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 263727cac19f..e75dda750cb2 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -242,8 +242,7 @@ def sample( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head_weight, logits, - sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 261f570284ec..471322a0ea14 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -34,17 +34,8 @@ QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.sampler import ( - Sampler, - _prune_hidden_states, - _apply_logits_processors, - _apply_penalties, - _apply_top_k_top_p, - _apply_min_p, - _sample, - _get_logprobs, - _build_sampler_output, -) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from vllm.model_executor.parallel_utils.parallel_state import ( @@ -85,90 +76,6 @@ def get_slopes_power_of_2(n): 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) -class JAISSampler(Sampler): - - def __init__(self, - vocab_size: int, - org_vocab_size: Optional[int] = None) -> None: - super().__init__(vocab_size, org_vocab_size) - - def forward( - self, - embedding: torch.Tensor, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - output_logits_scale: float, - embedding_bias: Optional[torch.Tensor] = None, - ) -> Optional[SamplerOutput]: - # Get the hidden states that we use for sampling. - if self.logits_as_hidden_states: - logits = hidden_states - else: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) - - # Get the logits for the next tokens. - logits = self._get_logits(hidden_states, embedding, embedding_bias) - if logits is not None: - logits *= torch.tensor(float(output_logits_scale), - dtype=logits.dtype) - - # Only perform sampling in the driver worker. - # Note: `_get_logits` is still distributed across TP workers because - # the `embedding` weight is distributed across TP workers. - # TODO(zhuohan): Change the get_logits part to a separate stage. - if not sampling_metadata.perform_sampling: - return None - - assert logits is not None - _, vocab_size = logits.shape - - # Apply logits processors (if any). 
- logits = _apply_logits_processors(logits, sampling_metadata) - - # Prepare sampling tensors with pinned memory to avoid blocking. - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = (SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype)) - - # Apply presence and frequency penalties. - if do_penalties: - logits = _apply_penalties( - logits, - sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties, - ) - - # Apply temperature scaling. - # Use in-place division to avoid creating a new tensor. - logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) - - if do_top_p_top_k: - logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) - - if do_min_p: - logits = _apply_min_p(logits, sampling_tensors.min_ps) - - # We use float32 for probabilities and log probabilities. - # Compute the probabilities. - probs = torch.softmax(logits, dim=-1, dtype=torch.float) - # Compute the log probabilities. - # Use log_softmax to ensure numerical stability. - logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) - - # Sample the next tokens. - sample_results = _sample(probs, logprobs, sampling_metadata) - # Get the logprobs query results. - prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, sampling_metadata, - prompt_logprobs, sample_logprobs) - - class JAISAttention(nn.Module): def __init__( @@ -381,7 +288,9 @@ def __init__( else: self.output_logits_scale = (config.mup_output_alpha * config.mup_width_scale) - self.sampler = JAISSampler(config.vocab_size) + self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, + scale=self.output_logits_scale) + self.sampler = Sampler() def forward( self, @@ -394,17 +303,18 @@ def forward( input_metadata) return hidden_states + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) + return logits + def sample( self, - hidden_states: torch.Tensor, + logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler( - self.lm_head_weight, - hidden_states, - sampling_metadata, - self.output_logits_scale, - ) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights( From d0b4df5457142486c1e0dc137890b49c1a6e71cd Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Thu, 21 Mar 2024 07:42:12 +0000 Subject: [PATCH 29/31] applied ruff and yapf --- vllm/model_executor/models/jais.py | 112 +++++++++++++++++------------ 1 file changed, 65 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 471322a0ea14..e1f31aaa3303 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -37,7 +37,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ) + VocabParallelEmbedding, +) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, @@ -56,28 +57,29 @@ class SwiGLUActivation(nn.Module): - 
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) def _get_alibi_slopes(n): - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) + start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( - 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + _get_alibi_slopes(2 * closest_power_of_2)[0::2][ + : n - closest_power_of_2 + ] + ) class JAISAttention(nn.Module): - def __init__( self, config: JAISConfig, @@ -87,7 +89,8 @@ def __init__( self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + get_tensor_model_parallel_world_size() + ) assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads @@ -131,14 +134,12 @@ def forward( qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) key_cache, value_cache = kv_cache - attn_output = self.attn(q, k, v, key_cache, value_cache, - input_metadata) + attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata) attn_output, _ = self.c_proj(attn_output) return attn_output class JAISMLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -154,12 +155,16 @@ def __init__( bias=True, linear_method=linear_method, ) - self.c_fc2 = (ColumnParallelLinear( - hidden_size, - intermediate_size, - bias=True, - linear_method=linear_method, - ) if self.swiglu else None) + self.c_fc2 = ( + ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) + if self.swiglu + else None + ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, @@ -173,14 +178,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.swiglu: hidden_states2, _ = self.c_fc2(hidden_states) hidden_states, _ = self.c_fc(hidden_states) - hidden_states = (self.act(hidden_states, hidden_states2) - if self.swiglu else self.act(hidden_states)) + hidden_states = ( + self.act(hidden_states, hidden_states2) + if self.swiglu + else self.act(hidden_states) + ) hidden_states, _ = self.c_proj(hidden_states) return hidden_states class JAISBlock(nn.Module): - def __init__( self, config: JAISConfig, @@ -188,8 +195,9 @@ def __init__( ): super().__init__() hidden_size = config.hidden_size - inner_dim = (config.n_inner if config.n_inner is not None else 4 * - hidden_size) + inner_dim = ( + config.n_inner if config.n_inner is not None else 4 * hidden_size + ) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = JAISAttention(config, linear_method) @@ -221,7 +229,6 @@ def forward( class JAISModel(nn.Module): - def __init__( self, config: JAISConfig, @@ -234,17 +241,21 @@ def __init__( assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) - self.wpe = (nn.Embedding(config.max_position_embeddings, - self.embed_dim) - if config.position_embedding_type != "alibi" else None) + self.wpe = ( + nn.Embedding(config.max_position_embeddings, self.embed_dim) 
+ if config.position_embedding_type != "alibi" + else None + ) if hasattr(config, "embeddings_scale"): self.embeddings_scale = config.embeddings_scale else: self.embeddings_scale = config.mup_embeddings_scale - self.h = nn.ModuleList([ - JAISBlock(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) + self.h = nn.ModuleList( + [ + JAISBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ] + ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -260,8 +271,9 @@ def forward( hidden_states = inputs_embeds + position_embeds else: hidden_states = inputs_embeds - hidden_states *= torch.tensor(float(self.embeddings_scale), - dtype=hidden_states.dtype) + hidden_states *= torch.tensor( + float(self.embeddings_scale), dtype=hidden_states.dtype + ) for i in range(len(self.h)): layer = self.h[i] @@ -272,7 +284,6 @@ def forward( class JAISLMHeadModel(nn.Module): - def __init__( self, config: JAISConfig, @@ -286,10 +297,12 @@ def __init__( if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale else: - self.output_logits_scale = (config.mup_output_alpha * - config.mup_width_scale) - self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, - scale=self.output_logits_scale) + self.output_logits_scale = ( + config.mup_output_alpha * config.mup_width_scale + ) + self.logits_processor = LogitsProcessor( + vocab_size=config.vocab_size, scale=self.output_logits_scale + ) self.sampler = Sampler() def forward( @@ -299,14 +312,17 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata) + hidden_states = self.transformer( + input_ids, positions, kv_caches, input_metadata + ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head_weight, hidden_states, - sampling_metadata) + def compute_logits( + self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata + ) -> torch.Tensor: + logits = self.logits_processor( + self.lm_head_weight, hidden_states, sampling_metadata + ) return logits def sample( @@ -326,7 +342,8 @@ def load_weights( ): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. 
@@ -349,6 +366,7 @@ def load_weights( if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) From 31c12c8486e2e123b99bf0a5af96bcf90dddcd7c Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Thu, 21 Mar 2024 07:58:05 +0000 Subject: [PATCH 30/31] apply ruff --- vllm/model_executor/models/jais.py | 115 ++++++++++++----------------- 1 file changed, 48 insertions(+), 67 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index e1f31aaa3303..6692cb7b5c15 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -37,8 +37,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, -) + VocabParallelEmbedding, ) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, @@ -50,36 +49,34 @@ from vllm.sequence import SamplerOutput from vllm.model_executor.sampling_metadata import ( SamplingMetadata, - SamplingTensors, ) KVCache = Tuple[torch.Tensor, torch.Tensor] class SwiGLUActivation(nn.Module): + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) def _get_alibi_slopes(n): + def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) + start = 2**(-(2**-(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + _get_alibi_slopes(2 * closest_power_of_2)[0::2][ - : n - closest_power_of_2 - ] - ) + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) class JAISAttention(nn.Module): + def __init__( self, config: JAISConfig, @@ -89,8 +86,7 @@ def __init__( self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size() - ) + get_tensor_model_parallel_world_size()) assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads @@ -134,12 +130,14 @@ def forward( qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) key_cache, value_cache = kv_cache - attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata) + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) attn_output, _ = self.c_proj(attn_output) return attn_output class JAISMLP(nn.Module): + def __init__( self, intermediate_size: int, @@ -155,16 +153,12 @@ def __init__( bias=True, linear_method=linear_method, ) - self.c_fc2 = ( - ColumnParallelLinear( - hidden_size, - intermediate_size, - bias=True, - linear_method=linear_method, - ) - if self.swiglu - else None - ) + self.c_fc2 = (ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) if self.swiglu else None) self.c_proj = RowParallelLinear( intermediate_size, 
hidden_size, @@ -178,16 +172,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.swiglu: hidden_states2, _ = self.c_fc2(hidden_states) hidden_states, _ = self.c_fc(hidden_states) - hidden_states = ( - self.act(hidden_states, hidden_states2) - if self.swiglu - else self.act(hidden_states) - ) + hidden_states = (self.act(hidden_states, hidden_states2) + if self.swiglu else self.act(hidden_states)) hidden_states, _ = self.c_proj(hidden_states) return hidden_states class JAISBlock(nn.Module): + def __init__( self, config: JAISConfig, @@ -195,9 +187,8 @@ def __init__( ): super().__init__() hidden_size = config.hidden_size - inner_dim = ( - config.n_inner if config.n_inner is not None else 4 * hidden_size - ) + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = JAISAttention(config, linear_method) @@ -229,6 +220,7 @@ def forward( class JAISModel(nn.Module): + def __init__( self, config: JAISConfig, @@ -241,21 +233,17 @@ def __init__( assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) - self.wpe = ( - nn.Embedding(config.max_position_embeddings, self.embed_dim) - if config.position_embedding_type != "alibi" - else None - ) + self.wpe = (nn.Embedding(config.max_position_embeddings, + self.embed_dim) + if config.position_embedding_type != "alibi" else None) if hasattr(config, "embeddings_scale"): self.embeddings_scale = config.embeddings_scale else: self.embeddings_scale = config.mup_embeddings_scale - self.h = nn.ModuleList( - [ - JAISBlock(config, linear_method) - for _ in range(config.num_hidden_layers) - ] - ) + self.h = nn.ModuleList([ + JAISBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -271,9 +259,8 @@ def forward( hidden_states = inputs_embeds + position_embeds else: hidden_states = inputs_embeds - hidden_states *= torch.tensor( - float(self.embeddings_scale), dtype=hidden_states.dtype - ) + hidden_states *= torch.tensor(float(self.embeddings_scale), + dtype=hidden_states.dtype) for i in range(len(self.h)): layer = self.h[i] @@ -284,6 +271,7 @@ def forward( class JAISLMHeadModel(nn.Module): + def __init__( self, config: JAISConfig, @@ -297,12 +285,10 @@ def __init__( if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale else: - self.output_logits_scale = ( - config.mup_output_alpha * config.mup_width_scale - ) - self.logits_processor = LogitsProcessor( - vocab_size=config.vocab_size, scale=self.output_logits_scale - ) + self.output_logits_scale = (config.mup_output_alpha * + config.mup_width_scale) + self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, + scale=self.output_logits_scale) self.sampler = Sampler() def forward( @@ -312,17 +298,14 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = self.transformer( - input_ids, positions, kv_caches, input_metadata - ) + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) return hidden_states - def compute_logits( - self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata - ) -> torch.Tensor: - logits = self.logits_processor( - self.lm_head_weight, hidden_states, sampling_metadata - ) + def compute_logits(self, hidden_states: torch.Tensor, + 
sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head_weight, hidden_states, + sampling_metadata) return logits def sample( @@ -342,8 +325,7 @@ def load_weights( ): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + model_name_or_path, cache_dir, load_format, revision): if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. @@ -366,7 +348,6 @@ def load_weights( if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) \ No newline at end of file From 54d17c7222c757ee5feb3347db9866772f8452f8 Mon Sep 17 00:00:00 2001 From: Lalit Pradhan Date: Thu, 21 Mar 2024 08:01:22 +0000 Subject: [PATCH 31/31] format --- vllm/model_executor/models/jais.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 6692cb7b5c15..74c8e7f96302 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -47,9 +47,7 @@ hf_model_weights_iterator, ) from vllm.sequence import SamplerOutput -from vllm.model_executor.sampling_metadata import ( - SamplingMetadata, -) +from vllm.model_executor.sampling_metadata import SamplingMetadata KVCache = Tuple[torch.Tensor, torch.Tensor]
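As a minimal end-to-end sketch of what this series enables, the snippet below loads one of the Jais checkpoints named in the supported-models table and generates a short completion. The prompt text, the sampling settings, and the trust_remote_code flag are illustrative assumptions rather than anything taken from the patches, and the checkpoint needs enough GPU memory for a 13B model.

from vllm import LLM, SamplingParams

# Load a Jais checkpoint through the newly registered JAISLMHeadModel.
# trust_remote_code is kept defensively; it may be unnecessary now that
# JAISConfig resolves from vllm.transformers_utils.configs.
llm = LLM(model="core42/jais-13b", trust_remote_code=True)

# Greedy decoding keeps the smoke test deterministic.
params = SamplingParams(temperature=0.0, max_tokens=32)

outputs = llm.generate(["The capital of the United Arab Emirates is"], params)
for output in outputs:
    print(output.outputs[0].text)

A run like this exercises the pieces the series touches: JAISConfig resolution, the JAISModel forward pass with ALiBi slopes sliced across tensor-parallel ranks, and the muP output scaling now applied through LogitsProcessor(scale=output_logits_scale).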