
Commit 61ebc95

Move all_reduce from custom op in fused_moe
Signed-off-by: ilmarkov <[email protected]>
1 parent: fde8bd1

File tree

1 file changed (+17, -20 lines)

  • vllm/model_executor/layers/fused_moe/layer.py


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 17 additions & 20 deletions
@@ -1599,6 +1599,19 @@ def forward(
                                   (0, self.hidden_size - og_hidden_states),
                                   mode='constant',
                                   value=0.0)
+        do_naive_dispatch_combine: bool = (
+            self.dp_size > 1
+            and not self.moe_parallel_config.use_deepep_ht_kernels
+            and not self.moe_config.use_flashinfer_cutlass_kernels)
+
+        def reduce_output(states: torch.Tensor) -> torch.Tensor:
+            if do_naive_dispatch_combine:
+                states = get_ep_group().combine(states)
+
+            if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+                states = self.maybe_all_reduce_tensor_model_parallel(states)
+
+            return states
 
         if self.shared_experts is None:
             if current_platform.is_tpu():
@@ -1609,7 +1622,7 @@ def forward(
             else:
                 fused_output = torch.ops.vllm.moe_forward(
                     hidden_states, router_logits, self.layer_name)
-                return fused_output[..., :og_hidden_states]
+                return reduce_output(fused_output[..., :og_hidden_states])
         else:
             if current_platform.is_tpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
@@ -1619,8 +1632,8 @@ def forward(
             else:
                 shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
                     hidden_states, router_logits, self.layer_name)
-                return (shared_output[..., :og_hidden_states],
-                        fused_output[..., :og_hidden_states])
+                return (reduce_output(shared_output[..., :og_hidden_states]),
+                        reduce_output(fused_output[..., :og_hidden_states]))
 
     def forward_impl_chunked(
         self,
@@ -1786,23 +1799,7 @@ def forward_impl(
                 shared_output,
                 final_hidden_states,
             )
-
-        def reduce_output(states: torch.Tensor) -> torch.Tensor:
-            if do_naive_dispatch_combine:
-                states = get_ep_group().combine(states)
-
-            if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-                states = self.maybe_all_reduce_tensor_model_parallel(states)
-
-            return states
-
-        if self.shared_experts is None:
-            return reduce_output(final_hidden_states)
-        else:
-            return (
-                reduce_output(final_hidden_states[0]),
-                reduce_output(final_hidden_states[1]),
-            )
+        return final_hidden_states
 
     @classmethod
     def make_expert_params_mapping(
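
For orientation, here is a minimal, self-contained sketch of the control flow this commit relocates: the combine/all-reduce step now runs in forward(), after the custom op (torch.ops.vllm.moe_forward / moe_forward_shared) returns, rather than inside forward_impl. The helpers ep_combine and tp_all_reduce below are hypothetical stand-ins for get_ep_group().combine and self.maybe_all_reduce_tensor_model_parallel; this is an illustration of the pattern, not vLLM code.

import torch

def ep_combine(states: torch.Tensor) -> torch.Tensor:
    # Stand-in for get_ep_group().combine(states): gathers the naive
    # data-parallel dispatch back onto each rank.
    return states

def tp_all_reduce(states: torch.Tensor) -> torch.Tensor:
    # Stand-in for self.maybe_all_reduce_tensor_model_parallel(states):
    # all-reduces partial results across the tensor-parallel group.
    return states

def reduce_output(states: torch.Tensor,
                  do_naive_dispatch_combine: bool,
                  reduce_results: bool,
                  tp_size: int,
                  ep_size: int) -> torch.Tensor:
    # Same ordering as the helper added to forward(): combine first,
    # then all-reduce only when results must be reduced across TP/EP.
    if do_naive_dispatch_combine:
        states = ep_combine(states)
    if reduce_results and (tp_size > 1 or ep_size > 1):
        states = tp_all_reduce(states)
    return states

# Single-rank example where both steps are effectively no-ops.
out = reduce_output(torch.randn(4, 8),
                    do_naive_dispatch_combine=False,
                    reduce_results=True,
                    tp_size=1,
                    ep_size=1)
print(out.shape)  # torch.Size([4, 8])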
