9 changes: 4 additions & 5 deletions examples/llm-api/quickstart_advanced.py
@@ -170,16 +170,15 @@ def setup_llm(args, **kwargs):

if spec_decode_algo == 'MTP':
if not args.use_one_model:
print(
"MTP only supports one model style spec decode; ignoring default use_one_model=False"
)

print("Running MTP eagle with two model style.")
spec_config = MTPDecodingConfig(
num_nextn_predict_layers=args.spec_decode_max_draft_len,
use_relaxed_acceptance_for_thinking=args.
use_relaxed_acceptance_for_thinking,
relaxed_topk=args.relaxed_topk,
relaxed_delta=args.relaxed_delta)
relaxed_delta=args.relaxed_delta,
mtp_eagle_one_model=args.use_one_model,
speculative_model_dir=args.model_dir)
elif spec_decode_algo == "EAGLE3":
spec_config = EagleDecodingConfig(
max_draft_len=args.spec_decode_max_draft_len,
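For orientation, a minimal sketch of how the reworked MTPDecodingConfig call might be used directly with the LLM API outside the quickstart script. The import path, checkpoint paths, and literal values below are assumptions for illustration and are not part of this PR.

from tensorrt_llm.llmapi import LLM, MTPDecodingConfig

# Sketch only: two-model MTP-Eagle configuration mirroring the quickstart change.
spec_config = MTPDecodingConfig(
    num_nextn_predict_layers=3,  # corresponds to args.spec_decode_max_draft_len
    use_relaxed_acceptance_for_thinking=False,
    relaxed_topk=1,
    relaxed_delta=0.0,
    mtp_eagle_one_model=False,  # False -> two-model style with a separate draft engine
    speculative_model_dir="/path/to/DeepSeek-V3",  # hypothetical checkpoint path
)

llm = LLM(model="/path/to/DeepSeek-V3", speculative_config=spec_config)
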
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/models/modeling_auto.py
@@ -31,6 +31,8 @@ def from_config(
"") # Strip the appended EAGLE3
if hasattr(config.pretrained_config, "draft_vocab_size"):
model_arch = "EAGLE3" + model_arch
if model_arch == "DeepseekV3ForCausalLM" and config.spec_config is not None and config.spec_config.max_draft_len == 0:
model_arch = "MTPDraftModelForCausalLM"

cls = MODEL_CLASS_MAPPING.get(model_arch)
if cls is None:
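The new branch reroutes a DeepseekV3 checkpoint to the dedicated MTP draft class when it is being built as a drafting engine. A standalone sketch of that dispatch rule follows; the config object is a simplified stand-in, and the premise that the draft engine's spec config carries max_draft_len == 0 is inferred from this PR rather than stated by the library.

from typing import Optional


class _SpecConfig:
    # Stand-in for the real speculative config; only the field used here.
    def __init__(self, max_draft_len: int):
        self.max_draft_len = max_draft_len


def resolve_model_arch(model_arch: str, spec_config: Optional[_SpecConfig]) -> str:
    # DeepseekV3 loaded with a spec config whose max_draft_len is 0 is treated
    # as the draft engine and routed to MTPDraftModelForCausalLM.
    if (model_arch == "DeepseekV3ForCausalLM" and spec_config is not None
            and spec_config.max_draft_len == 0):
        return "MTPDraftModelForCausalLM"
    return model_arch


assert resolve_model_arch("DeepseekV3ForCausalLM", _SpecConfig(0)) == "MTPDraftModelForCausalLM"
assert resolve_model_arch("DeepseekV3ForCausalLM", None) == "DeepseekV3ForCausalLM"
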
2,407 changes: 1,216 additions & 1,191 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
100644 → 100755

Large diffs are not rendered by default.

117 changes: 113 additions & 4 deletions tensorrt_llm/_torch/models/modeling_speculative.py
100644 → 100755
@@ -1,4 +1,4 @@
from typing import Dict, Generic, Optional, Tuple
from typing import Dict, Generic, List, Optional, Tuple

import torch
from torch import nn
@@ -18,6 +18,7 @@
from ..modules.rms_norm import RMSNorm
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..speculative import SpecMetadata, get_spec_worker
from ..utils import AuxStreamType
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
register_auto_model)
@@ -342,8 +343,8 @@ def __init__(
from .modeling_deepseekv3 import DeepseekV3MTP

spec_dec_mode = model_config.spec_config.spec_dec_mode
assert spec_dec_mode.is_mtp()
mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle(
assert spec_dec_mode.is_mtp_one_model()
mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle_one_model(
) else model_config.spec_config.num_nextn_predict_layers

moe_load_balancer_set_repeated_for_next_layer(
@@ -358,16 +359,124 @@
self.embed_tokens = model.embed_tokens


class MTPDraftModel(nn.Module):

def __init__(self, model_config: ModelConfig[PretrainedConfig],
layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
torch.cuda.Stream]):
super().__init__()
# Import here to avoid circular import
from .modeling_deepseekv3 import DeepseekV3MTP

mtp_layer = DeepseekV3MTP(model_config,
layer_idx,
aux_stream_dict,
is_separate_draft_engine=True)
setattr(self, f"layers.{layer_idx}", mtp_layer)
self.layers = mtp_layer
self.layer_idx = layer_idx
self.config = model_config.pretrained_config
self.embed_tokens = Embedding(
self.config.vocab_size,
self.config.hidden_size,
dtype=self.config.torch_dtype,
)

def __repr__(self):
"""Custom string representation to display layer index"""
return f"(layers): ({self.layer_idx}): {repr(self.layers)}"

def forward(
self,
input_ids: torch.IntTensor,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
all_rank_num_tokens: Optional[List[int]] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.layers(
input_ids,
position_ids,
hidden_states,
embed_tokens=self.embed_tokens,
attn_metadata=attn_metadata,
all_rank_num_tokens=all_rank_num_tokens,
)

return hidden_states


@register_auto_model("MTPDraftModelForCausalLM")
class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,
PretrainedConfig]):

def __init__(self, model_config: ModelConfig[PretrainedConfig]):
self.model_config = model_config
aux_stream_list = [torch.cuda.Stream() for _ in range(2)]
self.aux_stream_dict = {
AuxStreamType.Attention: aux_stream_list[0],
AuxStreamType.MoeShared: aux_stream_list[0],
AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
}
super().__init__(
MTPDraftModel(self.model_config,
self.model_config.pretrained_config.num_hidden_layers,
self.aux_stream_dict),
config=self.model_config,
hidden_size=self.model_config.pretrained_config.hidden_size,
vocab_size=self.model_config.pretrained_config.vocab_size)

def load_weights(self, weights: Dict):
# Import here to avoid circular import
from .modeling_deepseekv3 import DeepseekV3WeightLoader
weight_loader = DeepseekV3WeightLoader(self, is_draft_model=True)
weight_loader.load_weights(weights)

def load_weights_from_target_model(self,
target_model: torch.nn.Module) -> None:
if self.model.embed_tokens is None:
self.model.embed_tokens = target_model.model.embed_tokens
self.lm_head = target_model.lm_head

def forward(self,
attn_metadata: AttentionMetadata,
input_ids: torch.IntTensor = None,
position_ids: torch.IntTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
hidden_states: torch.Tensor = None,
**kwargs) -> torch.Tensor:

hidden_states = spec_metadata.get_hidden_states()
output = self.model(
input_ids=input_ids,
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
**kwargs)
return self.logits_processor.forward(
output,
self.lm_head,
attn_metadata,
return_context_logits,
)


def get_draft_model(model_config, draft_config, lm_head, model):
assert getattr(model_config, 'spec_config', None) != None
spec_dec_mode = model_config.spec_config.spec_dec_mode
if spec_dec_mode.is_eagle3_one_model():
return Eagle3ForCausalLM(
draft_config, model_config.pretrained_config.num_hidden_layers)
elif spec_dec_mode.is_mtp():
elif spec_dec_mode.is_mtp_one_model():
return MTPForCausalLM(model_config,
model_config.pretrained_config.num_hidden_layers,
lm_head, model)
elif spec_dec_mode.is_mtp_eagle():
return MTPDraftModelForCausalLM(model_config)
else:
raise NotImplementedError(
f"get_draft_model does not support speculative decoding mode {spec_dec_mode}."
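One detail worth calling out in MTPDraftModel.__init__ above is that the single MTP layer is registered twice: once under a dotted attribute name via setattr(self, f"layers.{layer_idx}", mtp_layer), and once as self.layers. The toy sketch below illustrates the effect on parameter naming; the layer index 61 is an assumption (DeepSeek-V3 has 61 decoder layers, so the appended MTP layer keeps that index in checkpoint weight names), and nn.Linear stands in for DeepseekV3MTP.

from torch import nn


class TinyDraftModel(nn.Module):
    def __init__(self, layer_idx: int = 61):
        super().__init__()
        layer = nn.Linear(8, 8)
        # Registering under "layers.61" keeps checkpoint-style names such as
        # "layers.61.weight" valid even though only this one layer exists...
        setattr(self, f"layers.{layer_idx}", layer)
        # ...while self.layers gives the forward pass a plain handle to the same module.
        self.layers = layer


print(list(TinyDraftModel().state_dict().keys()))
# ['layers.61.weight', 'layers.61.bias', 'layers.weight', 'layers.bias']

Both names point at the same parameters, which presumably lets DeepseekV3WeightLoader reuse the target checkpoint's layer numbering for the draft model unchanged.
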
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -330,6 +330,9 @@ def drafting_loop_wrapper(model):
is_draft_model=True,
drafting_loop_wrapper=drafting_loop_wrapper,
)
# For DeepseekV3 MTP, we need to set the num_hidden_layers to 1 for the draft model
if spec_config.spec_dec_mode.is_mtp_eagle():
draft_model_engine.model.model_config.pretrained_config.num_hidden_layers = 1
draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
draft_model_engine.load_weights_from_target_model(
model_engine.model)
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/speculative/eagle3.py
@@ -92,18 +92,18 @@ class Eagle3SpecMetadata(SpecMetadata):
is_draft_model: bool = False
is_first_draft: bool = False
eagle3_resource_manager: Optional[Eagle3ResourceManager] = None
is_mtp_eagle: bool = False

def __post_init__(self):
if self.is_draft_model:
self.layers_to_capture = (self.num_layers - 1, )
elif self.layers_to_capture is None:
if self.num_layers == 1:
if self.num_layers == 1 or self.is_mtp_eagle:
self.layers_to_capture = (self.num_layers - 1, )
else:
if self.num_layers <= 5:
raise ValueError(
"Not enough hidden layers for default EAGLE3 capture")

self.layers_to_capture = (1, self.num_layers // 2 - 1,
self.num_layers - 4)
else:
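A standalone sketch of the capture-layer selection that __post_init__ now implements when layers_to_capture is left unset; it is reduced to a free function, and the num_layers values in the asserts are chosen only for illustration.

from typing import Tuple


def select_layers_to_capture(num_layers: int,
                             is_draft_model: bool = False,
                             is_mtp_eagle: bool = False) -> Tuple[int, ...]:
    # Mirrors Eagle3SpecMetadata.__post_init__ for the default-capture case.
    if is_draft_model:
        return (num_layers - 1, )
    if num_layers == 1 or is_mtp_eagle:
        # MTP-Eagle (like a single-layer draft) captures only the final layer.
        return (num_layers - 1, )
    if num_layers <= 5:
        raise ValueError("Not enough hidden layers for default EAGLE3 capture")
    return (1, num_layers // 2 - 1, num_layers - 4)


assert select_layers_to_capture(61, is_mtp_eagle=True) == (60, )
assert select_layers_to_capture(32) == (1, 15, 28)
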
28 changes: 17 additions & 11 deletions tensorrt_llm/_torch/speculative/interface.py
@@ -12,6 +12,7 @@
class SpeculativeDecodingMode(IntEnum):
MTP = auto()
MTP_EAGLE = auto()
MTP_EAGLE_ONE_MODEL = auto()
EAGLE3 = auto()
EAGLE3_ONE_MODEL = auto()
NGRAM = auto()
@@ -20,8 +21,11 @@ class SpeculativeDecodingMode(IntEnum):
NONE = auto()
AUTO = auto()

def is_mtp(self):
return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE
def is_mtp_one_model(self):
return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL

def is_mtp_eagle_one_model(self):
return self == SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL

def is_mtp_vanilla(self):
return self == SpeculativeDecodingMode.MTP
@@ -33,7 +37,7 @@ def is_eagle3(self):
return self == SpeculativeDecodingMode.EAGLE3

def use_one_engine(self):
return self.is_mtp() or self.is_eagle3_one_model()
return self.is_eagle3_one_model() or self.is_mtp_one_model()

def is_eagle3_one_model(self):
return self == SpeculativeDecodingMode.EAGLE3_ONE_MODEL
@@ -51,31 +55,32 @@ def is_draft_target(self):
return self == SpeculativeDecodingMode.DRAFT_TARGET

def without_logits(self):
return self.is_mtp() or self.is_eagle3_one_model()
return self.is_mtp_one_model() or self.is_eagle3_one_model()

def needs_kv_cache_rewind(self):
return self.is_mtp() or self.is_eagle3_one_model() or self.is_ngram()
return self.is_mtp_one_model() or self.is_eagle3_one_model(
) or self.is_ngram()

def support_overlap_scheduler(self):
return self.is_mtp() or self.is_eagle3_one_model(
return self.is_mtp_one_model() or self.is_eagle3_one_model(
) or self.has_draft_model()

def support_guided_decoder(self):
return self.is_none() or self.has_spec_drafter()

def support_capturable_guided_decoder(self):
return self.is_mtp() or self.is_eagle3_one_model()
return self.is_mtp_one_model() or self.is_eagle3_one_model()

def has_draft_model(self):
return self.is_eagle3() or self.is_draft_target()
return self.is_eagle3() or self.is_draft_target() or self.is_mtp_eagle()

def needs_kv_cache_recompute(self):
"""
Whether the draft model needs to recompute the kv cache.
If true, the 1st draft model forward will recompute the kv cache for
the accepted draft tokens.
"""
return self.is_eagle3()
return self.is_eagle3() or self.is_mtp_eagle()

def need_load_draft_weights(self):
"""
@@ -85,11 +90,12 @@ def need_load_draft_weights(self):
return self.is_eagle3_one_model()

def has_spec_decoder(self):
return self.is_mtp() or self.is_eagle3() or self.is_eagle3_one_model()
return self.is_mtp_one_model() or self.is_mtp_eagle() or self.is_eagle3(
) or self.is_eagle3_one_model()

def has_spec_drafter(self):
return self.is_eagle3() or self.is_draft_target() or self.is_ngram(
) or self.is_user_provided()
) or self.is_user_provided() or self.is_mtp_eagle()

def extend_ctx(self, attention_backend: Type[AttentionBackend]):
"""
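To make the net effect of the renamed predicates concrete, here is a reduced sketch of the enum; only members and methods touched by this diff are reproduced (DRAFT_TARGET, NGRAM, and the other helpers are omitted), so it is illustrative rather than a drop-in copy.

from enum import IntEnum, auto


class Mode(IntEnum):
    MTP = auto()
    MTP_EAGLE = auto()
    MTP_EAGLE_ONE_MODEL = auto()
    EAGLE3 = auto()
    EAGLE3_ONE_MODEL = auto()

    def is_mtp_one_model(self):
        return self in (Mode.MTP, Mode.MTP_EAGLE_ONE_MODEL)

    def is_mtp_eagle(self):
        return self == Mode.MTP_EAGLE

    def is_eagle3(self):
        return self == Mode.EAGLE3

    def is_eagle3_one_model(self):
        return self == Mode.EAGLE3_ONE_MODEL

    def use_one_engine(self):
        return self.is_eagle3_one_model() or self.is_mtp_one_model()

    def has_draft_model(self):
        return self.is_eagle3() or self.is_mtp_eagle()


# After this change, plain MTP_EAGLE behaves as a two-model mode: it owns a draft
# model and no longer runs in a single engine, while MTP_EAGLE_ONE_MODEL keeps the
# previous one-engine behavior.
assert Mode.MTP_EAGLE.has_draft_model() and not Mode.MTP_EAGLE.use_one_engine()
assert Mode.MTP_EAGLE_ONE_MODEL.use_one_engine() and not Mode.MTP_EAGLE_ONE_MODEL.has_draft_model()
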
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/speculative/model_drafter.py
@@ -31,7 +31,7 @@ def get_draft_model_prompt(spec_dec_mode: SpeculativeDecodingMode,
Can be used to modify prompts for speculative algorithms that need to update tokens
before drafting.
"""
if spec_dec_mode.is_eagle3():
if spec_dec_mode.is_eagle3() or spec_dec_mode.is_mtp_eagle():
# EAGLE3 always throws away the first token when processing draft inputs
return input_tokens[1:]
return input_tokens
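A tiny worked example of the prompt adjustment above; the token ids are made up, only the slicing behavior matters.

input_tokens = [101, 7, 42, 9]  # hypothetical token ids

# EAGLE3-style drafting, and now MTP-Eagle two-model drafting as well, drops the
# first token so the draft model sees the prompt shifted by one position.
draft_prompt = input_tokens[1:]
assert draft_prompt == [7, 42, 9]
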
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -155,7 +155,7 @@ def all_rank_num_seqs(self):
@all_rank_num_seqs.setter
def all_rank_num_seqs(self, value: List[int]):
self._all_rank_num_seqs = value
if self.spec_dec_mode.is_mtp_eagle():
if self.spec_dec_mode.is_mtp_eagle_one_model():
self.subseq_all_rank_num_tokens = value

def prepare(self):
@@ -172,7 +172,7 @@ def prepare(self):
# while MTP Eagle worker uses (max_draft_len + 1) input tokens in the 1st draft
# forward and only one input token in the following draft forward.
# This num_tokens is used to set the all_rank_num_tokens for attention dp.
if not self.spec_dec_mode.is_mtp_eagle():
if not self.spec_dec_mode.is_mtp_eagle_one_model():
self.num_tokens -= self.num_generations

if self.mtp_hidden_states_manager is not None: # MTP vanilla or use relaxed acceptance
@@ -183,7 +183,7 @@ def prepare(self):
mtp_slot_ids.append(slot_id)

# MTP Vanilla: Update mtp hidden states and past tokens
if self.spec_dec_mode.is_mtp():
if self.spec_dec_mode.is_mtp_one_model():
mtp_hidden_states_ptrs = []
mtp_past_tokens_ptrs = []
for slot_id in mtp_slot_ids: