 from tqdm import tqdm
 from transformers import PretrainedConfig

-from tensorrt_llm._mnnvl_utils import MnnvlMemory
 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
@@ -413,10 +412,6 @@ def __init__(self,
         config = model_config.pretrained_config
         self.top_k = top_k
         self.use_dp = model_config.mapping.enable_attention_dp
-        self.enable_alltoall = Deepseekv3MoE.should_enable_alltoall(
-            model_config, top_k)
-        if self.enable_alltoall:
-            MnnvlMemory.initialize()
         self.gate = DeepseekV3Gate(
             hidden_size,
             num_experts,
@@ -439,7 +434,6 @@ def __init__(self,
             model_config=model_config,
             override_quant_config=override_quant_config,
             aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap],
-            enable_alltoall=self.enable_alltoall,
             layer_idx=layer_idx)

         self.mapping = model_config.mapping
@@ -505,33 +499,14 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int,

         return shared_tp_size, shared_output_scale

-    @staticmethod
-    def should_enable_alltoall(model_config: ModelConfig, top_k: int) -> bool:
-        if not model_config.mapping.enable_attention_dp:
-            return False
-
-        if model_config.mapping.tp_size == 1:
-            return False
-
-        if not MnnvlMemory.supports_mnnvl():
-            return False
-
-        if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
-            return False
-
-        if model_config.mapping.moe_ep_size <= top_k:
-            return False
-
-        return True
-
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, do_finalize):
         # max-throughput
         use_dp_padding = False
         if self.use_dp and self.mapping.tp_size > 1:
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if disable_fp4_allgather() and not self.enable_alltoall:
+            if disable_fp4_allgather() and not self.experts.enable_alltoall:
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,
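
Reviewer note: with this change, Deepseekv3MoE no longer decides whether the MoE alltoall path is used; the flag is read from the experts module instead (`self.experts.enable_alltoall`). For reference, the gating that the removed `should_enable_alltoall` staticmethod performed is sketched below as a standalone function. This is only a summary of the old behavior under stated assumptions, not the new implementation: `mapping` and `supports_mnnvl` are stand-ins for `model_config.mapping` and `MnnvlMemory.supports_mnnvl()`, and the place where the equivalent check now lives is not visible in this hunk.

import os


def should_enable_alltoall(mapping, top_k: int, supports_mnnvl: bool) -> bool:
    """Sketch of the gating removed from Deepseekv3MoE in this diff.

    Assumptions: `mapping` stands in for model_config.mapping and
    `supports_mnnvl` for MnnvlMemory.supports_mnnvl(); the experts backend is
    presumed to apply equivalent conditions when it sets `enable_alltoall`.
    """
    if not mapping.enable_attention_dp:  # alltoall path only used with attention DP
        return False
    if mapping.tp_size == 1:  # nothing to exchange with a single rank
        return False
    if not supports_mnnvl:  # requires MNNVL memory support
        return False
    if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
        return False  # explicit opt-out via environment variable
    if mapping.moe_ep_size <= top_k:  # only enabled when EP size exceeds top_k
        return False
    return True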