From b7fb14d4468cd862db5ab7203e95d73de09541dd Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Mon, 3 Mar 2025 21:05:23 +0900 Subject: [PATCH 01/36] Add PLaMo2 model at v0.6.3.post1 Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Co-Authored-By: Kento Nozawa Co-Authored-By: Hiroaki Mikami --- vllm/model_executor/models/plamo2.py | 910 +++++++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 2 files changed, 911 insertions(+) create mode 100644 vllm/model_executor/models/plamo2.py diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py new file mode 100644 index 000000000000..e35558115696 --- /dev/null +++ b/vllm/model_executor/models/plamo2.py @@ -0,0 +1,910 @@ +# coding=utf-8 +"""Inference-only Jamba model.""" +import copy +import enum +import math +from typing import Any, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig, PreTrainedModel + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + composed_weight_loader, default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) + +from .interfaces import HasInnerState, SupportsLoRA + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class LinearType(str, enum.Enum): + Normal = "normal" + Fp8 = "fp8" + Fp8Retain = "fp8-retain" + + +# Just for type hinting and PlamoPreTrainedModel.config_class. 
+class PlamoConfig(PretrainedConfig): # type: ignore + model_type: str = "plamo" + + def __init__( + self, + hidden_size: int = 4096, + num_hidden_layers: int = 32, + rms_norm_eps: float = 1e-6, + tie_word_embeddings: bool = False, + # Attention + num_attention_heads: int = 32, + num_key_value_heads: int = 4, + hidden_size_per_head: int = 128, + max_position_embeddings: int = 2048, + attention_window_size: int = 2048, + # Mamba + mamba_d_state: int = 64, + mamba_d_conv: int = 4, + mamba_num_heads: int = 64, + mamba_step: int = 2, + mamba_chunk_size: int = 256, + # MLP + intermediate_size: int = 13312, + # Tokenizer + vocab_size: int = 32000, + tokenizer_class: str = "PlamoTokenizer", + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + # MoE + n_expert: Optional[int] = None, + k_expert: Optional[int] = None, + expert_dropout: float = 0.0, + capacity_factor: float = 1.0, + group_size: int = 1024, + sparse_step: Optional[int] = None, + sparse_intermediate_size: Optional[int] = None, + shared_intermediate_size: Optional[int] = None, + # FP8 + linear_type: LinearType = LinearType.Normal, + fp8_accum_dtype: Optional[str] = None, + # Evaluation + eval_attention_n_bit: Optional[int] = None, + eval_mlp_n_bit: Optional[int] = None, + eval_offload_moe: bool = False, + use_cache: bool = True, + use_predefined_initial_state: bool = False, + **kwargs: Any, + ) -> None: + # max_position_embeddings is often used to determine the max length + # during inference, but samba should have extrapolation abilities + self.max_position_embeddings = max(10 * 1024 * 1024, + max_position_embeddings) + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_size_per_head = hidden_size_per_head + self.num_key_value_heads = num_key_value_heads + self.attention_window_size = attention_window_size + + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_num_heads = mamba_num_heads + self.mamba_step = mamba_step + self.mamba_chunk_size = mamba_chunk_size + + self.intermediate_size = intermediate_size + + self.vocab_size = vocab_size + + self.n_expert = n_expert + self.k_expert = k_expert + self.sparse_intermediate_size = sparse_intermediate_size + self.shared_intermediate_size = shared_intermediate_size + self.expert_dropout = expert_dropout + self.capacity_factor = capacity_factor + self.group_size = group_size + self.sparse_step = sparse_step + + self.linear_type = linear_type + self.fp8_accum_dtype = fp8_accum_dtype + + self.eval_attention_n_bit = eval_attention_n_bit + self.eval_mlp_n_bit = eval_mlp_n_bit + self.eval_offload_moe = eval_offload_moe + self.use_cache = use_cache + + self.use_predefined_initial_state = use_predefined_initial_state + + super().__init__( + tokenizer_class=tokenizer_class, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class PlamoPreTrainedModel(PreTrainedModel): # type: ignore + config_class = PlamoConfig + _no_split_modules: List[str] + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PlamoDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module: torch.nn.Module) -> None: + std = 0.02 + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, 
std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +def get_initial_dt_bias(num_heads: int) -> torch.Tensor: + dt_min = 0.001 + dt_max = 0.1 + dt = torch.exp( + torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min)) + dt = torch.clamp(dt, 1e-4) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + return inv_dt + + +def is_mamba(config: PlamoConfig, i: int) -> bool: + assert config.mamba_step > 1 + + if config.num_hidden_layers <= (config.mamba_step // 2): + # use attention in last layer + return i != config.num_hidden_layers - 1 + return (i % config.mamba_step) != (config.mamba_step // 2) + + +# TODO(Shinichi): Replace this with RMSNorm. +def _rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, + eps: float) -> torch.Tensor: + input_shape = hidden_states.shape + hidden_states = hidden_states.reshape(input_shape[:-1] + weight.shape) + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + eps) + hidden_states = hidden_states.to(input_dtype) + hidden_states = weight * hidden_states + return hidden_states.reshape(input_shape) + + +def _swiglu(h: torch.Tensor) -> torch.Tensor: + h0, h1 = h.chunk(2, dim=-1) + return torch.nn.functional.silu(h0) * h1 + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class Plamo2MambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, config: PlamoConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = (config.mamba_num_heads * + config.hidden_size_per_head) + self.hidden_size_per_head = config.hidden_size_per_head + self.num_heads = config.mamba_num_heads + self.time_step_rank = max(64, self.hidden_size // 16) + self.use_conv_bias = False + self.use_bias = False + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=self.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. 
+ self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.num_heads, + bias=False) + self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads)) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True, + ) + # The activation function is fixed to SiLU. + self.activation = "silu" + + self.dt_layernorm = RMSNorm(self.time_step_rank, + eps=config.rms_norm_eps) + self.b_layernorm = RMSNorm(self.ssm_state_size, + eps=config.rms_norm_eps) + self.c_layernorm = RMSNorm(self.ssm_state_size, + eps=config.rms_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams): + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0] + # Reshaping the projected states as in modeling_plamo.py. + length = len(hidden_states) + projected_states = projected_states.reshape(length, self.num_heads, -1) + gate, hidden_states = torch.split( + projected_states, + [self.hidden_size_per_head, self.hidden_size_per_head], + dim=-1) + hidden_states = hidden_states.reshape(length, -1).transpose(0, 1) + gate = gate.reshape(length, -1).transpose(0, 1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + # Splitting the ssm_parameters as in modeling_plamo.py. 
+ B, C, time_step = torch.split( + ssm_parameters, + [self.ssm_state_size, self.ssm_state_size, self.time_step_rank], + dim=-1, + ) + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_bias.float() if hasattr( + self.dt_proj, "bias") else None) + + # Broadcasting as in modeling_plamo.py. + discrete_time_step = discrete_time_step.transpose( + 0, 1)[..., None].expand(-1, -1, self.hidden_size_per_head) + discrete_time_step = discrete_time_step.reshape( + -1, self.intermediate_size).transpose(0, 1) + time_proj_bias = time_proj_bias[..., + None].expand(-1, + self.hidden_size_per_head) + time_proj_bias = time_proj_bias.reshape(self.intermediate_size) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + mamba_cache_params.ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(-2, + -1))[0] + return contextualized_states + + +class Plamo2MoE(nn.Module): + + def __init__(self, + config: PlamoConfig, + num_experts: Optional[int] = None, + top_k: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + assert num_experts is None or num_experts <= 1, "MoE not supported" + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = torch.nn.Linear(self.hidden_size, + self.intermediate_size * 2, + bias=False) + self.down_proj = torch.nn.Linear(self.intermediate_size, + self.hidden_size, + bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + h = self.gate_up_proj(hidden_states) + h = _swiglu(h) + return self.down_proj(h) # type: ignore + + +class DenseMLP(Plamo2MoE): + + def __init__(self, + config: PlamoConfig, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__(config, + num_experts=1, + top_k=1, + params_dtype=params_dtype, + tp_size=tp_size, + quant_config=quant_config) + + +class Plamo2MambaDecoderLayer(nn.Module): + + def __init__(self, + config: PlamoConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mamba = Plamo2MambaMixer(config) + + ffn_layer_class = DenseMLP + self.mlp = ffn_layer_class(config, quant_config=quant_config) + self.pre_mixer_norm = 
RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mixer_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_mlp_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mlp_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.pre_mixer_norm(hidden_states) + else: + hidden_states, residual = self.pre_mixer_norm( + hidden_states, residual) + + hidden_states = self.mamba(hidden_states, attn_metadata, + mamba_cache_params) + hidden_states = self.post_mixer_norm(hidden_states) + # Fully Connected + hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_norm(hidden_states) + return hidden_states, residual + + +class Plamo2AttentionDecoderLayer(nn.Module): + + def __init__(self, + config: PlamoConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None) -> None: + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size_per_head + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.rope_theta = config.rope_theta if hasattr(config, + "rope_theta") else 10000 + self.rope_scaling = config.rope_scaling if hasattr( + config, "rope_scaling") else None + self.max_position_embeddings = config.attention_window_size + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=self.rope_scaling, + max_model_len=scheduler_config.max_model_len, + ) + self.q_weight = torch.nn.Parameter( + torch.ones((self.num_heads, config.hidden_size_per_head))) + self.k_weight = torch.nn.Parameter( + torch.ones((self.num_kv_heads, config.hidden_size_per_head))) + + # TODO(Shinichi): Remove this workaround. 
+ cache_config = copy.deepcopy( + cache_config) if cache_config is not None else CacheConfig() + cache_config.sliding_window = config.attention_window_size + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + ) + + ffn_layer_class = DenseMLP + self.mlp = ffn_layer_class(config, quant_config=quant_config) + self.pre_mixer_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mixer_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_mlp_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mlp_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.layer_idx = layer_idx + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = _rms_norm(q, self.q_weight, 1e-6) + k = _rms_norm(k, self.k_weight, 1e-6) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.pre_mixer_norm(hidden_states) + else: + hidden_states, residual = self.pre_mixer_norm( + hidden_states, residual) + + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = self.post_mixer_norm(hidden_states) + # Fully Connected + hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_norm(hidden_states) + return hidden_states, residual + + +class PlamoModel(PlamoPreTrainedModel): + + def __init__( + self, + config: PlamoConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, + ) -> None: + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + layer_class = Plamo2MambaDecoderLayer if is_mamba( + config, i) else Plamo2AttentionDecoderLayer + decoder_layers.append( + layer_class(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config, + scheduler_config=scheduler_config)) + self.layers = nn.ModuleList(decoder_layers) + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_init() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None 
+ attention_layer_idx = 0 + mamba_layer_idx = 0 + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = None + layer_mamba_cache_params = None + if isinstance(layer, Plamo2AttentionDecoderLayer): + kv_cache = kv_caches[attention_layer_idx] + attention_layer_idx += 1 + if isinstance(layer, Plamo2MambaDecoderLayer): + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + mamba_layer_idx) + mamba_layer_idx += 1 + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=residual, + mamba_cache_params=layer_mamba_cache_params) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + +class Plamo2ForCausalLM(PlamoPreTrainedModel, HasInnerState, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: PlamoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, + ) -> None: + assert not cache_config.enable_prefix_caching, \ + "PLaMo2 currently does not support prefix caching" + + super().__init__(config) + self.config = config + self.scheduler_config = scheduler_config + self.model = PlamoModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config, + scheduler_config=scheduler_config) + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + num_embeddings = ((self.vocab_size + 15) // 16) * 16 + self.lm_head = ParallelLMHead( + num_embeddings, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + + # Used to track and store by the Mamba cache between steps. 
+ self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs): + if self.mamba_cache is None: + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config + else max(_BATCH_SIZES_TO_CAPTURE) + 2) + + num_mamba_layers = sum([ + is_mamba(self.config, i) + for i in range(self.config.num_hidden_layers) + ]) + + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, + *self._get_mamba_cache_shape()) + ( + mamba_cache_tensors, + state_indices_tensor, + ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, + **kwargs) + mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], + mamba_cache_tensors[1], + state_indices_tensor) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, mamba_cache_params) + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = (self.config.mamba_num_heads * + self.config.hidden_size_per_head) + conv_state_shape = ( + hidden_size // world_size, + self.config.mamba_d_conv - 1, + ) + temporal_state_shape = ( + hidden_size // world_size, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + + # Alignment team workaround: somehow when tie_word_embeddings=True, + # `lm_head.weight` may be in the safetensor, which causing dict key + # access error. + if name == "lm_head.weight" and self.config.tie_word_embeddings: + assert "lm_head.weight" not in params_dict + continue + + # Update the weight names to be compatible with the vllm version + # of the model. Do not change the order of the replacements. + replacements = { + # Skip PlamoDecoderLayers. + ".layers.layers": ".layers", + # Skip PlmoDecoderLayer. + ".mixer": "", + # Rename the final layernorm of the model. + "model.norm.weight": "model.final_layernorm.weight", + + # Rename each mamba layer's components. 
+ ".A_log": ".mamba.A", + ".B_norm_weight": ".mamba.b_layernorm.weight", + ".C_norm_weight": ".mamba.c_layernorm.weight", + ".dt_norm_weight": ".mamba.dt_layernorm.weight", + ".bcdt_proj.weight": ".mamba.x_proj.weight", + ".conv1d.weight": ".mamba.conv1d.weight", + ".in_proj.weight": ".mamba.in_proj.weight", + ".out_proj.weight": ".mamba.out_proj.weight", + ".D": ".mamba.D", + ".dt_bias": ".mamba.dt_bias", + ".dt_proj.weight": ".mamba.dt_proj.weight", + } + # Apply replacements based on the defined mappings + for old, new in replacements.items(): + if old in name: + name = name.replace(old, new) + + # Broadcast the loaded weight to match the model's parameter shape. + if ".A" in name: + loaded_weight = loaded_weight[:, None, None].expand( + -1, self.config.hidden_size_per_head, + self.config.mamba_d_state) + loaded_weight = loaded_weight.reshape( + -1, self.config.mamba_d_state) + elif ".D" in name: + loaded_weight = loaded_weight[:, None].expand( + -1, self.config.hidden_size_per_head) + loaded_weight = loaded_weight.reshape(-1) + # Offset parameter with vllm's RMSNorm haven't been supported yet. + if ".pre_mixer_norm" in name: + loaded_weight += 1.0 + elif ".post_mixer_norm" in name: + loaded_weight += 1.0 / 5 + elif ".pre_mlp_norm" in name: + loaded_weight += 1.0 + elif ".post_mlp_norm" in name: + loaded_weight += 1.0 / (5**1.5) + elif "model.final_layernorm.weight" in name: + loaded_weight += 1.0 + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4551d81e8a5d..eb391a13646b 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -93,6 +93,7 @@ "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + "PlamoForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), From e58e384cdef25a0b66933a5476b9fb0082a36074 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Tue, 4 Mar 2025 21:46:09 +0900 Subject: [PATCH 02/36] Follow-up to the latest based on Jamba implementaion Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Co-authored-by: Calvin Metzger --- vllm/model_executor/models/plamo2.py | 216 ++++++++++----------------- 1 file changed, 80 insertions(+), 136 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index e35558115696..ecc5fbb4d9a9 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -1,6 +1,5 @@ -# coding=utf-8 -"""Inference-only Jamba model.""" -import copy +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only PLaMo2 model.""" import enum import math from typing import Any, Iterable, List, Optional, Tuple @@ -11,8 +10,9 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from 
vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -25,20 +25,18 @@ selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.interfaces import HasInnerState, SupportsLoRA from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) +from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, - _get_graph_batch_size) - -from .interfaces import HasInnerState, SupportsLoRA KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -213,17 +211,9 @@ def _swiglu(h: torch.Tensor) -> torch.Tensor: # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class Plamo2MambaMixer(nn.Module): - """ - Compute ∆, A, B, C, and D the state space parameters and compute - the `contextualized_states`. A, D are input independent - (see Mamba paper [1] Section 3.5.2 "Interpretation of A" - for why A isn't selective) ∆, B, C are input-dependent - (this is a key difference between Mamba and the linear time - invariant S4, and is why Mamba is called - **selective** state spaces) - """ - - def __init__(self, config: PlamoConfig): + # TODO(Shinichi): Rebase on Mamba2 implementation. + + def __init__(self, config: PlamoConfig, prefix: str = ""): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -247,14 +237,18 @@ def __init__(self, config: PlamoConfig): # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = MergedColumnParallelLinear(self.hidden_size, - [self.intermediate_size] * 2, - bias=self.use_bias) + self.in_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias, + prefix=f"{prefix}.in_proj", + ) # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False, + prefix=f"{prefix}.x_proj", ) # time step projection (discretization) - # In the forward we need to apply dt_proj without the bias, @@ -295,9 +289,10 @@ def __init__(self, config: PlamoConfig): eps=config.rms_norm_eps) def forward(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams): + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0] # Reshaping the projected states as in modeling_plamo.py. 
@@ -455,9 +450,9 @@ def __init__(self, layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - scheduler_config: Optional[SchedulerConfig] = None) -> None: + max_model_len: int | None = None, + **kwargs) -> None: super().__init__() - self.layer_idx = layer_idx self.config = config self.mamba = Plamo2MambaMixer(config) @@ -475,7 +470,6 @@ def __init__(self, def forward( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, **kwargs, @@ -487,8 +481,7 @@ def forward( hidden_states, residual = self.pre_mixer_norm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, attn_metadata, - mamba_cache_params) + hidden_states = self.mamba(hidden_states, mamba_cache_params) hidden_states = self.post_mixer_norm(hidden_states) # Fully Connected hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) @@ -504,7 +497,9 @@ def __init__(self, layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - scheduler_config: Optional[SchedulerConfig] = None) -> None: + max_model_len: int | None = None, + prefix: str = "", + **kwargs) -> None: super().__init__() self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -543,32 +538,27 @@ def __init__(self, "rope_theta") else 10000 self.rope_scaling = config.rope_scaling if hasattr( config, "rope_scaling") else None - self.max_position_embeddings = config.attention_window_size + assert max_model_len is not None, "max_model_len must be provided" self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, + max_position=max_model_len, base=self.rope_theta, rope_scaling=self.rope_scaling, - max_model_len=scheduler_config.max_model_len, ) self.q_weight = torch.nn.Parameter( torch.ones((self.num_heads, config.hidden_size_per_head))) self.k_weight = torch.nn.Parameter( torch.ones((self.num_kv_heads, config.hidden_size_per_head))) - # TODO(Shinichi): Remove this workaround. 
- cache_config = copy.deepcopy( - cache_config) if cache_config is not None else CacheConfig() - cache_config.sliding_window = config.attention_window_size - self.attn = Attention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, + prefix=f"{prefix}.attn", ) ffn_layer_class = DenseMLP @@ -581,14 +571,11 @@ def __init__(self, eps=config.rms_norm_eps) self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.layer_idx = layer_idx def self_attention( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, **kwargs, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) @@ -596,7 +583,7 @@ def self_attention( q = _rms_norm(q, self.q_weight, 1e-6) k = _rms_norm(k, self.k_weight, 1e-6) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output @@ -604,8 +591,6 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], **kwargs, ): @@ -619,8 +604,6 @@ def forward( hidden_states = self.self_attention( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, ) hidden_states = self.post_mixer_norm(hidden_states) # Fully Connected @@ -630,28 +613,25 @@ def forward( return hidden_states, residual -class PlamoModel(PlamoPreTrainedModel): +class Plamo2Model(PlamoPreTrainedModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config.model_config.hf_config) + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config - def __init__( - self, - config: PlamoConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - lora_config: Optional[LoRAConfig] = None, - scheduler_config: Optional[SchedulerConfig] = None, - ) -> None: - super().__init__(config) self.config = config self.padding_idx = config.pad_token_id - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab + self.vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, + prefix=f"{prefix}.embed_tokens", ) decoder_layers = [] @@ -659,11 +639,13 @@ def __init__( layer_class = Plamo2MambaDecoderLayer if is_mamba( config, i) else Plamo2AttentionDecoderLayer decoder_layers.append( - layer_class(config, - layer_idx=i, - cache_config=cache_config, - quant_config=quant_config, - scheduler_config=scheduler_config)) + layer_class( + config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config, + max_model_len=vllm_config.scheduler_config.max_model_len, + prefix=f"{prefix}.decoder_layers.{i}")) self.layers = nn.ModuleList(decoder_layers) self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -673,30 +655,25 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> 
torch.Tensor: + # TODO(Shinichi): Implement pipeline parallelism. hidden_states = self.embed_tokens(input_ids) residual = None - attention_layer_idx = 0 - mamba_layer_idx = 0 - for i in range(len(self.layers)): - layer = self.layers[i] - kv_cache = None + + mamba_cache_index = 0 + for layer in self.layers: layer_mamba_cache_params = None - if isinstance(layer, Plamo2AttentionDecoderLayer): - kv_cache = kv_caches[attention_layer_idx] - attention_layer_idx += 1 if isinstance(layer, Plamo2MambaDecoderLayer): layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - mamba_layer_idx) - mamba_layer_idx += 1 + mamba_cache_index) + mamba_cache_index += 1 + hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, residual=residual, mamba_cache_params=layer_mamba_cache_params) hidden_states, _ = self.final_layernorm(hidden_states, residual) @@ -712,61 +689,38 @@ class Plamo2ForCausalLM(PlamoPreTrainedModel, HasInnerState, SupportsLoRA): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", - ] - embedding_modules = { - "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] - - def __init__( - self, - config: PlamoConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - scheduler_config: Optional[SchedulerConfig] = None, - ) -> None: - assert not cache_config.enable_prefix_caching, \ + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + config = vllm_config.model_config.hf_config + scheduler_config = vllm_config.scheduler_config + assert not vllm_config.cache_config.enable_prefix_caching, \ "PLaMo2 currently does not support prefix caching" - super().__init__(config) + super().__init__(vllm_config.model_config.hf_config) self.config = config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config self.scheduler_config = scheduler_config - self.model = PlamoModel(config, - cache_config=cache_config, - quant_config=quant_config, - lora_config=lora_config, - scheduler_config=scheduler_config) - self.vocab_size = config.vocab_size - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.model = Plamo2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.vocab_size = self.config.vocab_size + self.unpadded_vocab_size = self.config.vocab_size num_embeddings = ((self.vocab_size + 15) // 16) * 16 self.lm_head = ParallelLMHead( num_embeddings, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + self.config.hidden_size, + org_num_embeddings=self.config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=f"{prefix}.lm_head", ) - if config.tie_word_embeddings: + if self.config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) # Used to track and store by the Mamba cache between steps. 
self.mamba_cache: Optional[MambaCacheManager] = None self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - self.sampler = Sampler() + self.config.vocab_size) + self.sampler = get_sampler() # Initialize weights and apply final processing self.post_init() @@ -774,33 +728,23 @@ def __init__( def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (_get_graph_batch_size( - self.scheduler_config.max_num_seqs) if self.scheduler_config - else max(_BATCH_SIZES_TO_CAPTURE) + 2) - num_mamba_layers = sum([ is_mamba(self.config, i) for i in range(self.config.num_hidden_layers) ]) self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, *self._get_mamba_cache_shape()) - ( - mamba_cache_tensors, - state_indices_tensor, - ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, - **kwargs) - mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], - mamba_cache_tensors[1], - state_indices_tensor) - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + hidden_states = self.model(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): From b0f101eaa7be6355a0a6ea5f91961697a1432924 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Wed, 5 Mar 2025 11:44:03 +0900 Subject: [PATCH 03/36] Modify interfaces Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index ecc5fbb4d9a9..580bfaf68a63 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -30,7 +30,8 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.interfaces import HasInnerState, SupportsLoRA +from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, + SupportsV0Only) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.models.utils import maybe_prefix @@ -680,7 +681,8 @@ def forward( return hidden_states -class Plamo2ForCausalLM(PlamoPreTrainedModel, HasInnerState, SupportsLoRA): +class Plamo2ForCausalLM(PlamoPreTrainedModel, HasInnerState, IsHybrid, + SupportsV0Only): packed_modules_mapping = { "qkv_proj": [ "q_proj", From a783e312b8685d26394576c8ab02b38d9d565135 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 6 Mar 2025 02:30:35 +0900 Subject: [PATCH 04/36] Add workaround to use IsHybrid interface Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git 
a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 580bfaf68a63..b581a35fdde4 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -38,6 +38,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -697,11 +698,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: assert not vllm_config.cache_config.enable_prefix_caching, \ "PLaMo2 currently does not support prefix caching" - super().__init__(vllm_config.model_config.hf_config) + super().__init__(config) self.config = config self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.scheduler_config = scheduler_config + + # TODO(Shinichi): Remove this workaround. + self.config.layers_block_type = [ + "mamba" if is_mamba(self.config, i) else "attention" + for i in range(self.config.num_hidden_layers) + ] + self.model = Plamo2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.vocab_size = self.config.vocab_size @@ -734,10 +742,8 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - num_mamba_layers = sum([ - is_mamba(self.config, i) - for i in range(self.config.num_hidden_layers) - ]) + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, From 5ffec2c72b3ced75d165b767e4ed633274abad36 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 6 Mar 2025 03:49:50 +0900 Subject: [PATCH 05/36] Update dependencies for test Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- requirements-test.in | 1 + requirements-test.txt | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/requirements-test.in b/requirements-test.in index de33f92b37b9..48e4688bc4e9 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -24,6 +24,7 @@ timm # required for internvl test torch==2.5.1 torchaudio==2.5.1 transformers_stream_generator # required for qwen-vl test +mamba_ssm # required for plamo2 test matplotlib # required for qwen-vl test mistral_common[opencv] >= 1.5.0 # required for pixtral test datamodel_code_generator # required for minicpm3 test diff --git a/requirements-test.txt b/requirements-test.txt index f5722c82e201..7a57c8165acc 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -106,6 +106,7 @@ einops==0.8.0 # via # -r requirements-test.in # encodec + # mamba-ssm # vector-quantize-pytorch # vocos einx==0.3.0 @@ -221,6 +222,8 @@ lm-eval==0.4.4 # via -r requirements-test.in lxml==5.3.0 # via sacrebleu +mamba-ssm==2.2.4 + # via -r requirements-test.in markdown-it-py==3.0.0 # via rich markupsafe==3.0.2 @@ -256,6 +259,8 @@ mypy-extensions==1.0.0 # via black networkx==3.2.1 # via torch +ninja==1.11.1.3 + # via mamba-ssm nltk==3.9.1 # via rouge-score numba==0.60.0 @@ -341,6 +346,7 @@ packaging==24.1 # fastparquet # huggingface-hub # lazy-loader + # mamba-ssm # matplotlib # peft # plotly @@ -545,6 +551,7 @@ sentencepiece==0.2.0 # via mistral-common setuptools==75.8.0 # via + # mamba-ssm # pytablewriter # torch six==1.16.0 @@ -598,6 +605,7 @@ torch==2.5.1 # bitsandbytes # encodec # lm-eval + 
# mamba-ssm # peft # runai-model-streamer # sentence-transformers @@ -633,6 +641,7 @@ transformers==4.48.2 # -r requirements-test.in # genai-perf # lm-eval + # mamba-ssm # peft # sentence-transformers # transformers-stream-generator From 49dd3b0e48ce73d447b953ac78142358f933bb31 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 6 Mar 2025 04:55:45 +0900 Subject: [PATCH 06/36] Add test for plamo2 model Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- .../decoder_only/language/test_hybrid.py | 52 ++++++++++++------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py index a39b11923582..c04e9cafc9f3 100644 --- a/tests/models/decoder_only/language/test_hybrid.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 +import importlib +import site + +import pip import pytest from tests.utils import multi_gpu_test @@ -8,8 +12,14 @@ from ...utils import check_outputs_equal +# Install causal-conv1d here, as it is not compatible with pip-compile. +pip.main(['install', 'causal-conv1d']) +importlib.reload(site) + # This test is for the hybrid models -MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"] +MODELS = [ + "ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B", "pfnet/plamo-2-1b" +] @pytest.mark.parametrize("model", MODELS) @@ -25,17 +35,16 @@ def test_models( ) -> None: # numeric error produces different generation + model_kwargs = { + "use_mamba_kernels": False, # mamba kernels are not installed so HF + # don't use them + } if 'Bamba' in model: example_prompts.pop(3) + if 'plamo' in model: + model_kwargs = None - with hf_runner( - model, - dtype=dtype, - model_kwargs={ - "use_mamba_kernels": - False, # mamba kernels are not installed so HF - # don't use them - }) as hf_model: + with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) with vllm_runner(model, dtype=dtype) as vllm_model: @@ -97,6 +106,10 @@ def test_mamba_prefill_chunking_with_parallel_sampling( # correctly for n > 1 decoding steps inside a # chunked prefill forward pass (where we have both prefills # and decoding together ) + + if 'plamo' in model: + dtype = "float" # use a different dtype for plamo + sampling_params = SamplingParams(n=3, temperature=1, seed=0, @@ -128,15 +141,17 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, example_prompts.pop(3) example_prompts.pop(2) dtype = "half" # use a different dtype for Bamba + elif 'plamo' in model: + example_prompts.pop(7) - with hf_runner( - model, - dtype=dtype, - model_kwargs={ - "use_mamba_kernels": - False, # mamba kernels are not installed so HF - # don't use them - }) as hf_model: + model_kwargs = { + "use_mamba_kernels": False, # mamba kernels are not installed so HF + # don't use them + } + if 'plamo' in model: + model_kwargs = None + + with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model: non_chunked = hf_model.generate_greedy(example_prompts, max_tokens) with vllm_runner(model, @@ -205,7 +220,8 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. 
If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - vllm_config = EngineArgs(model=model).create_engine_config() + vllm_config = EngineArgs(model=model, + trust_remote_code=True).create_engine_config() while len(example_prompts) == vllm_config.pad_for_cudagraph( len(example_prompts)): example_prompts.append(example_prompts[0]) From 68d3bedab33e9bd8c651a99600c3ad73694142f3 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 6 Mar 2025 05:28:53 +0900 Subject: [PATCH 07/36] Modify code comment Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index b581a35fdde4..c9593868ed7e 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -49,7 +49,7 @@ class LinearType(str, enum.Enum): Fp8Retain = "fp8-retain" -# Just for type hinting and PlamoPreTrainedModel.config_class. +# Only used for type hinting and PlamoPreTrainedModel.config_class. class PlamoConfig(PretrainedConfig): # type: ignore model_type: str = "plamo" @@ -798,15 +798,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - # Alignment team workaround: somehow when tie_word_embeddings=True, - # `lm_head.weight` may be in the safetensor, which causing dict key - # access error. + # Both tie_word_embeddings=True and lm_head.weight in the safetensor + # at the same time causes dict key access error. if name == "lm_head.weight" and self.config.tie_word_embeddings: assert "lm_head.weight" not in params_dict continue # Update the weight names to be compatible with the vllm version - # of the model. Do not change the order of the replacements. + # of the model. + # Do not change the order of the replacements. replacements = { # Skip PlamoDecoderLayers. 
".layers.layers": ".layers", From bee8035de95f0b7bb87484af13439d7360a3c269 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 6 Mar 2025 12:19:35 +0900 Subject: [PATCH 08/36] Resolve mypy error Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- tests/models/decoder_only/language/test_hybrid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py index c04e9cafc9f3..4fa8b6f68d87 100644 --- a/tests/models/decoder_only/language/test_hybrid.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -42,7 +42,7 @@ def test_models( if 'Bamba' in model: example_prompts.pop(3) if 'plamo' in model: - model_kwargs = None + model_kwargs = {} with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) @@ -149,7 +149,7 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, # don't use them } if 'plamo' in model: - model_kwargs = None + model_kwargs = {} with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model: non_chunked = hf_model.generate_greedy(example_prompts, max_tokens) From 10587773e8fe7975c78a5c9b41c292295b7e5fbe Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 6 Mar 2025 12:48:55 +0900 Subject: [PATCH 09/36] Add plamo to test_registry Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- tests/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index b5ded20c5af5..98da2af86d7f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -174,6 +174,7 @@ def check_available_online( trust_remote_code=True), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), + "PlamoForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b"), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", trust_remote_code=True), "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct", From e86e46ffa5822c9f199603be7c7f3e9f097a60e2 Mon Sep 17 00:00:00 2001 From: shemmi Date: Mon, 10 Mar 2025 14:29:45 +0900 Subject: [PATCH 10/36] pip-compile Signed-off-by: shemmi --- requirements/test.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index f952bbe2cc10..c4c1027f8891 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -223,7 +223,7 @@ lm-eval==0.4.4 lxml==5.3.0 # via sacrebleu mamba-ssm==2.2.4 - # via -r requirements-test.in + # via -r requirements/test.in markdown-it-py==3.0.0 # via rich markupsafe==3.0.2 @@ -553,7 +553,6 @@ setuptools==75.8.0 # via # mamba-ssm # pytablewriter - # torch six==1.16.0 # via # python-dateutil From 765975543b5601df74119b44d49eb906d5c2819e Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Mon, 10 Mar 2025 18:04:38 +0900 Subject: [PATCH 11/36] pip-compile Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- requirements/test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/test.txt b/requirements/test.txt index c4c1027f8891..e19a2ba26f52 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -553,6 +553,7 @@ setuptools==75.8.0 # via # mamba-ssm # pytablewriter + # torch 
six==1.16.0 # via # python-dateutil
From e39437135b7d5a2b677ab90e8122e41abfc79a42 Mon Sep 17 00:00:00 2001
From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Date: Tue, 11 Mar 2025 23:03:19 +0900
Subject: [PATCH 12/36] Add workarounds to handle the difference in config assumptions
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
---
 vllm/model_executor/models/plamo2.py | 6 ++++++
 1 file changed, 6 insertions(+)
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index c9593868ed7e..e42fc441c725 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -705,10 +705,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.scheduler_config = scheduler_config # TODO(Shinichi): Remove this workaround. + # vllm.model_executor.models.interfaces.IsHybrid requires + # self.config.layers_block_type to be set. self.config.layers_block_type = [ "mamba" if is_mamba(self.config, i) else "attention" for i in range(self.config.num_hidden_layers) ] + # ModelConfig.get_head_size assumes head_dim is set or calculated as + # hidden_size // num_attention_heads. However, this is not always + # the case for PLaMo2, as indicated by the FIXME comment. + setattr(self.config, "head_dim", self.config.hidden_size_per_head) self.model = Plamo2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
From 9d7efcc08386244dddb68aaf854aded47a2cb238 Mon Sep 17 00:00:00 2001
From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Date: Tue, 11 Mar 2025 23:11:19 +0900
Subject: [PATCH 13/36] Make workaround simple
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
---
 vllm/model_executor/models/plamo2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index e42fc441c725..b88b20a1de16 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -714,7 +714,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # ModelConfig.get_head_size assumes head_dim is set or calculated as # hidden_size // num_attention_heads. However, this is not always # the case for PLaMo2, as indicated by the FIXME comment. - setattr(self.config, "head_dim", self.config.hidden_size_per_head) + self.config.head_dim = self.config.hidden_size_per_head self.model = Plamo2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
From f4a6ac192312d1ee7bd66ba139730b840476d79a Mon Sep 17 00:00:00 2001
From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Date: Wed, 19 Mar 2025 12:10:10 +0900
Subject: [PATCH 14/36] yapf
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
---
 tests/models/decoder_only/language/test_hybrid.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py
index c93b3330afe9..7946a87ae319 100644
--- a/tests/models/decoder_only/language/test_hybrid.py
+++ b/tests/models/decoder_only/language/test_hybrid.py
@@ -17,7 +17,10 @@ importlib.reload(site) # This test is for the hybrid models -MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct", "pfnet/plamo-2-1b"] +MODELS = [ + "ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct", + "pfnet/plamo-2-1b" +] # Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"] From 9e013484b99ee4594d05501d3470e38c355d62bf Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 20 Mar 2025 11:10:37 +0900 Subject: [PATCH 15/36] Added PLaMo to docs Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- docs/source/models/supported_models.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index fbcea826e6c9..bc7a135ef093 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -437,6 +437,11 @@ See [this page](#generative-models) for more information on how to use generativ * `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. * * ✅︎ +- * `Plamo2ForCausalLM` + * PLaMo2 + * `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. + * + * - * `QWenLMHeadModel` * Qwen * `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. From d051b1f867d7253a1fd53dfa224def9df091bd6d Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 20 Mar 2025 11:10:47 +0900 Subject: [PATCH 16/36] Set trust_remote_code=true for PLaMo in the test Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- tests/models/registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index b5c2c122ae5d..fb5be1a22252 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -178,7 +178,8 @@ def check_available_online( trust_remote_code=True), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), - "PlamoForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b"), + "PlamoForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", + trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", trust_remote_code=True), "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-7B-Instruct", From 1a7111bbc6a9b6ed8be19725193020fc79ff39b6 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Thu, 20 Mar 2025 12:35:15 +0900 Subject: [PATCH 17/36] Clean-up unused lines Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index b88b20a1de16..f72d5ab960f3 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -2,7 +2,7 @@ """Inference-only PLaMo2 model.""" import enum import math -from typing import Any, Iterable, List, Optional, Tuple +from typing import Any, Iterable, Optional, Tuple import torch from torch import nn @@ -40,8 +40,6 @@ from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType -KVCache = Tuple[torch.Tensor, torch.Tensor] - class LinearType(str, enum.Enum): Normal = "normal" @@ -153,12 +151,6 @@ def __init__( class PlamoPreTrainedModel(PreTrainedModel): # type: ignore config_class = PlamoConfig - _no_split_modules: List[str] - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["PlamoDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] def _init_weights(self, module: torch.nn.Module) -> None: std = 0.02 From 
b318d0fa7e05b121a47fd435a1e762115c03d6f4 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Fri, 21 Mar 2025 17:57:13 +0900 Subject: [PATCH 18/36] Revert renaming final norm component on loading model Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index f72d5ab960f3..58b04f7b780a 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -641,7 +641,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): max_model_len=vllm_config.scheduler_config.max_model_len, prefix=f"{prefix}.decoder_layers.{i}")) self.layers = nn.ModuleList(decoder_layers) - self.final_layernorm = RMSNorm(config.hidden_size, + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_init() @@ -670,7 +670,7 @@ def forward( hidden_states=hidden_states, residual=residual, mamba_cache_params=layer_mamba_cache_params) - hidden_states, _ = self.final_layernorm(hidden_states, residual) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -810,8 +810,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ".layers.layers": ".layers", # Skip PlmoDecoderLayer. ".mixer": "", - # Rename the final layernorm of the model. - "model.norm.weight": "model.final_layernorm.weight", # Rename each mamba layer's components. ".A_log": ".mamba.A", @@ -851,7 +849,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_weight += 1.0 elif ".post_mlp_norm" in name: loaded_weight += 1.0 / (5**1.5) - elif "model.final_layernorm.weight" in name: + elif "model.norm.weight" in name: loaded_weight += 1.0 param = params_dict[name] From d8df40da261f5ac66a1f3534a4b9fd8cbc340a26 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Fri, 21 Mar 2025 18:01:23 +0900 Subject: [PATCH 19/36] Clean-up PlamoConfig Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 127 ++++----------------------- 1 file changed, 19 insertions(+), 108 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 58b04f7b780a..f67b28b503bc 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Inference-only PLaMo2 model.""" -import enum import math -from typing import Any, Iterable, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn @@ -41,116 +40,29 @@ from vllm.utils import LayerBlockType -class LinearType(str, enum.Enum): - Normal = "normal" - Fp8 = "fp8" - Fp8Retain = "fp8-retain" - - -# Only used for type hinting and PlamoPreTrainedModel.config_class. +# Only used for type hinting. 
class PlamoConfig(PretrainedConfig): # type: ignore model_type: str = "plamo" - def __init__( - self, - hidden_size: int = 4096, - num_hidden_layers: int = 32, - rms_norm_eps: float = 1e-6, - tie_word_embeddings: bool = False, - # Attention - num_attention_heads: int = 32, - num_key_value_heads: int = 4, - hidden_size_per_head: int = 128, - max_position_embeddings: int = 2048, - attention_window_size: int = 2048, - # Mamba - mamba_d_state: int = 64, - mamba_d_conv: int = 4, - mamba_num_heads: int = 64, - mamba_step: int = 2, - mamba_chunk_size: int = 256, - # MLP - intermediate_size: int = 13312, - # Tokenizer - vocab_size: int = 32000, - tokenizer_class: str = "PlamoTokenizer", - pad_token_id: Optional[int] = None, - bos_token_id: int = 1, - eos_token_id: int = 2, - # MoE - n_expert: Optional[int] = None, - k_expert: Optional[int] = None, - expert_dropout: float = 0.0, - capacity_factor: float = 1.0, - group_size: int = 1024, - sparse_step: Optional[int] = None, - sparse_intermediate_size: Optional[int] = None, - shared_intermediate_size: Optional[int] = None, - # FP8 - linear_type: LinearType = LinearType.Normal, - fp8_accum_dtype: Optional[str] = None, - # Evaluation - eval_attention_n_bit: Optional[int] = None, - eval_mlp_n_bit: Optional[int] = None, - eval_offload_moe: bool = False, - use_cache: bool = True, - use_predefined_initial_state: bool = False, - **kwargs: Any, - ) -> None: - # max_position_embeddings is often used to determine the max length - # during inference, but samba should have extrapolation abilities - self.max_position_embeddings = max(10 * 1024 * 1024, - max_position_embeddings) - self.hidden_size = hidden_size - self.rms_norm_eps = rms_norm_eps - - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_size_per_head = hidden_size_per_head - self.num_key_value_heads = num_key_value_heads - self.attention_window_size = attention_window_size - - self.mamba_d_state = mamba_d_state - self.mamba_d_conv = mamba_d_conv - self.mamba_num_heads = mamba_num_heads - self.mamba_step = mamba_step - self.mamba_chunk_size = mamba_chunk_size - - self.intermediate_size = intermediate_size - - self.vocab_size = vocab_size - - self.n_expert = n_expert - self.k_expert = k_expert - self.sparse_intermediate_size = sparse_intermediate_size - self.shared_intermediate_size = shared_intermediate_size - self.expert_dropout = expert_dropout - self.capacity_factor = capacity_factor - self.group_size = group_size - self.sparse_step = sparse_step - - self.linear_type = linear_type - self.fp8_accum_dtype = fp8_accum_dtype - - self.eval_attention_n_bit = eval_attention_n_bit - self.eval_mlp_n_bit = eval_mlp_n_bit - self.eval_offload_moe = eval_offload_moe - self.use_cache = use_cache - - self.use_predefined_initial_state = use_predefined_initial_state - - super().__init__( - tokenizer_class=tokenizer_class, - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) + hidden_size: int + num_hidden_layers: int + rms_norm_eps: float + # Attention + num_attention_heads: int + hidden_size_per_head: int + num_key_value_heads: int + # Mamba + mamba_d_state: int + mamba_d_conv: int + mamba_num_heads: int + mamba_step: int + # MLP + intermediate_size: int + # Tokenizer + vocab_size: int class PlamoPreTrainedModel(PreTrainedModel): # type: ignore - config_class = PlamoConfig def _init_weights(self, module: torch.nn.Module) -> None: std = 0.02 @@ -641,8 +553,7 @@ def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = ""): max_model_len=vllm_config.scheduler_config.max_model_len, prefix=f"{prefix}.decoder_layers.{i}")) self.layers = nn.ModuleList(decoder_layers) - self.norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_init() def forward(
From a36caaf3800f03d784ee2c0be9cc508faf6bfb5a Mon Sep 17 00:00:00 2001
From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Date: Sun, 23 Mar 2025 20:37:18 +0900
Subject: [PATCH 20/36] Revert PlamoDecoder for class structure consistency with transformers implementation
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
---
 vllm/model_executor/models/plamo2.py | 73 ++++++++++++++++++----------
 1 file changed, 47 insertions(+), 26 deletions(-)
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index f67b28b503bc..b03455484d60 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -519,27 +519,15 @@ def forward( return hidden_states, residual -class Plamo2Model(PlamoPreTrainedModel): +class PlamoDecoder(torch.nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config.model_config.hf_config) + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.org_vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - prefix=f"{prefix}.embed_tokens", - ) - decoder_layers = [] for i in range(config.num_hidden_layers): layer_class = Plamo2MambaDecoderLayer if is_mamba( @@ -553,21 +541,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): max_model_len=vllm_config.scheduler_config.max_model_len, prefix=f"{prefix}.decoder_layers.{i}")) self.layers = nn.ModuleList(decoder_layers) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_init() def forward( self, - input_ids: torch.Tensor, positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # TODO(Shinichi): Implement pipeline parallelism.
- hidden_states = self.embed_tokens(input_ids) - residual = None - mamba_cache_index = 0 for layer in self.layers: layer_mamba_cache_params = None @@ -581,6 +562,48 @@ def forward( hidden_states=hidden_states, residual=residual, mamba_cache_params=layer_mamba_cache_params) + return hidden_states, residual + + +class Plamo2Model(PlamoPreTrainedModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config.model_config.hf_config) + + config = vllm_config.model_config.hf_config + + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + prefix=f"{prefix}.embed_tokens", + ) + self.layers = PlamoDecoder(vllm_config, prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_init() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # TODO(Shinichi): Implement pipeline parallelism. + hidden_states = self.embed_tokens(input_ids) + residual = None + + hidden_states, residual = self.layers( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=mamba_cache_params) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -717,8 +740,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # of the model. # Do not change the order of the replacements. replacements = { - # Skip PlamoDecoderLayers. - ".layers.layers": ".layers", # Skip PlmoDecoderLayer. 
".mixer": "", From 7cbdc8c94f3fdd187f8620337552b96e865a25fd Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 23 Mar 2025 21:03:17 +0900 Subject: [PATCH 21/36] Rename PlamoDecoder to Plamo2Decoder Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index b03455484d60..1c70b68eabb7 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -519,7 +519,7 @@ def forward( return hidden_states, residual -class PlamoDecoder(torch.nn.Module): +class Plamo2Decoder(torch.nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() @@ -583,7 +583,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, prefix=f"{prefix}.embed_tokens", ) - self.layers = PlamoDecoder(vllm_config, prefix=f"{prefix}.layers") + self.layers = Plamo2Decoder(vllm_config, prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_init() From 4368f63ba863f40cdce581ab15d5738b3e076afe Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 23 Mar 2025 21:20:00 +0900 Subject: [PATCH 22/36] Revert Plamo2DecoderLayer for consistency with transformers implementation Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 185 ++++++++++++--------------- 1 file changed, 79 insertions(+), 106 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 1c70b68eabb7..81dfd838f7d4 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -119,7 +119,13 @@ def _swiglu(h: torch.Tensor) -> torch.Tensor: class Plamo2MambaMixer(nn.Module): # TODO(Shinichi): Rebase on Mamba2 implementation. - def __init__(self, config: PlamoConfig, prefix: str = ""): + def __init__(self, + config: PlamoConfig, + cache_config: CacheConfig, + quant_config: QuantizationConfig, + max_model_len: int, + prefix: str = "", + **kwargs) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -150,11 +156,11 @@ def __init__(self, config: PlamoConfig, prefix: str = ""): prefix=f"{prefix}.in_proj", ) # selective projection used to make dt, B and C input dependent - self.x_proj = RowParallelLinear( + self.bcdt_proj = RowParallelLinear( self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False, - prefix=f"{prefix}.x_proj", + prefix=f"{prefix}.bcdt_proj", ) # time step projection (discretization) - # In the forward we need to apply dt_proj without the bias, @@ -194,8 +200,12 @@ def __init__(self, config: PlamoConfig, prefix: str = ""): self.c_layernorm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) - def forward(self, hidden_states: torch.Tensor, - mamba_cache_params: MambaCacheParams): + def forward( + self, + hidden_states: torch.Tensor, + mamba_cache_params: MambaCacheParams, + **kwargs, + ) -> torch.Tensor: attn_metadata: AttentionMetadata = get_forward_context().attn_metadata @@ -244,7 +254,7 @@ def forward(self, hidden_states: torch.Tensor, # 3. State Space Model sequence transformation # 3.a. 
input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + ssm_parameters = self.bcdt_proj(hidden_states.transpose(-2, -1))[0] # Splitting the ssm_parameters as in modeling_plamo.py. B, C, time_step = torch.split( @@ -349,60 +359,12 @@ def __init__(self, quant_config=quant_config) -class Plamo2MambaDecoderLayer(nn.Module): - - def __init__(self, - config: PlamoConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - max_model_len: int | None = None, - **kwargs) -> None: - super().__init__() - self.config = config - self.mamba = Plamo2MambaMixer(config) - - ffn_layer_class = DenseMLP - self.mlp = ffn_layer_class(config, quant_config=quant_config) - self.pre_mixer_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_mixer_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_mlp_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_mlp_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - mamba_cache_params: MambaCacheParams, - **kwargs, - ): - if residual is None: - residual = hidden_states - hidden_states = self.pre_mixer_norm(hidden_states) - else: - hidden_states, residual = self.pre_mixer_norm( - hidden_states, residual) - - hidden_states = self.mamba(hidden_states, mamba_cache_params) - hidden_states = self.post_mixer_norm(hidden_states) - # Fully Connected - hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) - hidden_states = self.mlp(hidden_states) - hidden_states = self.post_mlp_norm(hidden_states) - return hidden_states, residual - - -class Plamo2AttentionDecoderLayer(nn.Module): +class Plamo2AttentionMixer(nn.Module): def __init__(self, config: PlamoConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig, + quant_config: QuantizationConfig, max_model_len: int | None = None, prefix: str = "", **kwargs) -> None: @@ -467,21 +429,11 @@ def __init__(self, prefix=f"{prefix}.attn", ) - ffn_layer_class = DenseMLP - self.mlp = ffn_layer_class(config, quant_config=quant_config) - self.pre_mixer_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_mixer_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_mlp_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_mlp_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def self_attention( + def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], **kwargs, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) @@ -493,11 +445,52 @@ def self_attention( output, _ = self.o_proj(attn_output) return output + +class Plamo2DecoderLayer(nn.Module): + + def __init__(self, + vllm_config: VllmConfig, + layer_idx: int, + max_model_len: int | None = None, + prefix: str = "", + **kwargs) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + max_model_len = vllm_config.scheduler_config.max_model_len + + self.is_mamba = is_mamba(config, layer_idx) + if self.is_mamba: + self.mixer = Plamo2MambaMixer(config=config, + cache_config=cache_config, + quant_config=quant_config, + max_model_len=max_model_len, + 
prefix=f"{prefix}.mixer") + else: + self.mixer = Plamo2AttentionMixer(config=config, + cache_config=cache_config, + quant_config=quant_config, + max_model_len=max_model_len, + prefix=f"{prefix}.mixer") + + ffn_layer_class = DenseMLP + self.mlp = ffn_layer_class(config, quant_config=quant_config) + self.pre_mixer_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mixer_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_mlp_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mlp_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: @@ -507,10 +500,10 @@ def forward( hidden_states, residual = self.pre_mixer_norm( hidden_states, residual) - hidden_states = self.self_attention( - positions=positions, - hidden_states=hidden_states, - ) + hidden_states = self.mixer(positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=mamba_cache_params) hidden_states = self.post_mixer_norm(hidden_states) # Fully Connected hidden_states, residual = self.pre_mlp_norm(hidden_states, residual) @@ -523,24 +516,14 @@ class Plamo2Decoder(torch.nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - decoder_layers = [] - for i in range(config.num_hidden_layers): - layer_class = Plamo2MambaDecoderLayer if is_mamba( - config, i) else Plamo2AttentionDecoderLayer - decoder_layers.append( - layer_class( - config, - layer_idx=i, - cache_config=cache_config, - quant_config=quant_config, - max_model_len=vllm_config.scheduler_config.max_model_len, - prefix=f"{prefix}.decoder_layers.{i}")) - self.layers = nn.ModuleList(decoder_layers) + self.layers = nn.ModuleList([ + Plamo2DecoderLayer(vllm_config=vllm_config, + layer_idx=i, + prefix=f"{prefix}.layers.{i}") + for i in range(num_hidden_layers) + ]) def forward( self, @@ -552,7 +535,7 @@ def forward( mamba_cache_index = 0 for layer in self.layers: layer_mamba_cache_params = None - if isinstance(layer, Plamo2MambaDecoderLayer): + if layer.is_mamba: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( mamba_cache_index) mamba_cache_index += 1 @@ -740,21 +723,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # of the model. # Do not change the order of the replacements. replacements = { - # Skip PlmoDecoderLayer. - ".mixer": "", - - # Rename each mamba layer's components. - ".A_log": ".mamba.A", - ".B_norm_weight": ".mamba.b_layernorm.weight", - ".C_norm_weight": ".mamba.c_layernorm.weight", - ".dt_norm_weight": ".mamba.dt_layernorm.weight", - ".bcdt_proj.weight": ".mamba.x_proj.weight", - ".conv1d.weight": ".mamba.conv1d.weight", - ".in_proj.weight": ".mamba.in_proj.weight", - ".out_proj.weight": ".mamba.out_proj.weight", - ".D": ".mamba.D", - ".dt_bias": ".mamba.dt_bias", - ".dt_proj.weight": ".mamba.dt_proj.weight", + # Rename incompatible weight names. 
+ ".A_log": ".A", + ".B_norm_weight": ".b_layernorm.weight", + ".C_norm_weight": ".c_layernorm.weight", + ".dt_norm_weight": ".dt_layernorm.weight", } # Apply replacements based on the defined mappings for old, new in replacements.items():
From a4510114a76dea989857fd14411b543eca6a545f Mon Sep 17 00:00:00 2001
From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Date: Sun, 23 Mar 2025 21:25:10 +0900
Subject: [PATCH 23/36] Drop Plamo2MoE for consistency with transformers implementation
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
---
 vllm/model_executor/models/plamo2.py | 20 +-------------------
 1 file changed, 1 insertion(+), 19 deletions(-)
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index 81dfd838f7d4..fe566e725d99 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -318,17 +318,14 @@ def forward( return contextualized_states -class Plamo2MoE(nn.Module): +class DenseMLP(nn.Module): def __init__(self, config: PlamoConfig, - num_experts: Optional[int] = None, - top_k: Optional[int] = None, params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() - assert num_experts is None or num_experts <= 1, "MoE not supported" self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_up_proj = torch.nn.Linear(self.hidden_size, @@ -344,21 +341,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.down_proj(h) # type: ignore -class DenseMLP(Plamo2MoE): - - def __init__(self, - config: PlamoConfig, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None): - super().__init__(config, - num_experts=1, - top_k=1, - params_dtype=params_dtype, - tp_size=tp_size, - quant_config=quant_config) - - class Plamo2AttentionMixer(nn.Module): def __init__(self,
From 256957f2e1423785e9d21b719553d00632f6c678 Mon Sep 17 00:00:00 2001
From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Date: Sun, 23 Mar 2025 21:36:57 +0900
Subject: [PATCH 24/36] Minimize model's member renaming
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
---
 vllm/model_executor/models/plamo2.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index fe566e725d99..0968d1101a76 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -193,12 +193,9 @@ def __init__(self, # The activation function is fixed to SiLU.
self.activation = "silu" - self.dt_layernorm = RMSNorm(self.time_step_rank, - eps=config.rms_norm_eps) - self.b_layernorm = RMSNorm(self.ssm_state_size, - eps=config.rms_norm_eps) - self.c_layernorm = RMSNorm(self.ssm_state_size, - eps=config.rms_norm_eps) + self.dt_norm = RMSNorm(self.time_step_rank, eps=config.rms_norm_eps) + self.B_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) + self.C_norm = RMSNorm(self.ssm_state_size, eps=config.rms_norm_eps) def forward( self, @@ -262,9 +259,9 @@ def forward( [self.ssm_state_size, self.ssm_state_size, self.time_step_rank], dim=-1, ) - time_step = self.dt_layernorm(time_step.contiguous()) - B = self.b_layernorm(B.contiguous()) - C = self.c_layernorm(C.contiguous()) + time_step = self.dt_norm(time_step.contiguous()) + B = self.B_norm(B.contiguous()) + C = self.C_norm(C.contiguous()) discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) # 3.c perform the recurrence y ← SSM(A, B, C)(x) @@ -707,9 +704,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): replacements = { # Rename incompatible weight names. ".A_log": ".A", - ".B_norm_weight": ".b_layernorm.weight", - ".C_norm_weight": ".c_layernorm.weight", - ".dt_norm_weight": ".dt_layernorm.weight", + ".B_norm_weight": ".B_norm.weight", + ".C_norm_weight": ".C_norm.weight", + ".dt_norm_weight": ".dt_norm.weight", } # Apply replacements based on the defined mappings for old, new in replacements.items(): From 0f9f140ad6c8b7aa765182a2aafce77e24e1a391 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Sun, 23 Mar 2025 21:49:39 +0900 Subject: [PATCH 25/36] Move causal-conv1d installation to buildkite config Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- .buildkite/test-pipeline.yaml | 4 ++++ tests/models/decoder_only/language/test_hybrid.py | 11 +++-------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 230dd8383420..788ff9cbda1c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -392,6 +392,8 @@ steps: - tests/models/embedding/language - tests/models/encoder_decoder/language commands: + # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. + - pip install causal-conv1d - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/embedding/language -m core_model @@ -403,6 +405,8 @@ steps: - tests/models/embedding/language - tests/models/encoder_decoder/language commands: + # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. + - pip install causal-conv1d - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py index 7946a87ae319..71adefebb1de 100644 --- a/tests/models/decoder_only/language/test_hybrid.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -1,9 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import importlib -import site - -import pip import pytest from tests.utils import multi_gpu_test @@ -12,10 +8,6 @@ from ...utils import check_outputs_equal -# Install causal-conv1d here, as it is not compatible with pip-compile. 
-pip.main(['install', 'causal-conv1d']) -importlib.reload(site) - # This test is for the hybrid models MODELS = [ "ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct", @@ -23,6 +15,9 @@ ] # Bamba at Fp32 is too big for the CI (L4 GPU). # MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"] +# Note: Running Plamo2 in transformers implementation requires to install +# causal-conv1d package, which is not listed as a test dependency as it's +# not compatible with pip-compile. @pytest.mark.parametrize("model", MODELS)
From dc50e0a25705848b13abb19806fc1336a83544f3 Mon Sep 17 00:00:00 2001
From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Date: Sun, 23 Mar 2025 22:08:24 +0900
Subject: [PATCH 26/36] Simplify DenseMLP
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
---
 vllm/model_executor/models/plamo2.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py
index 0968d1101a76..e1da21070c7a 100644
--- a/vllm/model_executor/models/plamo2.py
+++ b/vllm/model_executor/models/plamo2.py
@@ -319,8 +319,6 @@ class DenseMLP(nn.Module): def __init__(self, config: PlamoConfig, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -453,8 +451,7 @@ def __init__(self, max_model_len=max_model_len, prefix=f"{prefix}.mixer") - ffn_layer_class = DenseMLP - self.mlp = ffn_layer_class(config, quant_config=quant_config) + self.mlp = DenseMLP(config=config, quant_config=quant_config)
From c6adb46639e6b61f443da909c10b771d60e2b8f8 Mon Sep 17 00:00:00 2001
From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Date: Mon, 24 Mar 2025 11:15:27 +0900
Subject: [PATCH 27/36] Stop specifying use_mamba_kernels=False as a mamba kernel is installed in the test
Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
---
 .../decoder_only/language/test_hybrid.py | 22 ++-----------------
 1 file changed, 2 insertions(+), 20 deletions(-)
diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py
index 71adefebb1de..1714c7834056 100644
--- a/tests/models/decoder_only/language/test_hybrid.py
+++ b/tests/models/decoder_only/language/test_hybrid.py
@@ -35,16 +35,7 @@ def test_models( if "Bamba" in model: example_prompts.pop(3) - model_kwargs = { - "use_mamba_kernels": False, # mamba kernels are not installed so HF - # don't use them - } - if "Zamba2" in model or "plamo" in model: - # Zamba2 HF implementation automatically checks if mamba kernels are - # installed - model_kwargs = {} - -
with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model: + with hf_runner(model, dtype=dtype) as hf_model: non_chunked = hf_model.generate_greedy(example_prompts, max_tokens) with vllm_runner(model, From 83f6be5ba2087eb4a9112c070cc7f6abb8eb25f5 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Wed, 26 Mar 2025 21:49:28 +0900 Subject: [PATCH 28/36] Remove nn.Linear for quantization support Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 30 ++++++++++++++++++---------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index e1da21070c7a..f3e6ead4ca6e 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -317,23 +317,31 @@ def forward( class DenseMLP(nn.Module): - def __init__(self, - config: PlamoConfig, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__( + self, + config: PlamoConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_up_proj = torch.nn.Linear(self.hidden_size, - self.intermediate_size * 2, - bias=False) - self.down_proj = torch.nn.Linear(self.intermediate_size, - self.hidden_size, - bias=False) + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, [self.intermediate_size] * 2, + bias=False, + prefix=f"{prefix}.gate_up_proj", + quant_config=quant_config) + self.down_proj = RowParallelLinear(self.intermediate_size, + self.hidden_size, + bias=False, + prefix=f"{prefix}.down_proj", + quant_config=quant_config) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - h = self.gate_up_proj(hidden_states) + h = self.gate_up_proj(hidden_states)[0] h = _swiglu(h) - return self.down_proj(h) # type: ignore + output, _ = self.down_proj(h) + return output # type: ignore class Plamo2AttentionMixer(nn.Module): From 81a19546da4b9c825fb34d297c177d558c44808b Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Tue, 1 Apr 2025 17:00:04 +0900 Subject: [PATCH 29/36] Properly pass prefixes Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/model_executor/models/plamo2.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index f3e6ead4ca6e..48abd4d4af1b 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -165,9 +165,12 @@ def __init__(self, # time step projection (discretization) - # In the forward we need to apply dt_proj without the bias, # as the bias is added in the selective scan kernel. - self.dt_proj = ColumnParallelLinear(self.time_step_rank, - self.num_heads, - bias=False) + self.dt_proj = ColumnParallelLinear( + self.time_step_rank, + self.num_heads, + bias=False, + prefix=f"{prefix}.dt_proj", + ) self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads)) tp_size = get_tensor_model_parallel_world_size() @@ -189,6 +192,7 @@ def __init__(self, self.hidden_size, bias=self.use_bias, input_is_parallel=True, + prefix=f"{prefix}.out_proj", ) # The activation function is fixed to SiLU. 
self.activation = "silu" @@ -459,7 +463,9 @@ def __init__(self, max_model_len=max_model_len, prefix=f"{prefix}.mixer") - self.mlp = DenseMLP(config=config, quant_config=quant_config) + self.mlp = DenseMLP(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_mixer_norm = RMSNorm(config.hidden_size, From 0ed004261befbe9f7b5f0ca19903a12b3594e00f Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Tue, 1 Apr 2025 18:24:34 +0900 Subject: [PATCH 30/36] Stop using float16 when dtype=auto is specified. Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/config.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index c510677d64ea..811f40308697 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2539,6 +2539,12 @@ def _get_and_verify_dtype( "instead of float16 by default. Please specify `dtype` " "if you want to use float16.") torch_dtype = torch.bfloat16 + elif config.model_type == "plamo": + logger.info( + "For PLaMo2, we downcast float32 to bfloat16, instead " + "of float32 by default. This is because float16 does " + "not work.") + torch_dtype = torch.bfloat16 else: # Following the common practice, we use float16 for float32 # models. @@ -2575,6 +2581,11 @@ def _get_and_verify_dtype( "using float16 by default. Please specify `dtype` if you " "want to use float16.") torch_dtype = torch.bfloat16 + elif dtype == "float16" and config.model_type == "plamo": + logger.warning( + "For PLaMo2, using float16 is unstable and might cause " + "unexpected behavior. Please use bfloat16 or float32 instead.") + torch_dtype = torch.float16 else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: raise ValueError(f"Unknown dtype: {dtype}") From 63283c1c937ff2487d4f2a629393c0b7a9898ccf Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Tue, 1 Apr 2025 18:58:50 +0900 Subject: [PATCH 31/36] Revert "Stop using float16 when dtype=auto is specified." This reverts commit 0ed004261befbe9f7b5f0ca19903a12b3594e00f. Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/config.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 811f40308697..c510677d64ea 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2539,12 +2539,6 @@ def _get_and_verify_dtype( "instead of float16 by default. Please specify `dtype` " "if you want to use float16.") torch_dtype = torch.bfloat16 - elif config.model_type == "plamo": - logger.info( - "For PLaMo2, we downcast float32 to bfloat16, instead " - "of float32 by default. This is because float16 does " - "not work.") - torch_dtype = torch.bfloat16 else: # Following the common practice, we use float16 for float32 # models. @@ -2581,11 +2575,6 @@ def _get_and_verify_dtype( "using float16 by default. Please specify `dtype` if you " "want to use float16.") torch_dtype = torch.bfloat16 - elif dtype == "float16" and config.model_type == "plamo": - logger.warning( - "For PLaMo2, using float16 is unstable and might cause " - "unexpected behavior. 
Please use bfloat16 or float32 instead.") - torch_dtype = torch.float16 else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: raise ValueError(f"Unknown dtype: {dtype}") From 19fcd5fabef068db816232ab107d83dd28033b15 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Tue, 1 Apr 2025 19:47:18 +0900 Subject: [PATCH 32/36] Handle dtype for plamo2 in config Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/config.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index c82c9763ccdc..efb8e9095bcc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2625,6 +2625,13 @@ def _get_and_verify_dtype( else: torch_dtype = config_dtype + if config.model_type == "plamo": + logger.info( + "For PLaMo2, we cast models to bfloat16 instead of using " + "float16 by default. This is because float16 does not work." + ) + torch_dtype = torch.bfloat16 + from vllm.platforms import current_platform if (current_platform.is_cpu() and current_platform.get_cpu_architecture() @@ -2654,6 +2661,11 @@ def _get_and_verify_dtype( "using float16 by default. Please specify `dtype` if you " "want to use float16.") torch_dtype = torch.bfloat16 + elif dtype == "float16" and config.model_type == "plamo": + logger.warning( + "For PLaMo2, using float16 is unstable and might cause " + "unexpected behavior. Please use bfloat16 or float32 instead.") + torch_dtype = torch.float16 else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: raise ValueError(f"Unknown dtype: {dtype}") From 3f446750b9752fd8dc8b1549dc5c311f4644b9e0 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Mon, 7 Apr 2025 19:25:18 +0900 Subject: [PATCH 33/36] Update object names to plamo2-prefixed Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- vllm/config.py | 2 +- vllm/model_executor/models/plamo2.py | 25 +++++++++---------------- vllm/model_executor/models/registry.py | 2 +- 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index efb8e9095bcc..cfdc2d048704 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2625,7 +2625,7 @@ def _get_and_verify_dtype( else: torch_dtype = config_dtype - if config.model_type == "plamo": + if config.model_type == "plamo2": logger.info( "For PLaMo2, we cast models to bfloat16 instead of using " "float16 by default. This is because float16 does not work." diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 48abd4d4af1b..fb1442526c6c 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -41,8 +41,8 @@ # Only used for type hinting. 
-class PlamoConfig(PretrainedConfig): # type: ignore - model_type: str = "plamo" +class Plamo2Config(PretrainedConfig): # type: ignore + model_type: str = "plamo2" hidden_size: int num_hidden_layers: int @@ -62,7 +62,7 @@ class PlamoConfig(PretrainedConfig): # type: ignore vocab_size: int -class PlamoPreTrainedModel(PreTrainedModel): # type: ignore +class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore def _init_weights(self, module: torch.nn.Module) -> None: std = 0.02 @@ -87,7 +87,7 @@ def get_initial_dt_bias(num_heads: int) -> torch.Tensor: return inv_dt -def is_mamba(config: PlamoConfig, i: int) -> bool: +def is_mamba(config: Plamo2Config, i: int) -> bool: assert config.mamba_step > 1 if config.num_hidden_layers <= (config.mamba_step // 2): @@ -120,7 +120,7 @@ class Plamo2MambaMixer(nn.Module): # TODO(Shinichi): Rebase on Mamba2 implementation. def __init__(self, - config: PlamoConfig, + config: Plamo2Config, cache_config: CacheConfig, quant_config: QuantizationConfig, max_model_len: int, @@ -323,7 +323,7 @@ class DenseMLP(nn.Module): def __init__( self, - config: PlamoConfig, + config: Plamo2Config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -351,7 +351,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Plamo2AttentionMixer(nn.Module): def __init__(self, - config: PlamoConfig, + config: Plamo2Config, cache_config: CacheConfig, quant_config: QuantizationConfig, max_model_len: int | None = None, @@ -538,7 +538,7 @@ def forward( return hidden_states, residual -class Plamo2Model(PlamoPreTrainedModel): +class Plamo2Model(Plamo2PreTrainedModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config.model_config.hf_config) @@ -581,7 +581,7 @@ def forward( return hidden_states -class Plamo2ForCausalLM(PlamoPreTrainedModel, HasInnerState, IsHybrid, +class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid, SupportsV0Only): packed_modules_mapping = { "qkv_proj": [ @@ -603,13 +603,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.model_config = vllm_config.model_config self.scheduler_config = scheduler_config - # TODO(Shinichi): Remove this workaround. - # vllm.model_executor.models.interfaces.IsHybrid requires - # self.config.layers_block_type to be set. - self.config.layers_block_type = [ - "mamba" if is_mamba(self.config, i) else "attention" - for i in range(self.config.num_hidden_layers) - ] # ModelConfig.get_head_size assumes head_dim is set or calculated as # hidden_size // num_attention_heads. However, this is not always # the case for PLaMo2, as indicated by the FIXME comment. 
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 9c9aada7840e..380fa4a23412 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -96,7 +96,7 @@ "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - "PlamoForCausalLM": ("plamo2", "Plamo2ForCausalLM"), + "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), From f43d02a8d2fca5ad9ca6f256e43d2152e611f60d Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Mon, 7 Apr 2025 20:18:50 +0900 Subject: [PATCH 34/36] Update object names to plamo2-prefixed in the tests Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- tests/models/decoder_only/language/test_hybrid.py | 4 ++-- vllm/config.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py index 1714c7834056..64a02cb8907b 100644 --- a/tests/models/decoder_only/language/test_hybrid.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -91,7 +91,7 @@ def test_mamba_prefill_chunking_with_parallel_sampling( # chunked prefill forward pass (where we have both prefills # and decoding together ) - if 'plamo' in model: + if 'plamo-2' in model: dtype = "float" # use a different dtype for plamo sampling_params = SamplingParams(n=3, @@ -129,7 +129,7 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, elif "Zamba2" in model: example_prompts.pop(7) dtype = "half" - elif "plamo" in model: + elif "plamo-2-1b" in model: example_prompts.pop(7) with hf_runner(model, dtype=dtype) as hf_model: diff --git a/vllm/config.py b/vllm/config.py index cfdc2d048704..d911f318b24a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2661,7 +2661,7 @@ def _get_and_verify_dtype( "using float16 by default. Please specify `dtype` if you " "want to use float16.") torch_dtype = torch.bfloat16 - elif dtype == "float16" and config.model_type == "plamo": + elif dtype == "float16" and config.model_type == "plamo2": logger.warning( "For PLaMo2, using float16 is unstable and might cause " "unexpected behavior. 
Please use bfloat16 or float32 instead.") From 7b41a187bbb32ec5f64ebbf9070f4e5ffad1f896 Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Tue, 15 Apr 2025 14:26:48 +0900 Subject: [PATCH 35/36] Fix Plamo2ForCausalLM class name Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- tests/models/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index a5bea25afc2e..51aeeb5e441d 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -204,7 +204,7 @@ def check_available_online( trust_remote_code=True), "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), - "PlamoForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", + "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", trust_remote_code=True), From 0c8fb3600924b29f62585a3487923818647cd45d Mon Sep 17 00:00:00 2001 From: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> Date: Tue, 15 Apr 2025 22:04:42 +0900 Subject: [PATCH 36/36] Split plamo2 initialization test for debugging purpose Signed-off-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com> --- .buildkite/test-pipeline.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c4ade9cdd3fe..c86f6add6cb2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -400,8 +400,9 @@ steps: - pytest -v -s models/test_transformers.py - pytest -v -s models/test_registry.py # V1 Test: https://github.com/vllm-project/vllm/issues/14531 - - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4' + - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2' - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4' + - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2' - label: Language Models Test (Standard) # 32min #mirror_hardwares: [amd]
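
As a quick end-to-end sanity check of the PLaMo2 support added by this series, a minimal offline-inference sketch follows. It is not part of the patch series itself; the prompt text and sampling settings are illustrative only, and it assumes a vLLM build that includes the Plamo2ForCausalLM registration from the patches above.

# Minimal sketch: run pfnet/plamo-2-1b through vLLM's offline API.
from vllm import LLM, SamplingParams

llm = LLM(
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,  # PLaMo2 config/tokenizer are defined in the HF repo
    dtype="bfloat16",        # float16 is unstable for PLaMo2 (see the config.py change above)
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of Japan is"], sampling_params)
print(outputs[0].outputs[0].text)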