
[RFC]: Reimplement and separate beam search on top of vLLM core #8306

@youkaichao


Motivation.

A rework of #6226

After further discussion with the community, we found that the common use case for beam search is:

  1. throughput-oriented
  2. mainly offline batch inference
  3. using the same beam search parameters for all prompts in the batch

After discussing with many contributors, we also found the following.

Because beam search is a search algorithm, it conflicts with all the other sampling algorithms. As a result, many features in vLLM already assert directly that beam search is not used, e.g.:

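```python
# Speculative decoding: a sequence group must hold exactly one sequence,
# which rules out beam search (several beams per group).
assert len(input_seq_group_metadata.seq_data) == 1, (
    "Beam search "
    "not supported in speculative decoding")
```

```python
# Multi-step decoding: same restriction, one sequence per group.
assert len(seqs) == 1, (
    "Beam search not supported in multi-step decoding.")
seq = seqs[0]
```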

Keeping beam search as-is in the codebase will not benefit current beam search users, since no optimization will target better beam search performance. Worse, very few developers understand beam search. Keeping it as-is will not only lead to more beam search bugs as the codebase evolves, but also increase the maintenance cost for all contributors.

In search of a win-win solution, on behalf of the vLLM team, I propose to separate beam search from the vLLM core and reimplement it on top of the core code.

Specifically, we can:

  1. remove the beam search logic from the scheduler;
  2. add an LLM.beam_search interface that calls the engine to generate one token with logprobs at every step, keeping the beam search logic only in the LLM.beam_search function;
  3. add a beam search emulator on top of the commonly used OpenAI-compatible API server, which internally calls the generation endpoint to generate one step with logprobs, keeping the beam search logic only in the emulator.
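The idea in item 2 can be sketched in plain Python as a loop that owns all the beam bookkeeping and only asks the "engine" for one token's logprobs per beam per step. This is a minimal sketch, not the proposed implementation: `step_fn` is a hypothetical stand-in for an engine call that generates a single token with logprobs (in vLLM, something along the lines of `SamplingParams(max_tokens=1, logprobs=beam_width)`), the per-beam calls are not batched, and no length penalty is applied. Finished beams occupy a slot in the step where they emit EOS.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical interface: given prompt + generated tokens, return the top
# next-token log-probabilities as {token_id: logprob}.
StepFn = Callable[[List[int]], Dict[int, float]]
Beam = Tuple[List[int], float]  # (generated tokens, cumulative logprob)


def beam_search(prompt: List[int], step_fn: StepFn, beam_width: int,
                max_tokens: int, eos_token_id: int) -> List[Beam]:
    """Return up to beam_width beams, best first (no length penalty)."""
    live: List[Beam] = [([], 0.0)]
    finished: List[Beam] = []

    for _ in range(max_tokens):
        # Expand every live beam by one token via the "engine".
        candidates: List[Beam] = []
        for tokens, score in live:
            next_logprobs = step_fn(prompt + tokens)
            top = sorted(next_logprobs.items(), key=lambda kv: kv[1],
                         reverse=True)
            for tok, lp in top[:beam_width]:
                candidates.append((tokens + [tok], score + lp))

        # Keep the best beam_width candidates; those ending in EOS are done.
        candidates.sort(key=lambda c: c[1], reverse=True)
        live = []
        for tokens, score in candidates[:beam_width]:
            if tokens[-1] == eos_token_id:
                finished.append((tokens, score))
            else:
                live.append((tokens, score))
        if not live:
            break

    # Beams that reached max_tokens without EOS still compete at the end.
    finished.extend(live)
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:beam_width]


# Toy "model" (hypothetical): next-token probabilities depend only on the
# last token. Token 0 is EOS; the prompt ends with token 9.
TABLE = {
    9: {1: 0.6, 2: 0.4},
    1: {3: 0.55, 0: 0.45},
    2: {0: 0.9, 3: 0.1},
    3: {0: 1.0},
}


def toy_step(seq: List[int]) -> Dict[int, float]:
    return {tok: math.log(p) for tok, p in TABLE[seq[-1]].items()}


greedy = beam_search([9], toy_step, beam_width=1, max_tokens=3, eos_token_id=0)
beams = beam_search([9], toy_step, beam_width=2, max_tokens=3, eos_token_id=0)
print(greedy[0][0])  # [1, 3, 0]  (probability 0.33)
print(beams[0][0])   # [2, 0]     (probability 0.36)
```

With a width of 1 the loop degenerates to greedy decoding and commits to token 1 early; a width of 2 keeps the less likely first token 2 alive and finds the higher-probability sequence. Because the engine only ever sees independent single-token requests, none of this logic has to live in the scheduler.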

From the initial discussion, one concern is the efficiency of such an implementation: from the vLLM core's perspective, requests will come and go again and again, one round per generated token. This should be solvable in two ways:

  1. turning on prefix caching lets each step reuse the computation from the previous step, so the KV cache of the prompt does not need to be recomputed again and again;
  2. once beam search is separated from the vLLM core, the two can be optimized independently, and the simplified code will be much easier to optimize.
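To give a feel for point 1, here is a rough token-count model. It is an assumption-laden sketch, not a measurement of vLLM: it assumes prefix caching reuses only full KV blocks of `block_size` tokens that were computed in the previous step and have not been evicted, and that the last token of every request is always computed because its logits are needed. The helper names are hypothetical.

```python
def tokens_computed(seq_len: int, block_size: int, prefix_caching: bool) -> int:
    """Tokens whose KV must be computed for one request of length seq_len."""
    if not prefix_caching:
        return seq_len  # the whole prompt + generated prefix is recomputed
    # Full blocks inside the first seq_len - 1 tokens were computed in the
    # previous step and are assumed to still be cached.
    cached = (seq_len - 1) // block_size * block_size
    return seq_len - cached


def total_tokens(prompt_len: int, beam_width: int, num_steps: int,
                 block_size: int, prefix_caching: bool) -> int:
    # Step 0: a single request for the bare prompt; nothing is cached yet.
    total = prompt_len
    for step in range(1, num_steps):
        seq_len = prompt_len + step  # prompt + tokens generated so far
        total += beam_width * tokens_computed(seq_len, block_size,
                                              prefix_caching)
    return total


print(total_tokens(100, 4, 10, 16, prefix_caching=False))  # 3880
print(total_tokens(100, 4, 10, 16, prefix_caching=True))   # 424
```

Under these assumptions, each step only pays for the tail of the last partially filled block rather than for the whole prompt, which is why the per-step request pattern need not be as costly as it first appears.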

vLLM is a community project, and we would like to seek not only opinions but also contributions from beam search users. Your help is truly needed to shape the future of beam search support in vLLM.

Proposed Change.

Summary of the change: implement beam search on top of the vLLM core and add wrappers for users; remove beam search from the vLLM core (scheduler).

Feedback Period.

1 week, from 9/9 to 9/15 (both inclusive)

CC List.

@hrsmanian @zhouyuan @lanking520 @nightflight-dk @HeegonJin @SemMulder @darabos @DhruvaBansal00 @tmostak @physicsrob @YooSungHyun @denadai2 @sjmielke @Reichenbachian @AaronFriel @hinnefe2 @mflaxman10
@WoosukKwon @zhuohan123 @simon-mo

Any Other Things.

No response

