2 files changed: +7 −7 lines

tensorrt_llm/_torch/speculative

@@ -109,8 +109,8 @@ std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th:
     TLLM_CHECK(draftTokensSizes[0] == (numGenerationRequest * numMTPModules));

     auto stream = at::cuda::getCurrentCUDAStream(logits.get_device());
-    auto acceptedTokens = torch::empty(
-        {batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
+    auto acceptedTokens
+        = torch::ones({batchSize, numMTPModules + 1}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));
     auto numAcceptedTokens = torch::ones({batchSize}, at::TensorOptions().dtype(torch::kInt32).device(logits.device()));

     // Fill params
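The substance of this hunk: the acceptedTokens output buffer switches from torch::empty to torch::ones, matching the numAcceptedTokens allocation on the next line. A plausible reading, inferred from the diff rather than stated in it, is that the downstream sampling kernel writes only the accepted prefix of each row, so an empty buffer leaves the remaining slots as uninitialized memory. A minimal Python sketch of that hazard; every name in it is illustrative:

import torch

# Sketch of the assumed rationale (not the TensorRT-LLM kernel): a kernel
# that fills only the accepted prefix of each row leaves the tail of a
# torch.empty buffer as arbitrary memory, while torch.ones gives every
# unwritten slot a defined value.
batch_size, mtp_num_modules = 2, 3

def fill_accepted_prefix(buf, num_accepted):
    # Stand-in for the sampling kernel: writes only the first n slots per row.
    for b in range(buf.size(0)):
        n = int(num_accepted[b])
        buf[b, :n] = torch.arange(n, dtype=torch.int)

num_accepted = torch.tensor([1, 3])

uninit = torch.empty((batch_size, mtp_num_modules + 1), dtype=torch.int)
inited = torch.ones((batch_size, mtp_num_modules + 1), dtype=torch.int)
fill_accepted_prefix(uninit, num_accepted)
fill_accepted_prefix(inited, num_accepted)

print(uninit)  # tail slots: whatever was in memory
print(inited)  # tail slots: deterministic 1s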
@@ -745,12 +745,13 @@ def sample_and_accept_draft_tokens(
         logits = logits.unsqueeze(0)

         # The return buffer
-        accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
-                                      dtype=torch.int,
-                                      device=logits.device)
-        num_accepted_tokens = torch.ones(batch_size,
+        if self.spec_config.use_relaxed_acceptance_for_thinking or not self.is_thop:
+            accepted_tokens = torch.ones((batch_size, (mtp_num_modules + 1)),
                                          dtype=torch.int,
                                          device=logits.device)
+        num_accepted_tokens = torch.ones(batch_size,
+                                         dtype=torch.int,
+                                         device=logits.device)
         if self.spec_config.use_relaxed_acceptance_for_thinking:
             mtp_relaxed_delta_pool = spec_metadata.mtp_hidden_states_manager.mtp_relaxed_delta_pool
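Read together with the first hunk, the new conditional appears to avoid a redundant allocation: when self.is_thop is set and relaxed acceptance is off, the C++ op allocates acceptedTokens itself, so only the pure-Python and relaxed-acceptance paths need a Python-side buffer, now pre-filled with ones instead of left uninitialized. A hedged sketch of that control flow; everything except the two flags from the diff is illustrative scaffolding:

import torch

def thop_op_stub(batch_size, mtp_num_modules):
    # Hypothetical stand-in for the C++ op in the first hunk, which
    # allocates its own outputs with torch::ones.
    accepted = torch.ones((batch_size, mtp_num_modules + 1), dtype=torch.int)
    num_accepted = torch.ones(batch_size, dtype=torch.int)
    return accepted, num_accepted

def sample_and_accept_sketch(is_thop, use_relaxed, batch_size, mtp_num_modules):
    accepted_tokens = None
    # Python-side buffer only when the relaxed-acceptance path will
    # post-process it, or when no THOP op exists to allocate it for us.
    if use_relaxed or not is_thop:
        accepted_tokens = torch.ones((batch_size, mtp_num_modules + 1),
                                     dtype=torch.int)
    num_accepted_tokens = torch.ones(batch_size, dtype=torch.int)
    if is_thop:
        accepted_tokens, num_accepted_tokens = thop_op_stub(
            batch_size, mtp_num_modules)
    return accepted_tokens, num_accepted_tokens

print(sample_and_accept_sketch(True, False, 2, 3)[0].shape)   # from the op
print(sample_and_accept_sketch(False, False, 2, 3)[0].shape)  # Python buffer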
@@ -1068,7 +1069,6 @@ class MTPEagleWorker(MTPWorker):
     def __init__(self, spec_config: MTPConfig):
         super().__init__(spec_config)
         self.mtp_num_modules = spec_config.num_nextn_predict_layers
-        self.is_thop = False

     def forward(
         self,
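Dropping the self.is_thop = False override means MTPEagleWorker now takes whatever MTPWorker.__init__ sets; the parent's default is not visible in this diff, so the sketch below simply assumes it enables the THOP path:

# Hypothetical illustration of the inherited flag; the parent default shown
# here is an assumption, not taken from the diff.
class MTPWorkerSketch:
    def __init__(self):
        self.is_thop = True  # assumed parent default

class MTPEagleWorkerSketch(MTPWorkerSketch):
    def __init__(self):
        super().__init__()
        # no `self.is_thop = False` override any more

print(MTPEagleWorkerSketch().is_thop)  # True, inherited from the parent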