
Commit 4a9375f

[Model] Pass param prefix to LLMHead (#24862)
Signed-off-by: whx-sjtu <[email protected]>
1 parent 03191cd · commit 4a9375f


58 files changed: +102, -31 lines
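The change itself is mechanical: every ParallelLMHead construction now receives prefix=maybe_prefix(prefix, "lm_head"), so the head is registered under its fully qualified parameter name rather than a bare one. For context, here is a minimal sketch of the maybe_prefix helper from vllm/model_executor/models/utils.py, reproduced from memory and therefore illustrative rather than authoritative:

    def maybe_prefix(prefix: str, name: str) -> str:
        # Join a child module name onto an optional parent prefix with a dot,
        # returning the bare name when the parent prefix is empty.
        return name if not prefix else f"{prefix}.{name}"

    # Usage: a top-level decoder passes prefix="", while a wrapped language
    # model (e.g. inside a multimodal wrapper) passes its own sub-prefix.
    assert maybe_prefix("", "lm_head") == "lm_head"
    assert maybe_prefix("language_model", "lm_head") == "language_model.lm_head"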

vllm/model_executor/models/arctic.py

Lines changed: 1 addition & 0 deletions
@@ -427,6 +427,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.vocab_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "lm_head"),
         )
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight

vllm/model_executor/models/aria.py

Lines changed: 1 addition & 0 deletions
@@ -539,6 +539,7 @@ def __init__(
             config.text_config.hidden_size,
             org_num_embeddings=self.language_model.org_vocab_size,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "lm_head"),
         )
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,

vllm/model_executor/models/baichuan.py

Lines changed: 4 additions & 2 deletions
@@ -51,7 +51,8 @@

 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)


 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -394,7 +395,8 @@ def __init__(
                                    position_embedding=position_embedding)
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
-                                      quant_config=quant_config)
+                                      quant_config=quant_config,
+                                      prefix=maybe_prefix(prefix, "lm_head"))
         self.lm_head.weight.weight_loader = self.lm_head_weight_loader
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight

vllm/model_executor/models/bamba.py

Lines changed: 1 addition & 0 deletions
@@ -514,6 +514,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             # We need bigger padding if using lora for kernel
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
+            prefix=maybe_prefix(prefix, "lm_head"),
         )
         # Used to track and store by the Mamba cache between steps.
         self.mamba_cache: Optional[MambaCacheManager] = None

vllm/model_executor/models/bloom.py

Lines changed: 3 additions & 1 deletion
@@ -330,7 +330,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = self.transformer.word_embeddings
         else:
             self.lm_head = ParallelLMHead(self.config.vocab_size,
-                                          self.config.hidden_size)
+                                          self.config.hidden_size,
+                                          prefix=maybe_prefix(
+                                              prefix, "lm_head"))

         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.make_empty_intermediate_tensors = (

vllm/model_executor/models/chameleon.py

Lines changed: 1 addition & 0 deletions
@@ -960,6 +960,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.lm_head = ParallelLMHead(
             self.unpadded_vocab_size,
             config.hidden_size,
+            prefix=maybe_prefix(prefix, "lm_head"),
         )
         if config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight

vllm/model_executor/models/dbrx.py

Lines changed: 1 addition & 0 deletions
@@ -438,6 +438,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "lm_head"),
         )
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)

vllm/model_executor/models/deepseek.py

Lines changed: 6 additions & 3 deletions
@@ -453,9 +453,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.quant_config = quant_config
         self.model = DeepseekModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "lm_head"),
+        )
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)

vllm/model_executor/models/deepseek_eagle.py

Lines changed: 2 additions & 1 deletion
@@ -199,7 +199,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.lm_head = ParallelLMHead(self.config.vocab_size,
                                       self.config.hidden_size,
-                                      quant_config=quant_config)
+                                      quant_config=quant_config,
+                                      prefix=maybe_prefix(prefix, "lm_head"))

         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.config.vocab_size,

vllm/model_executor/models/deepseek_v2.py

Lines changed: 6 additions & 3 deletions
@@ -823,9 +823,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = DeepseekV2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
         if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "lm_head"),
+            )
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)
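Why a prefix on the LM head matters: the head forwards its prefix to the quantization config when a per-layer quant method is selected, so schemes that include or exclude layers by qualified name (an unquantized lm_head is the common case) can match correctly once wrapped models report names like language_model.lm_head. The sketch below is a simplified, hypothetical illustration of that kind of prefix-based dispatch; get_quant_method_for and UNQUANTIZED_LAYERS are made-up names, not vLLM APIs:

    # Hypothetical sketch of prefix-based per-layer dispatch (not vLLM's API).
    # Exclusion by qualified name only works if each layer is constructed
    # with its full prefix, which is what this commit threads through.
    UNQUANTIZED_LAYERS = {"lm_head", "language_model.lm_head"}  # illustrative

    def get_quant_method_for(prefix: str) -> str:
        # Return "unquantized" for excluded layers, "quantized" otherwise.
        return "unquantized" if prefix in UNQUANTIZED_LAYERS else "quantized"

    assert get_quant_method_for("language_model.lm_head") == "unquantized"
    assert get_quant_method_for("model.layers.0.mlp.down_proj") == "quantized"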
