diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index dbbbc5122b80..a24fa4bcce33 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -615,7 +615,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | | `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | | ✅︎ | ✅︎ | +| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 131c042c3c2d..aff491f9596c 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -118,14 +117,15 @@ def __init__( if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " "Only silu is supported for now.") - - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - params_dtype=torch.float32, - prefix=f"{prefix}.gate") - + # NOTE In the transformers implementation, the gate isn't an nn.Linear, + # so we cannot use ReplicatedLinear here. + # See: https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260 + self.gate = nn.Linear( + config.hidden_size, + config.n_routed_experts, + bias=False, + dtype=torch.float32, + ) self.gate.e_score_correction_bias = nn.Parameter( torch.empty(config.n_routed_experts, dtype=torch.float32)) @@ -181,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) - router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) + router_logits = self.gate(hidden_states.to(dtype=torch.float32)) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits) * self.routed_scaling_factor