Commit 88076ee

[fix] Fix can_use_alltoall in fused_moe_wide_ep.py (#6173)
Signed-off-by: Jinyang Yuan <[email protected]>
1 parent: b4c7e8c · commit: 88076ee

File tree

1 file changed (+4, -5 lines)


tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 4 additions & 5 deletions
@@ -283,16 +283,14 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
         return (num_rows + self.moe_max_num_tokens -
                 1) // self.moe_max_num_tokens
 
-    def can_use_alltoall(self, input, all_rank_num_tokens):
+    def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
         # Disable alltoall when chunking is used
         if self.calculate_num_chunks(all_rank_num_tokens) > 1:
             return False
 
-        num_tokens = input.shape[0]
-
         # For DeepEPLowLatency, check if tokens exceed the threshold
         if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
-                and num_tokens > self.deep_ep_max_num_tokens):
+                and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
             return False
 
         return self.enable_alltoall
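
The hunk above changes the DeepEP low-latency check from the local input's row count (input.shape[0]) to the maximum token count across all ranks, presumably because the per-rank limit has to hold on every rank for the all-to-all path to be usable. Below is a minimal, self-contained sketch of the fixed logic; the stub class, its constructor, and the assumption that num_rows is the sum of per-rank token counts are for illustration only and are not the actual TensorRT-LLM implementation.

# Illustrative sketch only: a stub class reproducing the fixed check. Attribute and
# method names follow the diff above; everything else (the constructor, the enum
# members, num_rows being the sum of per-rank token counts) is assumed for this example.
from enum import Enum
from typing import List


class AlltoallMethodType(Enum):
    NotEnabled = 0
    DeepEPLowLatency = 1


class WideEPMoESketch:

    def __init__(self, moe_max_num_tokens: int, deep_ep_max_num_tokens: int,
                 alltoall_method_type: AlltoallMethodType, enable_alltoall: bool):
        self.moe_max_num_tokens = moe_max_num_tokens
        self.deep_ep_max_num_tokens = deep_ep_max_num_tokens
        self.alltoall_method_type = alltoall_method_type
        self.enable_alltoall = enable_alltoall

    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
        # Assumed here: num_rows is the total token count across ranks.
        num_rows = sum(all_rank_num_tokens)
        return (num_rows + self.moe_max_num_tokens - 1) // self.moe_max_num_tokens

    def can_use_alltoall(self, all_rank_num_tokens: List[int],
                         all_rank_max_num_tokens: int) -> bool:
        # Disable alltoall when chunking is used
        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
            return False

        # For DeepEPLowLatency, reject if any rank exceeds the per-rank token threshold
        if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
                and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
            return False

        return self.enable_alltoall
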
@@ -726,7 +724,8 @@ def forward(
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
         num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
-        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(all_rank_num_tokens,
+                                               all_rank_max_num_tokens)
 
         if use_dp_padding:
            all_rank_num_tokens_padded = [all_rank_max_num_tokens
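
For reference, a quick check of the updated call pattern against the sketch above; the values are hypothetical and the stub class is not the real module.

# Hypothetical values; mirrors the updated call site in forward() using the sketch class.
moe = WideEPMoESketch(moe_max_num_tokens=8192,
                      deep_ep_max_num_tokens=256,
                      alltoall_method_type=AlltoallMethodType.DeepEPLowLatency,
                      enable_alltoall=True)

all_rank_num_tokens = [120, 300, 64, 10]            # per-rank token counts
all_rank_max_num_tokens = max(all_rank_num_tokens)  # 300

# One rank exceeds deep_ep_max_num_tokens (300 > 256), so alltoall is rejected
# even though the local rank might only hold 120 tokens.
print(moe.can_use_alltoall(all_rank_num_tokens, all_rank_max_num_tokens))  # False
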
