Skip to content

Commit d6276c0

Browse files
committed
fix accepted_tokens tensor.
Signed-off-by: Fanrong Li <[email protected]>
1 parent fb38299 commit d6276c0

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

cpp/tensorrt_llm/thop/mtpOp.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th:
109109
TLLM_CHECK(draftTokensSizes[0] == (numGenerationRequest * numMTPModules));
110110

111111
auto stream = at::cuda::getCurrentCUDAStream(logits.get_device());
112-
auto acceptedTokens = torch::empty(
113-
{batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
112+
auto acceptedTokens
113+
= torch::ones({batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
114114
auto numAcceptedTokens = torch::ones({batchSize}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
115115

116116
// Fill params

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -745,12 +745,13 @@ def sample_and_accept_draft_tokens(
745745
logits = logits.unsqueeze(0)
746746

747747
# The return buffer
748-
accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
749-
dtype=torch.int,
750-
device=logits.device)
751-
num_accepted_tokens = torch.ones(batch_size,
748+
if self.spec_config.use_relaxed_acceptance_for_thinking or not self.is_thop:
749+
accepted_tokens = torch.ones((batch_size, (mtp_num_modules + 1)),
752750
dtype=torch.int,
753751
device=logits.device)
752+
num_accepted_tokens = torch.ones(batch_size,
753+
dtype=torch.int,
754+
device=logits.device)
754755
if self.spec_config.use_relaxed_acceptance_for_thinking:
755756
mtp_relaxed_delta_pool = spec_metadata.mtp_hidden_states_manager.mtp_relaxed_delta_pool
756757

@@ -1068,7 +1069,6 @@ class MTPEagleWorker(MTPWorker):
10681069
def __init__(self, spec_config: MTPConfig):
10691070
super().__init__(spec_config)
10701071
self.mtp_num_modules = spec_config.num_nextn_predict_layers
1071-
self.is_thop = False
10721072

10731073
def forward(
10741074
self,

0 commit comments

Comments
 (0)