diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 61cfc566dd31..a6bfdebb1a7e 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -25,7 +25,6 @@
 from vllm.attention import Attention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.distributed.utils import divide
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
@@ -128,10 +127,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
 
         self.config = config
-        self.vocab_size = config.vocab_size
-        self.unpadded_vocab_size = config.vocab_size
+        self.vocab_size = model_config.get_vocab_size()
+        self.unpadded_vocab_size = model_config.get_vocab_size()
 
         self.model: PreTrainedModel = AutoModel.from_config(
             self.config,
@@ -145,15 +146,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.apply_base_model_tp_plan(self.model)
 
         # Attention modifications (assumes 1 attention op per hidden layer)
-        tp_size = get_tensor_model_parallel_world_size()
+        num_heads = model_config.get_num_attention_heads(parallel_config)
+        head_size = model_config.get_head_size()
+        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.attention_instances = [
             Attention(
-                num_heads=divide(config.num_attention_heads, tp_size),
-                head_size=config.head_dim,
+                num_heads=num_heads,
+                head_size=head_size,
                 # NOTE: We use Llama scale as default, if it's set by
                 # Transformers, it's updated in vllm_flash_attention_forward
-                scale=config.head_dim**-0.5,
-                num_kv_heads=divide(config.num_key_value_heads, tp_size),
+                scale=head_size**-0.5,
+                num_kv_heads=num_kv_heads,
                 cache_config=cache_config,
                 quant_config=self.quant_config,
                 prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
@@ -163,7 +166,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.replace_vocab_embed_class(self.model)
 
         # ForCausalLM modifications
-        self.lm_head = ParallelLMHead(config.vocab_size,
+        self.lm_head = ParallelLMHead(self.vocab_size,
                                       config.hidden_size,
                                       quant_config=self.quant_config,
                                       prefix=maybe_prefix(prefix, "lm_head"))
@@ -172,7 +175,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
+                                                self.vocab_size, logit_scale)
         self.sampler = get_sampler()
 
     def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
@@ -203,12 +206,12 @@ def replace_vocab_embed_class(self, module: nn.Module):
         new_module = VocabParallelEmbedding(
             self.vocab_size,
             self.config.hidden_size,
-            org_num_embeddings=self.config.vocab_size,
+            org_num_embeddings=self.vocab_size,
             quant_config=None,
         )
         log_replacement("input embedding", self.model.get_input_embeddings(),
                         new_module)
-        self.model.set_input_embeddings(new_module)
+        module.set_input_embeddings(new_module)
 
     def forward(
         self,
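
Context for the attention change above: the `ModelConfig` accessors (`get_num_attention_heads`, `get_head_size`, `get_num_kv_heads`) already return per-tensor-parallel-rank values, which is why the manual `divide(..., tp_size)` calls and the `divide` import can be dropped. Below is a minimal standalone sketch (not vLLM code) of the per-rank values being replaced, assuming the common case where the HF config exposes `num_attention_heads`, `num_key_value_heads`, and `head_dim`, and the TP size evenly divides both head counts:

from dataclasses import dataclass


@dataclass
class HFConfigStub:
    # Stand-in for a Hugging Face config; values mirror a Llama-style model.
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    head_dim: int = 128


def per_rank_attention_shapes(config: HFConfigStub, tp_size: int):
    # Old-style manual computation, equivalent to divide(heads, tp_size).
    assert config.num_attention_heads % tp_size == 0
    assert config.num_key_value_heads % tp_size == 0
    return (
        config.num_attention_heads // tp_size,  # num_heads per TP rank
        config.head_dim,                        # head_size
        config.num_key_value_heads // tp_size,  # num_kv_heads per TP rank
    )


# With tp_size=4 this prints (8, 128, 2); in this case the ModelConfig
# accessors return the same per-rank values, while also handling models
# that name or derive these fields differently.
print(per_rank_attention_shapes(HFConfigStub(), tp_size=4))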