42 changes: 36 additions & 6 deletions vllm/distributed/eplb/eplb_state.py
@@ -29,12 +29,13 @@
import time
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, Union
from typing import Optional, Union, get_args

import torch
from torch.distributed import ProcessGroup, all_reduce

from vllm.config import ParallelConfig
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
in_the_same_node_as)
from vllm.distributed.utils import StatelessProcessGroup
@@ -161,6 +162,7 @@ class EplbState:
def build_initial_global_physical_to_logical_map(
num_routed_experts: int,
num_redundant_experts: int,
expert_placement_strategy: ExpertPlacementStrategy = "linear",
) -> Sequence[int]:
"""
Build an initial expert arrangement using the following structure:
@@ -171,11 +173,36 @@ def build_initial_global_physical_to_logical_map(
where each integer is the index of the logical expert
that the corresponding physical expert maps to.
"""
global_physical_to_logical_map = list(range(num_routed_experts))
global_physical_to_logical_map += [
i % num_routed_experts for i in range(num_redundant_experts)
]
return global_physical_to_logical_map
if expert_placement_strategy == "linear":
global_physical_to_logical_map = list(range(num_routed_experts))
global_physical_to_logical_map += [
i % num_routed_experts for i in range(num_redundant_experts)
]
return global_physical_to_logical_map

elif expert_placement_strategy == "round_robin":
assert num_redundant_experts == 0, (
"Round-robin expert placement is not supported with "
"redundant experts.")

ep_group = get_ep_group().device_group
ep_size = ep_group.size()

base = num_routed_experts // ep_size
remainder = num_routed_experts % ep_size
global_physical_to_logical_map = []
for i in range(ep_size):
cnt = base + (1 if i < remainder else 0)
for k in range(cnt):
gid = i + k * ep_size
if gid < num_routed_experts:
global_physical_to_logical_map.append(gid)

return global_physical_to_logical_map
else:
raise ValueError("Unsupported expert placement strategy "
f"'{expert_placement_strategy}', expected one of "
f"{get_args(ExpertPlacementStrategy)}")

@classmethod
def build(
@@ -190,10 +217,13 @@ def build(
"""
Build the initial EPLB state.
"""
expert_placement_strategy = parallel_config.expert_placement_strategy

physical_to_logical_map_list = (
cls.build_initial_global_physical_to_logical_map(
model.num_routed_experts,
model.num_redundant_experts,
expert_placement_strategy,
))
physical_to_logical_map = torch.tensor(
physical_to_logical_map_list,
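As a quick sanity check on the new round-robin branch, here is a minimal standalone sketch (illustrative values only, not part of the diff) of the map it builds for 6 routed experts, 0 redundant experts, and an EP size of 4:

num_routed_experts, ep_size = 6, 4

base = num_routed_experts // ep_size        # 1 expert per rank as a baseline
remainder = num_routed_experts % ep_size    # the first 2 ranks get one extra

mapping = []
for rank in range(ep_size):
    cnt = base + (1 if rank < remainder else 0)
    for k in range(cnt):
        gid = rank + k * ep_size
        if gid < num_routed_experts:
            mapping.append(gid)

print(mapping)  # [0, 4, 1, 5, 2, 3]
# Physical slots are laid out rank by rank: rank 0 serves logical experts
# {0, 4}, rank 1 serves {1, 5}, rank 2 serves {2}, and rank 3 serves {3},
# whereas the linear layout [0, 1, 2, 3, 4, 5] places consecutive logical
# experts on the same rank.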
46 changes: 23 additions & 23 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -34,6 +34,8 @@
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.fused_moe.routing_simulator import (
RoutingSimulator)
from vllm.model_executor.layers.fused_moe.utils import (
determine_expert_placement_strategy)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
@@ -1023,22 +1025,15 @@ def __init__(
assert num_redundant_experts == 0, \
"Redundant experts are only supported with EPLB."

expert_placement_strategy = (
vllm_config.parallel_config.expert_placement_strategy)
if expert_placement_strategy == "round_robin":
# TODO(Bruce): will support round robin expert placement with
# EPLB enabled in the future.
round_robin_supported = ((num_expert_group is not None
and num_expert_group > 1)
and num_redundant_experts == 0
and not self.enable_eplb)

if not round_robin_supported:
logger.warning(
"Round-robin expert placement is only supported for "
"models with multiple expert groups and no redundant "
"experts. Falling back to linear expert placement.")
expert_placement_strategy = "linear"
expert_placement_strategy = determine_expert_placement_strategy(
num_expert_group=num_expert_group,
num_redundant_experts=num_redundant_experts,
)
if expert_placement_strategy == "round_robin" and self.enable_eplb:
# When EPLB is enabled, it assumes a linear expert_map, so keep the
# map unchanged here and apply the round-robin placement logic
# elsewhere.
expert_placement_strategy = "linear"

self.expert_map: Optional[torch.Tensor]
local_num_experts, expert_map = determine_expert_map(
@@ -2071,22 +2066,27 @@ def reduce_output(states: torch.Tensor,

@classmethod
def make_expert_params_mapping(
cls,
ckpt_gate_proj_name: str,
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int,
num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
cls,
ckpt_gate_proj_name: str,
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int,
num_redundant_experts: int = 0,
expert_placement_strategy: Optional[ExpertPlacementStrategy] = None,
) -> list[tuple[str, str, int, str]]:

num_physical_experts = num_experts + num_redundant_experts
if expert_placement_strategy is None:
expert_placement_strategy = determine_expert_placement_strategy(
num_redundant_experts=num_redundant_experts)

# In the returned mapping:
# - `expert_id` is the physical expert id
# - `weight_name` contains the weight name of the logical expert
# So we need to map each physical expert id to its logical expert in `weight_name`
physical_to_logical_map = \
EplbState.build_initial_global_physical_to_logical_map(
num_experts, num_redundant_experts)
num_experts, num_redundant_experts, expert_placement_strategy)

return [
# (param_name, weight_name, expert_id, shard_id)
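Illustrative only (reusing the toy map [0, 4, 1, 5, 2, 3] from the sketch above, with hypothetical weight names): the placement strategy now also decides which logical expert's checkpoint weights each physical expert slot loads.

physical_to_logical = [0, 4, 1, 5, 2, 3]  # round-robin, 6 experts, EP size 4
for physical_id, logical_id in enumerate(physical_to_logical):
    # `weight_name` names the logical expert in the checkpoint, while
    # `physical_id` is the slot its weights are loaded into.
    weight_name = f"experts.{logical_id}.gate_proj"  # hypothetical name
    print(f"physical expert {physical_id} <- {weight_name}")
# e.g. physical expert slot 1 loads logical expert 4's weights, whereas under
# linear placement it would load logical expert 1's.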
40 changes: 39 additions & 1 deletion vllm/model_executor/layers/fused_moe/utils.py
@@ -1,11 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import prod
from typing import Optional, Union
from typing import Optional, Union, get_args

import torch

from vllm import _custom_ops as ops
from vllm.config import get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
@@ -19,6 +22,8 @@
from vllm.utils import cdiv
from vllm.utils.flashinfer import fp4_quantize

logger = init_logger(__name__)


@triton.jit
def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts,
@@ -272,3 +277,36 @@ def _validate_scale_shape(

def activation_without_mul(activation: str) -> str:
return activation + "_no_mul"


def determine_expert_placement_strategy(
num_expert_group: Optional[int] = None,
num_redundant_experts: Optional[int] = None,
) -> ExpertPlacementStrategy:
vllm_config = get_current_vllm_config()
expert_placement_strategy = \
vllm_config.parallel_config.expert_placement_strategy
if expert_placement_strategy == "linear":
return "linear"
elif expert_placement_strategy == "round_robin":
if num_redundant_experts is None:
num_redundant_experts = \
vllm_config.parallel_config.eplb_config.num_redundant_experts
if num_expert_group is None:
model_hf_config = vllm_config.model_config.hf_config
if hasattr(model_hf_config, "n_group"):
num_expert_group = model_hf_config.n_group
if num_expert_group is not None and num_expert_group > 1 \
and num_redundant_experts == 0:
return "round_robin"
else:
logger.warning(
"Round-robin expert placement is only supported for "
"models with multiple expert groups and no redundant "
"experts. Falling back to linear expert placement.")
return "linear"

else:
raise ValueError("Unsupported expert placement strategy "
f"'{expert_placement_strategy}', expected one of "
f"{get_args(ExpertPlacementStrategy)}")
14 changes: 9 additions & 5 deletions vllm/model_executor/model_loader/base_loader.py
@@ -5,7 +5,7 @@
import torch
import torch.nn as nn

from vllm.config import ModelConfig, VllmConfig
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import (
@@ -40,10 +40,14 @@ def load_model(self, vllm_config: VllmConfig,
load_device = device_config.device if load_config.device is None else \
load_config.device
target_device = torch.device(load_device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)

# Ensure the current vLLM config is available during model
# initialization, weight loading, and post-processing.
with set_current_vllm_config(vllm_config), \
set_default_torch_dtype(model_config.dtype), \
target_device:
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)

logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
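For context on why the loader now wraps initialization in set_current_vllm_config: helpers invoked while the model is being built and its weights loaded, such as determine_expert_placement_strategy above, look the config up rather than receiving it as an argument. A minimal sketch of that lookup (hypothetical call site):

from vllm.config import get_current_vllm_config

# Called somewhere under initialize_model() / weight loading:
vllm_config = get_current_vllm_config()
strategy = vllm_config.parallel_config.expert_placement_strategy
# Without the set_current_vllm_config(vllm_config) context in load_model(),
# this lookup would not see the user-provided config during loading.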