
Commit 5cc454f

kaiyux and zongfeijing committed
[None] [test] Add MNNVL AlltoAll tests to pre-merge (#7465)

Signed-off-by: Kaiyu Xie <[email protected]>
Signed-off-by: Zongfei Jing <[email protected]>
Co-authored-by: Zongfei Jing <[email protected]>
1 parent 77657a1 commit 5cc454f

5 files changed: 42 additions & 20 deletions


tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,7 @@
 from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
                                  MoEWeightLoadingMode, TRTLLMGenFusedMoE,
                                  create_moe)
+from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -570,6 +571,9 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
             all_rank_num_tokens=all_rank_num_tokens,
             all_rank_max_num_tokens=all_rank_max_num_tokens,
             use_dp_padding=use_dp_padding,
+            **({
+                "alltoall_result_do_sum": False
+            } if isinstance(self.experts, WideEPMoE) else {}),
         )

         return routed_output
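
A note on the hunk above: the **({...} if ... else {}) expansion passes alltoall_result_do_sum=False only when the experts module is a WideEPMoE, so other MoE backends never see the keyword. Below is a minimal, runnable sketch of that pattern; the Fake* classes are hypothetical stand-ins for the real backends, not TensorRT-LLM code.

# Minimal sketch of the conditional-kwargs pattern; FakeFusedMoE and
# FakeWideEPMoE are hypothetical stand-ins for the real MoE backends.
class FakeFusedMoE:

    def forward(self, x, **kwargs):
        # A backend without the new flag never receives it.
        assert "alltoall_result_do_sum" not in kwargs
        return x


class FakeWideEPMoE(FakeFusedMoE):

    def forward(self, x, alltoall_result_do_sum=True, **kwargs):
        # The wide-EP backend accepts the extra flag.
        return x


def compute_routed_output(experts, x):
    return experts.forward(
        x,
        **({
            "alltoall_result_do_sum": False
        } if isinstance(experts, FakeWideEPMoE) else {}),
    )


compute_routed_output(FakeFusedMoE(), 1)   # no extra keyword passed
compute_routed_output(FakeWideEPMoE(), 2)  # alltoall_result_do_sum=False passed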

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 27 additions & 16 deletions
@@ -368,15 +368,16 @@ def reducescatter_or_allreduce(
         return outputs

     def forward_chunk(
-        self,
-        x: Union[torch.Tensor, Fp4QuantizedTensor],
-        router_logits: torch.Tensor,
-        use_all_to_all: bool,
-        output_dtype: Optional[torch.dtype] = None,
-        all_rank_num_tokens: Optional[List[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
-        use_dp_padding: Optional[bool] = None,
-        repeating_info: Tuple = (True, True),
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        use_all_to_all: bool,
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        all_rank_max_num_tokens: Optional[int] = None,
+        use_dp_padding: Optional[bool] = None,
+        repeating_info: Tuple = (True, True),
+        alltoall_result_do_sum: bool = True,
     ) -> torch.Tensor:
         if isinstance(x, Fp4QuantizedTensor):
             assert output_dtype is not None
@@ -389,6 +390,9 @@ def forward_chunk(
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()

+        if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
+            alltoall_result_do_sum = True
+
         use_deepseek_fp8_block_scale = False
         use_w4_group_scaling = False
         weight_dtype = self.w3_w1_weight.dtype
@@ -679,7 +683,8 @@ def forward_chunk(
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             final_hidden_states = self.alltoall_combine(
-                final_hidden_states, alltoall_info, token_count)
+                final_hidden_states, alltoall_info, token_count,
+                alltoall_result_do_sum)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             final_hidden_states = self.unpad_tensors(
                 padded, final_hidden_states)
@@ -719,6 +724,7 @@ def forward(
         all_rank_num_tokens: Optional[List[int]] = None,
         all_rank_max_num_tokens: Optional[int] = None,
         use_dp_padding: Optional[bool] = None,
+        alltoall_result_do_sum: bool = True,
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
@@ -744,7 +750,8 @@ def forward(
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 all_rank_max_num_tokens=all_rank_max_num_tokens,
                 use_dp_padding=use_dp_padding,
-                repeating_info=(is_first_call, is_last_call))
+                repeating_info=(is_first_call, is_last_call),
+                alltoall_result_do_sum=alltoall_result_do_sum)
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 use_all_to_all,
@@ -804,7 +811,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_max_num_tokens=
                         all_rank_max_num_tokens_list[idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     if idx_chunk > 0:
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -822,7 +830,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -838,7 +847,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
                     all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                         idx_chunk],
-                    repeating_info=(is_first_call, is_last_call))
+                    repeating_info=(is_first_call, is_last_call),
+                    alltoall_result_do_sum=alltoall_result_do_sum)

             outputs_list.append(outputs)
         if not use_all_to_all:
@@ -894,7 +904,8 @@ def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
         return x, x_sf, token_selected_slots, token_final_scales

     def alltoall_combine(self, final_hidden_states: torch.Tensor,
-                         alltoall_info: MoEAlltoallInfo, token_count: int):
+                         alltoall_info: MoEAlltoallInfo, token_count: int,
+                         alltoall_result_do_sum: bool):
         top_k = self.routing_method.experts_per_token
         if isinstance(final_hidden_states, list):
             final_hidden_states = final_hidden_states[0]
@@ -907,7 +918,7 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
             top_k=top_k,
             token_count=token_count,
             use_low_precision_combine=self.use_low_precision_combine,
-            do_reduce=False)
+            do_reduce=alltoall_result_do_sum)

         return final_hidden_states
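
The net effect of these hunks: the new alltoall_result_do_sum flag defaults to True, is forced back to True on every non-MNNVL path, and is forwarded to the combine kernel's do_reduce argument only on the MNNVL all-to-all path. A simplified, self-contained sketch of that resolution logic follows; the enum and helper are illustrative, not the real WideEPMoE implementation.

# Simplified sketch of how forward_chunk resolves the new flag; the enum and
# helper below are illustrative, not the real WideEPMoE implementation.
from enum import Enum


class AlltoallMethodType(Enum):
    MNNVL = 0
    DeepEP = 1
    DeepEPLowLatency = 2


def resolve_do_sum(use_all_to_all, alltoall_method_type,
                   alltoall_result_do_sum=True):
    # Every path other than the MNNVL all-to-all forces the reduction back on.
    if not use_all_to_all or alltoall_method_type != AlltoallMethodType.MNNVL:
        alltoall_result_do_sum = True
    return alltoall_result_do_sum


# DeepSeek-V3 asks for do_sum=False, but the request only sticks on MNNVL:
assert resolve_do_sum(True, AlltoallMethodType.MNNVL, False) is False
assert resolve_do_sum(True, AlltoallMethodType.DeepEP, False) is True
assert resolve_do_sum(False, AlltoallMethodType.MNNVL, False) is True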

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ l0_dgx_b200:
   tests:
   - unittest/_torch/multi_gpu_modeling -k "deepseek"
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ l0_dgx_h100:
   - unittest/_torch/multi_gpu_modeling/test_deepseek.py::test_deepseek_streaming[tp4-bf16-trtllm-deepseekv3_lite]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype0]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0]
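
For reference, the bracketed suffixes in the two YAML entries added above are pytest parametrization IDs: a parametrize decorator with ids=lambda s: s.name turns each enum member into a selectable node such as test_fused_moe_alltoall[MNNVL]. A small illustrative sketch is below; the enum is redefined locally and may not list exactly the same members as test_fused_moe.py.

# Illustrative sketch of how the bracketed test IDs referenced in the test
# lists are generated; the enum is a local stand-in, not the real one.
from enum import Enum

import pytest


class AlltoallMethodType(Enum):
    MNNVL = 0
    DeepEP = 1
    DeepEPLowLatency = 2


@pytest.mark.parametrize("alltoall_method_type",
                         list(AlltoallMethodType),
                         ids=lambda s: s.name)
def test_fused_moe_alltoall(alltoall_method_type):
    # Collected as test_fused_moe_alltoall[MNNVL], [DeepEP], and
    # [DeepEPLowLatency]; the YAML entries select individual nodes by that ID.
    assert isinstance(alltoall_method_type, AlltoallMethodType)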

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 9 additions & 4 deletions
@@ -212,11 +212,14 @@ def per_rank_test_fused_moe_alltoall(job_id):
     weights = {}
     for expert_id in range(NUM_EXPERTS):
         w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         torch.nn.init.xavier_uniform_(w1_weight)
         torch.nn.init.xavier_uniform_(w2_weight)
         torch.nn.init.xavier_uniform_(w3_weight)
@@ -289,7 +292,6 @@ def per_rank_test_fused_moe_alltoall(job_id):
         assert r is None


-@pytest.mark.skip(reason="https://nvbugs/5467531")
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="needs 4 GPUs to run this test")
 @pytest.mark.parametrize("alltoall_method_type", [
@@ -299,6 +301,9 @@ def per_rank_test_fused_moe_alltoall(job_id):
                          ids=lambda s: s.name)
 def test_fused_moe_alltoall_fp4(alltoall_method_type):

+    if alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+        pytest.skip("Skipped due to https://nvbugs/5467531")
+
     world_size = 4
     dtype = torch.bfloat16
     HIDDEN_SIZE = 2560
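
The last two hunks re-enable the FP4 all-to-all test and narrow the skip: instead of a blanket @pytest.mark.skip that disabled every parametrization, only the DeepEPLowLatency case is skipped at runtime, so the MNNVL case now runs in pre-merge. A minimal sketch of that pattern follows; the enum and test name are illustrative, not the real test.

# Minimal sketch of skipping a single parametrization at runtime instead of
# decorating the whole test; the enum and test name are illustrative.
from enum import Enum

import pytest


class AlltoallMethodType(Enum):
    MNNVL = 0
    DeepEPLowLatency = 1


@pytest.mark.parametrize("method", list(AlltoallMethodType), ids=lambda s: s.name)
def test_alltoall_fp4_sketch(method):
    if method == AlltoallMethodType.DeepEPLowLatency:
        # Only this variant is skipped; the MNNVL variant still executes.
        pytest.skip("Skipped due to https://nvbugs/5467531")
    assert method == AlltoallMethodType.MNNVL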
