@@ -283,16 +283,14 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
         return (num_rows + self.moe_max_num_tokens -
                 1) // self.moe_max_num_tokens
 
-    def can_use_alltoall(self, input, all_rank_num_tokens):
+    def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
         # Disable alltoall when chunking is used
         if self.calculate_num_chunks(all_rank_num_tokens) > 1:
             return False
 
-        num_tokens = input.shape[0]
-
         # For DeepEPLowLatency, check if tokens exceed the threshold
         if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
-                and num_tokens > self.deep_ep_max_num_tokens):
+                and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
             return False
 
         return self.enable_alltoall
@@ -726,7 +724,8 @@ def forward(
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
         num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
-        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(all_rank_num_tokens,
+                                               all_rank_max_num_tokens)
 
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
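
For reference, below is a minimal standalone sketch of the gating logic this change introduces: the DeepEPLowLatency threshold is compared against the maximum token count across all ranks rather than the local rank's input size. The wrapper class MoEAlltoallGate and the concrete attribute values are illustrative assumptions for the sketch; only the two method bodies mirror the diff above.

# Minimal sketch, assuming a simplified AlltoallMethodType enum and only the
# attributes referenced in the diff (moe_max_num_tokens, deep_ep_max_num_tokens,
# alltoall_method_type, enable_alltoall). The real MoE class carries more state.
from enum import Enum
from typing import List


class AlltoallMethodType(Enum):
    NotEnabled = 0
    DeepEPLowLatency = 1


class MoEAlltoallGate:

    def __init__(self, moe_max_num_tokens: int, deep_ep_max_num_tokens: int,
                 alltoall_method_type: AlltoallMethodType,
                 enable_alltoall: bool):
        self.moe_max_num_tokens = moe_max_num_tokens
        self.deep_ep_max_num_tokens = deep_ep_max_num_tokens
        self.alltoall_method_type = alltoall_method_type
        self.enable_alltoall = enable_alltoall

    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
        # Ceiling division of the global token count by the chunk size.
        num_rows = sum(all_rank_num_tokens)
        return (num_rows + self.moe_max_num_tokens -
                1) // self.moe_max_num_tokens

    def can_use_alltoall(self, all_rank_num_tokens: List[int],
                         all_rank_max_num_tokens: int) -> bool:
        # Chunked execution falls back to the non-alltoall path.
        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
            return False
        # Gate on the per-rank maximum token count rather than the local
        # rank's input size, so every rank takes the same branch.
        if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
                and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
            return False
        return self.enable_alltoall


if __name__ == "__main__":
    # Hypothetical sizes for illustration: one rank holds 300 tokens, which
    # exceeds the low-latency threshold, so alltoall is disabled even though
    # the total fits in a single chunk.
    gate = MoEAlltoallGate(moe_max_num_tokens=8192,
                           deep_ep_max_num_tokens=256,
                           alltoall_method_type=AlltoallMethodType.DeepEPLowLatency,
                           enable_alltoall=True)
    print(gate.can_use_alltoall([100, 300], all_rank_max_num_tokens=300))  # False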