From e8e5235a3dc76a32c90f7f8587059437e6238c2d Mon Sep 17 00:00:00 2001 From: linquanh Date: Wed, 6 Aug 2025 02:33:47 -0700 Subject: [PATCH 1/2] refactor mtp worker Signed-off-by: linquanh --- .../_torch/models/modeling_deepseekv3.py | 110 +- .../_torch/models/modeling_speculative.py | 998 +++++++++++++++++- tensorrt_llm/_torch/speculative/eagle3.py | 2 +- tensorrt_llm/_torch/speculative/interface.py | 3 + tensorrt_llm/_torch/speculative/mtp.py | 24 +- tensorrt_llm/_torch/speculative/utils.py | 7 +- 6 files changed, 1020 insertions(+), 124 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index a5a61d9a7d2..e568347a1e3 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -26,50 +26,29 @@ # -------------------------------------------------- import copy -import math -import os -import warnings -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional import torch -import torch.nn.functional as F import triton import triton.language as tl from torch import nn from tqdm import tqdm from transformers import PretrainedConfig -from tensorrt_llm._ipc_utils import can_access_peer from tensorrt_llm._utils import get_sm_version -from tensorrt_llm.functional import PositionEmbeddingType -from tensorrt_llm.llmapi.utils import enable_llm_debug -from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization.utils.fp8_utils import ( resmooth_to_fp8_e8m0, transform_sf_into_required_layout) from ..attention_backend import AttentionMetadata -from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams -from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, - MoEAllReduce, MoEAllReduceParams, allgather) from ..model_config import ModelConfig -from ..models.modeling_utils import ModelConfig, QuantConfig -from ..modules.attention import MLA -from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding -from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, TRTLLMGenFusedMoE, - create_moe, - moe_load_balancer_set_repeated_for_next_layer) -from ..modules.gated_mlp import GatedMLP -from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig -from ..modules.multi_stream_utils import maybe_execute_in_parallel +from ..modules.fused_moe import moe_load_balancer_set_repeated_for_next_layer from ..modules.rms_norm import RMSNorm -from ..peft.lora.layer import LoraLayer -from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker -from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor -from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, - EagerFusionConfig, filter_weights, - register_auto_model) +from ..speculative import MTPSpecMetadata, SpecMetadata +from ..utils import AuxStreamType +from .modeling_speculative import (DeepseekV3DecoderLayer, + SpecDecOneEngineForCausalLM) +from .modeling_utils import DecoderModel, filter_weights, register_auto_model @triton.jit @@ -128,7 +107,6 @@ def weight_dequant(x: torch.Tensor, weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y - class DeepseekV3MTPHead(nn.Module): def __init__(self, model_config: ModelConfig[PretrainedConfig]): @@ -1078,6 +1056,8 @@ def forward( input_ids: Optional[torch.IntTensor] = None, position_ids: Optional[torch.IntTensor] = None, inputs_embeds: Optional[torch.FloatTensor] 
= None, + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, ) -> torch.Tensor: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( @@ -1102,8 +1082,8 @@ def forward( @register_auto_model("DeepseekV3ForCausalLM") -class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model, - PretrainedConfig]): +class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model, + PretrainedConfig]): def __init__(self, model_config: ModelConfig[PretrainedConfig]): # Rename some keys of quant_config_dict to support legacy checkpoints @@ -1118,10 +1098,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): model_config._frozen = False model_config.quant_config_dict = quant_config_dict model_config._frozen = True - super().__init__(DeepseekV3Model(model_config), - config=model_config, - hidden_size=model_config.pretrained_config.hidden_size, - vocab_size=model_config.pretrained_config.vocab_size) + + super().__init__(model=DeepseekV3Model(model_config), + model_config=model_config) self.model_nextn = 0 if model_config.spec_config is not None: @@ -1131,23 +1110,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint." if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla: moe_load_balancer_set_repeated_for_next_layer(model_nextn) - mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers, - self.model.aux_stream_dict) - self.model.layers.append(mtp_layer) - self.epilogue.append(mtp_layer) - self.mtp_worker = MTPEagleWorker(model_config.spec_config, - model_config) else: - mtp_layers = nn.ModuleList([ - DeepseekV3MTP(model_config, - layer_idx + self.num_hidden_layers, - self.model.aux_stream_dict) - for layer_idx in range(model_nextn) - ]) - self.model.layers.extend(mtp_layers) - self.epilogue.extend(mtp_layers) - self.mtp_worker = MTPWorker(model_config.spec_config, - model_config) # modify the QuantConfig to support duplicated mtp layers if model_config.quant_config.exclude_modules is not None: extend_exclude_modules = [] @@ -1165,7 +1128,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): ckpt_prefix, model_prefix)) self.model_config.quant_config.exclude_modules.extend( extend_exclude_modules) - self.epilogue.append(self.mtp_worker) + self.model.layers.extend(self.draft_model.mtp_layers) + self.epilogue.extend(self.draft_model.mtp_layers) + self.epilogue.append(self.spec_worker) def forward( self, @@ -1178,40 +1143,13 @@ def forward( **kwargs, ) -> torch.Tensor: attn_metadata.num_generations_per_batch = self.model_nextn + 1 - hidden_states = self.model( - input_ids=input_ids, - attn_metadata=attn_metadata, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - ) - - if spec_metadata and spec_metadata.spec_dec_mode.is_mtp(): - # get logits - logits = self.logits_processor.forward( - hidden_states[spec_metadata.gather_ids], - self.lm_head, - attn_metadata, - True, - ) - # get accepted tokens and next draft tokens - return self.mtp_worker( - input_ids=input_ids, - position_ids=position_ids, - hidden_states=hidden_states, - logits=logits, - lm_head=self.lm_head, - embed_tokens=self.model.embed_tokens, - attn_metadata=attn_metadata, - spec_metadata=spec_metadata, - mtp_layers=self.model.layers[self.num_hidden_layers:]) - else: - logits = self.logits_processor.forward( - hidden_states, - self.lm_head, - attn_metadata, - return_context_logits, - ) - return logits + return super().forward(attn_metadata=attn_metadata, + 
input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + spec_metadata=spec_metadata, + return_context_logits=return_context_logits, + **kwargs) def load_weights(self, weights: Dict): diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index e8a57742115..9414c25daf6 100644 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -1,26 +1,42 @@ -from typing import Any, Dict, Generic, Optional, Tuple +import math +import os +import warnings +from typing import Any, Dict, Generic, List, Optional, Tuple import torch +import torch.nn.functional as F from torch import nn -from transformers import LlamaConfig +from transformers import LlamaConfig, PretrainedConfig +from tensorrt_llm._ipc_utils import can_access_peer from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ BaseWeightMapper +from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.llmapi.utils import enable_llm_debug +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import QuantConfig from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams +from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, + MoEAllReduce, MoEAllReduceParams, allgather) from ..model_config import ModelConfig, TConfig -from ..modules.attention import Attention +from ..modules.attention import MLA, Attention from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding +from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, TRTLLMGenFusedMoE, + create_moe) from ..modules.gated_mlp import GatedMLP from ..modules.linear import (Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig) +from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm +from ..peft.lora.layer import LoraLayer from ..speculative import SpecMetadata, get_spec_worker -from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel, - register_auto_model) +from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor +from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, + EagerFusionConfig, TModel, register_auto_model) class Eagle3Attention(Attention): @@ -320,12 +336,950 @@ def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -def get_draft_model(model_config, draft_config): +class DeepseekV3MTPHead(nn.Module): + + def __init__(self, model_config: ModelConfig[PretrainedConfig]): + super().__init__() + config = model_config.pretrained_config + self.model_config = model_config + + self.norm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + + @torch.compile(options={"max-autotune": True}) + def get_last_token_states(self, hidden_states, attn_metadata): + last_tokens = torch.cumsum( + attn_metadata.seq_lens_cuda, + dim=0, + dtype=torch.long, + ) - 1 + return hidden_states[last_tokens] + + def forward(self, + hidden_states: torch.Tensor, + lm_head: Linear, + attn_metadata: AttentionMetadata, + return_context_logits: bool = False) -> torch.Tensor: + if not return_context_logits: + if attn_metadata is not None: + hidden_states = self.get_last_token_states( + hidden_states, attn_metadata) + else: + hidden_states = hidden_states[-1].unsqueeze(0) + + 
if not (self.model_config.mapping.enable_attention_dp): + lm_head.gather_output = False + logits = lm_head(hidden_states) + if not (self.model_config.mapping.enable_attention_dp): + lm_head.gather_output = True + return logits + + +class DeepseekV3Linear(Linear): + """ + A wrapper around Linear because we may optionally use min-latency kernels depending on input shapes. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + mapping: Optional[Mapping] = None, + tensor_parallel_mode: Optional[TensorParallelMode] = None, + gather_output: bool = False, # COLUMN parallel only + quant_config: Optional[QuantConfig] = None, + weights_loading_config: Optional[WeightsLoadingConfig] = None, + reduce_output: bool = True, # ROW parallel only + skip_create_weights_in_init: bool = False, + use_custom_cublas_mm: bool = False, + lora: Optional[LoraLayer] = None, + ): + super().__init__( + in_features, + out_features, + bias, + dtype, + mapping, + tensor_parallel_mode, + gather_output, + quant_config, + weights_loading_config, + reduce_output, + skip_create_weights_in_init, + use_custom_cublas_mm, + lora, + ) + + def apply_linear(self, + input, + bias, + lora_params: Optional[dict] | None = None, + layer_idx: Optional[int] | None = None): + num_tokens = input.shape[0] + if (not self.has_any_quant and 1 <= num_tokens <= 16 + and get_sm_version() != 120): + output = torch.ops.trtllm.dsv3_fused_a_gemm_op( + input, self.weight.t(), bias, None) + else: + output = super().apply_linear(input, bias, lora_params, layer_idx) + return output + + +class DeepseekV3Attention(MLA): + + def __init__( + self, + model_config: ModelConfig[PretrainedConfig], + layer_idx: Optional[int] = None, + aux_stream: Optional[torch.cuda.Stream] = None, + ): + config = model_config.pretrained_config + predicted_tokens_per_seq = model_config.spec_config.num_nextn_predict_layers + 1 if model_config.spec_config is not None else 1 + super().__init__(hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + qk_rope_head_dim=config.qk_rope_head_dim, + qk_nope_head_dim=config.qk_nope_head_dim, + q_lora_rank=config.q_lora_rank, + kv_lora_rank=config.kv_lora_rank, + v_head_dim=config.v_head_dim, + predicted_tokens_per_seq=predicted_tokens_per_seq, + max_position_embeddings=config.max_position_embeddings, + bias=False, + pos_embd_params=PositionalEmbeddingParams( + type=PositionEmbeddingType.yarn, + rope=RopeParams.from_config(config), + is_neox=False, + ), + layer_idx=layer_idx, + dtype=config.torch_dtype, + config=model_config, + aux_stream=aux_stream) + self.kv_a_proj_with_mqa = DeepseekV3Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim + + (self.q_lora_rank if not self.is_lite else 0), + bias=False, + dtype=config.torch_dtype, + quant_config=model_config.get_quant_config(), + skip_create_weights_in_init=model_config. 
+ skip_create_weights_in_init, + use_custom_cublas_mm=True) + + +class Deepseekv3RoutingImpl(): + + def __init__( + self, + top_k: int, + n_group: int, + topk_group: int, + routed_scaling_factor: float, + is_fused: bool = True, + ): + super().__init__() + self.top_k = top_k + self.topk_group = topk_group + self.n_group = n_group + self.routed_scaling_factor = routed_scaling_factor + self.is_fused = is_fused + + def noaux_tc(self, logits, e_score_correction_bias): + n_group = self.n_group + scores = F.sigmoid(logits) + scores_with_bias = scores + e_score_correction_bias + scores_shape = list(scores_with_bias.shape) + + if enable_llm_debug(): + has_nan = torch.isnan(scores_with_bias).any() + if has_nan: + warnings.warn( + "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation." + ) + + if not self.is_fused: + group_scores = torch.sum(torch.topk( + scores_with_bias.view(scores_shape[:-1] + + [n_group, scores_shape[-1] // n_group]), + k=2, + dim=-1, + largest=True, + sorted=True)[0], + dim=-1) + _, group_idx = torch.topk(group_scores, + k=self.topk_group, + dim=-1, + largest=True, + sorted=True) + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(-1, group_idx, 1) + score_mask = group_mask.unsqueeze(-1).expand( + scores_shape[:-1] + + [n_group, scores_shape[-1] // n_group]).reshape(scores_shape) + scores_with_bias = scores_with_bias * score_mask + _, topk_idx = torch.topk(scores_with_bias, + k=self.top_k, + dim=-1, + largest=True, + sorted=True) + new_mask = torch.zeros_like(scores) + new_mask.scatter_(-1, topk_idx, 1) + scores = scores * new_mask + score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20 + scores = scores / score_sum * \ + self.routed_scaling_factor + topk_values, topk_indices = torch.topk(scores, + k=self.top_k, + dim=-1, + largest=True) + return topk_values, topk_indices + else: + topk_values, topk_indices = torch.ops.trtllm.noaux_tc_op( + scores, scores_with_bias, n_group, self.topk_group, self.top_k, + self.routed_scaling_factor) + return topk_values, topk_indices + + def apply( + self, logits: torch.Tensor, e_score_correction_bias: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + topk_values, topk_indices = self.noaux_tc(logits, + e_score_correction_bias) + return topk_indices.to(torch.int32), topk_values.to(torch.float32) + + +class DeepseekV3Gate(DeepSeekV3MoeRoutingMethod): + + def __init__( + self, + hidden_size: int, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + routed_scaling_factor: float, + dtype: Optional[torch.dtype] = None, + fuse_routing_kernel: bool = True, + apply_routing: bool = False, + moe_backend: str = 'CUTLASS', + ): + super().__init__(top_k=top_k) + self.weight = nn.Parameter(torch.empty((num_experts, hidden_size), + dtype=dtype), + requires_grad=False) + self.moe_backend = moe_backend + if moe_backend == 'TRTLLM': + bias_dtype = torch.bfloat16 + else: + bias_dtype = torch.float32 + + self.e_score_correction_bias = nn.Parameter(torch.empty( + (num_experts), dtype=bias_dtype), + requires_grad=False) + + assert not apply_routing, "DeepseekV3Gate routing is called inside MoE" + + # TODO: e_score_correction_bias belongs in this gate class but is required by the routing impl. + # To avoid weight-loading issues, we treat this gate as the BaseMoeRoutingMethod and dispatch to the routing impl. + # This is a temporary hack that should be refactored later. 
+ self.routing_impl = Deepseekv3RoutingImpl( + top_k=top_k, + n_group=n_group, + topk_group=topk_group, + routed_scaling_factor=routed_scaling_factor, + is_fused=fuse_routing_kernel) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = torch.ops.trtllm.dsv3_router_gemm_op(hidden_states, + self.weight.t(), + bias=None, + out_dtype=torch.float32) + return logits + + def load_weights(self, weights: List[Dict]): + assert len(weights) == 1 + + self.weight.copy_(weights[0]["weight"][:]) + + self.e_score_correction_bias.copy_( + weights[0]["e_score_correction_bias"][:].to( + self.e_score_correction_bias.dtype)) + + def apply(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # topk routing + return self.routing_impl.apply(logits, self.e_score_correction_bias) + + @property + def routing_method(self) -> DeepSeekV3MoeRoutingMethod: + return self + + def get_experts_per_token(self): + return self.routing_impl.top_k + + +class Deepseekv3MoE(nn.Module): + + def __init__(self, + *, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + shared_expert_intermediate_size: int, + aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + dtype: Optional[torch.dtype] = None, + model_config: ModelConfig = ModelConfig(), + override_quant_config: Optional[QuantConfig] = None, + layer_idx: Optional[int] = None): + from ..distributed import AllReduce + + super().__init__() + config = model_config.pretrained_config + self.top_k = top_k + self.use_dp = model_config.mapping.enable_attention_dp + self.gate = DeepseekV3Gate( + hidden_size, + num_experts, + top_k=top_k, + n_group=config.n_group, + topk_group=config.topk_group, + routed_scaling_factor=config.routed_scaling_factor, + dtype=dtype, + fuse_routing_kernel=True, + apply_routing=False, + moe_backend=model_config.moe_backend) + self.experts = create_moe( + num_experts=num_experts, + routing_method=self.gate.routing_method, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results= + False, # In both low‑latency and attention‑DP modes, FusedMoE skips the in‑op all‑reduce. + model_config=model_config, + override_quant_config=override_quant_config, + aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap], + layer_idx=layer_idx) + + self.mapping = model_config.mapping + + # FIXME: incompatible with mixed quantization mode (including excluding modules from quantization) + block_size = 1 + if model_config.quant_config and model_config.quant_config.group_size is not None: + block_size = model_config.quant_config.group_size + + shared_tp_size, self.shared_output_scale = self._compute_shared_expert_tp_size( + shared_expert_intermediate_size, block_size) + + self.shared_experts = GatedMLP( + hidden_size=hidden_size, + intermediate_size=shared_expert_intermediate_size, + bias=False, + dtype=dtype, + config=model_config, + overridden_tp_size=shared_tp_size, + reduce_output=False) + + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy) + self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] + self.event_dict = { + key: torch.cuda.Event() + for key in [EventType.Main, EventType.MoeShared] + } + + def _compute_shared_expert_tp_size(self, intermediate_size: int, + block_size: int) -> int: + """ + In the case of Deepseek-R1, the TP size of MLP is capped by intermediate_size // block_size. 
+ For example, when the intermediate_size is 2048 and block scaling size is 128, + TP sizes are limited to {1, 2, 4, 8, 16} because of 2048/128 = 16. + + Args: + intermediate_size (int): MLP intermediate size. + block_size (int): The quantization block scale size. In the case of Deepseek FP8 recipe, + it's 128. For NVFP4, it's 16. + + Returns: + int: The computed tp_size. + """ + + assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size." + + shared_output_scale = None + # The block scale size is 128, which requires shared_expert_intermediate_size to be divisible by 128. + if self.use_dp: + # If using attention DP, the shared experts also use DP instead of TP. + shared_tp_size = 1 + else: + # Due to the restriction of block scale size (i.e., 128), the supported TP sizes only include 1, 2, 4, 8, and 16. + # The math.gcd operation ensures that shared_tp_size falls in the supported TP sizes. + shared_tp_size = math.gcd( + intermediate_size // block_size, + self.mapping.tp_size, + ) + # If shared_tp_size has been overridden, the output of shared experts needs to be scaled down accordingly before all-reduce. + if shared_tp_size != self.mapping.tp_size: + shared_output_scale = shared_tp_size / self.mapping.tp_size + + return shared_tp_size, shared_output_scale + + def compute_routed_output(self, hidden_states, hidden_states_fp4, + all_rank_num_tokens, all_rank_max_num_tokens, + do_finalize): + # max-throughput + use_dp_padding = False + if self.use_dp and self.mapping.tp_size > 1: + if isinstance(self.experts, TRTLLMGenFusedMoE): + hidden_states = allgather(hidden_states, + self.mapping, + dim=0, + sizes=all_rank_num_tokens) + + router_logits = self.gate(hidden_states) + + routed_output = self.experts( + hidden_states_fp4 or hidden_states, + router_logits, + do_finalize=do_finalize, + output_dtype=hidden_states.dtype, + all_rank_num_tokens=all_rank_num_tokens, + all_rank_max_num_tokens=all_rank_max_num_tokens, + use_dp_padding=use_dp_padding, + ) + + return routed_output + + def forward( + self, + hidden_states: torch.Tensor, + hidden_states_fp4: Optional[Fp4QuantizedTensor] = None, + all_rank_num_tokens: Optional[list[int]] = None, + all_rank_max_num_tokens: Optional[int] = None, + final_all_reduce_params: Optional[AllReduceParams] = None, + do_finalize: Optional[bool] = True, + ) -> torch.Tensor: + if not do_finalize: + assert not self.use_dp + + def _compute_shared_output(): + shared_output = self.shared_experts(hidden_states_fp4 + or hidden_states) + if self.shared_output_scale is not None: + shared_output *= self.shared_output_scale + return shared_output + + def _compute_routed_output(): + routed_output = self.compute_routed_output(hidden_states, + hidden_states_fp4, + all_rank_num_tokens, + all_rank_max_num_tokens, + do_finalize) + return routed_output + + routed_output, shared_output = maybe_execute_in_parallel( + _compute_routed_output, _compute_shared_output, + self.event_dict[EventType.Main], + self.event_dict[EventType.MoeShared], self.aux_stream) + + if not do_finalize: + return [shared_output, *routed_output] + else: + assert shared_output.size() == routed_output.size( + ), f'unmatched tensor shape' + final_hidden_states = shared_output + routed_output + if not self.use_dp and self.mapping.tp_size > 1: + final_hidden_states = self.allreduce( + final_hidden_states, + all_reduce_params=final_all_reduce_params) + + return final_hidden_states + + +class DeepseekV3DecoderLayer(DecoderLayer): + + def __init__(self, model_config: 
ModelConfig[PretrainedConfig], + layer_idx: int, aux_stream_dict: Dict[AuxStreamType, + torch.cuda.Stream]): + super().__init__() + self.model_config = model_config + config = model_config.pretrained_config + + self.hidden_size = config.hidden_size + self.moe_intermediate_size = config.moe_intermediate_size + self.num_experts = config.n_routed_experts + self.num_shared_experts = config.n_shared_experts + self.top_k = config.num_experts_per_tok + + self.mapping = model_config.mapping + mapping = self.mapping + + self.self_attn = DeepseekV3Attention( + model_config, + layer_idx=layer_idx, + aux_stream=aux_stream_dict[AuxStreamType.Attention]) + self.enable_attention_dp = mapping.enable_attention_dp + + self.mlp_tp_size = mapping.tp_size + self.is_p2p_supported = can_access_peer(mapping) + + self.fusion_config = EagerFusionConfig() + self.enable_fusion = os.environ.get( + "TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0" + self.enable_fusion &= not self.enable_attention_dp + + # FIXME: incompatible with mixed quantization mode + quant_config = self._get_decoder_layer_quant_config( + model_config, layer_idx) + self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4() + + has_tp = mapping.has_tp() + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + + self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp + self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION + + self.mlp = Deepseekv3MoE( + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=self.moe_intermediate_size, + shared_expert_intermediate_size=self.moe_intermediate_size * + self.num_shared_experts, + dtype=config.torch_dtype, + model_config=model_config, + override_quant_config=quant_config, + aux_stream_dict=aux_stream_dict, + layer_idx=layer_idx) + else: + block_size = 1 + if quant_config and quant_config.group_size is not None: + block_size = quant_config.group_size + self.mlp_tp_size = self._compute_mlp_tp_size( + config.intermediate_size, block_size) + + has_mlp_tp = self.mlp_tp_size > 1 + self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4 + self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp + + self.mlp = GatedMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + bias=False, + dtype=config.torch_dtype, + config=model_config, + overridden_tp_size=self.mlp_tp_size, + reduce_output=True) + + self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + + self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION + or self.fusion_config.PRE_MLP_FUSION + or self.mapping.tp_size == 1 + or self.enable_attention_dp) + + self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + self.layer_idx = layer_idx + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy, + dtype=config.torch_dtype) + self.moe_allreduce = MoEAllReduce(self.mapping) + self.next_layer_layernorm: RMSNorm = None + + def _get_decoder_layer_quant_config( + self, model_config: ModelConfig[PretrainedConfig], layer_idx: int): + """ + The MTP layer in the nvfp4 checkpoint is unquantized. Because the TRTLLM + moe_backend only supports fp8/fp4 quantization, we need to override + the quant_config for the MTP layer. 
+ """ + quant_config = model_config.quant_config + + layer_name = f"model.layers.{layer_idx}" + if quant_config.is_module_excluded_from_quantization(layer_name): + return QuantConfig( + quant_algo=None, + kv_cache_quant_algo=quant_config.kv_cache_quant_algo, + ) + else: + return model_config.quant_config + + def _compute_mlp_tp_size(self, intermediate_size: int, + block_size: int) -> int: + """ + For DeepSeek‑R1, MLP TP size is limited by intermediate_size // block_size + and must also be multiples of gpus_per_node to avoid expensive inter‑node allreduce. + + Args: + intermediate_size (int): MLP intermediate size. + block_size (int): The quantization block scale size. In the case of Deepseek FP8 recipe, + it's 128. For NVFP4, it's 16. + + Returns: + int: The computed tp_size. + """ + + assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size." + if self.enable_attention_dp: + # If using attention DP, the MLP also uses DP instead of TP. + mlp_tp_size = 1 + else: + # The two math.gcd operations ensure that mlp_tp_size falls in the candidate TP sizes. + tp = math.gcd( + intermediate_size // block_size, + self.mapping.tp_size, + ) + mlp_tp_size = math.gcd( + tp, + self.mapping.gpus_per_node, + ) if tp > self.mapping.gpus_per_node else tp # Avoid costly inter-node TP + return mlp_tp_size + + def forward( + self, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + all_reduce_params=AllReduceParams( + enable_allreduce=not (self.disable_attn_allreduce)), + **kwargs, + ) + + if isinstance(self.mlp, Deepseekv3MoE): + return self.forward_MoE( + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + ) + else: + assert isinstance(self.mlp, GatedMLP) + return self.forward_mlp( + hidden_states=hidden_states, + residual=residual, + ) + + def forward_MoE( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: torch.Tensor, + ) -> torch.Tensor: + + def _run_MoE(hidden_states, hidden_states_fp4, do_finalize): + return self.mlp( + hidden_states, + hidden_states_fp4, + all_rank_num_tokens=attn_metadata.all_rank_num_tokens, + all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens, + final_all_reduce_params=AllReduceParams( + enable_allreduce=not (self.fusion_config.POST_MOE_FUSION + or self.mapping.tp_size == 1)), + do_finalize=do_finalize, + ) + + if self.fusion_config.PRE_MOE_FUSION: + # moe_backend can be either CUTLASS or TRTLLM here + # TODO: unify the two min-latency MoE backends by enabling quant fusion + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + eps=self.post_attention_layernorm.variance_epsilon, + trigger_completion_at_end=False, + )) + else: + # No fusion + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # Note: this fusion pattern is only supported for single-node TRTLLM-nvfp4 backend now + do_finalize = self.mapping.is_multi_node() or ( + not (hidden_states.shape[0] <= self.moe_allreduce.max_token + and 
self.fusion_config.POST_MOE_FUSION + and self.model_config.moe_backend == "TRTLLM" + and self.mlp.experts.has_nvfp4 and self.is_p2p_supported)) + + hidden_states = _run_MoE(hidden_states, + hidden_states_fp4=None, + do_finalize=do_finalize) + + if self.fusion_config.POST_MOE_FUSION: + if do_finalize: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + trigger_completion_at_end=False, + )) + else: + assert len( + hidden_states) == 4, "hidden_states must have 4 elements" + + shared_output = hidden_states[0] + fc2_output = hidden_states[1] + expert_scale_factor = hidden_states[2] + expanded_idx_to_permuted_idx = hidden_states[3] + + moe_all_reduce_params = MoEAllReduceParams( + expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, + expert_scale_factor=expert_scale_factor, + shared_expert_output=shared_output, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + is_cutlass_min_latency=False, + ) + hidden_states, residual = self.moe_allreduce( + fc2_output, all_reduce_params=moe_all_reduce_params) + else: + if self.next_layer_layernorm is not None: + hidden_states, residual = self.next_layer_layernorm( + hidden_states, residual) + + return hidden_states, residual + + def forward_mlp( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ) -> torch.Tensor: + + if self.fusion_config.PRE_MLP_FUSION: + act_fp4, act_sf, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + scale=self.mlp.gate_up_proj.input_scale, + eps=self.post_attention_layernorm.variance_epsilon, + ), + ) + hidden_states = Fp4QuantizedTensor(act_fp4, act_sf) + else: + # No fusion + # We need to add twoshot allreduce here to avoid modifying MLA logic + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + hidden_states = self.mlp( + hidden_states, + final_all_reduce_params=AllReduceParams(enable_allreduce=not ( + self.fusion_config.POST_MLP_FUSION or self.mlp_tp_size == 1)), + ) + + if self.fusion_config.POST_MLP_FUSION: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.next_layer_layernorm.weight, + eps=self.next_layer_layernorm.variance_epsilon, + ), + ) + else: + if self.next_layer_layernorm is not None: + hidden_states, residual = self.next_layer_layernorm( + hidden_states, residual) + + return hidden_states, residual + + +class DeepseekV3MTP(DeepseekV3DecoderLayer): + + def __init__(self, model_config: ModelConfig[PretrainedConfig], + layer_idx: int, aux_stream_dict: Dict[AuxStreamType, + torch.cuda.Stream]): + super().__init__(model_config, layer_idx, aux_stream_dict) + config = model_config.pretrained_config + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = config.moe_intermediate_size + self.num_experts = config.n_routed_experts + self.num_shared_experts = config.n_shared_experts + self.top_k = config.num_experts_per_tok + + self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] + self.event_dict = { + key: torch.cuda.Event() + for key in [EventType.Main, 
EventType.MoeShared] + } + + self.enorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + + self.hnorm = RMSNorm(hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype) + if model_config.mapping.enable_attention_dp: + self.eh_proj = Linear( + config.hidden_size * 2, + config.hidden_size, + bias=False, + dtype=config.torch_dtype, + skip_create_weights_in_init=model_config. + skip_create_weights_in_init, + ) + else: + self.eh_proj = Linear( + config.hidden_size * 2, + config.hidden_size, + bias=False, + dtype=config.torch_dtype, + tensor_parallel_mode=TensorParallelMode.ROW, + mapping=model_config.mapping, + reduce_output=True, + skip_create_weights_in_init=model_config. + skip_create_weights_in_init, + ) + + self.shared_head = DeepseekV3MTPHead(model_config) + + def forward( + self, + input_ids: torch.IntTensor, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + embed_tokens: Embedding, + attn_metadata: AttentionMetadata, + all_rank_num_tokens: Optional[List[int]] = None, + all_rank_max_num_tokens: Optional[int] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + def norm_embeds(): + return self.enorm(embed_tokens(input_ids)) #emdedding + + def norm_hidden(): + return self.hnorm(hidden_states) + + inputs_embeds, hidden_states = maybe_execute_in_parallel( + norm_embeds, + norm_hidden, + self.event_dict[EventType.Main], + self.event_dict[EventType.MoeShared], + self.aux_stream, + ) + hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1) + # Split hidden_states columnwise based on TP + tp_size = self.model_config.mapping.tp_size + tp_rank = self.model_config.mapping.tp_rank + + if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp): + hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank] + hidden_states = self.eh_proj(hidden_states) + + # Input layer norm + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + all_reduce_params=AllReduceParams( + enable_allreduce=not (self.disable_attn_allreduce)), + **kwargs, + ) + + # MTP Layer Must have sparse MOE + if self.fusion_config.PRE_MOE_FUSION: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.post_attention_layernorm.weight, + eps=self.post_attention_layernorm.variance_epsilon, + ), + ) + else: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + # MoE + hidden_states = self.mlp( + hidden_states, + all_rank_num_tokens=all_rank_num_tokens, + all_rank_max_num_tokens=all_rank_max_num_tokens, + final_all_reduce_params=AllReduceParams( + enable_allreduce=not (self.fusion_config.POST_MOE_FUSION + or self.mapping.tp_size == 1)), + ) + + if self.fusion_config.POST_MOE_FUSION: + hidden_states, residual = self.allreduce( + hidden_states, + all_reduce_params=AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + residual=residual, + norm_weight=self.shared_head.norm.weight, + eps=self.shared_head.norm.variance_epsilon, + ), + ) + else: + hidden_states, _ = self.shared_head.norm(hidden_states, residual) + + return hidden_states + + +class MTPForCausalLM(nn.Module): + + def __init__( + self, + model, + model_config: PretrainedConfig, + start_layer_idx: 
int = 0, + ): + super().__init__() + spec_dec_mode = model_config.spec_config.spec_dec_mode + assert spec_dec_mode.is_mtp() + self.embed_tokens = model.embed_tokens + mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle( + ) else model_config.spec_config.num_nextn_predict_layers + + self.mtp_layers = nn.ModuleList([ + DeepseekV3MTP(model_config, layer_idx + start_layer_idx, + model.aux_stream_dict) + for layer_idx in range(mtp_num_layers) + ]) + + +def get_draft_model(model, model_config, draft_config): assert getattr(model_config, 'spec_config', None) != None spec_dec_mode = model_config.spec_config.spec_dec_mode if spec_dec_mode.is_eagle3_one_model(): return Eagle3ForCausalLM( draft_config, model_config.pretrained_config.num_hidden_layers) + elif spec_dec_mode.is_mtp(): + return MTPForCausalLM(model, model_config, + model_config.pretrained_config.num_hidden_layers) else: raise NotImplemented( f"get_draft_model does not support speculative decoding mode {spec_dec_mode}." @@ -341,23 +1295,24 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]): hidden_size=model_config.pretrained_config.hidden_size, vocab_size=model_config.pretrained_config.vocab_size) self.draft_model = None - if getattr( - model_config, 'spec_config', None - ) and model_config.spec_config.spec_dec_mode.use_one_engine(): - draft_config = ModelConfig.from_pretrained( - model_config.spec_config.speculative_model_dir, - trust_remote_code=True, - attn_backend=model_config.attn_backend, - moe_backend=model_config.moe_backend, - mapping=model_config.mapping, - spec_config=model_config.spec_config, - max_num_tokens=model_config.max_num_tokens, - moe_max_num_tokens=model_config.moe_max_num_tokens) - - draft_config.quant_config.kv_cache_quant_algo = \ + spec_config = getattr(model_config, 'spec_config', None) + if spec_config and spec_config.spec_dec_mode.use_one_engine(): + draft_config = None + if spec_config.spec_dec_mode.is_eagle3_one_model(): + draft_config = ModelConfig.from_pretrained( + model_config.spec_config.speculative_model_dir, + trust_remote_code=True, + attn_backend=model_config.attn_backend, + moe_backend=model_config.moe_backend, + mapping=model_config.mapping, + spec_config=model_config.spec_config, + max_num_tokens=model_config.max_num_tokens, + moe_max_num_tokens=model_config.moe_max_num_tokens) + draft_config.quant_config.kv_cache_quant_algo = \ model_config.quant_config.kv_cache_quant_algo - self.draft_model = get_draft_model(model_config, draft_config) + self.draft_model = get_draft_model(model, model_config, + draft_config) self.spec_worker = get_spec_worker(model_config.spec_config, model_config, model_config.mapping) @@ -394,6 +1349,7 @@ def forward( position_ids=position_ids, hidden_states=hidden_states, logits=logits, + lm_head=self.lm_head, attn_metadata=attn_metadata, spec_metadata=spec_metadata, draft_model=self.draft_model) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 417becf12f3..c30f407c93a 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -268,7 +268,7 @@ def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping): # Skip torch.compile for now since current Torch is not compatible with Triton 3.4 # @torch.compile(options={"max-autotune": True}) - def forward(self, input_ids, position_ids, hidden_states, logits, + def forward(self, input_ids, position_ids, hidden_states, logits, lm_head, attn_metadata, spec_metadata, draft_model): batch_size = 
attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 7abb97f3b42..acc68b4b94c 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -23,6 +23,9 @@ class SpeculativeDecodingMode(IntEnum): def is_mtp(self): return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE + def is_mtp_vanilla(self): + return self == SpeculativeDecodingMode.MTP + def is_mtp_eagle(self): return self == SpeculativeDecodingMode.MTP_EAGLE diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 1772125bcbf..22a8ea5b197 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -331,10 +331,9 @@ def forward( hidden_states, logits, lm_head, - embed_tokens, attn_metadata, spec_metadata, - mtp_layers, + draft_model, ): ''' Example: @@ -470,8 +469,9 @@ def forward( next_draft_tokens = [] last_tokens_idx = torch.cumsum( attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 - for _, mtp_layer in enumerate(mtp_layers): - hidden_states = mtp_layer(embed_tokens=embed_tokens, **draft_inputs) + for _, mtp_layer in enumerate(draft_model.mtp_layers): + hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens, + **draft_inputs) logits = mtp_layer.shared_head(hidden_states, lm_head, attn_metadata).float() new_draft_token = self.draft_sampler(logits) @@ -518,7 +518,6 @@ def skip_forward( hidden_states, logits, lm_head, - embed_tokens, attn_metadata, spec_metadata, mtp_layers, @@ -1128,10 +1127,9 @@ def forward( hidden_states, logits, lm_head, - embed_tokens, attn_metadata, spec_metadata, - mtp_layers, + draft_model, ): batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts @@ -1172,8 +1170,8 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata): next_draft_tokens = [] for i in range(self.mtp_num_modules): if i == 0: - hidden_states = mtp_layers[0]( - embed_tokens=embed_tokens, + hidden_states = draft_model.mtp_layers[0]( + embed_tokens=draft_model.embed_tokens, all_rank_num_tokens=spec_metadata.all_rank_num_tokens, all_rank_max_num_tokens=spec_metadata. all_rank_max_num_tokens, @@ -1186,8 +1184,8 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata): gather_ids = torch.concat( [last_tokens_idx[:num_contexts], gather_ids_gen], dim=0) else: - hidden_states = mtp_layers[0]( - embed_tokens=embed_tokens, + hidden_states = draft_model.mtp_layers[0]( + embed_tokens=draft_model.embed_tokens, all_rank_num_tokens=spec_metadata. 
subseq_all_rank_num_tokens, all_rank_max_num_tokens=max( @@ -1197,8 +1195,8 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata): **inputs) # All of the seq_len are 1, use batch_indices_cuda as gather_ids gather_ids = spec_metadata.batch_indices_cuda[:batch_size] - logits = mtp_layers[0].shared_head(hidden_states[gather_ids], - lm_head, attn_metadata, True) + logits = draft_model.mtp_layers[0].shared_head( + hidden_states[gather_ids], lm_head, attn_metadata, True) new_draft_token = self.draft_sampler(logits) hidden_states, position_ids = self.update_draft_tokens( diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index ad7fbf8fd56..c4a4ccf7e3c 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -154,11 +154,12 @@ def get_num_spec_layers(spec_config): def get_spec_worker(spec_config, model_config, mapping): - if spec_config.spec_dec_mode.is_mtp(): + spec_dec_mode = spec_config.spec_dec_mode + if spec_dec_mode.is_mtp_vanilla(): return MTPWorker(spec_config, model_config) - if spec_config.spec_dec_mode.is_mtp_eagle(): + if spec_dec_mode.is_mtp_eagle(): return MTPEagleWorker(spec_config, model_config) - if spec_config.spec_dec_mode.is_eagle3_one_model(): + if spec_dec_mode.is_eagle3_one_model(): return Eagle3OneModelWorker(spec_config, mapping) return None From d9d16b669bab95ec65e5db455514f51b5ceeede6 Mon Sep 17 00:00:00 2001 From: linquanh Date: Fri, 8 Aug 2025 03:46:05 -0700 Subject: [PATCH 2/2] update Signed-off-by: linquanh --- .../_torch/models/modeling_deepseekv3.py | 49 +- .../_torch/models/modeling_speculative.py | 962 +----------------- tensorrt_llm/_torch/speculative/eagle3.py | 2 +- tensorrt_llm/_torch/speculative/mtp.py | 10 +- 4 files changed, 60 insertions(+), 963 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index e568347a1e3..7a9a1c7997a 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -26,29 +26,49 @@ # -------------------------------------------------- import copy -from typing import Dict, Optional +import math +import os +import warnings +from typing import Dict, List, Optional, Tuple import torch +import torch.nn.functional as F import triton import triton.language as tl from torch import nn from tqdm import tqdm from transformers import PretrainedConfig +from tensorrt_llm._ipc_utils import can_access_peer from tensorrt_llm._utils import get_sm_version +from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.llmapi.utils import enable_llm_debug +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization.utils.fp8_utils import ( resmooth_to_fp8_e8m0, transform_sf_into_required_layout) from ..attention_backend import AttentionMetadata +from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams +from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, + MoEAllReduce, MoEAllReduceParams, allgather) from ..model_config import ModelConfig +from ..modules.attention import MLA +from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding -from ..modules.fused_moe import moe_load_balancer_set_repeated_for_next_layer +from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, TRTLLMGenFusedMoE, + create_moe, + 
moe_load_balancer_set_repeated_for_next_layer) +from ..modules.gated_mlp import GatedMLP +from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig +from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm +from ..peft.lora.layer import LoraLayer from ..speculative import MTPSpecMetadata, SpecMetadata -from ..utils import AuxStreamType -from .modeling_speculative import (DeepseekV3DecoderLayer, - SpecDecOneEngineForCausalLM) -from .modeling_utils import DecoderModel, filter_weights, register_auto_model +from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor +from .modeling_speculative import SpecDecOneEngineForCausalLM +from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights, + register_auto_model) @triton.jit @@ -107,6 +127,7 @@ def weight_dequant(x: torch.Tensor, weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y + class DeepseekV3MTPHead(nn.Module): def __init__(self, model_config: ModelConfig[PretrainedConfig]): @@ -512,7 +533,8 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4, router_logits = self.gate(hidden_states) routed_output = self.experts( - hidden_states_fp4 or hidden_states, + hidden_states_fp4 + if hidden_states_fp4 is not None else hidden_states, router_logits, do_finalize=do_finalize, output_dtype=hidden_states.dtype, @@ -536,8 +558,9 @@ def forward( assert not self.use_dp def _compute_shared_output(): - shared_output = self.shared_experts(hidden_states_fp4 - or hidden_states) + shared_output = self.shared_experts( + hidden_states_fp4 + if hidden_states_fp4 is not None else hidden_states) if self.shared_output_scale is not None: shared_output *= self.shared_output_scale return shared_output @@ -721,7 +744,7 @@ def forward( attn_metadata: AttentionMetadata, residual: torch.Tensor, **kwargs, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -753,7 +776,7 @@ def forward_MoE( hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, residual: torch.Tensor, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: def _run_MoE(hidden_states, hidden_states_fp4, do_finalize): return self.mlp( @@ -837,7 +860,7 @@ def forward_mlp( self, hidden_states: torch.Tensor, residual: torch.Tensor, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: if self.fusion_config.PRE_MLP_FUSION: act_fp4, act_sf, residual = self.allreduce( @@ -941,7 +964,7 @@ def forward( all_rank_num_tokens: Optional[List[int]] = None, all_rank_max_num_tokens: Optional[int] = None, **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: def norm_embeds(): return self.enorm(embed_tokens(input_ids)) #emdedding diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index 9414c25daf6..682c55919e0 100644 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -1,42 +1,26 @@ -import math -import os -import warnings -from typing import Any, Dict, Generic, List, Optional, Tuple +from typing import Any, Dict, Generic, Optional, Tuple import torch -import torch.nn.functional as F from torch import nn from transformers import LlamaConfig, PretrainedConfig -from tensorrt_llm._ipc_utils import can_access_peer from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ BaseWeightMapper -from 
tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import PositionEmbeddingType -from tensorrt_llm.llmapi.utils import enable_llm_debug -from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.modeling_utils import QuantConfig from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams -from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, - MoEAllReduce, MoEAllReduceParams, allgather) from ..model_config import ModelConfig, TConfig -from ..modules.attention import MLA, Attention +from ..modules.attention import Attention from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding -from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, TRTLLMGenFusedMoE, - create_moe) from ..modules.gated_mlp import GatedMLP from ..modules.linear import (Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig) -from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm -from ..peft.lora.layer import LoraLayer from ..speculative import SpecMetadata, get_spec_worker -from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor -from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, - EagerFusionConfig, TModel, register_auto_model) +from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel, + register_auto_model) class Eagle3Attention(Attention): @@ -336,931 +320,21 @@ def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class DeepseekV3MTPHead(nn.Module): - - def __init__(self, model_config: ModelConfig[PretrainedConfig]): - super().__init__() - config = model_config.pretrained_config - self.model_config = model_config - - self.norm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - - @torch.compile(options={"max-autotune": True}) - def get_last_token_states(self, hidden_states, attn_metadata): - last_tokens = torch.cumsum( - attn_metadata.seq_lens_cuda, - dim=0, - dtype=torch.long, - ) - 1 - return hidden_states[last_tokens] - - def forward(self, - hidden_states: torch.Tensor, - lm_head: Linear, - attn_metadata: AttentionMetadata, - return_context_logits: bool = False) -> torch.Tensor: - if not return_context_logits: - if attn_metadata is not None: - hidden_states = self.get_last_token_states( - hidden_states, attn_metadata) - else: - hidden_states = hidden_states[-1].unsqueeze(0) - - if not (self.model_config.mapping.enable_attention_dp): - lm_head.gather_output = False - logits = lm_head(hidden_states) - if not (self.model_config.mapping.enable_attention_dp): - lm_head.gather_output = True - return logits - - -class DeepseekV3Linear(Linear): - """ - A wrapper around Linear because we may optionally use min-latency kernels depending on input shapes. 
- """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - mapping: Optional[Mapping] = None, - tensor_parallel_mode: Optional[TensorParallelMode] = None, - gather_output: bool = False, # COLUMN parallel only - quant_config: Optional[QuantConfig] = None, - weights_loading_config: Optional[WeightsLoadingConfig] = None, - reduce_output: bool = True, # ROW parallel only - skip_create_weights_in_init: bool = False, - use_custom_cublas_mm: bool = False, - lora: Optional[LoraLayer] = None, - ): - super().__init__( - in_features, - out_features, - bias, - dtype, - mapping, - tensor_parallel_mode, - gather_output, - quant_config, - weights_loading_config, - reduce_output, - skip_create_weights_in_init, - use_custom_cublas_mm, - lora, - ) - - def apply_linear(self, - input, - bias, - lora_params: Optional[dict] | None = None, - layer_idx: Optional[int] | None = None): - num_tokens = input.shape[0] - if (not self.has_any_quant and 1 <= num_tokens <= 16 - and get_sm_version() != 120): - output = torch.ops.trtllm.dsv3_fused_a_gemm_op( - input, self.weight.t(), bias, None) - else: - output = super().apply_linear(input, bias, lora_params, layer_idx) - return output - - -class DeepseekV3Attention(MLA): - - def __init__( - self, - model_config: ModelConfig[PretrainedConfig], - layer_idx: Optional[int] = None, - aux_stream: Optional[torch.cuda.Stream] = None, - ): - config = model_config.pretrained_config - predicted_tokens_per_seq = model_config.spec_config.num_nextn_predict_layers + 1 if model_config.spec_config is not None else 1 - super().__init__(hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, - qk_rope_head_dim=config.qk_rope_head_dim, - qk_nope_head_dim=config.qk_nope_head_dim, - q_lora_rank=config.q_lora_rank, - kv_lora_rank=config.kv_lora_rank, - v_head_dim=config.v_head_dim, - predicted_tokens_per_seq=predicted_tokens_per_seq, - max_position_embeddings=config.max_position_embeddings, - bias=False, - pos_embd_params=PositionalEmbeddingParams( - type=PositionEmbeddingType.yarn, - rope=RopeParams.from_config(config), - is_neox=False, - ), - layer_idx=layer_idx, - dtype=config.torch_dtype, - config=model_config, - aux_stream=aux_stream) - self.kv_a_proj_with_mqa = DeepseekV3Linear( - config.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim + - (self.q_lora_rank if not self.is_lite else 0), - bias=False, - dtype=config.torch_dtype, - quant_config=model_config.get_quant_config(), - skip_create_weights_in_init=model_config. - skip_create_weights_in_init, - use_custom_cublas_mm=True) - - -class Deepseekv3RoutingImpl(): - - def __init__( - self, - top_k: int, - n_group: int, - topk_group: int, - routed_scaling_factor: float, - is_fused: bool = True, - ): - super().__init__() - self.top_k = top_k - self.topk_group = topk_group - self.n_group = n_group - self.routed_scaling_factor = routed_scaling_factor - self.is_fused = is_fused - - def noaux_tc(self, logits, e_score_correction_bias): - n_group = self.n_group - scores = F.sigmoid(logits) - scores_with_bias = scores + e_score_correction_bias - scores_shape = list(scores_with_bias.shape) - - if enable_llm_debug(): - has_nan = torch.isnan(scores_with_bias).any() - if has_nan: - warnings.warn( - "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation." 
- ) - - if not self.is_fused: - group_scores = torch.sum(torch.topk( - scores_with_bias.view(scores_shape[:-1] + - [n_group, scores_shape[-1] // n_group]), - k=2, - dim=-1, - largest=True, - sorted=True)[0], - dim=-1) - _, group_idx = torch.topk(group_scores, - k=self.topk_group, - dim=-1, - largest=True, - sorted=True) - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(-1, group_idx, 1) - score_mask = group_mask.unsqueeze(-1).expand( - scores_shape[:-1] + - [n_group, scores_shape[-1] // n_group]).reshape(scores_shape) - scores_with_bias = scores_with_bias * score_mask - _, topk_idx = torch.topk(scores_with_bias, - k=self.top_k, - dim=-1, - largest=True, - sorted=True) - new_mask = torch.zeros_like(scores) - new_mask.scatter_(-1, topk_idx, 1) - scores = scores * new_mask - score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20 - scores = scores / score_sum * \ - self.routed_scaling_factor - topk_values, topk_indices = torch.topk(scores, - k=self.top_k, - dim=-1, - largest=True) - return topk_values, topk_indices - else: - topk_values, topk_indices = torch.ops.trtllm.noaux_tc_op( - scores, scores_with_bias, n_group, self.topk_group, self.top_k, - self.routed_scaling_factor) - return topk_values, topk_indices - - def apply( - self, logits: torch.Tensor, e_score_correction_bias: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - topk_values, topk_indices = self.noaux_tc(logits, - e_score_correction_bias) - return topk_indices.to(torch.int32), topk_values.to(torch.float32) - - -class DeepseekV3Gate(DeepSeekV3MoeRoutingMethod): - - def __init__( - self, - hidden_size: int, - num_experts: int, - top_k: int, - n_group: int, - topk_group: int, - routed_scaling_factor: float, - dtype: Optional[torch.dtype] = None, - fuse_routing_kernel: bool = True, - apply_routing: bool = False, - moe_backend: str = 'CUTLASS', - ): - super().__init__(top_k=top_k) - self.weight = nn.Parameter(torch.empty((num_experts, hidden_size), - dtype=dtype), - requires_grad=False) - self.moe_backend = moe_backend - if moe_backend == 'TRTLLM': - bias_dtype = torch.bfloat16 - else: - bias_dtype = torch.float32 - - self.e_score_correction_bias = nn.Parameter(torch.empty( - (num_experts), dtype=bias_dtype), - requires_grad=False) - - assert not apply_routing, "DeepseekV3Gate routing is called inside MoE" - - # TODO: e_score_correction_bias belongs in this gate class but is required by the routing impl. - # To avoid weight-loading issues, we treat this gate as the BaseMoeRoutingMethod and dispatch to the routing impl. - # This is a temporary hack that should be refactored later. 
- self.routing_impl = Deepseekv3RoutingImpl( - top_k=top_k, - n_group=n_group, - topk_group=topk_group, - routed_scaling_factor=routed_scaling_factor, - is_fused=fuse_routing_kernel) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - logits = torch.ops.trtllm.dsv3_router_gemm_op(hidden_states, - self.weight.t(), - bias=None, - out_dtype=torch.float32) - return logits - - def load_weights(self, weights: List[Dict]): - assert len(weights) == 1 - - self.weight.copy_(weights[0]["weight"][:]) - - self.e_score_correction_bias.copy_( - weights[0]["e_score_correction_bias"][:].to( - self.e_score_correction_bias.dtype)) - - def apply(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - # topk routing - return self.routing_impl.apply(logits, self.e_score_correction_bias) - - @property - def routing_method(self) -> DeepSeekV3MoeRoutingMethod: - return self - - def get_experts_per_token(self): - return self.routing_impl.top_k - - -class Deepseekv3MoE(nn.Module): - - def __init__(self, - *, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - shared_expert_intermediate_size: int, - aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], - dtype: Optional[torch.dtype] = None, - model_config: ModelConfig = ModelConfig(), - override_quant_config: Optional[QuantConfig] = None, - layer_idx: Optional[int] = None): - from ..distributed import AllReduce - - super().__init__() - config = model_config.pretrained_config - self.top_k = top_k - self.use_dp = model_config.mapping.enable_attention_dp - self.gate = DeepseekV3Gate( - hidden_size, - num_experts, - top_k=top_k, - n_group=config.n_group, - topk_group=config.topk_group, - routed_scaling_factor=config.routed_scaling_factor, - dtype=dtype, - fuse_routing_kernel=True, - apply_routing=False, - moe_backend=model_config.moe_backend) - self.experts = create_moe( - num_experts=num_experts, - routing_method=self.gate.routing_method, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - dtype=dtype, - reduce_results= - False, # In both low‑latency and attention‑DP modes, FusedMoE skips the in‑op all‑reduce. - model_config=model_config, - override_quant_config=override_quant_config, - aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap], - layer_idx=layer_idx) - - self.mapping = model_config.mapping - - # FIXME: incompatible with mixed quantization mode (including excluding modules from quantization) - block_size = 1 - if model_config.quant_config and model_config.quant_config.group_size is not None: - block_size = model_config.quant_config.group_size - - shared_tp_size, self.shared_output_scale = self._compute_shared_expert_tp_size( - shared_expert_intermediate_size, block_size) - - self.shared_experts = GatedMLP( - hidden_size=hidden_size, - intermediate_size=shared_expert_intermediate_size, - bias=False, - dtype=dtype, - config=model_config, - overridden_tp_size=shared_tp_size, - reduce_output=False) - - self.allreduce = AllReduce(mapping=model_config.mapping, - strategy=model_config.allreduce_strategy) - self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] - self.event_dict = { - key: torch.cuda.Event() - for key in [EventType.Main, EventType.MoeShared] - } - - def _compute_shared_expert_tp_size(self, intermediate_size: int, - block_size: int) -> int: - """ - In the case of Deepseek-R1, the TP size of MLP is capped by intermediate_size // block_size. 
- For example, when the intermediate_size is 2048 and block scaling size is 128, - TP sizes are limited to {1, 2, 4, 8, 16} because of 2048/128 = 16. - - Args: - intermediate_size (int): MLP intermediate size. - block_size (int): The quantization block scale size. In the case of Deepseek FP8 recipe, - it's 128. For NVFP4, it's 16. - - Returns: - int: The computed tp_size. - """ - - assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size." - - shared_output_scale = None - # The block scale size is 128, which requires shared_expert_intermediate_size to be divisible by 128. - if self.use_dp: - # If using attention DP, the shared experts also use DP instead of TP. - shared_tp_size = 1 - else: - # Due to the restriction of block scale size (i.e., 128), the supported TP sizes only include 1, 2, 4, 8, and 16. - # The math.gcd operation ensures that shared_tp_size falls in the supported TP sizes. - shared_tp_size = math.gcd( - intermediate_size // block_size, - self.mapping.tp_size, - ) - # If shared_tp_size has been overridden, the output of shared experts needs to be scaled down accordingly before all-reduce. - if shared_tp_size != self.mapping.tp_size: - shared_output_scale = shared_tp_size / self.mapping.tp_size - - return shared_tp_size, shared_output_scale - - def compute_routed_output(self, hidden_states, hidden_states_fp4, - all_rank_num_tokens, all_rank_max_num_tokens, - do_finalize): - # max-throughput - use_dp_padding = False - if self.use_dp and self.mapping.tp_size > 1: - if isinstance(self.experts, TRTLLMGenFusedMoE): - hidden_states = allgather(hidden_states, - self.mapping, - dim=0, - sizes=all_rank_num_tokens) - - router_logits = self.gate(hidden_states) - - routed_output = self.experts( - hidden_states_fp4 or hidden_states, - router_logits, - do_finalize=do_finalize, - output_dtype=hidden_states.dtype, - all_rank_num_tokens=all_rank_num_tokens, - all_rank_max_num_tokens=all_rank_max_num_tokens, - use_dp_padding=use_dp_padding, - ) - - return routed_output - - def forward( - self, - hidden_states: torch.Tensor, - hidden_states_fp4: Optional[Fp4QuantizedTensor] = None, - all_rank_num_tokens: Optional[list[int]] = None, - all_rank_max_num_tokens: Optional[int] = None, - final_all_reduce_params: Optional[AllReduceParams] = None, - do_finalize: Optional[bool] = True, - ) -> torch.Tensor: - if not do_finalize: - assert not self.use_dp - - def _compute_shared_output(): - shared_output = self.shared_experts(hidden_states_fp4 - or hidden_states) - if self.shared_output_scale is not None: - shared_output *= self.shared_output_scale - return shared_output - - def _compute_routed_output(): - routed_output = self.compute_routed_output(hidden_states, - hidden_states_fp4, - all_rank_num_tokens, - all_rank_max_num_tokens, - do_finalize) - return routed_output - - routed_output, shared_output = maybe_execute_in_parallel( - _compute_routed_output, _compute_shared_output, - self.event_dict[EventType.Main], - self.event_dict[EventType.MoeShared], self.aux_stream) - - if not do_finalize: - return [shared_output, *routed_output] - else: - assert shared_output.size() == routed_output.size( - ), f'unmatched tensor shape' - final_hidden_states = shared_output + routed_output - if not self.use_dp and self.mapping.tp_size > 1: - final_hidden_states = self.allreduce( - final_hidden_states, - all_reduce_params=final_all_reduce_params) - - return final_hidden_states - - -class DeepseekV3DecoderLayer(DecoderLayer): - - def __init__(self, model_config: 
ModelConfig[PretrainedConfig], - layer_idx: int, aux_stream_dict: Dict[AuxStreamType, - torch.cuda.Stream]): - super().__init__() - self.model_config = model_config - config = model_config.pretrained_config - - self.hidden_size = config.hidden_size - self.moe_intermediate_size = config.moe_intermediate_size - self.num_experts = config.n_routed_experts - self.num_shared_experts = config.n_shared_experts - self.top_k = config.num_experts_per_tok - - self.mapping = model_config.mapping - mapping = self.mapping - - self.self_attn = DeepseekV3Attention( - model_config, - layer_idx=layer_idx, - aux_stream=aux_stream_dict[AuxStreamType.Attention]) - self.enable_attention_dp = mapping.enable_attention_dp - - self.mlp_tp_size = mapping.tp_size - self.is_p2p_supported = can_access_peer(mapping) - - self.fusion_config = EagerFusionConfig() - self.enable_fusion = os.environ.get( - "TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0" - self.enable_fusion &= not self.enable_attention_dp - - # FIXME: incompatible with mixed quantization mode - quant_config = self._get_decoder_layer_quant_config( - model_config, layer_idx) - self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4() - - has_tp = mapping.has_tp() - - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - - self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp - self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION - - self.mlp = Deepseekv3MoE( - num_experts=self.num_experts, - top_k=self.top_k, - hidden_size=self.hidden_size, - intermediate_size=self.moe_intermediate_size, - shared_expert_intermediate_size=self.moe_intermediate_size * - self.num_shared_experts, - dtype=config.torch_dtype, - model_config=model_config, - override_quant_config=quant_config, - aux_stream_dict=aux_stream_dict, - layer_idx=layer_idx) - else: - block_size = 1 - if quant_config and quant_config.group_size is not None: - block_size = quant_config.group_size - self.mlp_tp_size = self._compute_mlp_tp_size( - config.intermediate_size, block_size) - - has_mlp_tp = self.mlp_tp_size > 1 - self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4 - self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp - - self.mlp = GatedMLP(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - bias=False, - dtype=config.torch_dtype, - config=model_config, - overridden_tp_size=self.mlp_tp_size, - reduce_output=True) - - self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - - self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION - or self.fusion_config.PRE_MLP_FUSION - or self.mapping.tp_size == 1 - or self.enable_attention_dp) - - self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - self.layer_idx = layer_idx - self.allreduce = AllReduce(mapping=model_config.mapping, - strategy=model_config.allreduce_strategy, - dtype=config.torch_dtype) - self.moe_allreduce = MoEAllReduce(self.mapping) - self.next_layer_layernorm: RMSNorm = None - - def _get_decoder_layer_quant_config( - self, model_config: ModelConfig[PretrainedConfig], layer_idx: int): - """ - The MTP layer in the nvfp4 checkpoint is unquantized. Because the TRTLLM - moe_backend only supports fp8/fp4 quantization, we need to override - the quant_config for the MTP layer. 
- """ - quant_config = model_config.quant_config - - layer_name = f"model.layers.{layer_idx}" - if quant_config.is_module_excluded_from_quantization(layer_name): - return QuantConfig( - quant_algo=None, - kv_cache_quant_algo=quant_config.kv_cache_quant_algo, - ) - else: - return model_config.quant_config - - def _compute_mlp_tp_size(self, intermediate_size: int, - block_size: int) -> int: - """ - For DeepSeek‑R1, MLP TP size is limited by intermediate_size // block_size - and must also be multiples of gpus_per_node to avoid expensive inter‑node allreduce. - - Args: - intermediate_size (int): MLP intermediate size. - block_size (int): The quantization block scale size. In the case of Deepseek FP8 recipe, - it's 128. For NVFP4, it's 16. - - Returns: - int: The computed tp_size. - """ - - assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size." - if self.enable_attention_dp: - # If using attention DP, the MLP also uses DP instead of TP. - mlp_tp_size = 1 - else: - # The two math.gcd operations ensure that mlp_tp_size falls in the candidate TP sizes. - tp = math.gcd( - intermediate_size // block_size, - self.mapping.tp_size, - ) - mlp_tp_size = math.gcd( - tp, - self.mapping.gpus_per_node, - ) if tp > self.mapping.gpus_per_node else tp # Avoid costly inter-node TP - return mlp_tp_size - - def forward( - self, - position_ids: torch.IntTensor, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states = self.self_attn( - position_ids=position_ids, - hidden_states=hidden_states, - attn_metadata=attn_metadata, - all_reduce_params=AllReduceParams( - enable_allreduce=not (self.disable_attn_allreduce)), - **kwargs, - ) - - if isinstance(self.mlp, Deepseekv3MoE): - return self.forward_MoE( - hidden_states=hidden_states, - attn_metadata=attn_metadata, - residual=residual, - ) - else: - assert isinstance(self.mlp, GatedMLP) - return self.forward_mlp( - hidden_states=hidden_states, - residual=residual, - ) - - def forward_MoE( - self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: torch.Tensor, - ) -> torch.Tensor: - - def _run_MoE(hidden_states, hidden_states_fp4, do_finalize): - return self.mlp( - hidden_states, - hidden_states_fp4, - all_rank_num_tokens=attn_metadata.all_rank_num_tokens, - all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens, - final_all_reduce_params=AllReduceParams( - enable_allreduce=not (self.fusion_config.POST_MOE_FUSION - or self.mapping.tp_size == 1)), - do_finalize=do_finalize, - ) - - if self.fusion_config.PRE_MOE_FUSION: - # moe_backend can be either CUTLASS or TRTLLM here - # TODO: unify the two min-latency MoE backends by enabling quant fusion - hidden_states, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, - residual=residual, - norm_weight=self.post_attention_layernorm.weight, - eps=self.post_attention_layernorm.variance_epsilon, - trigger_completion_at_end=False, - )) - else: - # No fusion - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - # Note: this fusion pattern is only supported for single-node TRTLLM-nvfp4 backend now - do_finalize = self.mapping.is_multi_node() or ( - not (hidden_states.shape[0] <= self.moe_allreduce.max_token - and 
self.fusion_config.POST_MOE_FUSION - and self.model_config.moe_backend == "TRTLLM" - and self.mlp.experts.has_nvfp4 and self.is_p2p_supported)) - - hidden_states = _run_MoE(hidden_states, - hidden_states_fp4=None, - do_finalize=do_finalize) - - if self.fusion_config.POST_MOE_FUSION: - if do_finalize: - hidden_states, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, - residual=residual, - norm_weight=self.next_layer_layernorm.weight, - eps=self.next_layer_layernorm.variance_epsilon, - trigger_completion_at_end=False, - )) - else: - assert len( - hidden_states) == 4, "hidden_states must have 4 elements" - - shared_output = hidden_states[0] - fc2_output = hidden_states[1] - expert_scale_factor = hidden_states[2] - expanded_idx_to_permuted_idx = hidden_states[3] - - moe_all_reduce_params = MoEAllReduceParams( - expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx, - expert_scale_factor=expert_scale_factor, - shared_expert_output=shared_output, - residual=residual, - norm_weight=self.next_layer_layernorm.weight, - eps=self.next_layer_layernorm.variance_epsilon, - is_cutlass_min_latency=False, - ) - hidden_states, residual = self.moe_allreduce( - fc2_output, all_reduce_params=moe_all_reduce_params) - else: - if self.next_layer_layernorm is not None: - hidden_states, residual = self.next_layer_layernorm( - hidden_states, residual) - - return hidden_states, residual - - def forward_mlp( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor, - ) -> torch.Tensor: - - if self.fusion_config.PRE_MLP_FUSION: - act_fp4, act_sf, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4, - residual=residual, - norm_weight=self.post_attention_layernorm.weight, - scale=self.mlp.gate_up_proj.input_scale, - eps=self.post_attention_layernorm.variance_epsilon, - ), - ) - hidden_states = Fp4QuantizedTensor(act_fp4, act_sf) - else: - # No fusion - # We need to add twoshot allreduce here to avoid modifying MLA logic - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - hidden_states = self.mlp( - hidden_states, - final_all_reduce_params=AllReduceParams(enable_allreduce=not ( - self.fusion_config.POST_MLP_FUSION or self.mlp_tp_size == 1)), - ) - - if self.fusion_config.POST_MLP_FUSION: - hidden_states, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, - residual=residual, - norm_weight=self.next_layer_layernorm.weight, - eps=self.next_layer_layernorm.variance_epsilon, - ), - ) - else: - if self.next_layer_layernorm is not None: - hidden_states, residual = self.next_layer_layernorm( - hidden_states, residual) - - return hidden_states, residual - - -class DeepseekV3MTP(DeepseekV3DecoderLayer): - - def __init__(self, model_config: ModelConfig[PretrainedConfig], - layer_idx: int, aux_stream_dict: Dict[AuxStreamType, - torch.cuda.Stream]): - super().__init__(model_config, layer_idx, aux_stream_dict) - config = model_config.pretrained_config - self.hidden_dim = config.hidden_size - self.moe_intermediate_size = config.moe_intermediate_size - self.num_experts = config.n_routed_experts - self.num_shared_experts = config.n_shared_experts - self.top_k = config.num_experts_per_tok - - self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared] - self.event_dict = { - key: torch.cuda.Event() - for key in [EventType.Main, 
EventType.MoeShared] - } - - self.enorm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - - self.hnorm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - if model_config.mapping.enable_attention_dp: - self.eh_proj = Linear( - config.hidden_size * 2, - config.hidden_size, - bias=False, - dtype=config.torch_dtype, - skip_create_weights_in_init=model_config. - skip_create_weights_in_init, - ) - else: - self.eh_proj = Linear( - config.hidden_size * 2, - config.hidden_size, - bias=False, - dtype=config.torch_dtype, - tensor_parallel_mode=TensorParallelMode.ROW, - mapping=model_config.mapping, - reduce_output=True, - skip_create_weights_in_init=model_config. - skip_create_weights_in_init, - ) - - self.shared_head = DeepseekV3MTPHead(model_config) - - def forward( - self, - input_ids: torch.IntTensor, - position_ids: torch.IntTensor, - hidden_states: torch.Tensor, - embed_tokens: Embedding, - attn_metadata: AttentionMetadata, - all_rank_num_tokens: Optional[List[int]] = None, - all_rank_max_num_tokens: Optional[int] = None, - **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - - def norm_embeds(): - return self.enorm(embed_tokens(input_ids)) #emdedding - - def norm_hidden(): - return self.hnorm(hidden_states) - - inputs_embeds, hidden_states = maybe_execute_in_parallel( - norm_embeds, - norm_hidden, - self.event_dict[EventType.Main], - self.event_dict[EventType.MoeShared], - self.aux_stream, - ) - hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1) - # Split hidden_states columnwise based on TP - tp_size = self.model_config.mapping.tp_size - tp_rank = self.model_config.mapping.tp_rank - - if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp): - hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank] - hidden_states = self.eh_proj(hidden_states) - - # Input layer norm - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states = self.self_attn( - position_ids=position_ids, - hidden_states=hidden_states, - attn_metadata=attn_metadata, - all_reduce_params=AllReduceParams( - enable_allreduce=not (self.disable_attn_allreduce)), - **kwargs, - ) - - # MTP Layer Must have sparse MOE - if self.fusion_config.PRE_MOE_FUSION: - hidden_states, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, - residual=residual, - norm_weight=self.post_attention_layernorm.weight, - eps=self.post_attention_layernorm.variance_epsilon, - ), - ) - else: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - - # MoE - hidden_states = self.mlp( - hidden_states, - all_rank_num_tokens=all_rank_num_tokens, - all_rank_max_num_tokens=all_rank_max_num_tokens, - final_all_reduce_params=AllReduceParams( - enable_allreduce=not (self.fusion_config.POST_MOE_FUSION - or self.mapping.tp_size == 1)), - ) - - if self.fusion_config.POST_MOE_FUSION: - hidden_states, residual = self.allreduce( - hidden_states, - all_reduce_params=AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, - residual=residual, - norm_weight=self.shared_head.norm.weight, - eps=self.shared_head.norm.variance_epsilon, - ), - ) - else: - hidden_states, _ = self.shared_head.norm(hidden_states, residual) - - return hidden_states - - class MTPForCausalLM(nn.Module): def __init__( self, - model, - model_config: PretrainedConfig, + model_config: 
ModelConfig[PretrainedConfig],
         start_layer_idx: int = 0,
+        lm_head: nn.Module = None,
+        model: nn.Module = None,
     ):
         super().__init__()
+        # Import here to avoid circular import
+        from .modeling_deepseekv3 import DeepseekV3MTP
+
         spec_dec_mode = model_config.spec_config.spec_dec_mode
         assert spec_dec_mode.is_mtp()
-        self.embed_tokens = model.embed_tokens
         mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle(
         ) else model_config.spec_config.num_nextn_predict_layers
@@ -1269,19 +343,22 @@ def __init__(
                           model.aux_stream_dict)
             for layer_idx in range(mtp_num_layers)
         ])
+        self.lm_head = lm_head
+        self.embed_tokens = model.embed_tokens
 
 
-def get_draft_model(model, model_config, draft_config):
+def get_draft_model(model_config, draft_config, lm_head, model):
     assert getattr(model_config, 'spec_config', None) != None
     spec_dec_mode = model_config.spec_config.spec_dec_mode
     if spec_dec_mode.is_eagle3_one_model():
         return Eagle3ForCausalLM(
             draft_config, model_config.pretrained_config.num_hidden_layers)
     elif spec_dec_mode.is_mtp():
-        return MTPForCausalLM(model, model_config,
-                              model_config.pretrained_config.num_hidden_layers)
+        return MTPForCausalLM(model_config,
+                              model_config.pretrained_config.num_hidden_layers,
+                              lm_head, model)
     else:
-        raise NotImplemented(
+        raise NotImplementedError(
             f"get_draft_model does not support speculative decoding mode {spec_dec_mode}."
         )
@@ -1311,8 +388,8 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
                 draft_config.quant_config.kv_cache_quant_algo = \
                     model_config.quant_config.kv_cache_quant_algo
 
-            self.draft_model = get_draft_model(model, model_config,
-                                               draft_config)
+            self.draft_model = get_draft_model(model_config, draft_config,
+                                               self.lm_head, self.model)
             self.spec_worker = get_spec_worker(model_config.spec_config,
                                                model_config,
                                                model_config.mapping)
@@ -1349,7 +426,6 @@ def forward(
                 position_ids=position_ids,
                 hidden_states=hidden_states,
                 logits=logits,
-                lm_head=self.lm_head,
                 attn_metadata=attn_metadata,
                 spec_metadata=spec_metadata,
                 draft_model=self.draft_model)
diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py
index c30f407c93a..417becf12f3 100644
--- a/tensorrt_llm/_torch/speculative/eagle3.py
+++ b/tensorrt_llm/_torch/speculative/eagle3.py
@@ -268,7 +268,7 @@ def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
 
     # Skip torch.compile for now since current Torch is not compatible with Triton 3.4
     # @torch.compile(options={"max-autotune": True})
-    def forward(self, input_ids, position_ids, hidden_states, logits, lm_head,
+    def forward(self, input_ids, position_ids, hidden_states, logits,
                 attn_metadata, spec_metadata, draft_model):
         batch_size = attn_metadata.num_seqs
         num_contexts = attn_metadata.num_contexts
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index 22a8ea5b197..2658ce539b5 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -330,7 +330,6 @@ def forward(
         position_ids,
         hidden_states,
         logits,
-        lm_head,
         attn_metadata,
         spec_metadata,
         draft_model,
@@ -472,7 +471,7 @@ def forward(
         for _, mtp_layer in enumerate(draft_model.mtp_layers):
             hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
                                       **draft_inputs)
-            logits = mtp_layer.shared_head(hidden_states, lm_head,
+            logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head,
                                            attn_metadata).float()
             new_draft_token = self.draft_sampler(logits)
             next_draft_tokens.append(new_draft_token)
@@ -517,10 +516,9 @@ def skip_forward(
        position_ids,
        hidden_states,
        logits,
-        lm_head,
        attn_metadata,
        spec_metadata,
-        mtp_layers,
+        draft_model,
     ):
         batch_size = attn_metadata.num_seqs
         mtp_num_modules = self.spec_config.num_nextn_predict_layers
@@ -1126,7 +1124,6 @@ def forward(
         position_ids,
         hidden_states,
         logits,
-        lm_head,
         attn_metadata,
         spec_metadata,
         draft_model,
@@ -1196,7 +1193,8 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
             # All of the seq_len are 1, use batch_indices_cuda as gather_ids
             gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
             logits = draft_model.mtp_layers[0].shared_head(
-                hidden_states[gather_ids], lm_head, attn_metadata, True)
+                hidden_states[gather_ids], draft_model.lm_head, attn_metadata,
+                True)
             new_draft_token = self.draft_sampler(logits)
 
             hidden_states, position_ids = self.update_draft_tokens(