Merged
Changes from all commits (29 commits):
c954011  Merge pull request from ROCm/deepseek_085_sharedexperts_aiter_jun_new (valarLip, Aug 1, 2025)
487960d  merge upstream (kliuae, Aug 25, 2025)
dc038e3  fix deepseekv2 (kliuae, Aug 28, 2025)
7a0fe48  pass extra args (kliuae, Aug 28, 2025)
4724152  add assert (kliuae, Aug 28, 2025)
ff85b25  fix deepseekr1 weight loading (kliuae, Aug 29, 2025)
1970a87  split weight_scale_inv (kliuae, Sep 1, 2025)
5642f33  simplify weight loading logic (kliuae, Sep 2, 2025)
0f67d7e  clean up (kliuae, Sep 2, 2025)
db28ff8  merge upstream (kliuae, Sep 3, 2025)
53ca3c9  fix aiter routed scaling factor (kliuae, Sep 4, 2025)
e3760c0  fix aiter routed scaling factor (kliuae, Sep 4, 2025)
fcceeb0  fix aiter routed scaling factor (kliuae, Sep 4, 2025)
257e504  merge upstream (kliuae, Sep 5, 2025)
333a06c  precommit (kliuae, Sep 8, 2025)
a7874bb  dp ep (kliuae, Oct 2, 2025)
37fd055  remove extra args passing (kliuae, Oct 2, 2025)
4c09594  pass n_shared_experts as int (kliuae, Oct 2, 2025)
bfab05d  reorganize shared experts buffer init (kliuae, Oct 2, 2025)
3a38671  merge upstream (kliuae, Oct 2, 2025)
ab93665  fix (kliuae, Oct 2, 2025)
d74bfb2  merge upstream (kliuae, Oct 3, 2025)
d9e6aee  merge upstream (kliuae, Oct 8, 2025)
87a7aa0  merge upstream (kliuae, Oct 14, 2025)
fdd3036  fix and add comments (kliuae, Oct 14, 2025)
3b6c324  pass in default value to num fused shared experts (kliuae, Oct 15, 2025)
af4ca2e  Merge branch 'main' into upstream-aiter-fmoe-sharedexperts (tjtanaa, Oct 15, 2025)
38d5075  ci (kliuae, Oct 15, 2025)
3d70b23  Merge branch 'upstream-aiter-fmoe-sharedexperts' of https://github.co… (kliuae, Oct 15, 2025)
6 changes: 3 additions & 3 deletions tests/distributed/test_expert_placement.py
@@ -85,7 +85,7 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
else:
expected_test_local = base_experts

test_local_experts, test_expert_map = determine_expert_map(
test_local_experts, test_expert_map, _ = determine_expert_map(
ep_size=test_ep_size,
ep_rank=ep_rank,
global_num_experts=test_global_experts,
@@ -116,7 +116,7 @@ def test_expert_placement_edge_cases(expert_placement_strategy, world_size):
"""Test edge cases for round_robin expert placement."""

# Test case 1: ep_size = 1 (should return None for expert_map)
local_num_experts, expert_map = determine_expert_map(
local_num_experts, expert_map, _ = determine_expert_map(
ep_size=1,
ep_rank=0,
global_num_experts=8,
@@ -217,7 +217,7 @@ def test_determine_expert_map_comprehensive():
expected_local,
expected_map_pattern,
) in test_cases:
local_num_experts, expert_map = determine_expert_map(
local_num_experts, expert_map, _ = determine_expert_map(
ep_size=ep_size,
ep_rank=ep_rank,
global_num_experts=global_num_experts,
2 changes: 1 addition & 1 deletion tests/kernels/moe/test_moe_permute_unpermute.py
@@ -217,7 +217,7 @@ def test_moe_permute_unpermute(
expert_map = None
n_local_expert = n_expert
if ep_size != 1:
n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert)
n_local_expert, expert_map, _ = determine_expert_map(ep_size, ep_rank, n_expert)
expert_map = expert_map.cuda()
start_expert = n_local_expert * ep_rank
current_platform.seed_everything(0)
7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -113,6 +113,7 @@
VLLM_ROCM_USE_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@@ -913,6 +914,12 @@ def get_vllm_port() -> int | None:
os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
in ("true", "1")
),
# Whether to use aiter fusion shared experts ops.
# Enabled by default.
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower()
in ("true", "1")
),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")
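For reference, a minimal sketch of how the new VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS toggle added above parses, assuming the same truthy-string convention as the other ROCm flags in this file; the helper name below is hypothetical and not part of vLLM:

```python
import os

# Hypothetical helper mirroring the lambda registered in vllm/envs.py: any
# value other than "true"/"1" (case-insensitive) disables the fusion path.
def aiter_fusion_shared_experts_enabled() -> bool:
    return os.getenv(
        "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True"
    ).lower() in ("true", "1")

# Disabling the feature before launching vLLM:
os.environ["VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS"] = "0"
assert aiter_fusion_shared_experts_enabled() is False
```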
116 changes: 105 additions & 11 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -5,14 +5,15 @@
from collections.abc import Callable, Iterable
from contextlib import nullcontext
from enum import Enum
from functools import partial
from typing import Literal, get_args, overload

import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter

import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (
get_dp_group,
@@ -39,6 +40,8 @@
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
is_rocm_aiter_fusion_shared_expert_enabled,
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
@@ -87,7 +90,7 @@ def _eplb_map_to_physical_and_record(

if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk,
rocm_aiter_grouped_topk as grouped_topk_aiter,
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
@@ -634,6 +637,7 @@ def forward_cuda(
global_num_experts=global_num_experts,
zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type,
num_fused_shared_experts=layer.num_fused_shared_experts,
)

if self.rocm_aiter_moe_enabled:
@@ -860,7 +864,8 @@ def determine_expert_map(
ep_rank: int,
global_num_experts: int,
expert_placement_strategy: ExpertPlacementStrategy = "linear",
) -> tuple[int, torch.Tensor | None]:
num_fused_shared_experts: int = 0,
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
"""
Calculates how many experts should be assigned to each rank for EP and
creates a mapping from global to local expert index. Experts are
@@ -882,10 +887,16 @@
(global_num_experts,) mapping from global to local index.
Contains -1 for experts not assigned to the current rank.
Returns None if ep_size is 1.
- expert_mask (Optional[torch.Tensor]): A tensor of shape
(global_num_experts + num_fused_shared_experts + 1,)
containing 1 for experts assigned to the current rank
(including its fused shared experts) and 0 otherwise;
the final entry is a sentinel slot and is always 0.
Returns None if ep_size is 1.
Used only when AITER MoE is enabled.
"""
assert ep_size > 0
if ep_size == 1:
return (global_num_experts, None)
return (global_num_experts, None, None)

# Distribute experts as evenly as possible to each rank.
base_experts = global_num_experts // ep_size
@@ -914,7 +925,26 @@
f"'{expert_placement_strategy}', expected one of "
f"{get_args(ExpertPlacementStrategy)}"
)
return (local_num_experts, expert_map)

expert_mask = None
if is_rocm_aiter_moe_enabled():
expert_mask = torch.ones(
(global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
)
expert_mask[-1] = 0
expert_mask[:global_num_experts] = expert_map > -1
expert_map = torch.cat(
(
expert_map,
torch.tensor(
[local_num_experts + i for i in range(num_fused_shared_experts)],
dtype=torch.int32,
),
),
dim=0,
)

return (local_num_experts, expert_map, expert_mask)
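To make the new return values concrete, here is a small worked example assuming AITER MoE is enabled, linear placement, ep_size=2, ep_rank=0, global_num_experts=8, and one fused shared expert; the values are reconstructed from the logic above, not taken from a real run:

```python
import torch

# Rank 0 owns global routed experts 0..3, so before the AITER extension:
expert_map = torch.tensor([0, 1, 2, 3, -1, -1, -1, -1], dtype=torch.int32)
local_num_experts = 4

# expert_mask has one slot per routed expert, one per fused shared expert,
# and a trailing sentinel slot that is always 0.
expert_mask = torch.ones(8 + 1 + 1, dtype=torch.int32)
expert_mask[-1] = 0
expert_mask[:8] = (expert_map > -1).to(torch.int32)
# expert_mask -> [1, 1, 1, 1, 0, 0, 0, 0, 1, 0]

# The fused shared expert is appended to the map after the local routed
# experts, so it resolves to local index local_num_experts + 0 == 4.
expert_map = torch.cat(
    (expert_map, torch.tensor([4], dtype=torch.int32)), dim=0
)
# expert_map -> [0, 1, 2, 3, -1, -1, -1, -1, 4]
```

In the forward paths further down in this file, it is this expert_mask (rather than expert_map) that is handed to the routing call when the AITER MoE path is enabled.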


def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
@@ -1040,6 +1070,7 @@ def __init__(
zero_expert_num: int | None = 0,
zero_expert_type: str | None = None,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
):
super().__init__()
if params_dtype is None:
@@ -1096,6 +1127,22 @@ def __init__(
self.logical_to_physical_map: torch.Tensor | None = None
self.logical_replica_count: torch.Tensor | None = None

# ROCm aiter shared experts fusion
self.num_fused_shared_experts = (
n_shared_experts
if n_shared_experts is not None
and is_rocm_aiter_fusion_shared_expert_enabled()
else 0
)
if (
not is_rocm_aiter_fusion_shared_expert_enabled()
and self.num_fused_shared_experts != 0
):
raise ValueError(
"n_shared_experts is only supported on ROCm aiter when "
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled"
)
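A tiny sketch of how n_shared_experts resolves into num_fused_shared_experts under the fusion flag; the function name is a local stand-in for the constructor logic above, not vLLM API:

```python
# Shared experts are only fused when the flag is on and the model actually
# passes n_shared_experts; otherwise the count falls back to 0.
def resolve_num_fused_shared_experts(
    n_shared_experts: int | None, fusion_enabled: bool
) -> int:
    return (
        n_shared_experts
        if n_shared_experts is not None and fusion_enabled
        else 0
    )

assert resolve_num_fused_shared_experts(2, fusion_enabled=True) == 2
assert resolve_num_fused_shared_experts(2, fusion_enabled=False) == 0
assert resolve_num_fused_shared_experts(None, fusion_enabled=True) == 0
```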

# Determine expert maps
if self.use_ep:
if self.enable_eplb:
@@ -1129,14 +1176,16 @@ def __init__(
expert_placement_strategy = "linear"

self.expert_map: torch.Tensor | None
local_num_experts, expert_map = determine_expert_map(
local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
expert_placement_strategy=expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
logger.info_once(
"[EP Rank %s/%s] Expert parallelism is enabled. Expert "
"placement strategy: %s. Local/global"
@@ -1150,10 +1199,18 @@ def __init__(
get_compressed_expert_map(self.expert_map),
)
else:
self.local_num_experts, self.expert_map = (self.global_num_experts, None)
self.local_num_experts, self.expert_map, self.expert_mask = (
self.global_num_experts,
None,
None,
)

self.top_k = top_k

self._init_aiter_shared_experts_topK_buffer(
vllm_config=vllm_config, dp_size=dp_size_
)

assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size
@@ -1327,13 +1384,18 @@ def update_expert_map(self):
# ep_size and ep_rank should already be updated
assert self.expert_map is not None
with self.expert_map.device:
local_num_experts, expert_map = determine_expert_map(
local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
num_fused_shared_experts=self.num_fused_shared_experts,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
self._init_aiter_shared_experts_topK_buffer(
vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size
)

def _load_per_tensor_weight_scale(
self,
@@ -1504,6 +1566,24 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
return expert_id
return self.expert_map[expert_id].item()

def _init_aiter_shared_experts_topK_buffer(
self, vllm_config: VllmConfig, dp_size: int
):
if is_rocm_aiter_fusion_shared_expert_enabled():
if self.num_fused_shared_experts > 0:
init_aiter_topK_meta_data(
n_routed_experts=self.global_num_experts,
n_shared_experts=self.num_fused_shared_experts,
top_k=self.top_k,
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
tp_size=self.ep_size if self.use_ep else self.tp_size,
shared_experts_score=1.0,
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
* dp_size,
is_EP=self.use_ep,
)
self.local_num_experts += self.num_fused_shared_experts
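A quick sizing sketch for the call above, using illustrative numbers rather than anything from the PR:

```python
# Example: max_num_batched_tokens=32768 per scheduler, dp_size=2, and one
# fused shared expert on a rank that owns 64 routed experts.
max_num_batched_tokens = 32768
dp_size = 2
num_fused_shared_experts = 1
local_num_experts = 64

# The top-K metadata is sized for max_num_batched_tokens * dp_size tokens.
max_num_tokens = max_num_batched_tokens * dp_size
assert max_num_tokens == 65536

# The fused shared experts are counted as extra local experts afterwards.
local_num_experts += num_fused_shared_experts
assert local_num_experts == 65
```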

@overload
def weight_loader(
self,
@@ -1866,6 +1946,7 @@ def select_experts(
global_num_experts: int | None = None,
zero_expert_num: int | None = None,
zero_expert_type: str | None = None,
num_fused_shared_experts: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
@@ -1900,7 +1981,16 @@
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
if is_rocm_aiter_moe_enabled():
if not is_rocm_aiter_fusion_shared_expert_enabled():
assert num_fused_shared_experts == 0
grouped_topk_impl = partial(
grouped_topk_aiter,
num_fused_shared_experts=num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
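The dispatch above pre-binds the extra num_fused_shared_experts argument with functools.partial so that both branches are invoked through the same grouped-topk call signature; a self-contained sketch of that pattern, where the stub stands in for rocm_aiter_grouped_topk:

```python
from functools import partial

def grouped_topk_stub(gating_output, topk, num_fused_shared_experts=0):
    # Stand-in for rocm_aiter_grouped_topk: only the extra kwarg matters here.
    return topk + num_fused_shared_experts

# Bind the AITER-only argument up front, then call with the common signature.
grouped_topk_impl = partial(grouped_topk_stub, num_fused_shared_experts=1)
assert grouped_topk_impl(gating_output=None, topk=8) == 9
```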
@@ -2119,7 +2209,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled()
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
@@ -2244,7 +2336,9 @@ def forward_impl(
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled()
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,