@@ -368,15 +368,16 @@ def reducescatter_or_allreduce(
         return outputs
 
     def forward_chunk(
-        self,
-        x: Union[torch.Tensor, Fp4QuantizedTensor],
-        router_logits: torch.Tensor,
-        use_all_to_all: bool,
-        output_dtype: Optional[torch.dtype] = None,
-        all_rank_num_tokens: Optional[List[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
-        use_dp_padding: Optional[bool] = None,
-        repeating_info: Tuple = (True, True),
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        use_all_to_all: bool,
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        all_rank_max_num_tokens: Optional[int] = None,
+        use_dp_padding: Optional[bool] = None,
+        repeating_info: Tuple = (True, True),
+        alltoall_result_do_sum: bool = True,
     ) -> torch.Tensor:
         if isinstance(x, Fp4QuantizedTensor):
             assert output_dtype is not None
@@ -389,6 +390,9 @@ def forward_chunk(
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()
 
+        if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
+            alltoall_result_do_sum = True
+
         use_deepseek_fp8_block_scale = False
         use_w4_group_scaling = False
         weight_dtype = self.w3_w1_weight.dtype
@@ -679,7 +683,8 @@ def forward_chunk(
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
                 final_hidden_states = self.alltoall_combine(
-                    final_hidden_states, alltoall_info, token_count)
+                    final_hidden_states, alltoall_info, token_count,
+                    alltoall_result_do_sum)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
                 final_hidden_states = self.unpad_tensors(
                     padded, final_hidden_states)
@@ -719,6 +724,7 @@ def forward(
         all_rank_num_tokens: Optional[List[int]] = None,
         all_rank_max_num_tokens: Optional[int] = None,
         use_dp_padding: Optional[bool] = None,
+        alltoall_result_do_sum: bool = True,
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
@@ -744,7 +750,8 @@ def forward(
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 all_rank_max_num_tokens=all_rank_max_num_tokens,
                 use_dp_padding=use_dp_padding,
-                repeating_info=(is_first_call, is_last_call))
+                repeating_info=(is_first_call, is_last_call),
+                alltoall_result_do_sum=alltoall_result_do_sum)
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 use_all_to_all,
@@ -804,7 +811,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                                 all_rank_max_num_tokens=
                                 all_rank_max_num_tokens_list[idx_chunk],
                                 use_dp_padding=use_dp_padding,
-                                repeating_info=(is_first_call, is_last_call))
+                                repeating_info=(is_first_call, is_last_call),
+                                alltoall_result_do_sum=alltoall_result_do_sum)
                         if idx_chunk > 0:
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
@@ -822,7 +830,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                             all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                                 idx_chunk],
                             use_dp_padding=use_dp_padding,
-                            repeating_info=(is_first_call, is_last_call))
+                            repeating_info=(is_first_call, is_last_call),
+                            alltoall_result_do_sum=alltoall_result_do_sum)
                         with torch.cuda.stream(self.aux_stream):
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
@@ -838,7 +847,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
                         all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                             idx_chunk],
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
 
                 outputs_list.append(outputs)
                 if not use_all_to_all:
@@ -894,7 +904,8 @@ def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
         return x, x_sf, token_selected_slots, token_final_scales
 
     def alltoall_combine(self, final_hidden_states: torch.Tensor,
-                         alltoall_info: MoEAlltoallInfo, token_count: int):
+                         alltoall_info: MoEAlltoallInfo, token_count: int,
+                         alltoall_result_do_sum: bool):
         top_k = self.routing_method.experts_per_token
         if isinstance(final_hidden_states, list):
             final_hidden_states = final_hidden_states[0]
@@ -907,7 +918,7 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
             top_k=top_k,
             token_count=token_count,
             use_low_precision_combine=self.use_low_precision_combine,
-            do_reduce=False)
+            do_reduce=alltoall_result_do_sum)
 
         return final_hidden_states
 
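Taken together, these hunks thread a new alltoall_result_do_sum flag from forward() through forward_chunk() into alltoall_combine(), where it replaces the previously hard-coded do_reduce=False. The sketch below is a minimal, self-contained illustration of the gating logic the diff adds; resolve_do_sum is a hypothetical helper written for this note (it is not part of the patch), and the enum is trimmed to the members visible in the diff:

from enum import Enum, auto


class AlltoallMethodType(Enum):
    # Trimmed to the members that appear in the diff above.
    MNNVL = auto()
    DeepEP = auto()


def resolve_do_sum(use_all_to_all: bool, method: AlltoallMethodType,
                   alltoall_result_do_sum: bool) -> bool:
    # Mirrors the guard added near the top of forward_chunk: deferring
    # the reduction is only honored on the MNNVL all-to-all path; every
    # other path forces the flag back to True.
    if not use_all_to_all or method != AlltoallMethodType.MNNVL:
        return True
    return alltoall_result_do_sum


# The resolved value is what ultimately reaches the combine step as do_reduce.
assert resolve_do_sum(True, AlltoallMethodType.MNNVL, False) is False
assert resolve_do_sum(True, AlltoallMethodType.DeepEP, False) is True
assert resolve_do_sum(False, AlltoallMethodType.MNNVL, False) is True

One behavioral consequence worth noting: because the new parameter defaults to True, callers of forward() that do not pass it now get a summed combine result on the MNNVL path, whereas the old code always passed do_reduce=False there.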