@@ -455,12 +455,13 @@ def forward_chunk(
455455 elif self .alltoall_method_type == AlltoallMethodType .DeepEP :
456456 if not use_postquant_alltoall :
457457 x , recv_topk_idx , token_final_scales , num_recv_tokens_per_expert_list , deep_ep_handle = \
458- self .deep_ep_buffer .dispatch (x , token_selected_slots .to (torch .int64 ), token_final_scales , self .num_slots )
459- padded , x , _ , recv_topk_idx , token_final_scales = self .pad_empty_recv_tensors (
458+ self .deep_ep_buffer .dispatch (x , token_selected_slots , token_final_scales , self .num_slots ,
459+ self .expert_size_per_partition * self .mapping .moe_ep_rank )
460+ padded , x , _ , token_selected_slots , token_final_scales = self .pad_empty_recv_tensors (
460461 x , None , recv_topk_idx , token_final_scales )
461462 elif self .alltoall_method_type == AlltoallMethodType .DeepEPLowLatency :
462463 if not use_postquant_alltoall :
463- deep_ep_topk_idx = token_selected_slots . to ( torch . int64 )
464+ deep_ep_topk_idx = token_selected_slots
464465 deep_ep_topk_weights = token_final_scales
465466 x , recv_expert_count , deep_ep_handle = \
466467 self .deep_ep_buffer .low_latency_dispatch (x , deep_ep_topk_idx , self .deep_ep_max_num_tokens , self .num_slots )
@@ -588,8 +589,9 @@ def forward_chunk(
588589 x_sf_dtype = x_sf .dtype
589590 x_sf = x_sf .view (torch .float32 )
590591 (x , x_sf ), recv_topk_idx , token_final_scales , num_recv_tokens_per_expert_list , deep_ep_handle = \
591- self .deep_ep_buffer .dispatch ((x , x_sf ), token_selected_slots .to (torch .int64 ), token_final_scales , self .num_slots )
592- padded , x , x_sf , recv_topk_idx , token_final_scales = self .pad_empty_recv_tensors (
592+ self .deep_ep_buffer .dispatch ((x , x_sf ), token_selected_slots , token_final_scales , self .num_slots ,
593+ self .expert_size_per_partition * self .mapping .moe_ep_rank )
594+ padded , x , x_sf , token_selected_slots , token_final_scales = self .pad_empty_recv_tensors (
593595 x , x_sf , recv_topk_idx , token_final_scales )
594596 if x_sf is not None :
595597 x_sf = x_sf .view (x_sf_dtype )
@@ -619,7 +621,7 @@ def forward_chunk(
619621 fp4_packed_tensor [:,
620622 x .shape [1 ]:x .shape [1 ] + x_sf .shape [1 ]] = x_sf
621623
622- deep_ep_topk_idx = token_selected_slots . to ( torch . int64 )
624+ deep_ep_topk_idx = token_selected_slots
623625 deep_ep_topk_weights = token_final_scales
624626 # Each LL combine/dispatch kernel call requires that the `dispatch_rdma_recv_count_buffer` be properly cleaned.
625627 # However, the offset of this buffer within the entire RDMA buffer changes according to the hidden size.
@@ -668,15 +670,6 @@ def forward_chunk(
668670 f"Not available alltoall method type: { self .alltoall_method_type !r} "
669671 )
670672
671- if use_all_to_all :
672- # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
673- # TODO: remove the adapter by changing APIs
674- if self .alltoall_method_type == AlltoallMethodType .DeepEP :
675- token_selected_slots = recv_topk_idx .to (torch .int32 )
676- mask = token_selected_slots == - 1
677- token_selected_slots += self .expert_size_per_partition * self .mapping .moe_ep_rank
678- token_selected_slots [mask ] = self .num_slots
679-
680673 final_hidden_states = torch .ops .trtllm .fused_moe (
681674 x ,
682675 token_selected_slots ,
0 commit comments