[Performance]: Optimize beam search code #20316

@zhanggzh

Description

Proposal to improve performance

Dear vLLM community,

I've been analyzing the beam search implementation and have identified potential areas for performance improvement.

Problem Description

The beam search implementation currently expands candidates by concatenating every candidate token with its parent sequence, building a new BeamSearchSequence for each pair. This operation executes beam_width × (2 × beam_width) times per decoding step; with beam_width = 4, that is 32 list concatenations per step, each of which copies the parent's full token list. Profiling shows that this concatenation is highly time-consuming.

[screenshot: per-step timing statistics]

Technical Analysis

Key Code Locations

  • Primary logic: vllm/engine/protocol.py
    A double loop concatenates every candidate token with its parent beam into a new BeamSearchSequence, then selects the top-k candidates by sorting the full candidate list (sketched below).
    [screenshot: the double loop in beam_search]
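
Since the screenshot may not render here, the baseline pattern is sketched below. This is a paraphrase with simplified stand-in types, not vLLM's actual code: Beam stands in for BeamSearchSequence, and bookkeeping fields such as logprobs, lora_request, and multi-modal data are omitted.

    from dataclasses import dataclass

    @dataclass
    class Beam:
        # minimal stand-in for vLLM's BeamSearchSequence
        tokens: list[int]
        cum_logprob: float

    def expand_step(all_beams: list[Beam],
                    step_logprobs: list[dict[int, float]],
                    beam_width: int) -> list[Beam]:
        # Baseline: materialize every (beam, token) pair, then sort.
        new_beams = []
        for beam, logprobs in zip(all_beams, step_logprobs):
            for token_id, logprob in logprobs.items():
                # beam.tokens is copied for every candidate, including
                # the ones discarded by the sort below
                new_beams.append(
                    Beam(tokens=beam.tokens + [token_id],
                         cum_logprob=beam.cum_logprob + logprob))
        new_beams.sort(key=lambda b: b.cum_logprob, reverse=True)
        return new_beams[:beam_width]

With beam_width beams and 2 × beam_width candidate tokens per beam, this allocates 2 × beam_width² token lists per step while keeping only beam_width of them.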

However, if we first select the top beam_width candidate tokens by cumulative log probability and only then build the new sequences, the time-consuming concatenation is performed only for candidates that survive pruning, improving inference speed. Here is my revised code:

            # NOTE: assumes `import numpy as np` at module level;
            # logprobs_num is the number of logprobs returned per beam
            # (vLLM's beam search requests 2 * beam_width of them).
            new_beams = []
            # Flat list of every candidate token generated this step
            all_beams_token_id = []
            # Cumulative log probability of each candidate
            all_beams_logprob = []
            # Iterate over the inference results of all beams
            for i, result in enumerate(output):
                current_beam = all_beams[i]
                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    all_beams_token_id.extend(list(logprobs.keys()))
                    all_beams_logprob.extend([
                        current_beam.cum_logprob + obj.logprob
                        for obj in logprobs.values()
                    ])

            all_beams_token_id = np.array(all_beams_token_id)
            all_beams_logprob = np.array(all_beams_logprob)

            # Handle end-of-sequence (EOS) tokens
            if not ignore_eos:
                # Indices of EOS tokens among all generated candidates
                eos_idx = np.where(
                    all_beams_token_id == tokenizer.eos_token_id)[0]
                for idx in eos_idx:
                    # idx // logprobs_num recovers the parent beam index
                    current_beam = all_beams[idx // logprobs_num]
                    result = output[idx // logprobs_num]
                    completed.append(
                        BeamSearchSequence(
                            tokens=(current_beam.tokens +
                                    [tokenizer.eos_token_id]
                                    if include_stop_str_in_output else
                                    current_beam.tokens),
                            logprobs=current_beam.logprobs +
                            [result.outputs[0].logprobs[0]],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=tokenizer.eos_token_id))
                # Mask finished candidates so they cannot be selected again
                all_beams_logprob[eos_idx] = -np.inf

            # Process non-EOS tokens: take the indices of the top
            # beam_width cumulative log probabilities without a full sort
            topn_idx = np.argpartition(-all_beams_logprob,
                                       beam_width)[:beam_width]

            for idx in topn_idx:
                current_beam = all_beams[idx // logprobs_num]
                result = output[idx // logprobs_num]
                token_id = int(all_beams_token_id[idx])
                # Only the surviving beam_width candidates pay for the
                # token-list concatenation here
                new_beams.append(
                    BeamSearchSequence(
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs +
                        [result.outputs[0].logprobs[0]],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                        multi_modal_data=current_beam.multi_modal_data,
                        mm_processor_kwargs=current_beam.mm_processor_kwargs))

            all_beams = new_beams
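
To isolate the effect outside of vLLM, here is a self-contained micro-benchmark of the core trick; every name in it is illustrative rather than taken from vLLM. np.argpartition finds the indices of the top beam_width cumulative log probabilities in O(n) without a full sort, and parent token lists are copied only for the survivors:

    import timeit

    import numpy as np

    rng = np.random.default_rng(0)
    beam_width, logprobs_num, seq_len = 8, 16, 1024
    n = beam_width * logprobs_num  # candidates per decoding step
    cum_logprob = rng.standard_normal(n)
    parent_tokens = [list(range(seq_len)) for _ in range(beam_width)]
    token_ids = rng.integers(0, 32000, size=n)

    def baseline():
        # materialize every candidate, then sort all of them
        cands = [(parent_tokens[i // logprobs_num] + [int(token_ids[i])],
                  float(cum_logprob[i])) for i in range(n)]
        cands.sort(key=lambda c: c[1], reverse=True)
        return cands[:beam_width]

    def optimized():
        # prune first with an O(n) partial selection, then materialize
        # token lists for the surviving beam_width candidates only
        top_idx = np.argpartition(-cum_logprob, beam_width)[:beam_width]
        return [(parent_tokens[i // logprobs_num] + [int(token_ids[i])],
                 float(cum_logprob[i])) for i in top_idx]

    print("baseline :", timeit.timeit(baseline, number=1000))
    print("optimized:", timeit.timeit(optimized, number=1000))

One caveat: np.argpartition does not return the selected indices in sorted order, so if any downstream code expects all_beams to stay sorted by cumulative log probability, the beam_width survivors still need a final (cheap) sort.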

Related PR: #19347

Performance Profiling

With the optimizations above, the processing time is reduced by nearly 40%.
[screenshot: before/after timing comparison]
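
For an end-to-end reproduction, something along these lines should work. LLM.beam_search and BeamSearchParams exist in recent vLLM releases, but the model name and parameter values below are placeholders, and attribute names may differ across versions:

    import time

    from vllm import LLM
    from vllm.sampling_params import BeamSearchParams

    # placeholder model; any model supported by vLLM works here
    llm = LLM(model="facebook/opt-125m")
    params = BeamSearchParams(beam_width=4, max_tokens=128)

    start = time.perf_counter()
    outputs = llm.beam_search([{"prompt": "The capital of France is"}],
                              params)
    print(f"elapsed: {time.perf_counter() - start:.2f} s")
    for seq in outputs[0].sequences:
        print(seq.cum_logprob, len(seq.tokens))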

I appreciate your review and hope to see these improvements merged into the main branch soon.

Report of performance regression

No response

Misc discussion on performance

No response

Your current environment (if you think it is necessary)

The output of `python collect_env.py`

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
