diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 2a988b8644b5..a7af12a1d72f 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1341,6 +1341,13 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
 
+        # FIXME: We apply the router weights to the input before the mm.
+        # To do this properly, we should move this into our MoE kernels.
+        if apply_router_weight_on_input:
+            assert topk_ids.shape[1] == 1, (
+                "Can only apply router weight on input when topk is 1!")
+            curr_hidden_states = curr_hidden_states * curr_topk_weights
+
         qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input(
             A=curr_hidden_states,
             B=w1,
@@ -1367,7 +1374,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 sorted_token_ids,
                                 expert_ids,
                                 num_tokens_post_padded,
-                                apply_router_weight_on_input,
+                                # FIXME: Always False here because
+                                # fused_moe_kernel applies the router weight
+                                # to the mm result, not to the input before
+                                # the mm; that is done outside the kernel.
+                                False,
                                 top_k_num,
                                 config,
                                 compute_type=compute_type,
@@ -1486,8 +1497,8 @@ def fused_moe(
         Defaults to False.
     - global_num_experts (int): The total number of experts in the global
         expert space.
-    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices 
-        from the global expert space to the local expert space of the expert 
+    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
+        from the global expert space to the local expert space of the expert
         parallel shard.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
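
For context, here is a minimal standalone sketch of the pre-scaling step the first hunk introduces. The helper name `apply_router_weight_on_input_ref` is hypothetical (it is not part of vLLM); the sketch only assumes plain PyTorch tensors shaped like the ones in `fused_experts_impl`. With topk == 1, each token is routed to exactly one expert, so folding the router weight into the activations before the expert matmuls is well-defined, and the fused kernel can then be invoked with its own weight application disabled (the `False` in the second hunk) so the weight is not applied twice.

```python
import torch


def apply_router_weight_on_input_ref(hidden_states: torch.Tensor,
                                     topk_weights: torch.Tensor,
                                     topk_ids: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference helper, not vLLM's kernel path: pre-scale the
    # activations by the router weight, mirroring the new block in
    # fused_experts_impl. Only valid for topk == 1, where each token has a
    # single expert and hence a single weight to fold in.
    assert topk_ids.shape[1] == 1, (
        "Can only apply router weight on input when topk is 1!")
    # topk_weights has shape [num_tokens, 1], so this broadcasts over the
    # hidden dimension; the cast keeps the activation dtype (e.g. fp16).
    return hidden_states * topk_weights.to(hidden_states.dtype)


# Example: 4 tokens, hidden size 8, topk = 1.
x = torch.randn(4, 8, dtype=torch.float16)
w = torch.rand(4, 1, dtype=torch.float32)   # router weights from softmax
ids = torch.zeros(4, 1, dtype=torch.int64)  # chosen expert per token
x_scaled = apply_router_weight_on_input_ref(x, w, ids)
assert x_scaled.shape == x.shape and x_scaled.dtype == x.dtype
```

Note that because a nonlinear activation sits between the two expert matmuls, scaling the input is not numerically equivalent to scaling the output; the diff deliberately changes where the weight is applied, rather than merely reordering an equivalent computation.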