17 changes: 14 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1341,6 +1341,13 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

# FIXME: We apply the router weights before the mm. To do this
# properly, the scaling should be moved into the MoE kernels.
if apply_router_weight_on_input:
    assert topk_ids.shape[1] == 1, (
        "Can only apply router weight on input when topk is 1!")
    curr_hidden_states = curr_hidden_states * curr_topk_weights
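As a minimal sketch (not the vLLM kernel itself, and with made-up shapes): when topk is 1, scaling the input before the matmul is mathematically equivalent to scaling the matmul result, because `(w * x) @ W == w * (x @ W)` when `w` broadcasts per token. That equivalence is why the multiply can be hoisted out of the kernel here.

```python
import torch

# Hypothetical sizes for illustration only.
num_tokens, hidden = 4, 8
topk = 1

hidden_states = torch.randn(num_tokens, hidden)
topk_weights = torch.rand(num_tokens, topk)  # one routing weight per token
W = torch.randn(hidden, hidden)             # stand-in for an expert's weight

# Scale the input before the mm (what this code path does) ...
out_pre = (hidden_states * topk_weights) @ W
# ... versus scaling the mm result (what the kernel would otherwise do).
out_post = (hidden_states @ W) * topk_weights

# Equivalent up to floating-point error, but only because topk == 1:
# with topk > 1 each token would need a different weight per expert output.
assert torch.allclose(out_pre, out_post, atol=1e-5)
```

This also makes the `topk == 1` assert above concrete: for topk > 1 the per-expert weights cannot be folded into a single input scaling.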

qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
A=curr_hidden_states,
B=w1,
@@ -1367,7 +1374,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
Collaborator comment: this comment is probably not needed; instead, add the explanation where apply_router_weight_on_input is actually used.

# FIXME: Always False here because fused_moe_kernel
# applies the router weight on the mm result, not on
# the input before the mm. Applying the router weight
# on the input is done outside the kernel for now.
False,
top_k_num,
config,
compute_type=compute_type,
@@ -1486,8 +1497,8 @@ def fused_moe(
Defaults to False.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
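To make the expert_map docstring concrete, here is a hypothetical sketch (the numbers and the -1 convention for non-local experts are assumptions for illustration, not taken from this diff): 8 global experts, with this expert-parallel shard owning global experts 4..7 as local experts 0..3.

```python
import torch

# Hypothetical setup: 8 experts globally, this shard owns 4 of them.
global_num_experts = 8
local_expert_ids = [4, 5, 6, 7]  # global ids owned by this shard

# expert_map[g] = local index of global expert g on this shard,
# or -1 (assumed sentinel) if expert g lives on another shard.
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
for local_id, global_id in enumerate(local_expert_ids):
    expert_map[global_id] = local_id

assert expert_map[5].item() == 1   # global expert 5 -> local slot 1
assert expert_map[0].item() == -1  # expert 0 is not on this shard
```

A lookup through this tensor is how routing decisions made in the global expert space get translated into indices usable by the shard's local weight tensors.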