diff --git a/vllm/distributed/eplb/eplb_adaptor/__init__.py b/vllm/distributed/eplb/eplb_adaptor/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/distributed/eplb/eplb_adaptor/abstract_adaptor.py b/vllm/distributed/eplb/eplb_adaptor/abstract_adaptor.py
new file mode 100644
index 000000000000..a1f531cf8989
--- /dev/null
+++ b/vllm/distributed/eplb/eplb_adaptor/abstract_adaptor.py
@@ -0,0 +1,107 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class BaseAdaptor(ABC):
+    """
+    Abstract base class for Expert Parallel Load Balancer (EPLB) adaptors.
+
+    This class defines the interface required for coordination with EPLB,
+    including obtaining workloads, managing expert maps, and updating
+    expert weights. Specific adaptor implementations (e.g., for vLLM)
+    should inherit from this base class and implement all abstract methods.
+    """
+
+    @abstractmethod
+    def __init__(self, **args):
+        """
+        Initializes the adaptor.
+
+        Args:
+            **args: Any additional initialization arguments.
+        """
+        pass
+
+    @abstractmethod
+    def get_rank_expert_workload(self):
+        """
+        Abstract method: Retrieves the expert workload statistics for the
+        current rank.
+
+        Concrete implementations should return a tensor or other data
+        structure representing the workload metrics for MoE layers within
+        the current process.
+
+        Raises:
+            NotImplementedError: If the subclass does not implement this
+                method.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_init_expert_map(self, num_moe_layers: Any) -> Any:
+        """
+        Abstract method: Collects the initial expert mappings across all
+        ranks.
+
+        Concrete implementations should return a tensor or other data
+        structure representing the global expert map.
+
+        Args:
+            num_moe_layers: The number of MoE layers to process.
+
+        Returns:
+            Any: The global expert map.
+
+        Raises:
+            NotImplementedError: If the subclass does not implement this
+                method.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def do_update_expert_map(self, layer_id: Any,
+                             updated_expert_map: Any) -> Any:
+        """
+        Abstract method: Performs an update of the expert map.
+
+        Concrete implementations should apply the updated expert map to the
+        specified MoE layer.
+
+        Args:
+            layer_id: The ID of the MoE layer to update.
+            updated_expert_map: The tensor or data structure containing the
+                new expert map.
+
+        Returns:
+            Any: The result of the update operation (if applicable).
+
+        Raises:
+            NotImplementedError: If the subclass does not implement this
+                method.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def do_update_expert_weight(self, layer_id: Any,
+                                local_expert_to_replace: Any,
+                                buffer_tensor_id: Any) -> Any:
+        """
+        Abstract method: Performs an update of expert weights.
+
+        Concrete implementations should copy weights from a specified buffer
+        tensor to the target local expert.
+
+        Args:
+            layer_id: The ID of the MoE layer containing the expert to
+                update.
+            local_expert_to_replace: The local ID of the expert whose
+                weights are to be replaced.
+            buffer_tensor_id: The ID of the buffer tensor containing the new
+                weights.
+
+        Returns:
+            Any: The result of the update operation (if applicable).
+
+        Raises:
+            NotImplementedError: If the subclass does not implement this
+                method.
+        """
+        raise NotImplementedError
diff --git a/vllm/distributed/eplb/eplb_adaptor/vllm_adaptor.py b/vllm/distributed/eplb/eplb_adaptor/vllm_adaptor.py
new file mode 100644
index 000000000000..5e8f3f15736c
--- /dev/null
+++ b/vllm/distributed/eplb/eplb_adaptor/vllm_adaptor.py
@@ -0,0 +1,384 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Expert Parallel Load Balancer (EPLB) adaptor implementation for vLLM.
+
+This module implements distributed expert management for
+Mixture-of-Experts (MoE) models in the vLLM framework.
+Key features include:
+
+1. Expert Mapping Management
+   - Maintains physical/logical expert mappings across devices
+   - Handles expert placement updates during load balancing
+
+2. Weight Synchronization
+   - Manages expert weight buffers for dynamic parameter updates
+   - Supports quantized expert weights (w8a8 format)
+
+3. Distributed Coordination
+   - Collects global expert workload metrics
+   - Implements cross-rank expert map synchronization
+
+Supported model architectures:
+- DeepSeek V3 (with quantization support)
+- Qwen3-MoE
+- Kimi K2
+- Standard MoE models with configurable expert layers
+
+Note: The current implementation assumes a homogeneous expert structure
+across MoE layers.
+"""
+
+import json
+from typing import Any
+
+import torch
+import torch.distributed as dist
+
+from vllm.distributed.eplb.eplb_adaptor.abstract_adaptor import BaseAdaptor
+
+
+class VllmEplbAdaptor(BaseAdaptor):
+    """vLLM implementation of the Expert Parallel Load Balancer (EPLB)
+    adaptor.
+
+    Handles expert mapping management, weight synchronization, and
+    distributed coordination for MoE models in the vLLM framework.
+
+    Attributes:
+        model: vLLM model instance
+        rank_id: Current process rank in the distributed group
+        world_size: Total number of processes in the distributed group
+        num_dense_layers: Number of dense layers before the MoE layers
+        global_expert_num: Total number of experts in the model
+        num_moe_layers: Number of MoE layers in the model
+        expert_weight_names: List of parameter names for expert weights
+    """
+
+    def __init__(self, model, **args):
+        """Initialize the adaptor with the model configuration.
+
+        Args:
+            model: vLLM model instance containing MoE layers
+            **args: Additional base class arguments
+        """
+        self.model = model
+        self.rank_id = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        self.param_dict = dict(self.model.named_parameters())
+        if self.model.config.model_type == "qwen3_moe":
+            self.num_dense_layers = 0
+            self.global_expert_num = self.model.config.num_experts
+        else:
+            self.num_dense_layers = self.model.config.first_k_dense_replace
+            self.global_expert_num = self.model.config.n_routed_experts
+        self.num_moe_layers = \
+            self.model.config.num_hidden_layers - self.num_dense_layers
+
+        # TODO: init self.expert_weight_names depending on the model type;
+        # only DeepSeek V3 (w8a8) and Qwen3-MoE are supported here.
+        if self.model.quant_config is not None:
+            self.expert_weight_names = [
+                "w13_weight", "w2_weight", "w13_weight_scale",
+                "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
+            ]
+        else:
+            self.expert_weight_names = ["w13_weight", "w2_weight"]
+
+        # Reference to the on-device expert maps, used for expert map
+        # updates.
+        self.expert_map_per_layer = dict()
+        # CPU copy of the expert maps, kept to avoid frequent device
+        # synchronization.
+        self.expert_map_per_layer_cpu = dict()
+        for layer_idx in range(self.num_moe_layers):
+            self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
+                self.model.get_expert_map(self.num_dense_layers + layer_idx)
+
+        # TODO: here we set the number of buffer tensors equal to the number
+        # of local experts in each layer, which can be improved.
+        num_buffer_tensor = torch.where(
+            self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
+        self.buffer_tensor_list: list[list[Any]] = [
+            [] for _ in range(num_buffer_tensor)
+        ]
+        self.init_buffer_tensor(num_buffer_tensor)
+
+        self.expert_param_per_layer = dict()
+        self.init_expert_param_per_layer()
+
+        self.log2phy_map_per_layer = dict()
+        for layer_idx in range(self.num_moe_layers):
+            self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
+                self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
+
+        self.all_topk_ids = []
+
+    def init_buffer_tensor(self, num_buffer_tensor):
+        """Initialize buffer tensors for expert weight updates.
+
+        Args:
+            num_buffer_tensor: Number of buffer slots per expert parameter
+        """
+        for name in self.expert_weight_names:
+            complete_name = "model.layers." + str(
+                self.num_dense_layers) + ".mlp.experts." + name
+            expert_tensor = self.param_dict[complete_name].data[
+                0:num_buffer_tensor]
+            buffer_tensors = torch.empty_like(expert_tensor)
+            for buffer_id in range(num_buffer_tensor):
+                self.buffer_tensor_list[buffer_id].append(
+                    buffer_tensors[buffer_id])
+
+    def init_expert_param_per_layer(self):
+        """Initialize expert parameter references for all MoE layers."""
+        num_local_expert = self.param_dict["model.layers." + \
+            str(self.num_dense_layers) + ".mlp.experts." + \
+            self.expert_weight_names[0]].data.shape[0]
+        for moe_layer_id in range(self.num_moe_layers):
+            layer_idx = self.num_dense_layers + moe_layer_id
+            self.expert_param_per_layer[layer_idx] = list()
+            for local_expert_id in range(num_local_expert):
+                self.expert_param_per_layer[layer_idx].append([
+                    self.param_dict["model.layers." + str(layer_idx) +
+                                    ".mlp.experts." +
+                                    name].data[local_expert_id]
+                    for name in self.expert_weight_names
+                ])
+
+    def get_rank_expert_workload(self) -> torch.Tensor:
+        """Get the current rank's expert workload statistics.
+
+        Returns:
+            torch.Tensor: Tensor containing MoE layer workload metrics
+        """
+        self.moe_load = self.model.get_all_moe_loads()
+        return self.moe_load
+
+    def get_init_expert_map(self, num_moe_layers):
+        """Collect initial expert mappings across all ranks.
+
+        Args:
+            num_moe_layers: Number of MoE layers to process
+
+        Returns:
+            torch.Tensor: Global expert mapping tensor [layers, ranks, experts]
+        """
+        expert_map = self.model.get_all_expert_map(num_moe_layers)
+        if dist.is_initialized():
+            world_size = self.world_size
+
+        gathered = torch.empty(
+            (world_size, *expert_map.shape),  # [W, L, E]
+            dtype=expert_map.dtype,
+            device=expert_map.device)
+
+        dist.all_gather_into_tensor(gathered, expert_map)
+        all_maps = gathered.permute(1, 0, 2)
+        all_expert_maps = all_maps.cpu()
+
+        for layer_idx in range(num_moe_layers):
+            self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \
+                all_expert_maps[layer_idx][self.rank_id]
+
+        return all_expert_maps
+
+    def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):
+        """Retrieve initial expert mappings from a file.
+
+        If the file does not exist or cannot be parsed, this falls back to
+        the default expert map determination logic.
+
+        Args:
+            num_moe_layers: The number of MoE layers to process.
+            expert_map_path: The path to the JSON file containing expert
+                mapping information.
+
+        Returns:
+            torch.Tensor: A global expert mapping tensor of shape
+                [layers, ranks, experts].
+        """
+        try:
+            expert_map_tensor, layers_num, ranks_num = \
+                self._expert_file_to_tensor(expert_map_path)
+            expert_map_all = self.local2global(expert_map_tensor)
+        except (TypeError, FileNotFoundError, OSError, json.JSONDecodeError,
+                KeyError):
+            expert_map_all = self.determine_expert_map_all()
+
+        for layer_idx in range(num_moe_layers):
+            if self.model.config.model_type == "qwen3_moe":
+                self.expert_map_per_layer_cpu[layer_idx] = \
+                    expert_map_all[layer_idx][self.rank_id]
+            else:  # Handles both DeepSeek V3 and Kimi K2.
+                self.expert_map_per_layer_cpu[layer_idx + \
+                    self.num_dense_layers] = \
+                    expert_map_all[layer_idx][self.rank_id]
+        return expert_map_all
+
+    def _expert_file_to_tensor(self, expert_map_path: str):
+        """Read expert mappings from a JSON file and convert them to a
+        PyTorch tensor.
+
+        The file is expected to contain 'moe_layer_count' and 'layer_list'.
+        Each layer in 'layer_list' should contain a 'device_list', and each
+        device should contain 'device_expert'.
+
+        Args:
+            expert_map_path: The path to the JSON file containing expert
+                mapping information.
+
+        Returns:
+            tuple: A tuple containing (expert map tensor, number of layers,
+                number of GPUs).
+
+        Raises:
+            FileNotFoundError: If the file does not exist.
+            json.JSONDecodeError: If the file content is not valid JSON.
+            KeyError: If the JSON structure does not conform to expectations.
+        """
+        with open(expert_map_path) as f:
+            data = json.load(f)
+        layers_num = data["moe_layer_count"]
+        gpus_num = data["layer_list"][0]["device_count"]
+
+        tensor_data = []
+        for layer in data["layer_list"]:
+            device_data = []
+            for device in layer["device_list"]:
+                device_data.append(device["device_expert"])
+            tensor_data.append(device_data)
+        expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
+        return expert_map_tensor, layers_num, gpus_num
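+    # NOTE (illustrative): _expert_file_to_tensor expects a JSON layout
+    # like the following; the field names come from the parser above and
+    # the concrete numbers are made up:
+    #
+    #   {
+    #     "moe_layer_count": 1,
+    #     "layer_list": [{
+    #       "device_count": 2,
+    #       "device_list": [{"device_expert": [0, 1, 2]},
+    #                       {"device_expert": [3, 4, 5]}]
+    #     }]
+    #   }
+    #
+    # This parses to an expert-map tensor of shape [1, 2, 3] (layers,
+    # devices, local expert slots), holding global expert IDs.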
+    def do_update_expert_map(self, layer_id, updated_expert_map):
+        """Performs an update of the expert map.
+
+        Copies the updated expert map to both the on-device map and its CPU
+        copy.
+
+        Args:
+            layer_id: The ID of the MoE layer to update.
+            updated_expert_map: A PyTorch tensor containing the new expert
+                map.
+        """
+        self.expert_map_per_layer[layer_id].copy_(updated_expert_map)
+        self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
+
+    def do_update_expert_weight(self, layer_id, local_expert_to_replace,
+                                buffer_tensor_id):
+        """Performs an update of expert weights.
+
+        Copies weights from a specified buffer tensor to the target local
+        expert.
+
+        Args:
+            layer_id: The ID of the MoE layer containing the expert to
+                update.
+            local_expert_to_replace: The local ID of the expert whose
+                weights are to be replaced.
+            buffer_tensor_id: The ID of the buffer tensor containing the
+                new weights.
+        """
+        for expert_tensor, buffer_tensor in zip(
+                self.expert_param_per_layer[layer_id][local_expert_to_replace],
+                self.buffer_tensor_list[buffer_tensor_id]):
+            expert_tensor.copy_(buffer_tensor)
+
+    def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
+        """Performs an update of the logical-to-physical map.
+
+        If a logical-to-physical map exists for the given layer, it is
+        updated with the new values.
+
+        Args:
+            layer_id: The ID of the MoE layer to update.
+            updated_log2phy_map: A PyTorch tensor containing the new
+                logical-to-physical map.
+        """
+        if self.log2phy_map_per_layer[layer_id] is not None:
+            self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map)
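+    # NOTE (illustrative): during a rebalancing step, the EPLB engine is
+    # expected to drive the three hooks above roughly in this order (the
+    # driver itself is not part of this file):
+    #
+    #   adaptor.do_update_expert_map(layer_id, new_map)
+    #   adaptor.do_update_expert_weight(layer_id, local_expert, buffer_id)
+    #   adaptor.do_update_log2phy_map(layer_id, new_log2phy_map)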
+    def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
+        """Converts a local expert placement map to a global expert
+        placement map.
+
+        A local map contains, for each local slot on each device, the
+        global ID of the expert stored there. This function transforms it
+        into a global view indexed by global expert ID.
+
+        For example, if the local placement is `[[0, 1], [2, 3]]` (meaning
+        device 0 holds experts 0 and 1, and device 1 holds experts 2 and 3),
+        the global placement is `[[0, 1, -1, -1], [-1, -1, 0, 1]]` (meaning
+        global expert 0 is in slot 0 on device 0, global expert 1 is in
+        slot 1 on device 0, global expert 2 is in slot 0 on device 1,
+        and global expert 3 is in slot 1 on device 1).
+
+        Args:
+            placement_local: A local expert placement tensor of shape
+                [L, G, E_local], where L is the number of layers,
+                G is the number of GPUs, and E_local is the number of
+                local expert slots per GPU. Values represent global
+                expert IDs.
+
+        Returns:
+            torch.Tensor: A global expert placement tensor of shape
+                [L, G, E_global], where E_global is the total number of
+                global experts. Values represent local slot IDs,
+                with -1 indicating that a global expert is not on this
+                device.
+        """
+        L, G, E_local = placement_local.shape
+        device = placement_local.device
+
+        max_id = torch.max(placement_local)
+        E_global = (max_id + 1).item() if max_id >= 0 else 0
+
+        if E_global == 0:
+            return torch.empty((L, G, 0), dtype=torch.long, device=device)
+
+        placement_global = torch.full((L, G, E_global),
+                                      fill_value=-1,
+                                      dtype=torch.long,
+                                      device=device)
+
+        valid = placement_local >= 0
+        l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
+        gid_idx = placement_local[l_idx, g_idx, slot_idx]
+
+        placement_global[l_idx, g_idx, gid_idx] = slot_idx
+
+        return placement_global
+
+    def determine_expert_map_all(self):
+        """Determines the default expert mapping across all ranks.
+
+        This method distributes experts evenly among the ranks based on the
+        total number of global experts and the world size. Each rank is
+        responsible for a contiguous range of global experts.
+
+        Returns:
+            torch.Tensor: A global expert mapping tensor of shape
+                [layers, ranks, global_expert_num]. The values in the
+                tensor represent the local slot ID of that global expert
+                on the corresponding rank.
+        """
+        local_num_experts = self.global_expert_num // self.world_size
+
+        expert_map_all = torch.full(
+            (self.num_moe_layers, self.world_size, self.global_expert_num),
+            -1,
+            dtype=torch.int32)
+
+        for r in range(self.world_size):
+            if r < self.world_size - 1:
+                start = r * local_num_experts
+                end = (r + 1) * local_num_experts
+                local_count = local_num_experts
+            else:
+                start = r * local_num_experts
+                end = self.global_expert_num
+                local_count = self.global_expert_num - r * local_num_experts
+
+            local_ids = torch.arange(local_count, dtype=torch.int32)
+            expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
+                self.num_moe_layers, -1)
+
+        return expert_map_all
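To make the default placement concrete, here is a small self-contained sketch that mirrors `determine_expert_map_all` for 6 global experts, 2 ranks, and 1 MoE layer (the numbers are illustrative only):

```python
import torch

# Mirror of determine_expert_map_all for num_moe_layers=1, world_size=2,
# global_expert_num=6 (illustrative; the real method lives on the adaptor).
num_layers, world_size, num_experts = 1, 2, 6
local = num_experts // world_size  # 3 experts per rank

expert_map = torch.full((num_layers, world_size, num_experts), -1,
                        dtype=torch.int32)
for r in range(world_size):
    start = r * local
    end = num_experts if r == world_size - 1 else (r + 1) * local
    expert_map[:, r, start:end] = torch.arange(end - start,
                                               dtype=torch.int32)

print(expert_map[0])
# tensor([[ 0,  1,  2, -1, -1, -1],
#         [-1, -1, -1,  0,  1,  2]], dtype=torch.int32)
```

Each rank owns a contiguous range of global experts, and the stored value is the rank-local slot holding that expert; `local2global` produces the same encoding from a file-based placement.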
diff --git a/vllm/distributed/eplb/gpu_model_register.py b/vllm/distributed/eplb/gpu_model_register.py
new file mode 100644
index 000000000000..9238205741c3
--- /dev/null
+++ b/vllm/distributed/eplb/gpu_model_register.py
@@ -0,0 +1,116 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import types
+import typing
+from typing import Callable
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.models.utils import is_pp_missing_parameter
+
+logger = init_logger(__name__)
+
+
+def set_eplb_state(
+    self,
+    expert_load_view: torch.Tensor,
+    logical_to_physical_map: torch.Tensor,
+    logical_replica_count: torch.Tensor,
+) -> None:
+    for layer_idx, layer in enumerate(self.moe_layers):
+        # Register the expert weights.
+        self.expert_weights.append(layer.get_expert_weights())
+        layer.set_eplb_state(
+            moe_layer_idx=layer_idx,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
+        )
+
+
+def update_physical_experts_metadata(
+    self,
+    num_physical_experts: int,
+    num_local_physical_experts: int,
+) -> None:
+    assert self.num_local_physical_experts == num_local_physical_experts
+    self.num_physical_experts = num_physical_experts
+    self.num_local_physical_experts = num_local_physical_experts
+    self.num_redundant_experts = (num_physical_experts -
+                                  self.num_logical_experts)
+    for layer in self.model.layers:
+        # self.example_moe stores an MoE *instance*, so compare against its
+        # type; passing the instance itself to isinstance() would raise a
+        # TypeError.
+        if isinstance(layer.mlp, type(self.example_moe)):
+            moe = layer.mlp
+            moe.n_local_physical_experts = num_local_physical_experts
+            moe.n_physical_experts = num_physical_experts
+            moe.n_redundant_experts = self.num_redundant_experts
+            moe.experts.update_expert_map()
+
+
+def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+    # Params for weights, fp8 weight scales, fp8 activation scales
+    # (param_name, weight_name, expert_id, shard_id)
+    # DeepSeek-style configs expose `n_routed_experts`, while Qwen3-MoE
+    # exposes `num_experts`; models without redundant-expert support may
+    # not define `num_redundant_experts`, so default it to 0.
+    num_experts = getattr(self.config, "n_routed_experts",
+                          getattr(self.config, "num_experts", None))
+    return FusedMoE.make_expert_params_mapping(
+        ckpt_gate_proj_name="gate_proj",
+        ckpt_down_proj_name="down_proj",
+        ckpt_up_proj_name="up_proj",
+        num_experts=num_experts,
+        num_redundant_experts=getattr(self, "num_redundant_experts", 0))
+
+
+def load_expert_weight(self, mapping, name, loaded_weight, params_dict):
+    ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale",
+                       ".v_scale", "_v_scale", ".weight_scale",
+                       "_weight_scale", ".input_scale", "_input_scale")
+
+    expert_matched = False
+    is_continue = False
+    success = False
+    name_mapped = ''
+    param_name, weight_name, expert_id, shard_id = mapping
+    if weight_name not in name:
+        is_continue = True
+        return expert_matched, is_continue, success, name_mapped
+
+    # Anyway, this is an expert weight and should not be
+    # attempted to load as other weights later
+    expert_matched = True
+
+    # Do not modify `name` since the loop may continue here
+    # Instead, create a new variable
+    name_mapped = name.replace(weight_name, param_name)
+
+    if is_pp_missing_parameter(name_mapped, self):
+        is_continue = True
+        return expert_matched, is_continue, success, name_mapped
+
+    # Skip loading extra parameters for GPTQ/modelopt models.
+    if name_mapped.endswith(ignore_suffixes) \
+            and name_mapped not in params_dict:
+        is_continue = True
+        return expert_matched, is_continue, success, name_mapped
+
+    param = params_dict[name_mapped]
+    # We should ask the weight loader to return success or not
+    # here since otherwise we may skip experts with other
+    # available replicas.
+    weight_loader = typing.cast(Callable[..., bool],
+                                param.weight_loader)
+    success = weight_loader(param,
+                            loaded_weight,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True)
+    return expert_matched, is_continue, success, name_mapped
+
+
+def model_register(model):
+    """
+    Registers the Expert Parallel Load Balancing (EPLB) helper methods
+    defined in this module as bound methods on the given vLLM model
+    instance.
+
+    Args:
+        model: The vLLM model instance to which the methods will be added.
+    """
+    model.set_eplb_state = types.MethodType(set_eplb_state, model)
+    model.load_expert_weight = types.MethodType(load_expert_weight, model)
+    model.update_physical_experts_metadata = \
+        types.MethodType(update_physical_experts_metadata, model)
+    model.model.get_expert_mapping = \
+        types.MethodType(get_expert_mapping, model.model)
+    logger.info("EPLB model registration complete")
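The three model files below (`deepseek_v2.py`, `glm4_moe.py`, `qwen3_moe.py`) all replace their inlined expert-loading loops with this shared helper. A sketch of the calling contract, with variable names taken from the diffs (`self`, `name`, `loaded_weight`, and `params_dict` come from the surrounding loader):

```python
# Sketch of how load_expert_weight is consumed in the per-model
# load_weights loops below (illustrative driver, not a new code path).
is_expert_weight = False
for mapping in expert_params_mapping:
    expert_matched, is_continue, success, name_mapped = \
        load_expert_weight(self, mapping, name, loaded_weight, params_dict)
    if expert_matched:
        # The checkpoint tensor is an expert weight; never fall through
        # to the generic (non-expert) loading path for it.
        is_expert_weight = True
    if is_continue:
        # The mapping did not apply (wrong weight name, PP-missing
        # parameter, or an ignorable extra parameter): try the next one.
        continue
    if success:
        name = name_mapped
        break
```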
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 636554bd648f..da2c91eae77f 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -842,7 +842,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_expert_groups = config.n_group
 
         self.moe_layers: list[FusedMoE] = []
-        example_moe = None
+        self.example_moe = None
         for layer in self.model.layers:
             if isinstance(layer, PPMissingLayer):
                 continue
@@ -850,52 +850,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             assert isinstance(layer, DeepseekV2DecoderLayer)
             if isinstance(layer.mlp, DeepseekV2MoE):
                 # Pick last one layer since the first ones may be dense layers.
-                example_moe = layer.mlp
+                self.example_moe = layer.mlp
                 self.moe_layers.append(layer.mlp.experts)
 
-        if example_moe is None:
+        if self.example_moe is None:
             raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
 
-        self.num_logical_experts = example_moe.n_logical_experts
-        self.num_physical_experts = example_moe.n_physical_experts
-        self.num_local_physical_experts = example_moe.n_local_physical_experts
-        self.num_routed_experts = example_moe.n_routed_experts
-        self.num_shared_experts = example_moe.n_shared_experts
-        self.num_redundant_experts = example_moe.n_redundant_experts
-
-    def set_eplb_state(
-        self,
-        expert_load_view: torch.Tensor,
-        logical_to_physical_map: torch.Tensor,
-        logical_replica_count: torch.Tensor,
-    ) -> None:
-        for layer_idx, layer in enumerate(self.moe_layers):
-            # Register the expert weights.
-            self.expert_weights.append(layer.get_expert_weights())
-            layer.set_eplb_state(
-                moe_layer_idx=layer_idx,
-                expert_load_view=expert_load_view,
-                logical_to_physical_map=logical_to_physical_map,
-                logical_replica_count=logical_replica_count,
-            )
-
-    def update_physical_experts_metadata(
-        self,
-        num_physical_experts: int,
-        num_local_physical_experts: int,
-    ) -> None:
-        assert self.num_local_physical_experts == num_local_physical_experts
-        self.num_physical_experts = num_physical_experts
-        self.num_local_physical_experts = num_local_physical_experts
-        self.num_redundant_experts = (num_physical_experts -
-                                      self.num_logical_experts)
-        for layer in self.model.layers:
-            if isinstance(layer.mlp, DeepseekV2MoE):
-                moe = layer.mlp
-                moe.n_local_physical_experts = num_local_physical_experts
-                moe.n_physical_experts = num_physical_experts
-                moe.n_redundant_experts = self.num_redundant_experts
-                moe.experts.update_expert_map()
+        self.num_logical_experts = self.example_moe.n_logical_experts
+        self.num_physical_experts = self.example_moe.n_physical_experts
+        self.num_local_physical_experts = \
+            self.example_moe.n_local_physical_experts
+        self.num_routed_experts = self.example_moe.n_routed_experts
+        self.num_shared_experts = self.example_moe.n_shared_experts
+        self.num_redundant_experts = self.example_moe.n_redundant_experts
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -929,16 +895,10 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("fused_qkv_a_proj", "q_a_proj", 0),
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
-
+        from vllm.distributed.eplb.gpu_model_register import (
+            get_expert_mapping, load_expert_weight)
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
-            num_redundant_experts=self.num_redundant_experts)
-
+        expert_params_mapping = get_expert_mapping(self)
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
@@ -984,34 +944,17 @@ def load_weights(self, weights: Iterable[tuple[str,
                     break
             else:
                 is_expert_weight = False
+                is_continue = False
                 for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-
-                    # Anyway, this is an expert weight and should not be
-                    # attempted to load as other weights later
-                    is_expert_weight = True
-
-                    # Do not modify `name` since the loop may continue here
-                    # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
+                    expert_matched, is_continue, success, name_mapped = \
+                        load_expert_weight(self, mapping, name,
+                                           loaded_weight, params_dict)
+                    if expert_matched:
+                        is_expert_weight = True
 
-                    if is_pp_missing_parameter(name_mapped, self):
+                    if is_continue:
                         continue
 
-                    param = params_dict[name_mapped]
-                    # We should ask the weight loader to return success or not
-                    # here since otherwise we may skip experts with other
-                    # available replicas.
-                    weight_loader = typing.cast(Callable[..., bool],
-                                                param.weight_loader)
-                    success = weight_loader(param,
-                                            loaded_weight,
-                                            name_mapped,
-                                            shard_id=shard_id,
-                                            expert_id=expert_id,
-                                            return_success=True)
                     if success:
                         name = name_mapped
                         break
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index 1acbd18091fb..c8c25abb0c8e 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -496,15 +496,6 @@ def make_empty_intermediate_tensors(
                          device=device),
         })
 
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts)
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -515,10 +506,10 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-
+        from vllm.distributed.eplb.gpu_model_register import (
+            get_expert_mapping, load_expert_weight)
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
-        expert_params_mapping = self.get_expert_mapping()
+        expert_params_mapping = get_expert_mapping(self)
         for name, loaded_weight in weights:
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is not None:
@@ -548,34 +539,17 @@ def load_weights(self, weights: Iterable[tuple[str,
                     break
             else:
                 is_expert_weight = False
+                is_continue = False
                 for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-
-                    # Anyway, this is an expert weight and should not be
-                    # attempted to load as other weights later
-                    is_expert_weight = True
-
-                    # Do not modify `name` since the loop may continue here
-                    # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
+                    expert_matched, is_continue, success, name_mapped = \
+                        load_expert_weight(self, mapping, name,
+                                           loaded_weight, params_dict)
+                    if expert_matched:
+                        is_expert_weight = True
 
-                    if is_pp_missing_parameter(name_mapped, self):
+                    if is_continue:
                         continue
 
-                    param = params_dict[name_mapped]
-                    # We should ask the weight loader to return success or not
-                    # here since otherwise we may skip experts with other
-                    # available replicas.
-                    weight_loader = typing.cast(Callable[..., bool],
-                                                param.weight_loader)
-                    success = weight_loader(param,
-                                            loaded_weight,
-                                            name_mapped,
-                                            shard_id=shard_id,
-                                            expert_id=expert_id,
-                                            return_success=True)
                     if success:
                         name = name_mapped
                         break
@@ -649,7 +623,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_expert_groups = config.n_group
 
         self.moe_layers: list[FusedMoE] = []
-        example_moe = None
+        self.example_moe = None
        for layer in self.model.layers:
             if isinstance(layer, PPMissingLayer):
                 continue
@@ -657,34 +631,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             assert isinstance(layer, Glm4MoeDecoderLayer)
             if isinstance(layer.mlp, Glm4MoE):
                 # Pick last one layer since the first ones may be dense layers.
-                example_moe = layer.mlp
+                self.example_moe = layer.mlp
                 self.moe_layers.append(layer.mlp.experts)
 
-        if example_moe is None:
+        if self.example_moe is None:
             raise RuntimeError("No Glm4MoE layer found in model.layers.")
 
-        self.num_logical_experts = example_moe.n_logical_experts
-        self.num_physical_experts = example_moe.n_physical_experts
-        self.num_local_physical_experts = example_moe.n_local_physical_experts
-        self.num_routed_experts = example_moe.n_routed_experts
-        self.num_shared_experts = example_moe.n_shared_experts
-        self.num_redundant_experts = example_moe.n_redundant_experts
-
-    def set_eplb_state(
-        self,
-        expert_load_view: torch.Tensor,
-        logical_to_physical_map: torch.Tensor,
-        logical_replica_count: torch.Tensor,
-    ) -> None:
-        for layer_idx, layer in enumerate(self.moe_layers):
-            # Register the expert weights.
-            self.expert_weights.append(layer.get_expert_weights())
-            layer.set_eplb_state(
-                moe_layer_idx=layer_idx,
-                expert_load_view=expert_load_view,
-                logical_to_physical_map=logical_to_physical_map,
-                logical_replica_count=logical_replica_count,
-            )
+        self.num_logical_experts = self.example_moe.n_logical_experts
+        self.num_physical_experts = self.example_moe.n_physical_experts
+        self.num_local_physical_experts = \
+            self.example_moe.n_local_physical_experts
+        self.num_routed_experts = self.example_moe.n_routed_experts
+        self.num_shared_experts = self.example_moe.n_shared_experts
+        self.num_redundant_experts = self.example_moe.n_redundant_experts
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 029309c49efd..416912e6ec63 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -438,16 +438,6 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts,
-            num_redundant_experts=self.num_redundant_experts)
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -463,10 +453,10 @@ def load_weights(self, weights: Iterable[tuple[str,
         ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale",
                            ".v_scale", "_v_scale", ".weight_scale",
                            "_weight_scale", ".input_scale", "_input_scale")
-
+        from vllm.distributed.eplb.gpu_model_register import (
+            get_expert_mapping, load_expert_weight)
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
-        expert_params_mapping = self.get_expert_mapping()
+        expert_params_mapping = get_expert_mapping(self)
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -507,40 +497,17 @@ def load_weights(self, weights: Iterable[tuple[str,
                     break
             else:
                 is_expert_weight = False
+                is_continue = False
                 for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-
-                    # Anyway, this is an expert weight and should not be
-                    # attempted to load as other weights later
-                    is_expert_weight = True
-
-                    # Do not modify `name` since the loop may continue here
-                    # Instead, create a new variable
-                    name_mapped = name.replace(weight_name, param_name)
-
-                    if is_pp_missing_parameter(name_mapped, self):
-                        continue
+                    expert_matched, is_continue, success, name_mapped = \
+                        load_expert_weight(self, mapping, name,
+                                           loaded_weight, params_dict)
+                    if expert_matched:
+                        is_expert_weight = True
 
-                    # Skip loading extra parameters for GPTQ/modelopt models.
-                    if name_mapped.endswith(
-                            ignore_suffixes
-                    ) and name_mapped not in params_dict:
+                    if is_continue:
                         continue
 
-                    param = params_dict[name_mapped]
-                    # We should ask the weight loader to return success or not
-                    # here since otherwise we may skip experts with other
-                    # available replicas.
-                    weight_loader = typing.cast(Callable[..., bool],
-                                                param.weight_loader)
-                    success = weight_loader(param,
-                                            loaded_weight,
-                                            name_mapped,
-                                            shard_id=shard_id,
-                                            expert_id=expert_id,
-                                            return_success=True)
                     if success:
                         name = name_mapped
                         break
@@ -617,61 +584,27 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.expert_weights = []
 
         self.moe_layers: list[FusedMoE] = []
-        example_layer = None
+        self.example_moe = None
         for layer in self.model.layers:
             if isinstance(layer, PPMissingLayer):
                 continue
 
             assert isinstance(layer, Qwen3MoeDecoderLayer)
             if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
-                example_layer = layer.mlp
+                self.example_moe = layer.mlp
                 self.moe_layers.append(layer.mlp.experts)
 
-        if example_layer is None:
+        if self.example_moe is None:
             raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
 
         self.num_moe_layers = len(self.moe_layers)
         self.num_expert_groups = 1
         self.num_shared_experts = 0
-        self.num_logical_experts = example_layer.n_logical_experts
-        self.num_physical_experts = example_layer.n_physical_experts
-        self.num_local_physical_experts = example_layer.n_local_physical_experts
-        self.num_routed_experts = example_layer.n_routed_experts
-        self.num_redundant_experts = example_layer.n_redundant_experts
-
-    def set_eplb_state(
-        self,
-        expert_load_view: torch.Tensor,
-        logical_to_physical_map: torch.Tensor,
-        logical_replica_count: torch.Tensor,
-    ) -> None:
-        for layer_idx, layer in enumerate(self.moe_layers):
-            # Register the expert weights.
-            self.expert_weights.append(layer.get_expert_weights())
-            layer.set_eplb_state(
-                moe_layer_idx=layer_idx,
-                expert_load_view=expert_load_view,
-                logical_to_physical_map=logical_to_physical_map,
-                logical_replica_count=logical_replica_count,
-            )
-
-    def update_physical_experts_metadata(
-        self,
-        num_physical_experts: int,
-        num_local_physical_experts: int,
-    ) -> None:
-        assert self.num_local_physical_experts == num_local_physical_experts
-        self.num_physical_experts = num_physical_experts
-        self.num_local_physical_experts = num_local_physical_experts
-        self.num_redundant_experts = (num_physical_experts -
-                                      self.num_logical_experts)
-        for layer in self.model.layers:
-            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
-                moe = layer.mlp
-                moe.n_local_physical_experts = num_local_physical_experts
-                moe.n_physical_experts = num_physical_experts
-                moe.n_redundant_experts = self.num_redundant_experts
-                moe.experts.update_expert_map()
+        self.num_logical_experts = self.example_moe.n_logical_experts
+        self.num_physical_experts = self.example_moe.n_physical_experts
+        self.num_local_physical_experts = \
+            self.example_moe.n_local_physical_experts
+        self.num_routed_experts = self.example_moe.n_routed_experts
+        self.num_redundant_experts = self.example_moe.n_redundant_experts
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
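With the per-model methods now deduplicated, the only remaining wiring is registration. A hedged sketch of what `model_register` makes available; the `model` object is assumed to be an already-loaded vLLM MoE model:

```python
# Illustrative only: what model_register(model) binds onto the model.
from vllm.distributed.eplb.gpu_model_register import model_register

model_register(model)  # `model` is a loaded vLLM MoE model (assumed)

# Bound as instance methods afterwards:
#   model.set_eplb_state(...)                   - hook MoE layers into EPLB
#   model.load_expert_weight(...)               - per-mapping expert loading
#   model.update_physical_experts_metadata(...) - post-rebalance bookkeeping
#   model.model.get_expert_mapping()            - checkpoint-to-param mapping
expert_params_mapping = model.model.get_expert_mapping()
```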
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e8ad9c2fca07..dd0f62011619 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -27,6 +27,7 @@
 from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config, update_config)
 from vllm.distributed.eplb.eplb_state import EplbState
+from vllm.distributed.eplb.gpu_model_register import model_register
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
@@ -2520,6 +2521,8 @@ def load_model(self, eep_scale_up: bool = False) -> None:
             logger.info("Loading model from scratch...")
             self.model = model_loader.load_model(
                 vllm_config=self.vllm_config, model_config=self.model_config)
+            if self.parallel_config.enable_eplb:
+                model_register(self.model)
 
         if self.lora_config:
             self.model = self.load_lora_model(self.model,
                                               self.model_config,
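Finally, a hedged end-to-end sketch of how the pieces compose. The `setup_eplb` driver below is illustrative and not part of this PR; it assumes a loaded MoE model and an initialized process group:

```python
# Illustrative glue (not part of this PR): register the EPLB hooks on a
# loaded MoE model, then build the rank-local adaptor that the load
# balancer polls for workloads and expert maps.
from vllm.distributed.eplb.eplb_adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm.distributed.eplb.gpu_model_register import model_register


def setup_eplb(model):
    model_register(model)  # bind set_eplb_state, load_expert_weight, ...
    adaptor = VllmEplbAdaptor(model=model)
    # Global [layers, ranks, experts] view of the initial placement:
    expert_maps = adaptor.get_init_expert_map(adaptor.num_moe_layers)
    # Per-rank workload statistics consumed by the balancing policy:
    moe_load = adaptor.get_rank_expert_workload()
    return adaptor, expert_maps, moe_load
```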