2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/deep_ep/CMakeLists.txt
@@ -1,4 +1,4 @@
-set(DEEP_EP_COMMIT 7b15af835942675df041eca2dcb9930b880287e1)
+set(DEEP_EP_COMMIT edf3ea2b086a393d3163bf2773eab69d9191cc01)
set(NVSHMEM_URL_HASH
SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)

18 changes: 18 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py
@@ -154,6 +154,24 @@ def low_latency_dispatch(self, hidden_states: torch.Tensor,
        # Later, you can use our GEMM library to do the computation with this specific format
        return recv_hidden_states, recv_expert_count, handle

+    def low_latency_dispatch_fp4(self, hidden_states: torch.Tensor,
+                                 scales: torch.Tensor, topk_idx: torch.Tensor,
+                                 num_max_dispatch_tokens_per_rank: int,
+                                 num_experts: int):
+        assert num_experts == self.num_experts
+
+        # Do MoE dispatch, compatible with CUDA graph (but you may need to restore some buffer state when you replay)
+        recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook = \
+            self.buffer.low_latency_dispatch_fp4(hidden_states, scales, topk_idx, num_max_dispatch_tokens_per_rank, num_experts)
+        assert event.event is None
+        assert hook is None
+
+        # NOTE: the tensor is not actually received until you call `hook()`;
+        # this is useful for double-batch overlapping, but **without any SM occupation**.
+        # If you don't want to overlap, set `return_recv_hook=False`.
+        # Later, you can use our GEMM library to do the computation with this specific format.
+        return recv_hidden_states, recv_scales, recv_expert_count, handle
+
    def low_latency_combine(self, hidden_states: torch.Tensor,
                            topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                            handle: Tuple):
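A minimal sketch (not part of this PR) of the tensor layout the new `low_latency_dispatch_fp4` wrapper expects, based on the asserts in fused_moe_wide_ep.py below: packed NVFP4 activations are uint8 with hidden_size // 2 bytes per token, and the scale factors are uint8 with hidden_size // 16 bytes per token. The dispatch call itself requires an initialized DeepEP buffer and multiple EP ranks, so only the shape bookkeeping is shown; the variable names and example sizes are assumptions.

import torch

# Example values only; the PR does not fix a hidden size here.
token_num, hidden_size = 4, 7168
assert hidden_size % 32 == 0

# Packed FP4 payload: two 4-bit values per byte -> hidden_size // 2 bytes per token.
x = torch.zeros(token_num, hidden_size // 2, dtype=torch.uint8)
# NVFP4 scale factors: one scale byte per 16-element group -> hidden_size // 16 bytes per token.
x_sf = torch.zeros(token_num, hidden_size // 16, dtype=torch.uint8)

# With an initialized wrapper (hypothetical `deep_ep_buffer`), the call would be:
#   recv_x, recv_sf, recv_expert_count, handle = deep_ep_buffer.low_latency_dispatch_fp4(
#       x, x_sf, topk_idx, num_max_dispatch_tokens_per_rank, num_experts)
# where recv_x / recv_sf come back as 3-D per-expert tensors, still in uint8.
assert x.shape[1] * 2 == hidden_size
assert x_sf.shape[1] * 16 == hidden_size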
43 changes: 13 additions & 30 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -588,43 +588,26 @@ def forward_chunk(
                x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
                                  self.scaling_vector_size)
        elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+            assert x_sf is not None and self.has_nvfp4
            token_num = x_row
            hidden_size = x_col
-            assert x_sf is not None and self.has_nvfp4
            assert hidden_size % 32 == 0
-            x_sf_dtype = x_sf.dtype
-            x_dtype = x.dtype
-            assert x_sf_dtype == torch.uint8 and x_dtype == torch.uint8
-            x_sf = x_sf.view(torch.bfloat16)
+            assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
            assert x_sf.shape[0] == token_num and x_sf.shape[
-                1] == hidden_size // 16 // 2
-            x = x.view(torch.bfloat16)
-            assert x.shape[0] == token_num and x.shape[1] == hidden_size // 4
-            # DeepEP LL dispatch only supports bf16 tensors with a hidden size of 2560, 4096, 5120, or 7168 as input. A hidden size of 2560 is sufficient to accommodate packed FP4 data.
-            packed_hidden_size = 2560
-            assert x.shape[1] + x_sf.shape[1] <= packed_hidden_size
-            fp4_packed_tensor = torch.empty((token_num, packed_hidden_size),
-                                            dtype=torch.bfloat16,
-                                            device=x.device)
-            fp4_packed_tensor[:, :x.shape[1]] = x
-            fp4_packed_tensor[:,
-                              x.shape[1]:x.shape[1] + x_sf.shape[1]] = x_sf
+                1] == hidden_size // 16
+            assert x.shape[0] == token_num and x.shape[1] == hidden_size // 2

            deep_ep_topk_idx = token_selected_slots
            deep_ep_topk_weights = token_final_scales

            assert all_rank_max_num_tokens <= self.deep_ep_max_num_tokens
-            fp4_packed_tensor, recv_expert_count, deep_ep_handle = \
-                self.deep_ep_buffer.low_latency_dispatch(fp4_packed_tensor, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
-            deep_ep_handle = list(deep_ep_handle)
-            deep_ep_handle[3] = hidden_size
-            deep_ep_handle = tuple(deep_ep_handle)
-
-            assert fp4_packed_tensor.ndim == 3 and fp4_packed_tensor.shape[
-                2] == packed_hidden_size
-            x_sf = fp4_packed_tensor[:, :, x.shape[1]:x.shape[1] +
-                                     x_sf.shape[1]].contiguous()
-            x = fp4_packed_tensor[:, :, :x.shape[1]].contiguous()
+            x, x_sf, recv_expert_count, deep_ep_handle = \
+                self.deep_ep_buffer.low_latency_dispatch_fp4(x, x_sf, deep_ep_topk_idx, all_rank_max_num_tokens, self.num_slots)
+            assert x.dtype == torch.uint8 and x_sf.dtype == torch.uint8
+            assert x.dim() == 3 and x_sf.dim() == 3
+            assert x.shape[2] == hidden_size // 2 and x_sf.shape[
+                2] == hidden_size // 16

            mask = torch.arange(
                x.shape[1], dtype=torch.int32, device=x.device).expand(
                    x.shape[0], x.shape[1]) < recv_expert_count.unsqueeze(1)
@@ -634,9 +617,9 @@ def forward_chunk(
                    x.shape[0] * (self.mapping.moe_ep_rank + 1),
                    dtype=torch.int32,
                    device=x.device).unsqueeze(1), self.num_slots)
-            x = x.reshape(x.shape[0] * x.shape[1], x.shape[2]).view(x_dtype)
+            x = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
            x_sf = x_sf.reshape(x_sf.shape[0] * x_sf.shape[1],
-                                x_sf.shape[2]).view(x_sf_dtype)
+                                x_sf.shape[2])
            x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
                              self.scaling_vector_size)
            token_selected_slots = token_selected_slots.view(x.shape[0], 1)
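For context on what the removed packing path was doing, here is a sketch of the arithmetic implied by the deleted comment and asserts (7168 is only an example hidden size, not something this PR fixes): the FP4 payload and scale factors were reinterpreted as bf16 and copied into one bf16 tensor whose hidden size of 2560 is large enough to hold both, so the generic bf16 `low_latency_dispatch` could move them. The new `low_latency_dispatch_fp4` path sends the uint8 payload and scales directly and drops the pack/unpack step and the handle patching.

# Why a packed bf16 hidden size of 2560 was sufficient in the removed path
# (example hidden size; the deleted code asserted this at runtime):
hidden_size = 7168
fp4_payload_as_bf16 = hidden_size // 4    # hidden_size // 2 uint8 bytes viewed as bf16
scales_as_bf16 = hidden_size // 16 // 2   # hidden_size // 16 uint8 bytes viewed as bf16
packed_hidden_size = 2560                 # one of the hidden sizes DeepEP LL dispatch supports
assert fp4_payload_as_bf16 + scales_as_bf16 <= packed_hidden_size  # 1792 + 224 = 2016

# The new path skips this: x (uint8, hidden_size // 2) and x_sf (uint8, hidden_size // 16)
# go through low_latency_dispatch_fp4 without the bf16 view or the packed buffer.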