97 changes: 29 additions & 68 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -53,7 +53,6 @@
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
MoEAllReduce, MoEAllReduceParams, allgather)
from ..model_config import ModelConfig
from ..models.modeling_utils import ModelConfig, QuantConfig
from ..modules.attention import MLA
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
@@ -65,10 +64,10 @@
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..peft.lora.layer import LoraLayer
from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from ..speculative import MTPSpecMetadata, SpecMetadata
from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
EagerFusionConfig, filter_weights,
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights,
register_auto_model)


@@ -534,7 +533,8 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
router_logits = self.gate(hidden_states)

routed_output = self.experts(
hidden_states_fp4 or hidden_states,
hidden_states_fp4
if hidden_states_fp4 is not None else hidden_states,
router_logits,
do_finalize=do_finalize,
output_dtype=hidden_states.dtype,
@@ -558,8 +558,9 @@ def forward(
assert not self.use_dp

def _compute_shared_output():
shared_output = self.shared_experts(hidden_states_fp4
or hidden_states)
shared_output = self.shared_experts(
hidden_states_fp4
if hidden_states_fp4 is not None else hidden_states)
if self.shared_output_scale is not None:
shared_output *= self.shared_output_scale
return shared_output
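
Note on the two hunks above: the earlier `hidden_states_fp4 or hidden_states` idiom relies on the left operand's truthiness, which raises for multi-element torch tensors and is easy to get wrong for wrapper types, so the diff switches to an explicit `is not None` check. A minimal standalone sketch of the safer idiom (the names here are placeholders, not the TRT-LLM types):

import torch

def pick_moe_input(maybe_fp4, hidden_states):
    # Explicit None check: use the quantized activation only when it exists.
    return maybe_fp4 if maybe_fp4 is not None else hidden_states

x = torch.randn(4, 8)
assert pick_moe_input(None, x) is x
# By contrast, `torch.randn(4, 8) or x` raises "Boolean value of Tensor with
# more than one element is ambiguous", which is why the `or` idiom is avoided.
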
@@ -743,7 +744,7 @@ def forward(
attn_metadata: AttentionMetadata,
residual: torch.Tensor,
**kwargs,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -775,7 +776,7 @@ def forward_MoE(
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: torch.Tensor,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:

def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
return self.mlp(
@@ -859,7 +860,7 @@ def forward_mlp(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:

if self.fusion_config.PRE_MLP_FUSION:
act_fp4, act_sf, residual = self.allreduce(
@@ -963,7 +964,7 @@ def forward(
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:

def norm_embeds():
return self.enorm(embed_tokens(input_ids)) #emdedding
@@ -1078,6 +1079,8 @@ def forward(
input_ids: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
@@ -1102,8 +1105,8 @@ def forward(


@register_auto_model("DeepseekV3ForCausalLM")
class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
PretrainedConfig]):
class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
PretrainedConfig]):

def __init__(self, model_config: ModelConfig[PretrainedConfig]):
# Rename some keys of quant_config_dict to support legacy checkpoints
@@ -1118,10 +1121,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
model_config._frozen = False
model_config.quant_config_dict = quant_config_dict
model_config._frozen = True
super().__init__(DeepseekV3Model(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)

super().__init__(model=DeepseekV3Model(model_config),
model_config=model_config)

self.model_nextn = 0
if model_config.spec_config is not None:
@@ -1131,23 +1133,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint."
if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
moe_load_balancer_set_repeated_for_next_layer(model_nextn)
mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
self.model.aux_stream_dict)
self.model.layers.append(mtp_layer)
self.epilogue.append(mtp_layer)
self.mtp_worker = MTPEagleWorker(model_config.spec_config,
model_config)
else:
mtp_layers = nn.ModuleList([
DeepseekV3MTP(model_config,
layer_idx + self.num_hidden_layers,
self.model.aux_stream_dict)
for layer_idx in range(model_nextn)
])
self.model.layers.extend(mtp_layers)
self.epilogue.extend(mtp_layers)
self.mtp_worker = MTPWorker(model_config.spec_config,
model_config)
# modify the QuantConfig to support duplicated mtp layers
if model_config.quant_config.exclude_modules is not None:
extend_exclude_modules = []
@@ -1165,7 +1151,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
ckpt_prefix, model_prefix))
self.model_config.quant_config.exclude_modules.extend(
extend_exclude_modules)
self.epilogue.append(self.mtp_worker)
self.model.layers.extend(self.draft_model.mtp_layers)
self.epilogue.extend(self.draft_model.mtp_layers)
self.epilogue.append(self.spec_worker)

def forward(
self,
@@ -1178,40 +1166,13 @@ def forward(
**kwargs,
) -> torch.Tensor:
attn_metadata.num_generations_per_batch = self.model_nextn + 1
hidden_states = self.model(
input_ids=input_ids,
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
)

if spec_metadata and spec_metadata.spec_dec_mode.is_mtp():
# get logits
logits = self.logits_processor.forward(
hidden_states[spec_metadata.gather_ids],
self.lm_head,
attn_metadata,
True,
)
# get accepted tokens and next draft tokens
return self.mtp_worker(
input_ids=input_ids,
position_ids=position_ids,
hidden_states=hidden_states,
logits=logits,
lm_head=self.lm_head,
embed_tokens=self.model.embed_tokens,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
mtp_layers=self.model.layers[self.num_hidden_layers:])
else:
logits = self.logits_processor.forward(
hidden_states,
self.lm_head,
attn_metadata,
return_context_logits,
)
return logits
return super().forward(attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
spec_metadata=spec_metadata,
return_context_logits=return_context_logits,
**kwargs)

def load_weights(self, weights: Dict):

68 changes: 50 additions & 18 deletions tensorrt_llm/_torch/models/modeling_speculative.py
@@ -2,7 +2,7 @@

import torch
from torch import nn
from transformers import LlamaConfig
from transformers import LlamaConfig, PretrainedConfig

from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
BaseWeightMapper
@@ -320,14 +320,45 @@ def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


def get_draft_model(model_config, draft_config):
class MTPForCausalLM(nn.Module):

def __init__(
self,
model_config: ModelConfig[PretrainedConfig],
start_layer_idx: int = 0,
lm_head: nn.Module = None,
model: nn.Module = None,
):
super().__init__()
# Import here to avoid circular import
from .modeling_deepseekv3 import DeepseekV3MTP

spec_dec_mode = model_config.spec_config.spec_dec_mode
assert spec_dec_mode.is_mtp()
mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle(
) else model_config.spec_config.num_nextn_predict_layers

self.mtp_layers = nn.ModuleList([
DeepseekV3MTP(model_config, layer_idx + start_layer_idx,
model.aux_stream_dict)
for layer_idx in range(mtp_num_layers)
])
self.lm_head = lm_head
self.embed_tokens = model.embed_tokens


def get_draft_model(model_config, draft_config, lm_head, model):
assert getattr(model_config, 'spec_config', None) != None
spec_dec_mode = model_config.spec_config.spec_dec_mode
if spec_dec_mode.is_eagle3_one_model():
return Eagle3ForCausalLM(
draft_config, model_config.pretrained_config.num_hidden_layers)
elif spec_dec_mode.is_mtp():
return MTPForCausalLM(model_config,
model_config.pretrained_config.num_hidden_layers,
lm_head, model)
else:
raise NotImplemented(
raise NotImplementedError(
f"get_draft_model does not support speculative decoding mode {spec_dec_mode}."
)
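
The new MTPForCausalLM wrapper above keeps references to the target model's `lm_head` and `embed_tokens` (handed in by `get_draft_model`) rather than building its own copies, so the MTP draft layers share those weights with the target. A small self-contained sketch of that sharing pattern, with toy modules standing in for the real classes:

from torch import nn

class ToyTarget(nn.Module):
    def __init__(self, vocab=100, hidden=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

class ToyDraft(nn.Module):
    def __init__(self, target: ToyTarget, num_mtp_layers=1, hidden=16):
        super().__init__()
        self.mtp_layers = nn.ModuleList(
            [nn.Linear(hidden, hidden) for _ in range(num_mtp_layers)])
        # References, not copies: the draft head and embedding are the target's
        # own modules, so there is a single set of weights to load and update.
        self.lm_head = target.lm_head
        self.embed_tokens = target.embed_tokens

target = ToyTarget()
draft = ToyDraft(target)
assert draft.lm_head is target.lm_head
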

@@ -341,23 +372,24 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)
self.draft_model = None
if getattr(
model_config, 'spec_config', None
) and model_config.spec_config.spec_dec_mode.use_one_engine():
draft_config = ModelConfig.from_pretrained(
model_config.spec_config.speculative_model_dir,
trust_remote_code=True,
attn_backend=model_config.attn_backend,
moe_backend=model_config.moe_backend,
mapping=model_config.mapping,
spec_config=model_config.spec_config,
max_num_tokens=model_config.max_num_tokens,
moe_max_num_tokens=model_config.moe_max_num_tokens)

draft_config.quant_config.kv_cache_quant_algo = \
spec_config = getattr(model_config, 'spec_config', None)
if spec_config and spec_config.spec_dec_mode.use_one_engine():
draft_config = None
if spec_config.spec_dec_mode.is_eagle3_one_model():
draft_config = ModelConfig.from_pretrained(
model_config.spec_config.speculative_model_dir,
trust_remote_code=True,
attn_backend=model_config.attn_backend,
moe_backend=model_config.moe_backend,
mapping=model_config.mapping,
spec_config=model_config.spec_config,
max_num_tokens=model_config.max_num_tokens,
moe_max_num_tokens=model_config.moe_max_num_tokens)
draft_config.quant_config.kv_cache_quant_algo = \
model_config.quant_config.kv_cache_quant_algo

self.draft_model = get_draft_model(model_config, draft_config)
self.draft_model = get_draft_model(model_config, draft_config,
self.lm_head, self.model)
self.spec_worker = get_spec_worker(model_config.spec_config,
model_config,
model_config.mapping)
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/speculative/interface.py
@@ -23,6 +23,9 @@ class SpeculativeDecodingMode(IntEnum):
def is_mtp(self):
return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE

def is_mtp_vanilla(self):
return self == SpeculativeDecodingMode.MTP

def is_mtp_eagle(self):
return self == SpeculativeDecodingMode.MTP_EAGLE
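
The added `is_mtp_vanilla` helper distinguishes plain MTP from the MTP-Eagle variant. A minimal standalone illustration of the predicate-on-IntEnum pattern (a toy enum with placeholder values, not the full SpeculativeDecodingMode):

from enum import IntEnum

class Mode(IntEnum):
    MTP = 1
    MTP_EAGLE = 2

    def is_mtp(self):
        return self in (Mode.MTP, Mode.MTP_EAGLE)

    def is_mtp_vanilla(self):
        return self == Mode.MTP

    def is_mtp_eagle(self):
        return self == Mode.MTP_EAGLE

assert Mode.MTP.is_mtp() and Mode.MTP.is_mtp_vanilla()
assert Mode.MTP_EAGLE.is_mtp() and not Mode.MTP_EAGLE.is_mtp_vanilla()
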

32 changes: 14 additions & 18 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -330,11 +330,9 @@ def forward(
position_ids,
hidden_states,
logits,
lm_head,
embed_tokens,
attn_metadata,
spec_metadata,
mtp_layers,
draft_model,
):
'''
Example:
@@ -470,9 +468,10 @@ def forward(
next_draft_tokens = []
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
for _, mtp_layer in enumerate(mtp_layers):
hidden_states = mtp_layer(embed_tokens=embed_tokens, **draft_inputs)
logits = mtp_layer.shared_head(hidden_states, lm_head,
for _, mtp_layer in enumerate(draft_model.mtp_layers):
hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
**draft_inputs)
logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head,
attn_metadata).float()
new_draft_token = self.draft_sampler(logits)
next_draft_tokens.append(new_draft_token)
@@ -517,11 +516,9 @@ def skip_forward(
position_ids,
hidden_states,
logits,
lm_head,
embed_tokens,
attn_metadata,
spec_metadata,
mtp_layers,
draft_model,
):
batch_size = attn_metadata.num_seqs
mtp_num_modules = self.spec_config.num_nextn_predict_layers
@@ -1127,11 +1124,9 @@ def forward(
position_ids,
hidden_states,
logits,
lm_head,
embed_tokens,
attn_metadata,
spec_metadata,
mtp_layers,
draft_model,
):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
@@ -1172,8 +1167,8 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
next_draft_tokens = []
for i in range(self.mtp_num_modules):
if i == 0:
hidden_states = mtp_layers[0](
embed_tokens=embed_tokens,
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
all_rank_max_num_tokens=spec_metadata.
all_rank_max_num_tokens,
@@ -1186,8 +1181,8 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
hidden_states = mtp_layers[0](
embed_tokens=embed_tokens,
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
all_rank_max_num_tokens=max(
@@ -1197,8 +1192,9 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
**inputs)
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
logits = mtp_layers[0].shared_head(hidden_states[gather_ids],
lm_head, attn_metadata, True)
logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head, attn_metadata,
True)
new_draft_token = self.draft_sampler(logits)

hidden_states, position_ids = self.update_draft_tokens(
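
Across these hunks the MTP workers stop taking `lm_head`, `embed_tokens`, and `mtp_layers` as separate arguments and instead read them from a single `draft_model` handle. A rough sketch of the worker-side view after this change, using toy types and names rather than the real MTPWorker/MTPEagleWorker:

from torch import nn

class ToyDraftModel(nn.Module):
    def __init__(self, mtp_layers: nn.ModuleList, lm_head: nn.Module,
                 embed_tokens: nn.Module):
        super().__init__()
        self.mtp_layers = mtp_layers
        self.lm_head = lm_head
        self.embed_tokens = embed_tokens

def run_mtp_layers(hidden_states, draft_model: ToyDraftModel):
    # One handle instead of three positional arguments; everything the draft
    # loop needs travels together with the draft model.
    for mtp_layer in draft_model.mtp_layers:
        hidden_states = mtp_layer(hidden_states)
    return draft_model.lm_head(hidden_states)
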