1313from  vllm  import  LLM , SamplingParams 
1414from  vllm .engine .arg_utils  import  EngineArgs 
1515from  vllm .inputs  import  PromptType 
16+ from  vllm .sampling_params  import  BeamSearchParams 
1617from  vllm .utils  import  FlexibleArgumentParser 
1718
1819
@@ -40,6 +41,20 @@ def main(args: argparse.Namespace):
4041        "prompt_token_ids" : batch 
4142    } for  batch  in  dummy_prompt_token_ids .tolist ()]
4243
44+     def  llm_generate ():
45+         if  not  args .use_beam_search :
46+             llm .generate (dummy_prompts ,
47+                          sampling_params = sampling_params ,
48+                          use_tqdm = False )
49+         else :
50+             llm .beam_search (
51+                 dummy_prompts ,
52+                 BeamSearchParams (
53+                     beam_width = args .n ,
54+                     max_tokens = args .output_len ,
55+                     ignore_eos = True ,
56+                 ))
57+ 
4358    def  run_to_completion (profile_dir : Optional [str ] =  None ):
4459        if  profile_dir :
4560            with  torch .profiler .profile (
@@ -49,15 +64,11 @@ def run_to_completion(profile_dir: Optional[str] = None):
4964                    ],
5065                    on_trace_ready = torch .profiler .tensorboard_trace_handler (
5166                        str (profile_dir ))) as  p :
52-                 llm .generate (dummy_prompts ,
53-                              sampling_params = sampling_params ,
54-                              use_tqdm = False )
67+                 llm_generate ()
5568            print (p .key_averages ().table (sort_by = "self_cuda_time_total" ))
5669        else :
5770            start_time  =  time .perf_counter ()
58-             llm .generate (dummy_prompts ,
59-                          sampling_params = sampling_params ,
60-                          use_tqdm = False )
71+             llm_generate ()
6172            end_time  =  time .perf_counter ()
6273            latency  =  end_time  -  start_time 
6374            return  latency 
0 commit comments