From 5e926243c2073f770ec77fe527f26a4b14544a81 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 30 Jul 2024 14:45:20 -0700 Subject: [PATCH] [Serving] Fix handling of num_tokens_for_next_decode in spec decoding --- cpp/serve/engine_actions/batch_draft.cc | 4 ++++ cpp/serve/engine_actions/batch_verify.cc | 3 +++ cpp/serve/engine_actions/eagle_batch_verify.cc | 5 +++++ 3 files changed, 12 insertions(+) diff --git a/cpp/serve/engine_actions/batch_draft.cc b/cpp/serve/engine_actions/batch_draft.cc index c65fc4cc7c..1d330c4abd 100644 --- a/cpp/serve/engine_actions/batch_draft.cc +++ b/cpp/serve/engine_actions/batch_draft.cc @@ -99,6 +99,10 @@ class BatchDraftActionObj : public EngineActionObj { input_tokens.clear(); for (int i = 0; i < num_rsentries; ++i) { // The first draft proposal uses the last committed token. + if (draft_id == 0) { + ICHECK_EQ(mstates[i]->num_tokens_for_next_decode, 1); + mstates[i]->num_tokens_for_next_decode = 0; + } input_tokens.push_back(draft_id == 0 ? mstates[i]->committed_tokens.back().GetTokenId() : mstates[i]->draft_output_tokens.back().GetTokenId()); diff --git a/cpp/serve/engine_actions/batch_verify.cc b/cpp/serve/engine_actions/batch_verify.cc index 5c8adb4719..d02ae6e541 100644 --- a/cpp/serve/engine_actions/batch_verify.cc +++ b/cpp/serve/engine_actions/batch_verify.cc @@ -222,6 +222,9 @@ class BatchVerifyActionObj : public EngineActionObj { for (int i = 0; i < num_rsentries; ++i) { rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(&draft_token_slots_); draft_token_workspace_manager_->FreeSlots(draft_token_slots_); + // reset num_tokens_for_next_decode to 1 + rsentries[i]->mstates[verify_model_id_]->num_tokens_for_next_decode = 0; + rsentries[i]->mstates[draft_model_id_]->num_tokens_for_next_decode = 1; } auto tend = std::chrono::high_resolution_clock::now(); diff --git a/cpp/serve/engine_actions/eagle_batch_verify.cc b/cpp/serve/engine_actions/eagle_batch_verify.cc index b08fc33f6f..977c40235c 100644 --- 
a/cpp/serve/engine_actions/eagle_batch_verify.cc +++ b/cpp/serve/engine_actions/eagle_batch_verify.cc @@ -320,6 +320,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj { } } } + // reset num_tokens_for_next_decode + for (const RequestStateEntry& rsentry : rsentries) { + rsentry->mstates[verify_model_id_]->num_tokens_for_next_decode = 0; + rsentry->mstates[draft_model_id_]->num_tokens_for_next_decode = 0; + } auto tend = std::chrono::high_resolution_clock::now(); double elapsed_time = static_cast<double>((tend - tstart).count()) / 1e9; estate->metrics.engine_decode_time_sum += elapsed_time;