4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -59,6 +59,7 @@
from ..modules.embedding import Embedding
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
MoEWeightLoadingMode, create_moe)
from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -849,6 +850,9 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
output_dtype=hidden_states.dtype,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
**({
"alltoall_result_do_sum": False
} if isinstance(self.experts, WideEPMoE) else {}),
)

return routed_output
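A minimal sketch of the conditional-kwarg pattern used in compute_routed_output above, with hypothetical stand-in classes (BaseMoE and the simplified WideEPMoE below are illustrations, not the real TensorRT-LLM modules): the flag is injected only when the expert module is a WideEPMoE, and other backends never see it because the generic forward signature absorbs unknown keywords through **kwargs.

    # Hypothetical stand-ins for illustration; not the real TensorRT-LLM classes.
    class BaseMoE:
        def forward(self, hidden_states, **kwargs):
            # Generic backends accept and ignore unknown keyword arguments.
            return hidden_states

    class WideEPMoE(BaseMoE):
        def forward(self, hidden_states, alltoall_result_do_sum: bool = True, **kwargs):
            # Only the wide-EP backend interprets the flag.
            return hidden_states

    def compute_routed_output(experts, hidden_states):
        # Pass the flag only when the backend is WideEPMoE, mirroring the diff above.
        return experts.forward(
            hidden_states,
            **({"alltoall_result_do_sum": False}
               if isinstance(experts, WideEPMoE) else {}),
        )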
40 changes: 24 additions & 16 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -419,14 +419,15 @@ def reducescatter_or_allreduce(
return outputs

def forward_chunk(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
use_all_to_all: bool,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
repeating_info: Tuple = (True, True),
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
use_all_to_all: bool,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
repeating_info: Tuple = (True, True),
alltoall_result_do_sum: bool = True,
) -> torch.Tensor:
all_rank_max_num_tokens = max(all_rank_num_tokens)
if isinstance(x, Fp4QuantizedTensor):
@@ -441,7 +442,7 @@ def forward_chunk(
self.layer_load_balancer.start_wait_gpu_stage()

if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
pass
alltoall_result_do_sum = True

weight_dtype = self.w3_w1_weight.dtype

@@ -706,7 +707,8 @@ def forward_chunk(
if self.enable_dummy_allreduce:
self.dummy_allreduce()
final_hidden_states = self.alltoall_combine(
final_hidden_states, alltoall_info, token_count)
final_hidden_states, alltoall_info, token_count,
alltoall_result_do_sum)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
final_hidden_states = self.unpad_tensors(
padded, final_hidden_states)
@@ -751,6 +753,7 @@ def forward_impl(
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
alltoall_result_do_sum: bool = True,
**kwargs,
) -> torch.Tensor:
assert all_rank_num_tokens is not None
@@ -778,7 +781,8 @@
output_dtype,
all_rank_num_tokens=all_rank_num_tokens_padded,
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
repeating_info=(is_first_call, is_last_call),
alltoall_result_do_sum=alltoall_result_do_sum)
outputs = self.reducescatter_or_allreduce(
outputs,
use_all_to_all,
@@ -836,7 +840,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk],
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
repeating_info=(is_first_call, is_last_call),
alltoall_result_do_sum=alltoall_result_do_sum)
if idx_chunk > 0:
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
@@ -852,7 +857,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk],
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
repeating_info=(is_first_call, is_last_call),
alltoall_result_do_sum=alltoall_result_do_sum)
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
@@ -866,7 +872,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
router_logits,
use_all_to_all,
all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
repeating_info=(is_first_call, is_last_call))
repeating_info=(is_first_call, is_last_call),
alltoall_result_do_sum=alltoall_result_do_sum)

outputs_list.append(outputs)
if not use_all_to_all:
@@ -922,7 +929,8 @@ def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
return x, x_sf, token_selected_slots, token_final_scales

def alltoall_combine(self, final_hidden_states: torch.Tensor,
alltoall_info: MoEAlltoallInfo, token_count: int):
alltoall_info: MoEAlltoallInfo, token_count: int,
alltoall_result_do_sum: bool):
top_k = self.routing_method.experts_per_token
if isinstance(final_hidden_states, list):
final_hidden_states = final_hidden_states[0]
@@ -935,7 +943,7 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
top_k=top_k,
token_count=token_count,
use_low_precision_combine=self.use_low_precision_combine,
do_reduce=False)
do_reduce=alltoall_result_do_sum)

return final_hidden_states

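For context, a rough sketch of how the new flag travels through this file (the class below is a simplified assumption that only models the control flow, not the actual communication kernels): forward_impl threads alltoall_result_do_sum into each forward_chunk call, forward_chunk forces it back to True whenever the MNNVL alltoall path is not taken, and alltoall_combine maps it onto do_reduce, which was previously hard-coded to False.

    # Hypothetical model of the flag's propagation; not the real WideEPMoE implementation.
    class WideEPMoEFlagFlow:
        def __init__(self, uses_mnnvl_alltoall: bool):
            self.uses_mnnvl_alltoall = uses_mnnvl_alltoall

        def forward_impl(self, x, alltoall_result_do_sum: bool = True):
            # Each chunk receives the same caller-provided flag.
            return self.forward_chunk(x, alltoall_result_do_sum=alltoall_result_do_sum)

        def forward_chunk(self, x, alltoall_result_do_sum: bool = True):
            if not self.uses_mnnvl_alltoall:
                # Outside the MNNVL alltoall path the combine must still reduce.
                alltoall_result_do_sum = True
            return self.alltoall_combine(x, alltoall_result_do_sum)

        def alltoall_combine(self, x, alltoall_result_do_sum: bool):
            # do_reduce used to be hard-coded to False; it now follows the flag.
            do_reduce = alltoall_result_do_sum
            # The real cross-rank reduction is elided; the sketch just reports the decision.
            return x, do_reduce

    # With MNNVL alltoall the caller's False survives; otherwise it is forced to True.
    print(WideEPMoEFlagFlow(True).forward_impl([1.0], alltoall_result_do_sum=False))   # ([1.0], False)
    print(WideEPMoEFlagFlow(False).forward_impl([1.0], alltoall_result_do_sum=False))  # ([1.0], True)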
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/interface.py
@@ -242,6 +242,7 @@ def forward(
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
**kwargs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
if self.register_to_config and is_torch_compiling():
hidden_states = x.fp4_tensor if isinstance(
@@ -274,6 +275,7 @@
output_dtype=output_dtype,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
**kwargs,
)

@property
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -16,6 +16,7 @@ l0_dgx_b200:
tests:
- unittest/_torch/multi_gpu_modeling -k "deepseek"
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -104,6 +104,7 @@ l0_dgx_h100:
- unittest/_torch/multi_gpu_modeling/test_deepseek.py::test_deepseek_streaming[tp4-bf16-trtllm-deepseekv3_lite]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype0]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0]
13 changes: 9 additions & 4 deletions tests/unittest/_torch/modules/test_fused_moe.py
@@ -212,11 +212,14 @@ def per_rank_test_fused_moe_alltoall(job_id):
weights = {}
for expert_id in range(NUM_EXPERTS):
w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
dtype=dtype)
dtype=dtype,
device="cuda")
w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
dtype=dtype)
dtype=dtype,
device="cuda")
w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
dtype=dtype)
dtype=dtype,
device="cuda")
torch.nn.init.xavier_uniform_(w1_weight)
torch.nn.init.xavier_uniform_(w2_weight)
torch.nn.init.xavier_uniform_(w3_weight)
@@ -292,7 +295,6 @@ def per_rank_test_fused_moe_alltoall(job_id):
assert r is None


@pytest.mark.skip(reason="https://nvbugs/5467531")
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
@pytest.mark.parametrize("alltoall_method_type", [
Expand All @@ -302,6 +304,9 @@ def per_rank_test_fused_moe_alltoall(job_id):
ids=lambda s: s.name)
def test_fused_moe_alltoall_fp4(alltoall_method_type):

if alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
pytest.skip("Skipped due to https://nvbugs/5467531")

world_size = 4
dtype = torch.bfloat16
HIDDEN_SIZE = 2560