@@ -114,16 +114,12 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
114114 RECORD_EVENT (trace_recorder_, request_ids, " start verify" );
115115 ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden (
116116 embeddings, request_internal_ids, verify_lengths);
117- NDArray logits =
118- models_[verify_model_id_]->GetLogits (hidden_states, 1 , cum_verify_lengths[num_rsentries]);
117+ NDArray logits = models_[verify_model_id_]->GetLogits (hidden_states);
119118 RECORD_EVENT (trace_recorder_, request_ids, " finish verify" );
120- ICHECK_EQ (logits->ndim , 3 );
121- ICHECK_EQ (logits->shape [0 ], 1 );
122- ICHECK_EQ (logits->shape [1 ], cum_verify_lengths[num_rsentries]);
119+ ICHECK_EQ (logits->ndim , 2 );
120+ ICHECK_EQ (logits->shape [0 ], cum_verify_lengths.back ());
123121
124122 // - Update logits.
125- logits =
126- logits.CreateView ({cum_verify_lengths[num_rsentries], logits->shape [2 ]}, logits->dtype );
127123 logit_processor_->InplaceUpdateLogits (logits, generation_cfg, verify_request_mstates,
128124 request_ids, &cum_verify_lengths, &draft_output_tokens);
129125
@@ -273,19 +269,16 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
273269 fused_embedding_hidden_states, request_internal_ids);
274270
275271 if (models_[draft_model_id_]->CanGetLogits ()) {
276- logits = models_[draft_model_id_]->GetLogits (hidden_states, /* batch_size*/ num_rsentries,
277- /* seq_len*/ 1 );
272+ logits = models_[draft_model_id_]->GetLogits (hidden_states);
278273 } else {
279274 // - Use base model's head.
280- logits = models_[0 ]->GetLogits (hidden_states, /* batch_size */ num_rsentries, /* seq_len */ 1 );
275+ logits = models_[0 ]->GetLogits (hidden_states);
281276 }
282277 RECORD_EVENT (trace_recorder_, request_ids, " finish proposal decode" );
283- ICHECK_EQ (logits->ndim , 3 );
278+ ICHECK_EQ (logits->ndim , 2 );
284279 ICHECK_EQ (logits->shape [0 ], num_rsentries);
285- ICHECK_EQ (logits->shape [1 ], 1 );
286280
287281 // - Update logits.
288- logits = logits.CreateView ({num_rsentries, logits->shape [2 ]}, logits->dtype );
289282 logit_processor_->InplaceUpdateLogits (logits, generation_cfg, mstates, request_ids);
290283
291284 // - Compute probability distributions.
0 commit comments