cpp/tensorrt_llm/deep_ep/CMakeLists.txt (10 additions, 3 deletions)
@@ -1,4 +1,4 @@
-set(DEEP_EP_COMMIT edf3ea2b086a393d3163bf2773eab69d9191cc01)
+set(DEEP_EP_COMMIT 515a311f290eb6d9592fcccfcc80c40f5123ca72)
 set(NVSHMEM_URL_HASH
     SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)
 
@@ -19,8 +19,15 @@ foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
   set(CUDA_ARCH_MINOR ${CMAKE_MATCH_2})
   set(CUDA_ARCH_POSTFIX ${CMAKE_MATCH_3})
   if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 9)
-    list(APPEND DEEP_EP_CUDA_ARCHITECTURES
-         "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}")
+    # The FP4-related conversion instructions in DeepEP require SM100a, SM110a,
+    # or SM120a.
+    if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 10 AND ${CUDA_ARCH_MINOR} EQUAL 0)
+      list(APPEND DEEP_EP_CUDA_ARCHITECTURES
+           "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}a${CUDA_ARCH_POSTFIX}")
+    else()
+      list(APPEND DEEP_EP_CUDA_ARCHITECTURES
+           "${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}${CUDA_ARCH_POSTFIX}")
+    endif()
   endif()
 endforeach()
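For reference, a minimal Python sketch (not the build code itself) of the architecture-suffix rule introduced above, assuming the surrounding foreach splits each CMAKE_CUDA_ARCHITECTURES entry into leading major digits, a single minor digit, and a postfix such as "-real"; the parsing regex sits outside the shown hunk, so this mapping is an assumption.

import re
from typing import Optional

def deep_ep_arch(entry: str) -> Optional[str]:
    # Assumed split: major digits, one minor digit, remaining postfix.
    m = re.match(r"^(\d+)(\d)(.*)$", entry)
    major, minor, postfix = int(m.group(1)), int(m.group(2)), m.group(3)
    if major < 9:
        return None  # DeepEP kernels are only built for SM90 and newer
    if major >= 10 and minor == 0:
        # SM100a / SM110a / SM120a expose the FP4 conversion instructions
        return f"{major}{minor}a{postfix}"
    return f"{major}{minor}{postfix}"

print([deep_ep_arch(a) for a in ["89-real", "90-real", "100-real", "120-real"]])
# -> [None, '90-real', '100a-real', '120a-real']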
tensorrt_llm/_torch/modules/fused_moe/deep_ep_utils.py (17 additions, 11 deletions)
@@ -154,34 +154,40 @@ def low_latency_dispatch(self, hidden_states: torch.Tensor,
         # Later, you can use our GEMM library to do the computation with this specific format
         return recv_hidden_states, recv_expert_count, handle
 
+    def low_latency_combine(self, hidden_states: torch.Tensor,
+                            topk_idx: torch.Tensor, topk_weights: torch.Tensor,
+                            handle: Tuple):
+        # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
+        combined_hidden_states, event, hook = \
+            self.buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle)
+        assert event.event is None
+        assert hook is None
+
+        # NOTES: the same behavior as described in the dispatch kernel
+        return combined_hidden_states
+
     def low_latency_dispatch_fp4(self, hidden_states: torch.Tensor,
                                  scales: torch.Tensor, topk_idx: torch.Tensor,
                                  num_max_dispatch_tokens_per_rank: int,
                                  num_experts: int):
         assert num_experts == self.num_experts
 
         # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
         recv_hidden_states, recv_scales, recv_expert_count, handle, event, hook = \
             self.buffer.low_latency_dispatch_fp4(hidden_states, scales, topk_idx, num_max_dispatch_tokens_per_rank, num_experts)
         assert event.event is None
         assert hook is None
 
         # NOTES: the actual tensor will not be received only if you call `hook()`,
         # it is useful for double-batch overlapping, but **without any SM occupation**
         # If you don't want to overlap, please set `return_recv_hook=False`
         # Later, you can use our GEMM library to do the computation with this specific format
         return recv_hidden_states, recv_scales, recv_expert_count, handle
 
-    def low_latency_combine(self, hidden_states: torch.Tensor,
-                            topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                            handle: Tuple):
-        # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay)
+    def low_latency_combine_fp4(self, hidden_states: torch.Tensor,
+                                global_scales: torch.Tensor,
+                                topk_idx: torch.Tensor,
+                                topk_weights: torch.Tensor, handle: Tuple):
         combined_hidden_states, event, hook = \
-            self.buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle)
+            self.buffer.low_latency_combine_fp4(hidden_states, global_scales, topk_idx, topk_weights, handle)
         assert event.event is None
         assert hook is None
 
         # NOTES: the same behavior as described in the dispatch kernel
         return combined_hidden_states
 
     def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int,
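A hedged usage sketch of how the two FP4 wrappers above pair up around the expert computation. `buffer` stands for the wrapper object exposing the methods in this diff, and `run_experts` is a hypothetical stand-in for the grouped-GEMM expert computation; neither name comes from the source.

import torch

def moe_fp4_roundtrip(buffer, hidden_states, scales, topk_idx, topk_weights,
                      num_max_dispatch_tokens_per_rank, num_experts, run_experts):
    # Dispatch pre-quantized FP4 payloads plus their scales to the expert ranks.
    recv_hidden, recv_scales, recv_expert_count, handle = \
        buffer.low_latency_dispatch_fp4(hidden_states, scales, topk_idx,
                                        num_max_dispatch_tokens_per_rank,
                                        num_experts)

    # Expert computation on the received tokens (grouped GEMM, activation, ...).
    expert_out = run_experts(recv_hidden, recv_scales, recv_expert_count)

    # Per-token global scale for the FP4 combine, mirroring forward_chunk:
    # 448 (FP8 E4M3 max) * 6 (FP4 E2M1 max) divided by the row-wise amax.
    global_scales = (448 * 6) / expert_out.abs().max(
        dim=-1, keepdim=True).values.to(torch.float32)

    # Combine the expert outputs back to the token-owning ranks in FP4.
    return buffer.low_latency_combine_fp4(expert_out, global_scales,
                                          topk_idx, topk_weights, handle)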
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (14 additions, 3 deletions)
@@ -183,11 +183,15 @@ def __init__(
             f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}",
             key="alltoall_method_type")
         self.use_postquant_alltoall = False
+        self.use_low_precision_combine = False
         if self.enable_alltoall:
             qm = self.quant_config.quant_mode
             self.use_postquant_alltoall = (os.environ.get(
                 "TRTLLM_MOE_POST_QUANT_ALLTOALLV", "1")
                                            == "1") and qm.has_nvfp4()
+            self.use_low_precision_combine = (os.environ.get(
+                "TRTLLM_MOE_USE_LOW_PRECISION_COMBINE", "0")
+                                              == "1") and qm.has_nvfp4()
             # TODO: support alltoall without allgather for top_k % 4 != 0
             self.enable_alltoall_without_allgather = (
                 os.environ.get("TRTLLM_MOE_ENABLE_ALLTOALL_WITHOUT_ALLGATHER",
@@ -684,9 +688,16 @@ def forward_chunk(
                     final_hidden_states = final_hidden_states.view(
                         self.expert_size_per_partition,
                         num_tokens_per_expert_for_fused_moe, self.hidden_size)
-                    final_hidden_states = self.deep_ep_buffer.low_latency_combine(
-                        final_hidden_states, deep_ep_topk_idx, deep_ep_topk_weights,
-                        deep_ep_handle)
+                    if self.use_low_precision_combine:
+                        global_scales = (448 * 6) / final_hidden_states.abs().max(
+                            dim=-1, keepdim=True).values.to(torch.float32)
+                        final_hidden_states = self.deep_ep_buffer.low_latency_combine_fp4(
+                            final_hidden_states, global_scales, deep_ep_topk_idx,
+                            deep_ep_topk_weights, deep_ep_handle)
+                    else:
+                        final_hidden_states = self.deep_ep_buffer.low_latency_combine(
+                            final_hidden_states, deep_ep_topk_idx,
+                            deep_ep_topk_weights, deep_ep_handle)
                 else:
                     raise NotImplementedError(
                         f"Not available alltoall method type: {self.alltoall_method_type!r}"
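A worked sketch of the per-token global scale computed in forward_chunk above. The constants come straight from the diff: 448 is the maximum magnitude of FP8 E4M3 (the block-scale format) and 6 is the maximum magnitude of FP4 E2M1 (the element format), so (448 * 6) / amax maps each token's largest element onto the largest magnitude the two-level NVFP4 scaling can represent.

import torch

# Toy [tokens, hidden] expert output; 7168 is just an illustrative hidden size.
x = torch.randn(4, 7168)
amax = x.abs().max(dim=-1, keepdim=True).values.to(torch.float32)
global_scales = (448 * 6) / amax  # shape [4, 1], one scale per token
print(global_scales.shape)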