23 changes: 13 additions & 10 deletions cpp/serve/engine_actions/batch_prefill_base.cc
@@ -62,15 +62,12 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
   // Then we make a reduction to return the maximum common inputs.
   for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
     std::vector<PrefillInput> prefill_inputs;
-    // - Try to prefill pending requests, in addition to reserved decode requests.
+    // - Try to prefill pending requests.
     int total_input_length = 0;
+    int total_required_pages = num_decode_inputs;
-    // Reserve decode requests first.
     for (const RequestStateEntry& rsentry : *running_rsentries) {
-      prefill_inputs.push_back(
-          {rsentry, rsentry->mstates[i]->num_tokens_for_next_decode, 0, /*is_decode=*/true});
       total_input_length += rsentry->mstates[i]->num_tokens_for_next_decode;
     }
-    int total_required_pages = num_decode_inputs;
     int num_available_pages;
     int num_running_rsentries = num_decode_inputs;
     int current_total_seq_len;
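
A minimal, self-contained sketch (not part of the diff) of how the per-model pass reads after this hunk, assuming the loop over running_rsentries is kept only to count decode token lengths: running decode requests still count toward the per-model length and page budget, but they are no longer pushed into the per-model prefill_inputs. DecodeEntry, Budget, and AccountForRunningDecodes are hypothetical names introduced here for illustration.

// Hypothetical, simplified stand-ins for RequestStateEntry and the local counters
// in the real code; only the accounting pattern mirrors the hunk above.
#include <cstdio>
#include <vector>

struct DecodeEntry {
  int num_tokens_for_next_decode;  // stand-in for mstates[i]->num_tokens_for_next_decode
};

struct Budget {
  int total_input_length = 0;
  int total_required_pages = 0;
};

Budget AccountForRunningDecodes(const std::vector<DecodeEntry>& running_rsentries,
                                int num_decode_inputs) {
  Budget budget;
  // One page is budgeted per reserved decode input.
  budget.total_required_pages = num_decode_inputs;
  // Decode token lengths still count toward the input-length budget,
  // but no PrefillInput entries are created for them here.
  for (const DecodeEntry& entry : running_rsentries) {
    budget.total_input_length += entry.num_tokens_for_next_decode;
  }
  return budget;
}

int main() {
  std::vector<DecodeEntry> running = {{1}, {1}, {2}};
  Budget budget = AccountForRunningDecodes(running, /*num_decode_inputs=*/3);
  std::printf("length=%d pages=%d\n", budget.total_input_length, budget.total_required_pages);
}
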
@@ -211,14 +208,20 @@ BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {
         std::min(num_prefill_inputs, static_cast<int>(prefill_inputs_for_all_models[i].size()));
   }
 
-  // If all inputs are decode inputs, since no prefill inputs can be added, skip prefill action
-  if (num_prefill_inputs == num_decode_inputs) {
+  if (num_prefill_inputs == 0) {
     return {};
   }
 
-  std::vector<PrefillInput> prefill_inputs(
-      prefill_inputs_for_all_models[0].begin(),
-      prefill_inputs_for_all_models[0].begin() + num_prefill_inputs);
+  // Add the decode requests to the prefill inputs.
+  std::vector<PrefillInput> prefill_inputs;
+  prefill_inputs.reserve(num_decode_inputs + num_prefill_inputs);
+  for (const RequestStateEntry& rsentry : *running_rsentries) {
+    prefill_inputs.push_back(
+        {rsentry, rsentry->mstates[0]->num_tokens_for_next_decode, 0, /*is_decode=*/true});
+  }
+  prefill_inputs.insert(prefill_inputs.end(), prefill_inputs_for_all_models[0].begin(),
+                        prefill_inputs_for_all_models[0].begin() + num_prefill_inputs);
+  num_prefill_inputs += num_decode_inputs;
   {
     NVTXScopedRange nvtx_scope("reduction");
     for (int i = 1; i < static_cast<int>(prefill_inputs_for_all_models.size()); ++i) {
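
A minimal, self-contained sketch (not part of the diff) of the assembly order introduced by the second hunk: the reserved decode entries are appended once, against mstates[0], followed by the first num_prefill_inputs entries common to all models, after which num_prefill_inputs is bumped by num_decode_inputs. Input and AssembleInputs are hypothetical stand-ins for PrefillInput and the added lines above.

// Hypothetical, simplified illustration of the new ordering: decode entries
// first, then the common prefill entries.
#include <cstdio>
#include <vector>

struct Input {
  int entry_id;    // which request state entry
  int num_tokens;  // tokens to run in this step
  bool is_decode;  // true for the reserved decode requests
};

std::vector<Input> AssembleInputs(const std::vector<Input>& running_decodes,
                                  const std::vector<Input>& common_prefills,
                                  int num_prefill_inputs) {
  std::vector<Input> inputs;
  inputs.reserve(running_decodes.size() + num_prefill_inputs);
  // Decode requests go in first, exactly once (not once per model as before).
  inputs.insert(inputs.end(), running_decodes.begin(), running_decodes.end());
  // Then the first num_prefill_inputs entries common to all models.
  inputs.insert(inputs.end(), common_prefills.begin(),
                common_prefills.begin() + num_prefill_inputs);
  return inputs;
}

int main() {
  std::vector<Input> decodes = {{0, 1, true}, {1, 1, true}};
  std::vector<Input> prefills = {{2, 128, false}, {3, 64, false}, {4, 256, false}};
  std::vector<Input> inputs = AssembleInputs(decodes, prefills, /*num_prefill_inputs=*/2);
  std::printf("total inputs: %zu\n", inputs.size());  // 4: two decodes + two prefills
}

Appending the decode entries once, after the per-model reduction, avoids reserving them separately inside every model's candidate list, and it lets the early return key off num_prefill_inputs == 0, i.e. whether there is any pure prefill work at all.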