diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py
index 74ac9590a38..b3811204dfa 100644
--- a/tensorrt_llm/_torch/distributed/ops.py
+++ b/tensorrt_llm/_torch/distributed/ops.py
@@ -312,6 +312,7 @@ def __init__(self, mapping: Mapping, dtype: torch.dtype):
     def get_supported_dtypes():
         return (torch.float16, torch.bfloat16, torch.float32)
 
+    # Check if MNNVL is supported
     @staticmethod
     def is_mnnvl(mapping: Mapping, dtype: torch.dtype) -> bool:
         from tensorrt_llm._mnnvl_utils import MnnvlMemory
@@ -455,8 +456,14 @@ def __init__(self,
                 self.workspace = get_allreduce_workspace(self.mapping)
 
             # Initialize MNNVL AllReduce if needed
-            if self.strategy == AllReduceStrategy.MNNVL:
-                if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
+            if self.strategy in (AllReduceStrategy.AUTO,
+                                 AllReduceStrategy.MNNVL):
+                if self.mapping.tp_size != self.mapping.world_size:
+                    logger.debug(
+                        f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
+                        f"!= world_size:{self.mapping.world_size}")
+                    self.mnnvl_allreduce = None
+                elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
                     try:
                         self.mnnvl_allreduce = MNNVLAllReduce(
                             self.mapping, dtype) if dtype else None
@@ -474,6 +481,9 @@ def __init__(self,
                         )
                         self.mnnvl_allreduce = None
 
+    def is_mnnvl(self) -> bool:
+        return self.mnnvl_allreduce is not None
+
     def forward(
         self,
         input: torch.Tensor,
diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index 5fdcb43be3a..f34ff92ba93 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -611,7 +611,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                                                        torch.cuda.Stream]):
         super().__init__()
         self.model_config = model_config
-        config = model_config.pretrained_config
+        self.config = model_config.pretrained_config
+        config = self.config
 
         self.hidden_size = config.hidden_size
         self.moe_intermediate_size = config.moe_intermediate_size
@@ -642,6 +643,10 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
             self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4()
 
         has_tp = mapping.has_tp()
+        self.allreduce = AllReduce(mapping=model_config.mapping,
+                                   strategy=model_config.allreduce_strategy,
+                                   dtype=config.torch_dtype)
+        self.moe_allreduce = MoEAllReduce(self.mapping)
 
         if (config.n_routed_experts is not None
                 and layer_idx >= config.first_k_dense_replace
@@ -694,10 +699,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
         self.layer_idx = layer_idx
-        self.allreduce = AllReduce(mapping=model_config.mapping,
-                                   strategy=model_config.allreduce_strategy,
-                                   dtype=config.torch_dtype)
-        self.moe_allreduce = MoEAllReduce(self.mapping)
         self.next_layer_layernorm: RMSNorm = None
 
     def _get_decoder_layer_quant_config(
@@ -743,10 +744,15 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
                     intermediate_size // block_size,
                     self.mapping.tp_size,
                 )
-                mlp_tp_size = math.gcd(
-                    tp,
-                    self.mapping.gpus_per_node,
-                ) if tp > self.mapping.gpus_per_node else tp  # Avoid costly inter-node TP
+
+                if tp > self.mapping.gpus_per_node and not self.allreduce.is_mnnvl(
+                ):
+                    mlp_tp_size = math.gcd(
+                        tp,
+                        self.mapping.gpus_per_node,
+                    )  # Avoid costly inter-node TP when MNNVL is not supported
+                else:
+                    mlp_tp_size = tp
         return mlp_tp_size
 
     def forward(
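
The behavioral core of the second file's change is in `_compute_mlp_tp_size`: MLP tensor parallelism is clamped to a single node only when MNNVL all-reduce is unavailable. Below is a minimal, self-contained sketch of that decision; the free-function form, parameter names, and example sizes are illustrative stand-ins for the real method, which reads `self.mapping.tp_size`, `self.mapping.gpus_per_node`, and `self.allreduce.is_mnnvl()` on the decoder layer.

```python
import math


def compute_mlp_tp_size(intermediate_size: int, block_size: int, tp_size: int,
                        gpus_per_node: int, mnnvl_enabled: bool) -> int:
    """Illustrative sketch of the patched MLP TP-size decision (names are hypothetical)."""
    # Choose a TP width that evenly divides the number of weight blocks
    # along the sharded dimension.
    tp = math.gcd(intermediate_size // block_size, tp_size)

    if tp > gpus_per_node and not mnnvl_enabled:
        # Without MNNVL, inter-node all-reduce is costly, so clamp TP to a
        # divisor that fits inside a single node.
        return math.gcd(tp, gpus_per_node)
    # With MNNVL (or when TP already fits in one node), keep the full TP width.
    return tp


if __name__ == "__main__":
    # 16-way TP on 8-GPU nodes: stays at 16 only when MNNVL is available.
    print(compute_mlp_tp_size(2048, 128, 16, 8, mnnvl_enabled=True))   # 16
    print(compute_mlp_tp_size(2048, 128, 16, 8, mnnvl_enabled=False))  # 8
```

This also explains why the `AllReduce` and `MoEAllReduce` construction moves earlier in `__init__` in the first hunks of the model file: `self.allreduce.is_mnnvl()` must already exist by the time `_compute_mlp_tp_size` is consulted while the layer is being built.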