
Commit 82e7899

yifeizhang-creasonsolo authored and committed
[TRTLLM-6368] Update deepep dispatch API (NVIDIA#6037)
Signed-off-by: Yifei Zhang <[email protected]>
1 parent fde179a commit 82e7899

3 files changed (+12, -18 lines)


cpp/tensorrt_llm/deep_ep/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-set(DEEP_EP_COMMIT c381dadf43a85062f6a8947592017ee513abc70b)
+set(DEEP_EP_COMMIT eb3f072664251c05074c3ecc3c3f5dad179c29a9)
 set(NVSHMEM_URL_HASH
     SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)

tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py

Lines changed: 3 additions & 2 deletions
@@ -59,7 +59,7 @@ def reserve(self, hidden_size: int, hidden_dtype: torch.dtype):
 
     def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                  topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                 num_experts: int) -> \
+                 num_experts: int, global_expert_id_offset: int) -> \
             Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple]:
         # NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
         # of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
@@ -76,7 +76,8 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
             self.buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                                  num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-                                 is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert)
+                                 is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
+                                 global_expert_id_offset=global_expert_id_offset)
         assert event.event is None
 
         # For event management, please refer to the docs of the `EventOverlap` class
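
For reference, a minimal sketch of driving the updated wrapper signature; the helper name, buffer handle, and rank arguments below are illustrative assumptions, not part of the commit:

import torch

def dispatch_with_offset(deep_ep_buffer, x: torch.Tensor,
                         token_selected_slots: torch.Tensor,
                         token_final_scales: torch.Tensor,
                         num_slots: int, experts_per_rank: int, ep_rank: int):
    # Offset of this EP rank's first expert slot in the global slot space;
    # mirrors `self.expert_size_per_partition * self.mapping.moe_ep_rank`
    # at the call sites in fused_moe_wide_ep.py below.
    global_expert_id_offset = experts_per_rank * ep_rank
    # The wrapper forwards the offset to `Buffer.dispatch`, so the received
    # top-k indices should arrive already shifted into this rank's slot
    # range (this replaces the post-dispatch adapter removed below).
    recv_x, recv_topk_idx, recv_topk_weights, \
        num_recv_tokens_per_expert_list, handle = \
        deep_ep_buffer.dispatch(x, token_selected_slots, token_final_scales,
                                num_slots, global_expert_id_offset)
    return recv_x, recv_topk_idx, recv_topk_weights, \
        num_recv_tokens_per_expert_list, handle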

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 8 additions & 15 deletions
@@ -455,12 +455,13 @@ def forward_chunk(
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             if not use_postquant_alltoall:
                 x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
-                    self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
-                padded, x, _, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
+                    self.deep_ep_buffer.dispatch(x, token_selected_slots, token_final_scales, self.num_slots,
+                                                 self.expert_size_per_partition * self.mapping.moe_ep_rank)
+                padded, x, _, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
                     x, None, recv_topk_idx, token_final_scales)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
             if not use_postquant_alltoall:
-                deep_ep_topk_idx = token_selected_slots.to(torch.int64)
+                deep_ep_topk_idx = token_selected_slots
                 deep_ep_topk_weights = token_final_scales
                 x, recv_expert_count, deep_ep_handle = \
                     self.deep_ep_buffer.low_latency_dispatch(x, deep_ep_topk_idx, self.deep_ep_max_num_tokens, self.num_slots)
@@ -588,8 +589,9 @@ def forward_chunk(
                 x_sf_dtype = x_sf.dtype
                 x_sf = x_sf.view(torch.float32)
                 (x, x_sf), recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
-                    self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
-                padded, x, x_sf, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
+                    self.deep_ep_buffer.dispatch((x, x_sf), token_selected_slots, token_final_scales, self.num_slots,
+                                                 self.expert_size_per_partition * self.mapping.moe_ep_rank)
+                padded, x, x_sf, token_selected_slots, token_final_scales = self.pad_empty_recv_tensors(
                     x, x_sf, recv_topk_idx, token_final_scales)
                 if x_sf is not None:
                     x_sf = x_sf.view(x_sf_dtype)
@@ -619,7 +621,7 @@
                 fp4_packed_tensor[:,
                                   x.shape[1]:x.shape[1] + x_sf.shape[1]] = x_sf
 
-                deep_ep_topk_idx = token_selected_slots.to(torch.int64)
+                deep_ep_topk_idx = token_selected_slots
                 deep_ep_topk_weights = token_final_scales
                 # Each LL combine/dispatch kernel call requires that the `dispatch_rdma_recv_count_buffer` be properly cleaned.
                 # However, the offset of this buffer within the entire RDMA buffer changes according to the hidden size.
@@ -668,15 +670,6 @@
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )
 
-        if use_all_to_all:
-            # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
-            # TODO: remove the adapter by changing APIs
-            if self.alltoall_method_type == AlltoallMethodType.DeepEP:
-                token_selected_slots = recv_topk_idx.to(torch.int32)
-                mask = token_selected_slots == -1
-                token_selected_slots += self.expert_size_per_partition * self.mapping.moe_ep_rank
-                token_selected_slots[mask] = self.num_slots
-
         final_hidden_states = torch.ops.trtllm.fused_moe(
             x,
             token_selected_slots,
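
For context, the deleted adapter at the end of `forward_chunk` used to apply the expert-ID offset on the receive side, after dispatch. Paraphrased from the removed lines (with the `self.` qualifiers dropped), it did the following; passing `global_expert_id_offset` into `dispatch` makes this post-processing step unnecessary:

# Paraphrase of the removed adapter: shift received expert IDs into this
# rank's global slot range, then restore DeepEP's -1 "no expert" marker as
# `num_slots` (presumably treated as "no expert" by torch.ops.trtllm.fused_moe).
token_selected_slots = recv_topk_idx.to(torch.int32)
mask = token_selected_slots == -1          # remember invalid entries first
token_selected_slots += expert_size_per_partition * moe_ep_rank
token_selected_slots[mask] = num_slots     # re-mark the invalid entries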
