@@ -1599,6 +1599,19 @@ def forward(
                                   (0, self.hidden_size - og_hidden_states),
                                   mode='constant',
                                   value=0.0)
+        do_naive_dispatch_combine: bool = (
+            self.dp_size > 1
+            and not self.moe_parallel_config.use_deepep_ht_kernels
+            and not self.moe_config.use_flashinfer_cutlass_kernels)
+
+        def reduce_output(states: torch.Tensor) -> torch.Tensor:
+            if do_naive_dispatch_combine:
+                states = get_ep_group().combine(states)
+
+            if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+                states = self.maybe_all_reduce_tensor_model_parallel(states)
+
+            return states
 
         if self.shared_experts is None:
             if current_platform.is_tpu():
@@ -1609,7 +1622,7 @@ def forward(
             else:
                 fused_output = torch.ops.vllm.moe_forward(
                     hidden_states, router_logits, self.layer_name)
-            return fused_output[..., :og_hidden_states]
+            return reduce_output(fused_output[..., :og_hidden_states])
         else:
             if current_platform.is_tpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
@@ -1619,8 +1632,8 @@ def forward(
             else:
                 shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
                     hidden_states, router_logits, self.layer_name)
-            return (shared_output[..., :og_hidden_states],
-                    fused_output[..., :og_hidden_states])
+            return (reduce_output(shared_output[..., :og_hidden_states]),
+                    reduce_output(fused_output[..., :og_hidden_states]))
 
     def forward_impl_chunked(
         self,
@@ -1786,23 +1799,7 @@ def forward_impl(
                 shared_output,
                 final_hidden_states,
             )
-
-        def reduce_output(states: torch.Tensor) -> torch.Tensor:
-            if do_naive_dispatch_combine:
-                states = get_ep_group().combine(states)
-
-            if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-                states = self.maybe_all_reduce_tensor_model_parallel(states)
-
-            return states
-
-        if self.shared_experts is None:
-            return reduce_output(final_hidden_states)
-        else:
-            return (
-                reduce_output(final_hidden_states[0]),
-                reduce_output(final_hidden_states[1]),
-            )
+        return final_hidden_states
 
     @classmethod
     def make_expert_params_mapping(
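For context on what the moved helper does: `reduce_output` chains two collectives. When the naive DP dispatch/combine path is active it first combines expert outputs across the expert-parallel group via `get_ep_group().combine`, and when `reduce_results` is set it then all-reduces the partial sums across tensor-parallel/expert-parallel ranks. The diff itself is purely structural: the helper now lives in `forward()` and wraps the outputs of `moe_forward` / `moe_forward_shared`, so `forward_impl` can simply return `final_hidden_states`. The sketch below illustrates the combine-then-all-reduce ordering with plain `torch.distributed` primitives; the function name, the process-group arguments, and the reduce-scatter realization of the combine step are illustrative assumptions, not vLLM's actual plumbing.

```python
# Minimal sketch of the ordering applied by reduce_output() after the
# fused-MoE kernel. Everything here is an illustrative assumption: vLLM's
# real code goes through get_ep_group() and
# maybe_all_reduce_tensor_model_parallel(), not raw torch.distributed.
import torch
import torch.distributed as dist


def reduce_output_sketch(
    states: torch.Tensor,
    ep_group: dist.ProcessGroup,
    tp_group: dist.ProcessGroup,
    do_naive_dispatch_combine: bool,
    reduce_results: bool,
) -> torch.Tensor:
    if do_naive_dispatch_combine:
        # "Combine" step of the naive dispatch/combine path, modeled here as a
        # sum reduce-scatter: each rank holds expert outputs for the globally
        # gathered token batch and gets back only its own tokens, with the
        # contributions from all expert-parallel ranks summed.
        ep_world = dist.get_world_size(ep_group)
        out_shape = (states.shape[0] // ep_world, *states.shape[1:])
        combined = torch.empty(out_shape, dtype=states.dtype, device=states.device)
        dist.reduce_scatter_tensor(combined, states, group=ep_group)
        states = combined

    if reduce_results and dist.get_world_size(tp_group) > 1:
        # Standard tensor-parallel all-reduce of the partial results produced
        # by the sharded expert weights.
        dist.all_reduce(states, group=tp_group)

    return states
```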