diff --git a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
index f4c3f48bbb2..088391aef4f 100644
--- a/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
+++ b/cpp/tensorrt_llm/deep_ep/CMakeLists.txt
@@ -1,4 +1,4 @@
-set(DEEP_EP_COMMIT 7b15af835942675df041eca2dcb9930b880287e1)
+set(DEEP_EP_COMMIT edf3ea2b086a393d3163bf2773eab69d9191cc01)
 
 set(NVSHMEM_URL_HASH
     SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)
diff --git a/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
index 385a5ec4b91..5ad37024817 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
@@ -154,6 +154,24 @@ def low_latency_dispatch(self, hidden_states: torch.Tensor,
         # Later, you can use our GEMM library to do the computation with this specific format
         return recv_hidden_states, recv_expert_count, handle
 
+    def low_latency_dispatch_fp4(self, hidden_states: torch.Tensor,
+                                 scales: torch.Tensor, topk_idx: torch.Tensor,
+                                 num_max_dispatch_tokens_per_rank: int,
+                                 num_experts: int):
+        assert num_experts == self.num_experts
+
+        # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
+        recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook = \
+            self.buffer.low_latency_dispatch_fp4(hidden_states, scales, topk_idx, num_max_dispatch_tokens_per_rank, num_experts)
+        assert event.event is None
+        assert hook is None
+
+        # NOTES: the actual tensor will not be received only if you call `hook()`,
+        # it is useful for double-batch overlapping, but **without any SM occupation**
+        # If you don't want to overlap, please set `return_recv_hook=False`
+        # Later, you can use our GEMM library to do the computation with this specific format
+        return recv_hidden_states, recv_scales, recv_expert_count, handle
+
     def low_latency_combine(self, hidden_states: torch.Tensor,
                             topk_idx: torch.Tensor,
                             topk_weights: torch.Tensor, handle: Tuple):
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index 81778c28544..f10387d1f68 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -588,43 +588,26 @@ def forward_chunk(
                 x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
                                   self.scaling_vector_size)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-            assert x_sf is not None and self.has_nvfp4
             token_num = x_row
             hidden_size = x_col
+            assert x_sf is not None and self.has_nvfp4
             assert hidden_size % 32 == 0
-            x_sf_dtype = x_sf.dtype
-            x_dtype = x.dtype
-            assert x_sf_dtype == torch.uint8 and x_dtype == torch.uint8
-            x_sf = x_sf.view(torch.bfloat16)
+            assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
             assert x_sf.shape[0] == token_num and x_sf.shape[
-                1] == hidden_size // 16 // 2
-            x = x.view(torch.bfloat16)
-            assert x.shape[0] == token_num and x.shape[1] == hidden_size // 4
-            # DeepEP LL dispatch only supports bf16 tensors with a hidden size of 2560, 4096, 5120, or 7168 as input. A hidden size of 2560 is sufficient to accommodate packed FP4 data.
-            packed_hidden_size = 2560
-            assert x.shape[1] + x_sf.shape[1] <= packed_hidden_size
-            fp4_packed_tensor = torch.empty((token_num, packed_hidden_size),
-                                            dtype=torch.bfloat16,
-                                            device=x.device)
-            fp4_packed_tensor[:, :x.shape[1]] = x
-            fp4_packed_tensor[:,
-                              x.shape[1]:x.shape[1] + x_sf.shape[1]] = x_sf
+                1] == hidden_size // 16
+            assert x.shape[0] == token_num and x.shape[1] == hidden_size // 2
 
             deep_ep_topk_idx = token_selected_slots
             deep_ep_topk_weights = token_final_scales
             assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens
-            fp4_packed_tensor, recv_expert_count, deep_ep_handle = \
-                self.deep_ep_buffer.low_latency_dispatch(fp4_packed_tensor, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
-            deep_ep_handle = list(deep_ep_handle)
-            deep_ep_handle[3] = hidden_size
-            deep_ep_handle = tuple(deep_ep_handle)
-
-            assert fp4_packed_tensor.ndim == 3 and fp4_packed_tensor.shape[
-                2] == packed_hidden_size
-            x_sf = fp4_packed_tensor[:, :, x.shape[1]:x.shape[1] +
-                                     x_sf.shape[1]].contiguous()
-            x = fp4_packed_tensor[:, :, :x.shape[1]].contiguous()
+            x, x_sf, recv_expert_count, deep_ep_handle = \
+                self.deep_ep_buffer.low_latency_dispatch_fp4(x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
+            assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
+            assert x.dim() == 3 and x_sf.dim() == 3
+            assert x.shape[2] == hidden_size // 2 and x_sf.shape[
+                2] == hidden_size // 16
+
             mask = torch.arange(
                 x.shape[1], dtype=torch.int32, device=x.device).expand(
                     x.shape[0], x.shape[1]) < recv_expert_count.unsqueeze(1)
@@ -634,9 +617,9 @@ def forward_chunk(
                     x.shape[0] * (self.mapping.moe_ep_rank + 1),
                     dtype=torch.int32,
                     device=x.device).unsqueeze(1), self.num_slots)
-            x = x.reshape(x.shape[0] * x.shape[1], x.shape[2]).view(x_dtype)
+            x = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
             x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1],
-                                x_sf.shape[2]).view(x_sf_dtype)
+                                x_sf.shape[2])
             x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
                               self.scaling_vector_size)
             token_selected_slots = token_selected_slots.view(x.shape[0], 1)
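
For readers of this diff, the sketch below (not part of the change itself) summarizes the tensor contract that the new FP4 low-latency path relies on. The concrete values of token_num and hidden_size and the zero-filled tensors are illustrative assumptions only; the dtype and shape checks mirror the asserts added to forward_chunk(), and the actual dispatch is performed by self.deep_ep_buffer.low_latency_dispatch_fp4(...).

# Illustrative sketch of the FP4 low-latency dispatch shape/dtype contract.
# Dummy values below are assumptions; only the asserts reflect the diff.
import torch

token_num = 8        # corresponds to x_row in forward_chunk()
hidden_size = 7168   # corresponds to x_col; must be divisible by 32
assert hidden_size % 32 == 0

# Packed FP4 activations: two 4-bit values per uint8 byte -> hidden_size // 2 bytes per token.
x = torch.zeros(token_num, hidden_size // 2, dtype=torch.uint8)
# One uint8 scaling factor per 16 FP4 elements -> hidden_size // 16 bytes per token.
x_sf = torch.zeros(token_num, hidden_size // 16, dtype=torch.uint8)

# Pre-dispatch contract (asserted before calling low_latency_dispatch_fp4).
assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
assert x.shape == (token_num, hidden_size // 2)
assert x_sf.shape == (token_num, hidden_size // 16)

# Post-dispatch contract (asserted on the tensors returned by the dispatch):
# both come back as 3-D uint8 tensors whose last dimension keeps the same
# per-token byte widths:
#   x.dim() == 3    and x.shape[2]    == hidden_size // 2
#   x_sf.dim() == 3 and x_sf.shape[2] == hidden_size // 16

Compared with the removed code path, dispatching x and x_sf directly avoids repacking the FP4 payload and its scales into a 2560-wide bf16 buffer and slicing them back out after dispatch.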