diff --git a/cpp/serve/engine_actions/batch_prefill_base.cc b/cpp/serve/engine_actions/batch_prefill_base.cc index 2a23f0f6b3..61b52539de 100644 --- a/cpp/serve/engine_actions/batch_prefill_base.cc +++ b/cpp/serve/engine_actions/batch_prefill_base.cc @@ -62,15 +62,12 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) { // Then we make a reduction to return the maximum common inputs. for (int i = 0; i < static_cast(models_.size()); ++i) { std::vector prefill_inputs; - // - Try to prefill pending requests, in addition to reserved decode requests. + // - Try to prefill pending requests. int total_input_length = 0; - int total_required_pages = num_decode_inputs; - // Reserve decode requests first. for (const RequestStateEntry& rsentry : *running_rsentries) { - prefill_inputs.push_back( - {rsentry, rsentry->mstates[i]->num_tokens_for_next_decode, 0, /*is_decode=*/true}); total_input_length += rsentry->mstates[i]->num_tokens_for_next_decode; } + int total_required_pages = num_decode_inputs; int num_available_pages; int num_running_rsentries = num_decode_inputs; int current_total_seq_len; @@ -211,14 +208,20 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) { std::min(num_prefill_inputs, static_cast(prefill_inputs_for_all_models[i].size())); } - // If all inputs are decode inputs, since no prefill inputs can be added, skip prefill action - if (num_prefill_inputs == num_decode_inputs) { + if (num_prefill_inputs == 0) { return {}; } - std::vector prefill_inputs( - prefill_inputs_for_all_models[0].begin(), - prefill_inputs_for_all_models[0].begin() + num_prefill_inputs); + // Add the decode requests to the prefill inputs. + std::vector prefill_inputs; + prefill_inputs.reserve(num_decode_inputs + num_prefill_inputs); + for (const RequestStateEntry& rsentry : *running_rsentries) { + prefill_inputs.push_back( + {rsentry, rsentry->mstates[0]->num_tokens_for_next_decode, 0, /*is_decode=*/true}); + } + prefill_inputs.insert(prefill_inputs.end(), prefill_inputs_for_all_models[0].begin(), + prefill_inputs_for_all_models[0].begin() + num_prefill_inputs); + num_prefill_inputs += num_decode_inputs; { NVTXScopedRange nvtx_scope("reduction"); for (int i = 1; i < static_cast(prefill_inputs_for_all_models.size()); ++i) {