
Commit edfcef6

lfr-0531 authored and dominicshanshan committed
Fix: fix the deterministic issue in the MTP Eagle path (NVIDIA#5285)
Signed-off-by: Fanrong Li <[email protected]>
1 parent a533c21 commit edfcef6

File tree

3 files changed: +107 lines, -100 lines


cpp/tensorrt_llm/thop/mtpOp.cpp

Lines changed: 2 additions & 2 deletions
@@ -109,8 +109,8 @@ std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th:
     TLLM_CHECK(draftTokensSizes[0] == (numGenerationRequest * numMTPModules));

     auto stream = at::cuda::getCurrentCUDAStream(logits.get_device());
-    auto acceptedTokens = torch::empty(
-        {batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
+    auto acceptedTokens
+        = torch::ones({batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
     auto numAcceptedTokens = torch::ones({batchSize}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));

     // Fill params
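Note on the change above: torch::empty returns an uninitialized buffer, so any slot the acceptance kernel never writes (positions past the number of accepted tokens) keeps whatever bytes were left in the allocation and can differ from run to run. Pre-filling with torch::ones makes those unwritten slots deterministic. A minimal Python sketch of the same idea, with illustrative names rather than the actual TRT-LLM op:

import torch

def make_accept_buffers(batch_size: int, num_mtp_modules: int, device):
    # torch.empty would hand back arbitrary bytes; slots past the accepted-token
    # count are never overwritten by the sampling kernel, so their contents
    # could change between otherwise identical runs.
    accepted_tokens = torch.ones(batch_size, num_mtp_modules + 1,
                                 dtype=torch.int32, device=device)
    # At least one token (the newly sampled one) is always accepted, hence ones.
    num_accepted_tokens = torch.ones(batch_size, dtype=torch.int32, device=device)
    return accepted_tokens, num_accepted_tokens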

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 19 additions & 18 deletions
@@ -135,19 +135,23 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)

-    def forward(self, hidden_states: torch.Tensor, lm_head: Linear,
-                attn_metadata: AttentionMetadata) -> torch.Tensor:
-        if attn_metadata is not None:
-            last_tokens = torch.cumsum(
-                attn_metadata.seq_lens_cuda,
-                dim=0,
-                dtype=torch.long,
-            ) - 1
-            last_token_hidden_states = hidden_states[last_tokens]
-        else:
-            last_token_hidden_states = hidden_states[-1].unsqueeze(0)
+    def forward(self,
+                hidden_states: torch.Tensor,
+                lm_head: Linear,
+                attn_metadata: AttentionMetadata,
+                return_context_logits: bool = False) -> torch.Tensor:
+        if not return_context_logits:
+            if attn_metadata is not None:
+                last_tokens = torch.cumsum(
+                    attn_metadata.seq_lens_cuda,
+                    dim=0,
+                    dtype=torch.long,
+                ) - 1
+                hidden_states = hidden_states[last_tokens]
+            else:
+                hidden_states = hidden_states[-1].unsqueeze(0)

-        logits = lm_head(last_token_hidden_states)
+        logits = lm_head(hidden_states)
         return logits

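The new return_context_logits flag lets a caller that has already gathered the rows it needs skip the last-token selection. When the flag is False, the head picks one hidden state per sequence from the packed buffer using the cumulative sequence lengths. A standalone sketch of that gather, with made-up tensors standing in for the TRT-LLM metadata:

import torch

def gather_last_token_states(hidden_states: torch.Tensor,
                             seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: [total_tokens, hidden], sequences packed back to back.
    # The last token of sequence i sits at cumsum(seq_lens)[i] - 1.
    last_tokens = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1
    return hidden_states[last_tokens]

hidden = torch.randn(7, 4)          # two packed sequences of length 3 and 4
lens = torch.tensor([3, 4])
assert gather_last_token_states(hidden, lens).shape == (2, 4)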
@@ -931,10 +935,9 @@ def forward(
         input_ids: torch.IntTensor,
         position_ids: torch.IntTensor,
         hidden_states: torch.Tensor,
-        lm_head: Linear,
         embed_tokens: Embedding,
         attn_metadata: AttentionMetadata,
-        spec_metadata: MTPSpecMetadata,
+        all_rank_num_tokens: Optional[List[int]] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -975,7 +978,7 @@ def forward(
         # MoE
         hidden_states = self.mlp(
             hidden_states,
-            all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
+            all_rank_num_tokens=all_rank_num_tokens,
             final_all_reduce_params=AllReduceParams(
                 enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
                                       or self.mapping.tp_size == 1)),
@@ -994,9 +997,7 @@ def forward(
         else:
             hidden_states, _ = self.shared_head.norm(hidden_states, residual)

-        logits = self.shared_head(hidden_states, lm_head, attn_metadata).float()
-
-        return hidden_states, logits
+        return hidden_states


 class DeepseekV3Model(DecoderModel):

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 86 additions & 80 deletions
@@ -140,6 +140,13 @@ class MTPSpecMetadata(SpecMetadata):
     slot_ids: Optional[torch.Tensor] = None
     # The index of the batch inputs
     batch_indices_cuda: Optional[torch.Tensor] = None
+    # The number of sequences for the speculative model/layer on each rank
+    _all_rank_num_seqs: Optional[List[int]] = None
+    # This is used for attention dp in the MTP Eagle worker. The number of input
+    # tokens varies between the 1st draft forward and subsequent ones. To support
+    # CUDA graph, we use this tensor to store the number of input tokens for the
+    # subsequent draft forwards.
+    subseq_all_rank_num_tokens: Optional[List[int]] = None

     def __post_init__(self) -> None:
         if self.mtp_hidden_states_manager is not None:
@@ -166,6 +173,16 @@ def __post_init__(self) -> None:
                 device='cuda',
             )

+    @property
+    def all_rank_num_seqs(self):
+        return self._all_rank_num_seqs
+
+    @all_rank_num_seqs.setter
+    def all_rank_num_seqs(self, value: List[int]):
+        self._all_rank_num_seqs = value
+        if self.spec_dec_mode.is_mtp_eagle():
+            self.subseq_all_rank_num_tokens = value
+
     def prepare(self):
         assert self.request_ids is not None
         num_seqs = len(self.request_ids)
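The setter above keeps the Eagle worker's per-rank token counts for the 2nd and later draft forwards in sync with the per-rank sequence counts, since each request contributes exactly one token in those steps. A rough sketch of the pattern with an illustrative class (not the real MTPSpecMetadata):

from typing import List, Optional

class EagleMetaSketch:
    def __init__(self, is_mtp_eagle: bool):
        self._all_rank_num_seqs: Optional[List[int]] = None
        self.subseq_all_rank_num_tokens: Optional[List[int]] = None
        self._is_mtp_eagle = is_mtp_eagle

    @property
    def all_rank_num_seqs(self) -> Optional[List[int]]:
        return self._all_rank_num_seqs

    @all_rank_num_seqs.setter
    def all_rank_num_seqs(self, value: List[int]) -> None:
        self._all_rank_num_seqs = value
        # One input token per sequence in the subsequent draft forwards, so the
        # per-rank token counts equal the per-rank sequence counts.
        if self._is_mtp_eagle:
            self.subseq_all_rank_num_tokens = value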
@@ -176,10 +193,11 @@ def prepare(self):
                                      pin_memory=True)
         self.batch_indices_cuda[:num_seqs].copy_(batch_indices,
                                                  non_blocking=True)
-        # MTP module need different number of input tokens in generation phase
-        if self.spec_dec_mode.is_mtp_eagle():
-            self.num_tokens -= (self.num_generations) * self.mtp_num_modules
-        else:
+        # The MTP vanilla worker uses max_draft_tokens input tokens per generation
+        # request, while the MTP Eagle worker uses (max_draft_tokens + 1) input tokens
+        # in the 1st draft forward and only one input token in the following draft
+        # forwards. This num_tokens is used to set all_rank_num_tokens for attention dp.
+        if not self.spec_dec_mode.is_mtp_eagle():
             self.num_tokens -= self.num_generations

         if self.mtp_hidden_states_manager is not None:  # MTP vanilla or use relaxed acceptance
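A small worked example of the accounting in the comment above (numbers are illustrative, not taken from the commit): with mtp_num_modules = max_draft_tokens = 3, every generation request feeds 3 + 1 = 4 tokens to the target model. The Eagle worker reuses all 4 in its 1st draft forward, so num_tokens is left unchanged; the vanilla worker feeds only the 3 draft positions, so one token per generation request is subtracted.

mtp_num_modules = 3                      # == max_draft_tokens
num_context_tokens = 10
num_generations = 2
num_tokens = num_context_tokens + num_generations * (mtp_num_modules + 1)  # 18

eagle_num_tokens = num_tokens                        # 18, unchanged
vanilla_num_tokens = num_tokens - num_generations    # 16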
@@ -375,9 +393,9 @@ def forward(
                 num_accepted_tokens=num_accepted_tokens,
                 spec_metadata=spec_metadata,
                 attn_metadata=attn_metadata)
-            hidden_states, logits = mtp_layer(lm_head=lm_head,
-                                              embed_tokens=embed_tokens,
-                                              **draft_inputs)
+            hidden_states = mtp_layer(embed_tokens=embed_tokens, **draft_inputs)
+            logits = mtp_layer.shared_head(hidden_states, lm_head,
+                                           attn_metadata).float()
             previous_layer_draft_tokens = self.draft_sampler(logits)
             next_draft_tokens.append(previous_layer_draft_tokens)

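With the MTP layer now returning only hidden states, the worker decides which rows go through the shared head, and logits are produced explicitly at the call site. A sketch of the two call patterns introduced in this commit, using stand-in objects rather than the actual TRT-LLM modules:

def draft_step(mtp_layer, lm_head, embed_tokens, attn_metadata, draft_inputs,
               gather_ids=None):
    hidden_states = mtp_layer(embed_tokens=embed_tokens, **draft_inputs)
    if gather_ids is None:
        # MTP vanilla path: the shared head gathers the last token of each
        # sequence itself (return_context_logits defaults to False).
        logits = mtp_layer.shared_head(hidden_states, lm_head,
                                       attn_metadata).float()
    else:
        # MTP Eagle path: the rows were gathered already, so skip the gather
        # inside the head by passing return_context_logits=True.
        logits = mtp_layer.shared_head(hidden_states[gather_ids], lm_head,
                                       attn_metadata, True)
    return hidden_states, logits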
@@ -727,12 +745,13 @@ def sample_and_accept_draft_tokens(
             logits = logits.unsqueeze(0)

         # The return buffer
-        accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
-                                      dtype=torch.int,
-                                      device=logits.device)
-        num_accepted_tokens = torch.ones(batch_size,
+        if self.spec_config.use_relaxed_acceptance_for_thinking or not self.is_thop:
+            accepted_tokens = torch.ones((batch_size, (mtp_num_modules + 1)),
                                          dtype=torch.int,
                                          device=logits.device)
+            num_accepted_tokens = torch.ones(batch_size,
+                                             dtype=torch.int,
+                                             device=logits.device)
         if self.spec_config.use_relaxed_acceptance_for_thinking:
             mtp_relaxed_delta_pool = spec_metadata.mtp_hidden_states_manager.mtp_relaxed_delta_pool

@@ -1021,7 +1040,6 @@ def prepare_drafter_inputs(
             "position_ids": position_ids,
             "hidden_states": return_hidden_states,
             "attn_metadata": attn_metadata,
-            "spec_metadata": spec_metadata,
         }

     def draft_sampler(
@@ -1066,6 +1084,7 @@ def forward(
     ):
         batch_size = attn_metadata.num_seqs
         num_contexts = attn_metadata.num_contexts
+        num_gens = batch_size - num_contexts

         # Sample and verify draft tokens
         raw_logits = logits
@@ -1079,58 +1098,79 @@ def forward(
         # Prepare inputs for the 1st MTP layer
         position_ids = position_ids.squeeze(0)
-        inputs = self.prepare_drafter_inputs(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            hidden_states=hidden_states,
-            accepted_tokens=accepted_tokens,
-            num_accepted_tokens=num_accepted_tokens,
-            attn_metadata=attn_metadata,
-            spec_metadata=spec_metadata)
+        last_tokens_idx = torch.cumsum(
+            attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
+        inputs = self.prepare_drafter_inputs(input_ids=input_ids,
+                                             position_ids=position_ids,
+                                             last_tokens_idx=last_tokens_idx,
+                                             hidden_states=hidden_states,
+                                             accepted_tokens=accepted_tokens,
+                                             attn_metadata=attn_metadata,
+                                             spec_metadata=spec_metadata)

         # Predict draft tokens
         next_draft_tokens = []
         for i in range(self.mtp_num_modules):
-            hidden_states, logits = mtp_layers[0](lm_head=lm_head,
-                                                  embed_tokens=embed_tokens,
-                                                  **inputs)
+            if i == 0:
+                hidden_states = mtp_layers[0](
+                    embed_tokens=embed_tokens,
+                    all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
+                    **inputs)
+                start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
+                                 (self.mtp_num_modules + 1)).long()
+                gather_ids_gen = (start_ids_gen +
+                                  num_accepted_tokens[num_contexts:] - 1 +
+                                  attn_metadata.num_ctx_tokens)
+                gather_ids = torch.concat(
+                    [last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
+            else:
+                hidden_states = mtp_layers[0](embed_tokens=embed_tokens,
+                                              all_rank_num_tokens=spec_metadata.
+                                              subseq_all_rank_num_tokens,
+                                              **inputs)
+                # All of the seq_lens are 1, so use batch_indices_cuda as gather_ids
+                gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
+            logits = mtp_layers[0].shared_head(hidden_states[gather_ids],
+                                               lm_head, attn_metadata, True)
             new_draft_token = self.draft_sampler(logits)
             next_draft_tokens.append(new_draft_token)
             # update inputs
-            last_tokens = torch.cumsum(
-                attn_metadata.seq_lens_cuda,
-                dim=0,
-                dtype=torch.long,
-            ) - 1
-            position_ids = inputs["position_ids"][last_tokens] + 1
-            hidden_states = hidden_states[last_tokens]
-            attn_metadata._seq_lens[:attn_metadata.num_contexts].fill_(1)
-            attn_metadata._seq_lens_cuda[:attn_metadata.num_contexts].fill_(1)
-            attn_metadata.on_update()
-            # cannot run generation if there is no kv cache
-            if inputs["attn_metadata"].kv_cache_manager is not None:
-                attn_metadata.host_request_types[:attn_metadata.
-                                                 num_contexts].fill_(1)
-                attn_metadata.num_contexts = 0
-                if i == 0 and num_contexts > 0 and attn_metadata.enable_flash_mla:
+            hidden_states = hidden_states[gather_ids]
+            position_ids = inputs["position_ids"][gather_ids] + 1
+            # update attn_metadata
+            if i == 0:
+                attn_metadata._seq_lens[:batch_size].fill_(1)
+                attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
+                attn_metadata.on_update()
+                # cannot run generation if there is no kv cache
+                has_kv_cache = inputs[
+                    "attn_metadata"].kv_cache_manager is not None
+                if has_kv_cache:
+                    attn_metadata.host_request_types[:attn_metadata.
+                                                     num_contexts].fill_(1)
+                    attn_metadata.num_contexts = 0
+                # update kv_lens_cuda
+                if hasattr(attn_metadata, 'kv_lens_cuda'):
+                    attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
+                        self.mtp_num_modules -
+                        num_accepted_tokens[num_contexts:])
+                    attn_metadata.kv_lens_cuda[:num_contexts] += 1
+                # update metadata for flash mla
+                if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
                     reorder_block_ids_per_seq = torch.cat([
                         attn_metadata.
                         kv_block_ids_per_seq[num_contexts:batch_size],
                         attn_metadata.kv_block_ids_per_seq[:num_contexts]
                     ])
                     attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
                         reorder_block_ids_per_seq, non_blocking=True)
-            if hasattr(attn_metadata, 'kv_lens_cuda'):
+            elif hasattr(attn_metadata, 'kv_lens_cuda'):
                 attn_metadata.kv_lens_cuda[:batch_size] += 1
-            # support attention dp
-            if spec_metadata.all_rank_num_tokens is not None:
-                spec_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
             inputs = {
                 "input_ids": new_draft_token,
                 "position_ids": position_ids,
                 "hidden_states": hidden_states,
                 "attn_metadata": attn_metadata,
-                "spec_metadata": spec_metadata,
             }
         next_draft_tokens = torch.stack(next_draft_tokens, dim=1)

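The gather after the 1st draft forward is the heart of the fix: each context request contributes its last input token, and each generation request contributes the position of its last accepted token inside its (mtp_num_modules + 1)-token block, which starts after all context tokens. A self-contained sketch with illustrative tensors (in TRT-LLM these come from attn_metadata and spec_metadata):

import torch

def first_forward_gather_ids(last_tokens_idx, num_accepted_tokens,
                             num_contexts, num_ctx_tokens, mtp_num_modules):
    num_gens = num_accepted_tokens.shape[0] - num_contexts
    batch_indices = torch.arange(num_gens, dtype=torch.long)
    # Each generation request owns a contiguous block of (mtp_num_modules + 1)
    # tokens placed after all context tokens; pick its last accepted one.
    start_ids_gen = batch_indices * (mtp_num_modules + 1)
    gather_ids_gen = (start_ids_gen + num_accepted_tokens[num_contexts:] - 1 +
                      num_ctx_tokens)
    return torch.concat([last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)

# One context request of 5 tokens plus two generation requests with
# mtp_num_modules = 2 (blocks of 3 tokens), accepting 2 and 1 tokens.
last_idx = torch.tensor([4, 7, 10])            # cumsum([5, 3, 3]) - 1
num_accepted = torch.tensor([1, 2, 1])
print(first_forward_gather_ids(last_idx, num_accepted, 1, 5, 2))
# tensor([4, 6, 8])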
@@ -1159,66 +1199,32 @@ def prepare_drafter_inputs(
         self,
         input_ids: torch.IntTensor,
         position_ids: torch.IntTensor,
+        last_tokens_idx: torch.LongTensor,
         hidden_states: torch.Tensor,
         accepted_tokens: torch.Tensor,
-        num_accepted_tokens: torch.Tensor,
         attn_metadata: AttentionMetadata,
         spec_metadata: MTPSpecMetadata,
     ):
-        batch_size = attn_metadata.num_seqs
         num_contexts = attn_metadata.num_contexts
-        num_gens = batch_size - num_contexts
-        num_ctx_tokens = attn_metadata.num_ctx_tokens
-        hidden_size = hidden_states.shape[1]
-        last_tokens_idx = torch.cumsum(
-            attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1

         # context
-        hidden_states_ctx = hidden_states[:attn_metadata.num_ctx_tokens, :]
         input_ctx_ids = input_ids[:attn_metadata.num_ctx_tokens]
         input_ids_ctx = torch.empty_like(input_ctx_ids,
                                          dtype=torch.int32,
                                          device="cuda")
         input_ids_ctx[:-1].copy_(input_ctx_ids[1:])
         input_ids_ctx[
             last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
-        position_ids_ctx = position_ids[:num_ctx_tokens]

         # generation
-        gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens]
-        gen_token_idx = num_accepted_tokens[num_contexts:] - 1
-        hidden_states_gen = hidden_states[attn_metadata.num_ctx_tokens:, :]
-        hidden_states_gen = hidden_states_gen.reshape(num_gens,
-                                                      self.mtp_num_modules + 1,
-                                                      hidden_size)
-        hidden_states_gen = hidden_states_gen[gen_batch_idx, gen_token_idx, :]
-        accepted_tokens_gen = accepted_tokens[num_contexts:, :]
-        input_ids_gen = accepted_tokens_gen[gen_batch_idx, gen_token_idx]
-        position_ids_gen = position_ids[num_ctx_tokens:].reshape(
-            num_gens, self.mtp_num_modules + 1)
-        position_ids_gen = position_ids_gen[gen_batch_idx, gen_token_idx]
+        input_ids_gen = accepted_tokens[num_contexts:, :].flatten()

         # get draft inputs
         input_ids = torch.concat([input_ids_ctx, input_ids_gen], dim=0)
-        hidden_states = torch.concat([hidden_states_ctx, hidden_states_gen],
-                                     dim=0)
-        position_ids = torch.concat([position_ids_ctx, position_ids_gen], dim=0)
-
-        # change attn_metadata
-        attn_metadata._seq_lens[num_contexts:batch_size].fill_(1)
-        attn_metadata._seq_lens_cuda[num_contexts:batch_size].fill_(1)
-        attn_metadata.on_update()
-        if hasattr(attn_metadata, 'kv_lens_cuda'):
-            # Note that it's important to not free the seq_lens_cuda
-            # buffer once the graph has been captured also - this will invalidate
-            # the graph and force an expensive recapture.
-            attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
-                self.mtp_num_modules + 1 - num_accepted_tokens[num_contexts:])

         return {
             "input_ids": input_ids,
             "position_ids": position_ids,
             "hidden_states": hidden_states,
             "attn_metadata": attn_metadata,
-            "spec_metadata": spec_metadata,
         }
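The slimmed-down prepare_drafter_inputs now builds only the token ids: context tokens are shifted left by one with the first accepted (golden) token written at each context request's last position, and generation requests pass their full accepted-token block, leaving the row selection to the gather shown above. A sketch with illustrative tensors (names mirror the function, but this is not the TRT-LLM code):

import torch

def build_drafter_input_ids(input_ids, accepted_tokens, last_tokens_idx,
                            num_contexts, num_ctx_tokens):
    # Context part: shift each request's tokens left by one and place the first
    # accepted token at the request's last position (boundary slots are
    # overwritten, so the cross-request shift is harmless).
    input_ctx_ids = input_ids[:num_ctx_tokens]
    input_ids_ctx = torch.empty_like(input_ctx_ids)
    input_ids_ctx[:-1].copy_(input_ctx_ids[1:])
    input_ids_ctx[last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
    # Generation part: keep the whole (mtp_num_modules + 1)-wide accepted-token
    # block per request; the post-forward gather picks the right row later.
    input_ids_gen = accepted_tokens[num_contexts:, :].flatten()
    return torch.concat([input_ids_ctx, input_ids_gen], dim=0)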
