4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/thop/mtpOp.cpp
@@ -109,8 +109,8 @@ std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th:
TLLM_CHECK(draftTokensSizes[0] == (numGenerationRequest * numMTPModules));

auto stream = at::cuda::getCurrentCUDAStream(logits.get_device());
auto acceptedTokens = torch::empty(
{batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
auto acceptedTokens
= torch::ones({batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
auto numAcceptedTokens = torch::ones({batchSize}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));

// Fill params
37 changes: 19 additions & 18 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -134,19 +134,23 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

def forward(self, hidden_states: torch.Tensor, lm_head: Linear,
attn_metadata: AttentionMetadata) -> torch.Tensor:
if attn_metadata is not None:
last_tokens = torch.cumsum(
attn_metadata.seq_lens_cuda,
dim=0,
dtype=torch.long,
) - 1
last_token_hidden_states = hidden_states[last_tokens]
else:
last_token_hidden_states = hidden_states[-1].unsqueeze(0)
def forward(self,
hidden_states: torch.Tensor,
lm_head: Linear,
attn_metadata: AttentionMetadata,
return_context_logits: bool = False) -> torch.Tensor:
if not return_context_logits:
if attn_metadata is not None:
last_tokens = torch.cumsum(
attn_metadata.seq_lens_cuda,
dim=0,
dtype=torch.long,
) - 1
hidden_states = hidden_states[last_tokens]
else:
hidden_states = hidden_states[-1].unsqueeze(0)

logits = lm_head(last_token_hidden_states)
logits = lm_head(hidden_states)
return logits
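A minimal sketch of the last-token gather that the updated shared-head forward performs when return_context_logits is False (when it is True, the gather is skipped and the lm_head sees every token position). The sizes below are made up; the real code reads attn_metadata.seq_lens_cuda on the GPU:

import torch

# Made-up packed batch: two requests with 3 and 2 tokens respectively.
seq_lens = torch.tensor([3, 2])
hidden_states = torch.arange(5 * 4, dtype=torch.float32).reshape(5, 4)  # [num_tokens, hidden]

# cumsum - 1 yields the flat index of each request's last token.
last_tokens = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1
last_token_hidden = hidden_states[last_tokens]

print(last_tokens.tolist())     # [2, 4]
print(last_token_hidden.shape)  # torch.Size([2, 4])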


@@ -976,10 +980,9 @@ def forward(
input_ids: torch.IntTensor,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
lm_head: Linear,
embed_tokens: Embedding,
attn_metadata: AttentionMetadata,
spec_metadata: MTPSpecMetadata,
all_rank_num_tokens: Optional[List[int]] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -1020,7 +1023,7 @@
# MoE
hidden_states = self.mlp(
hidden_states,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
all_rank_num_tokens=all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(
enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
or self.mapping.tp_size == 1)),
@@ -1039,9 +1042,7 @@
else:
hidden_states, _ = self.shared_head.norm(hidden_states, residual)

logits = self.shared_head(hidden_states, lm_head, attn_metadata).float()

return hidden_states, logits
return hidden_states


class DeepseekV3Model(DecoderModel):
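With the logits computation moved out of the MTP decoder layer, callers now pass all_rank_num_tokens directly and run the shared head themselves, as the speculative worker in mtp.py below does. A rough sketch of the new calling convention, using a hypothetical stand-in for the layer rather than the real classes:

import torch

class FakeMTPLayer:
    # Hypothetical stand-in that mimics only the new interface:
    # __call__ returns hidden states, shared_head turns them into logits.
    def __init__(self):
        self.shared_head = (
            lambda h, lm_head, attn_metadata, return_context_logits=False: lm_head(h))

    def __call__(self, hidden_states, embed_tokens=None, all_rank_num_tokens=None, **kwargs):
        return hidden_states  # the layer no longer returns logits

mtp_layer = FakeMTPLayer()
lm_head = torch.nn.Linear(8, 32, bias=False)
hidden_states = torch.randn(5, 8)

# Old: hidden_states, logits = mtp_layer(lm_head=lm_head, spec_metadata=..., ...)
# New: the caller computes logits explicitly via the layer's shared head.
hidden_states = mtp_layer(hidden_states, all_rank_num_tokens=[5])
logits = mtp_layer.shared_head(hidden_states, lm_head, None).float()
print(logits.shape)  # torch.Size([5, 32])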
166 changes: 86 additions & 80 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -140,6 +140,13 @@ class MTPSpecMetadata(SpecMetadata):
slot_ids: Optional[torch.Tensor] = None
    # The indices of the batch inputs
batch_indices_cuda: Optional[torch.Tensor] = None
    # The number of sequences for the speculative model/layer on each rank
_all_rank_num_seqs: Optional[List[int]] = None
    # This is used for attention dp in the MTP Eagle worker. The number of input
    # tokens varies between the 1st draft forward and subsequent ones. To support
    # CUDA graphs, we use this list to store the number of input tokens for the
    # subsequent draft forwards.
subseq_all_rank_num_tokens: Optional[List[int]] = None

def __post_init__(self) -> None:
if self.mtp_hidden_states_manager is not None:
@@ -166,6 +173,16 @@ def __post_init__(self) -> None:
device='cuda',
)

@property
def all_rank_num_seqs(self):
return self._all_rank_num_seqs

@all_rank_num_seqs.setter
def all_rank_num_seqs(self, value: List[int]):
self._all_rank_num_seqs = value
if self.spec_dec_mode.is_mtp_eagle():
self.subseq_all_rank_num_tokens = value

def prepare(self):
assert self.request_ids is not None
num_seqs = len(self.request_ids)
@@ -176,10 +193,11 @@ def prepare(self):
pin_memory=True)
self.batch_indices_cuda[:num_seqs].copy_(batch_indices,
non_blocking=True)
        # The MTP module needs a different number of input tokens in the generation phase
if self.spec_dec_mode.is_mtp_eagle():
self.num_tokens -= (self.num_generations) * self.mtp_num_modules
else:
        # The MTP vanilla worker uses max_draft_tokens input tokens per generation
        # request, while the MTP Eagle worker uses (max_draft_tokens + 1) input tokens
        # in the 1st draft forward and only one input token in the following draft
        # forwards. This num_tokens is used to set the all_rank_num_tokens for attention dp.
if not self.spec_dec_mode.is_mtp_eagle():
self.num_tokens -= self.num_generations

if self.mtp_hidden_states_manager is not None: # MTP vanilla or use relaxed acceptance
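A worked example of the token bookkeeping described in the comments above in prepare(), assuming mtp_num_modules = max_draft_tokens = 3, one context request with 10 prompt tokens, and two generation requests (a single rank's view; the real all_rank_* fields hold one entry per rank). The numbers are illustrative only:

mtp_num_modules = 3                 # assumed equal to max_draft_tokens here
num_contexts, ctx_tokens = 1, 10    # one context request with 10 prompt tokens
num_generations = 2                 # each generation request carries mtp_num_modules + 1 tokens

# Token count of the target-model forward, which prepare() starts from.
num_tokens = ctx_tokens + num_generations * (mtp_num_modules + 1)    # 18

# MTP vanilla: the draft forward uses max_draft_tokens tokens per generation
# request, one fewer than the target forward, hence num_tokens -= num_generations.
vanilla_num_tokens = num_tokens - num_generations                    # 16

# MTP Eagle: the 1st draft forward reuses all (max_draft_tokens + 1) tokens, so
# num_tokens is left untouched; later draft forwards see one token per sequence,
# which is what subseq_all_rank_num_tokens (set from all_rank_num_seqs) records.
eagle_first_draft_tokens = num_tokens                                # 18
eagle_subseq_tokens = num_contexts + num_generations                 # 3
print(vanilla_num_tokens, eagle_first_draft_tokens, eagle_subseq_tokens)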
@@ -375,9 +393,9 @@ def forward(
num_accepted_tokens=num_accepted_tokens,
spec_metadata=spec_metadata,
attn_metadata=attn_metadata)
hidden_states, logits = mtp_layer(lm_head=lm_head,
embed_tokens=embed_tokens,
**draft_inputs)
hidden_states = mtp_layer(embed_tokens=embed_tokens, **draft_inputs)
logits = mtp_layer.shared_head(hidden_states, lm_head,
attn_metadata).float()
previous_layer_draft_tokens = self.draft_sampler(logits)
next_draft_tokens.append(previous_layer_draft_tokens)

@@ -727,12 +745,13 @@ def sample_and_accept_draft_tokens(
logits = logits.unsqueeze(0)

# The return buffer
accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
dtype=torch.int,
device=logits.device)
num_accepted_tokens = torch.ones(batch_size,
if self.spec_config.use_relaxed_acceptance_for_thinking or not self.is_thop:
accepted_tokens = torch.ones((batch_size, (mtp_num_modules + 1)),
dtype=torch.int,
device=logits.device)
num_accepted_tokens = torch.ones(batch_size,
dtype=torch.int,
device=logits.device)
if self.spec_config.use_relaxed_acceptance_for_thinking:
mtp_relaxed_delta_pool = spec_metadata.mtp_hidden_states_manager.mtp_relaxed_delta_pool
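An illustrative sketch (not the real kernel) of how the two buffers above are laid out: accepted_tokens holds up to mtp_num_modules + 1 token ids per request and num_accepted_tokens says how many of them are valid; the last valid entry is what later seeds the drafter inputs. Token ids here are made up:

import torch

mtp_num_modules, batch_size = 2, 2
accepted_tokens = torch.ones((batch_size, mtp_num_modules + 1), dtype=torch.int)
num_accepted_tokens = torch.ones(batch_size, dtype=torch.int)

# Suppose request 0 accepted two tokens and request 1 accepted one.
accepted_tokens[0, :2] = torch.tensor([101, 102], dtype=torch.int)
num_accepted_tokens[0] = 2
accepted_tokens[1, 0] = 205

# Last accepted token per request, indexed the same way the drafter inputs are.
rows = torch.arange(batch_size)
last_accepted = accepted_tokens[rows, num_accepted_tokens.long() - 1]
print(last_accepted.tolist())  # [102, 205]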

@@ -1021,7 +1040,6 @@ def prepare_drafter_inputs(
"position_ids": position_ids,
"hidden_states": return_hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}

def draft_sampler(
@@ -1066,6 +1084,7 @@ def forward(
):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts

# Sample and verify draft tokens
raw_logits = logits
@@ -1079,58 +1098,79 @@

# Prepare inputs for the 1st MTP layer
position_ids = position_ids.squeeze(0)
inputs = self.prepare_drafter_inputs(
input_ids=input_ids,
position_ids=position_ids,
hidden_states=hidden_states,
accepted_tokens=accepted_tokens,
num_accepted_tokens=num_accepted_tokens,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
inputs = self.prepare_drafter_inputs(input_ids=input_ids,
position_ids=position_ids,
last_tokens_idx=last_tokens_idx,
hidden_states=hidden_states,
accepted_tokens=accepted_tokens,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)

# Predict draft tokens
next_draft_tokens = []
for i in range(self.mtp_num_modules):
hidden_states, logits = mtp_layers[0](lm_head=lm_head,
embed_tokens=embed_tokens,
**inputs)
if i == 0:
hidden_states = mtp_layers[0](
embed_tokens=embed_tokens,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
**inputs)
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.mtp_num_modules + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
hidden_states = mtp_layers[0](embed_tokens=embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
**inputs)
                # All of the seq_lens are 1, so use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
logits = mtp_layers[0].shared_head(hidden_states[gather_ids],
lm_head, attn_metadata, True)
new_draft_token = self.draft_sampler(logits)
next_draft_tokens.append(new_draft_token)
# update inputs
last_tokens = torch.cumsum(
attn_metadata.seq_lens_cuda,
dim=0,
dtype=torch.long,
) - 1
position_ids = inputs["position_ids"][last_tokens] + 1
hidden_states = hidden_states[last_tokens]
attn_metadata._seq_lens[:attn_metadata.num_contexts].fill_(1)
attn_metadata._seq_lens_cuda[:attn_metadata.num_contexts].fill_(1)
attn_metadata.on_update()
            # cannot run generation if there is no kv cache
if inputs["attn_metadata"].kv_cache_manager is not None:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
if i == 0 and num_contexts > 0 and attn_metadata.enable_flash_mla:
hidden_states = hidden_states[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
                # cannot run generation if there is no kv cache
has_kv_cache = inputs[
"attn_metadata"].kv_cache_manager is not None
if has_kv_cache:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.mtp_num_modules -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
# update metadata for flash mla
if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
reorder_block_ids_per_seq = torch.cat([
attn_metadata.
kv_block_ids_per_seq[num_contexts:batch_size],
attn_metadata.kv_block_ids_per_seq[:num_contexts]
])
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
reorder_block_ids_per_seq, non_blocking=True)
if hasattr(attn_metadata, 'kv_lens_cuda'):
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
# support attention dp
if spec_metadata.all_rank_num_tokens is not None:
spec_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
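A worked example of the gather-index computation in the first draft iteration above, with made-up sizes: one context request of 5 tokens, two generation requests, and mtp_num_modules = 2, so each generation request contributes mtp_num_modules + 1 = 3 rows after the context rows:

import torch

mtp_num_modules = 2
num_contexts, num_gens = 1, 2
num_ctx_tokens = 5
seq_lens = torch.tensor([5, 3, 3])             # context first, then generation requests
num_accepted_tokens = torch.tensor([1, 2, 1])  # per request, context first
batch_indices = torch.arange(num_contexts + num_gens)

# Context requests gather the hidden state of their last prompt token.
last_tokens_idx = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1        # [4, 7, 10]

# Generation requests gather the hidden state of their last accepted token,
# located inside their (mtp_num_modules + 1)-row block after the context rows.
start_ids_gen = (batch_indices[:num_gens] * (mtp_num_modules + 1)).long()    # [0, 3]
gather_ids_gen = start_ids_gen + num_accepted_tokens[num_contexts:] - 1 + num_ctx_tokens  # [6, 8]

gather_ids = torch.concat([last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
print(gather_ids.tolist())  # [4, 6, 8] -> rows whose logits the shared head produces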

@@ -1159,66 +1199,32 @@ def prepare_drafter_inputs(
self,
input_ids: torch.IntTensor,
position_ids: torch.IntTensor,
last_tokens_idx: torch.LongTensor,
hidden_states: torch.Tensor,
accepted_tokens: torch.Tensor,
num_accepted_tokens: torch.Tensor,
attn_metadata: AttentionMetadata,
spec_metadata: MTPSpecMetadata,
):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
num_ctx_tokens = attn_metadata.num_ctx_tokens
hidden_size = hidden_states.shape[1]
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1

# context
hidden_states_ctx = hidden_states[:attn_metadata.num_ctx_tokens, :]
input_ctx_ids = input_ids[:attn_metadata.num_ctx_tokens]
input_ids_ctx = torch.empty_like(input_ctx_ids,
dtype=torch.int32,
device="cuda")
input_ids_ctx[:-1].copy_(input_ctx_ids[1:])
input_ids_ctx[
last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
position_ids_ctx = position_ids[:num_ctx_tokens]

# generation
gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens]
gen_token_idx = num_accepted_tokens[num_contexts:] - 1
hidden_states_gen = hidden_states[attn_metadata.num_ctx_tokens:, :]
hidden_states_gen = hidden_states_gen.reshape(num_gens,
self.mtp_num_modules + 1,
hidden_size)
hidden_states_gen = hidden_states_gen[gen_batch_idx, gen_token_idx, :]
accepted_tokens_gen = accepted_tokens[num_contexts:, :]
input_ids_gen = accepted_tokens_gen[gen_batch_idx, gen_token_idx]
position_ids_gen = position_ids[num_ctx_tokens:].reshape(
num_gens, self.mtp_num_modules + 1)
position_ids_gen = position_ids_gen[gen_batch_idx, gen_token_idx]
input_ids_gen = accepted_tokens[num_contexts:, :].flatten()

# get draft inputs
input_ids = torch.concat([input_ids_ctx, input_ids_gen], dim=0)
hidden_states = torch.concat([hidden_states_ctx, hidden_states_gen],
dim=0)
position_ids = torch.concat([position_ids_ctx, position_ids_gen], dim=0)

# change attn_metadata
attn_metadata._seq_lens[num_contexts:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[num_contexts:batch_size].fill_(1)
attn_metadata.on_update()
if hasattr(attn_metadata, 'kv_lens_cuda'):
            # Note that it's also important not to free the seq_lens_cuda
            # buffer once the graph has been captured - this will invalidate
            # the graph and force an expensive recapture.
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.mtp_num_modules + 1 - num_accepted_tokens[num_contexts:])

return {
"input_ids": input_ids,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
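A small worked example of the context-side input shift above, with made-up token ids. For a single context request the draft input is the prompt shifted left by one position, with the first accepted token written into the last slot (the real code uses last_tokens_idx[:num_contexts] to hit the last slot of every packed context request); generation requests now simply flatten all of their accepted-token slots:

import torch

# One context request with 4 prompt tokens and its first accepted token.
prompt_ids = torch.tensor([11, 12, 13, 14], dtype=torch.int32)

input_ids_ctx = torch.empty_like(prompt_ids)
input_ids_ctx[:-1].copy_(prompt_ids[1:])   # shift left: [12, 13, 14, _]
input_ids_ctx[-1] = 99                     # last slot <- accepted_tokens[:num_contexts, 0]
print(input_ids_ctx.tolist())              # [12, 13, 14, 99]

# Two generation requests, mtp_num_modules = 2: flatten all accepted-token slots.
accepted_tokens_gen = torch.tensor([[101, 102, 1], [205, 1, 1]], dtype=torch.int32)
input_ids_gen = accepted_tokens_gen.flatten()
print(input_ids_gen.tolist())              # [101, 102, 1, 205, 1, 1]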