@@ -65,8 +65,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
6565 Array<GenerationConfig> generation_cfg;
6666 std::vector<RandomGenerator*> rngs;
6767 std::vector<std::vector<SampleResult>> draft_output_tokens;
68+ std::vector<int64_t > token_tree_parent_ptr;
6869 request_internal_ids.reserve (num_rsentries);
6970 all_tokens_to_verify.reserve (total_draft_length);
71+ token_tree_parent_ptr.reserve (total_draft_length);
7072 verify_request_mstates.reserve (num_rsentries);
7173 rngs.reserve (num_rsentries);
7274 generation_cfg.reserve (num_rsentries);
@@ -83,9 +85,12 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
8385 // the last committed token + all the draft tokens but the last one.
8486 all_tokens_to_verify.push_back (draft_mstate->committed_tokens .back ().GetTokenId ());
8587 draft_token_slots_.push_back (0 ); // placeholder for the last committed token
88+ token_tree_parent_ptr.push_back (-1 );
89+
8690 for (int j = 0 ; j < static_cast <int >(draft_mstate->draft_output_tokens .size ()); ++j) {
8791 all_tokens_to_verify.push_back (draft_mstate->draft_output_tokens [j].GetTokenId ());
8892 draft_token_slots_.push_back (draft_mstate->draft_token_slots [j]);
93+ token_tree_parent_ptr.push_back (draft_mstate->draft_token_parent_idx [j] + 1 );
8994 }
9095 verify_request_mstates.push_back (verify_mstate);
9196 generation_cfg.push_back (rsentries[i]->request ->generation_cfg );
@@ -111,16 +116,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
111116 {IntTuple{all_tokens_to_verify.begin (), all_tokens_to_verify.end ()}});
112117 RECORD_EVENT (trace_recorder_, request_ids, " finish verify embedding" );
113118
114- // Construct the token tree. Right now only chains are supported.
115- std::vector<int64_t > token_tree_parent_ptr;
116- token_tree_parent_ptr.reserve (cum_verify_lengths.back ());
117- for (int i = 0 ; i < num_rsentries; ++i) {
118- for (int pos = 0 ; pos < verify_lengths[i]; ++pos) {
119- token_tree_parent_ptr.push_back (pos - 1 );
120- }
121- }
122- ICHECK_EQ (token_tree_parent_ptr.size (), cum_verify_lengths.back ());
123-
124119 RECORD_EVENT (trace_recorder_, request_ids, " start verify" );
125120 ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden (
126121 embeddings, request_internal_ids, verify_lengths, token_tree_parent_ptr);
@@ -143,7 +138,7 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
143138 std::vector<std::vector<SampleResult>> sample_results_arr =
144139 sampler_->BatchVerifyDraftTokensWithProbAfterTopP (
145140 renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs,
146- draft_output_tokens, draft_probs_on_device);
141+ draft_output_tokens, token_tree_parent_ptr, draft_probs_on_device);
147142 ICHECK_EQ (sample_results_arr.size (), num_rsentries);
148143
149144 // We collect the requests whose drafts are fully accepted.
@@ -398,7 +393,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
398393 &model_workspaces_[0 ].draft_hidden_states_storage );
399394 }
400395 for (int i = 0 ; i < static_cast <int >(mstates.size ()); ++i) {
401- mstates[i]->AddDraftToken (sample_results[i], draft_token_slots_[i]);
396+ int64_t parent_idx = static_cast <int64_t >(mstates[i]->draft_output_tokens .size ()) - 1 ;
397+ mstates[i]->AddDraftToken (sample_results[i], draft_token_slots_[i], parent_idx);
402398 }
403399 }
404400 /* !
0 commit comments