vllm/entrypoints/openai/serving_engine.py (103 changes: 70 additions & 33 deletions)
@@ -10,6 +10,7 @@
 from http import HTTPStatus
 from typing import Any, ClassVar, Generic, TypeAlias, TypeVar

+import numpy as np
 import torch
 from fastapi import Request
 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
@@ -389,8 +390,9 @@ async def beam_search(

         sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

+        logprobs_num = 2 * beam_width
         beam_search_params = SamplingParams(
-            logprobs=2 * beam_width,
+            logprobs=logprobs_num,
             max_tokens=1,
             temperature=temperature,
         )
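(The 2 * beam_width here, now named logprobs_num, presumably follows the common beam search convention: requesting twice as many candidates as beams guarantees that, even after EOS candidates are set aside, at least beam_width live continuations remain to choose from.)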
@@ -443,40 +445,75 @@ async def beam_search(
             output = [x[0] for x in await asyncio.gather(*tasks)]

             new_beams = []
-            for i, current_beam in enumerate(all_beams):
-                result = output[i]
-
+            # Store every candidate token id proposed for the beams
+            all_beams_token_id = []
+            # Store the cumulative log probability of every candidate
+            # token proposed by beam search
+            all_beams_logprob = []
+            # Iterate through all beam inference results
+            for i, result in enumerate(output):
+                current_beam = all_beams[i]
                 if result.outputs[0].logprobs is not None:
                     logprobs = result.outputs[0].logprobs[0]
-                    for token_id, logprob_obj in logprobs.items():
-                        if token_id == eos_token_id and not ignore_eos:
-                            completed.append(
-                                BeamSearchSequence(
-                                    tokens=current_beam.tokens + [token_id]
-                                    if include_stop_str_in_output
-                                    else current_beam.tokens,
-                                    logprobs=current_beam.logprobs + [logprobs],
-                                    cum_logprob=current_beam.cum_logprob
-                                    + logprob_obj.logprob,
-                                    finish_reason="stop",
-                                    stop_reason=eos_token_id,
-                                )
-                            )
-                        else:
-                            new_beams.append(
-                                BeamSearchSequence(
-                                    tokens=current_beam.tokens + [token_id],
-                                    logprobs=current_beam.logprobs + [logprobs],
-                                    lora_request=current_beam.lora_request,
-                                    cum_logprob=current_beam.cum_logprob
-                                    + logprob_obj.logprob,
-                                    multi_modal_data=current_beam.multi_modal_data,
-                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
-                                )
-                            )
-
-            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
-            all_beams = sorted_beams[:beam_width]
+                    all_beams_token_id.extend(list(logprobs.keys()))
+                    all_beams_logprob.extend(
+                        [
+                            current_beam.cum_logprob + obj.logprob
+                            for obj in logprobs.values()
+                        ]
+                    )
+
+            # Handle the end-of-sequence (EOS) token
+            all_beams_token_id = np.array(all_beams_token_id)
+            all_beams_logprob = np.array(all_beams_logprob)
+
+            if not ignore_eos:
+                # Get the index positions of the EOS token among all candidates
+                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
+                for idx in eos_idx:
+                    current_beam = all_beams[idx // logprobs_num]
+                    result = output[idx // logprobs_num]
+                    assert result.outputs[0].logprobs is not None
+                    logprobs_entry = result.outputs[0].logprobs[0]
+                    completed.append(
+                        BeamSearchSequence(
+                            tokens=current_beam.tokens + [eos_token_id]
+                            if include_stop_str_in_output
+                            else current_beam.tokens,
+                            logprobs=current_beam.logprobs + [logprobs_entry],
+                            cum_logprob=float(all_beams_logprob[idx]),
+                            finish_reason="stop",
+                            stop_reason=eos_token_id,
+                        )
+                    )
+                # After recording them, set the log probability of the EOS
+                # candidates to negative infinity so they cannot be selected.
+                all_beams_logprob[eos_idx] = -np.inf
+
+            # Process non-EOS tokens:
+            # get the indices of the top beam_width log probabilities
+            topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
+                :beam_width
+            ]
+
+            for idx in topn_idx:
+                current_beam = all_beams[idx // logprobs_num]
+                result = output[idx // logprobs_num]
+                token_id = int(all_beams_token_id[idx])
+                assert result.outputs[0].logprobs is not None
+                logprobs_entry = result.outputs[0].logprobs[0]
+                new_beams.append(
+                    BeamSearchSequence(
+                        tokens=current_beam.tokens + [token_id],
+                        logprobs=current_beam.logprobs + [logprobs_entry],
+                        lora_request=current_beam.lora_request,
+                        cum_logprob=float(all_beams_logprob[idx]),
+                        multi_modal_data=current_beam.multi_modal_data,
+                        mm_processor_kwargs=current_beam.mm_processor_kwargs,
+                    )
+                )
+
+            all_beams = new_beams

         completed.extend(all_beams)
         sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
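For reference, here is a minimal, runnable sketch of the selection pattern the new code relies on: flatten every beam's candidates into parallel arrays, mask EOS entries to -inf, and take an O(n) top-k with np.argpartition. All names and values below (cand_token_id, cand_logprob, the toy numbers) are illustrative stand-ins, not part of the vLLM API.

import numpy as np

# Hypothetical toy inputs: 2 beams, logprobs_num = 2 * beam_width
# candidates each, flattened into parallel arrays. Integer division
# by logprobs_num recovers the beam that proposed a candidate.
beam_width = 2
logprobs_num = 2 * beam_width
eos_token_id = 2
cand_token_id = np.array([5, 2, 7, 9, 2, 8, 3, 6])
cand_logprob = np.array([-1.0, -1.2, -2.5, -3.0, -0.9, -1.1, -2.0, -4.0])

# 1) EOS candidates are finished sequences: record them, then mask
#    them with -inf so they can no longer be picked as live beams.
eos_idx = np.where(cand_token_id == eos_token_id)[0]
finished = [(int(i) // logprobs_num, float(cand_logprob[i])) for i in eos_idx]
print(finished)  # [(0, -1.2), (1, -0.9)]
cand_logprob[eos_idx] = -np.inf

# 2) np.argpartition runs in O(n) and yields the indices of the
#    beam_width best log probs, in no particular order.
topn_idx = np.argpartition(np.negative(cand_logprob), beam_width)[:beam_width]

for idx in topn_idx:
    beam = int(idx) // logprobs_num      # which beam proposed the token
    token_id = int(cand_token_id[idx])   # which token to append
    print(beam, token_id, float(cand_logprob[idx]))
# -> (beam 0, token 5, -1.0) and (beam 1, token 8, -1.1)

Note that np.argpartition returns the selected indices unordered, unlike the per-step sort it replaces; that is fine here because completed sequences are still re-sorted with sort_beams_key at the end.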