2424""" 
2525import  argparse 
2626import  asyncio 
27+ import  base64 
28+ import  io 
2729import  json 
2830import  os 
2931import  random 
3032import  time 
3133import  warnings 
3234from  dataclasses  import  dataclass 
3335from  datetime  import  datetime 
34- from  typing  import  Any , AsyncGenerator , Dict , List , Optional , Tuple 
36+ from  typing  import  Any , AsyncGenerator , Collection ,  Dict , List , Optional , Tuple 
3537
3638import  numpy  as  np 
3739from  backend_request_func  import  (ASYNC_REQUEST_FUNCS , RequestFuncInput ,
3840                                  RequestFuncOutput )
41+ from  datasets  import  load_dataset 
42+ from  PIL .Image  import  Image 
3943from  tqdm .asyncio  import  tqdm 
4044from  transformers  import  PreTrainedTokenizerBase 
4145
@@ -84,7 +88,7 @@ def sample_sharegpt_requests(
8488    num_requests : int ,
8589    tokenizer : PreTrainedTokenizerBase ,
8690    fixed_output_len : Optional [int ] =  None ,
87- ) ->  List [Tuple [str , int , int ]]:
91+ ) ->  List [Tuple [str , int , int ,  None ]]:
8892    if  fixed_output_len  is  not   None  and  fixed_output_len  <  4 :
8993        raise  ValueError ("output_len too small" )
9094    # Load the dataset. 
@@ -119,7 +123,7 @@ def sample_sharegpt_requests(
119123        if  prompt_len  >  1024  or  prompt_len  +  output_len  >  2048 :
120124            # Prune too long sequences. 
121125            continue 
122-         filtered_dataset .append ((prompt , prompt_len , output_len ))
126+         filtered_dataset .append ((prompt , prompt_len , output_len ,  None ))
123127
124128    return  filtered_dataset 
125129
@@ -131,7 +135,7 @@ def sample_sonnet_requests(
131135    output_len : int ,
132136    prefix_len : int ,
133137    tokenizer : PreTrainedTokenizerBase ,
134- ) ->  List [Tuple [str , str , int , int ]]:
138+ ) ->  List [Tuple [str , str , int , int ,  None ]]:
135139    assert  (
136140        input_len  >  prefix_len 
137141    ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." 
@@ -189,7 +193,65 @@ def sample_sonnet_requests(
189193            message , add_generation_prompt = True , tokenize = False )
190194        prompt_len  =  len (tokenizer (prompt_formatted ).input_ids )
191195        sampled_requests .append (
192-             (prompt , prompt_formatted , prompt_len , output_len ))
196+             (prompt , prompt_formatted , prompt_len , output_len , None ))
197+ 
198+     return  sampled_requests 
199+ 
200+ 
def sample_hf_requests(
    dataset_path: str,
    dataset_subset: str,
    dataset_split: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int, Optional[Dict[str, Collection[str]]]]]:
    """Sample benchmark requests from a HuggingFace dataset.

    Streams the dataset, shuffles it, and keeps conversations that have at
    least a prompt and a reference completion. Each sampled request is a
    tuple of ``(prompt, prompt_len, output_len, multi_modal_content)`` where
    the multi-modal content is an OpenAI-style ``image_url`` payload (base64
    data URL) when the row carries a PIL image, or ``None`` otherwise.

    Args:
        dataset_path: HuggingFace dataset ID (or local path).
        dataset_subset: Dataset configuration/subset name, or ``None``.
        dataset_split: Split to stream (e.g. ``"train"``).
        num_requests: Number of requests to sample.
        tokenizer: Tokenizer used to measure prompt/completion lengths.
        fixed_output_len: If given, overrides the per-row completion length.

    Returns:
        A list of ``(prompt, prompt_len, output_len, mm_content)`` tuples.
    """
    # Stream to avoid downloading/materializing the whole dataset.
    dataset = load_dataset(dataset_path,
                           name=dataset_subset,
                           split=dataset_split,
                           streaming=True)
    assert "conversations" in dataset.features, (
        "HF Dataset must have 'conversations' column.")
    # Need at least two turns: turn 0 is the prompt, turn 1 the completion.
    filtered_dataset = dataset.shuffle().filter(
        lambda x: len(x["conversations"]) >= 2)
    sampled_requests: List[Tuple[str, int, int,
                                 Optional[Dict[str, Collection[str]]]]] = []
    for data in filtered_dataset:
        if len(sampled_requests) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = data["conversations"][0]["value"]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = data["conversations"][1]["value"]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids
                         ) if fixed_output_len is None else fixed_output_len
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue

        if "image" in data and isinstance(data["image"], Image):
            # Encode the image as a JPEG base64 data URL in the OpenAI
            # chat "image_url" content format.
            image: Image = data["image"]
            image = image.convert("RGB")
            image_data = io.BytesIO()
            image.save(image_data, format='JPEG')
            image_base64 = base64.b64encode(
                image_data.getvalue()).decode("utf-8")
            mm_content = {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{image_base64}"
                },
            }
        else:
            mm_content = None

        sampled_requests.append((prompt, prompt_len, output_len, mm_content))

    return sampled_requests
195257
@@ -223,8 +285,8 @@ def sample_random_requests(
223285                                  [(offsets [i ] +  i  +  j ) %  tokenizer .vocab_size 
224286                                   for  j  in  range (input_lens [i ])])
225287
226-         input_requests .append (
227-             ( prompt ,  int ( prefix_len   +   input_lens [ i ]),  int (output_lens [i ])))
288+         input_requests .append (( prompt ,  int ( prefix_len   +   input_lens [ i ]), 
289+                                 int (output_lens [i ]),  None ))
228290
229291    return  input_requests 
230292
@@ -343,7 +405,12 @@ async def benchmark(
343405        raise  ValueError (f"Unknown backend: { backend }  " )
344406
345407    print ("Starting initial single prompt test run..." )
346-     test_prompt , test_prompt_len , test_output_len  =  input_requests [0 ]
408+     test_prompt , test_prompt_len , test_output_len , test_mm_content  =  (
409+         input_requests [0 ])
410+     if  backend  !=  "openai-chat"  and  test_mm_content  is  not   None :
411+         # multi-modal benchmark is only available on OpenAI Chat backend. 
412+         raise  ValueError (
413+             "Multi-modal content is only supported on 'openai-chat' backend." )
347414    test_input  =  RequestFuncInput (
348415        model = model_id ,
349416        prompt = test_prompt ,
@@ -353,6 +420,7 @@ async def benchmark(
353420        logprobs = logprobs ,
354421        best_of = best_of ,
355422        use_beam_search = use_beam_search ,
423+         multi_modal_content = test_mm_content ,
356424    )
357425    test_output  =  await  request_func (request_func_input = test_input )
358426    if  not  test_output .success :
@@ -373,6 +441,7 @@ async def benchmark(
373441            logprobs = logprobs ,
374442            best_of = best_of ,
375443            use_beam_search = use_beam_search ,
444+             multi_modal_content = test_mm_content ,
376445        )
377446        profile_output  =  await  request_func (request_func_input = profile_input )
378447        if  profile_output .success :
@@ -385,7 +454,7 @@ async def benchmark(
385454    benchmark_start_time  =  time .perf_counter ()
386455    tasks : List [asyncio .Task ] =  []
387456    async  for  request  in  get_request (input_requests , request_rate ):
388-         prompt , prompt_len , output_len  =  request 
457+         prompt , prompt_len , output_len ,  mm_content  =  request 
389458        request_func_input  =  RequestFuncInput (
390459            model = model_id ,
391460            prompt = prompt ,
@@ -395,6 +464,7 @@ async def benchmark(
395464            logprobs = logprobs ,
396465            best_of = best_of ,
397466            use_beam_search = use_beam_search ,
467+             multi_modal_content = mm_content ,
398468        )
399469        tasks .append (
400470            asyncio .create_task (
@@ -575,6 +645,16 @@ def main(args: argparse.Namespace):
575645                              for  prompt , prompt_formatted , prompt_len ,
576646                              output_len  in  input_requests ]
577647
648+     elif  args .dataset_name  ==  "hf" :
649+         input_requests  =  sample_hf_requests (
650+             dataset_path = args .dataset_path ,
651+             dataset_subset = args .hf_subset ,
652+             dataset_split = args .hf_split ,
653+             num_requests = args .num_prompts ,
654+             tokenizer = tokenizer ,
655+             fixed_output_len = args .hf_output_len ,
656+         )
657+ 
578658    elif  args .dataset_name  ==  "random" :
579659        input_requests  =  sample_random_requests (
580660            prefix_len = args .random_prefix_len ,
@@ -685,13 +765,14 @@ def main(args: argparse.Namespace):
685765        "--dataset-name" ,
686766        type = str ,
687767        default = "sharegpt" ,
688-         choices = ["sharegpt" , "sonnet" , "random" ],
768+         choices = ["sharegpt" , "sonnet" , "random" ,  "hf" ],
689769        help = "Name of the dataset to benchmark on." ,
690770    )
691771    parser .add_argument ("--dataset-path" ,
692772                        type = str ,
693773                        default = None ,
694-                         help = "Path to the dataset." )
774+                         help = "Path to the sharegpt/sonnet dataset. " 
775+                         "Or the huggingface dataset ID if using HF dataset." )
695776    parser .add_argument (
696777        "--model" ,
697778        type = str ,
@@ -718,26 +799,6 @@ def main(args: argparse.Namespace):
718799        default = 1000 ,
719800        help = "Number of prompts to process." ,
720801    )
721-     parser .add_argument (
722-         "--sharegpt-output-len" ,
723-         type = int ,
724-         default = None ,
725-         help = "Output length for each request. Overrides the output length " 
726-         "from the ShareGPT dataset." )
727-     parser .add_argument (
728-         "--sonnet-input-len" ,
729-         type = int ,
730-         default = 550 ,
731-         help = 
732-         "Number of input tokens per request, used only for sonnet dataset." ,
733-     )
734-     parser .add_argument (
735-         "--sonnet-output-len" ,
736-         type = int ,
737-         default = 150 ,
738-         help = 
739-         "Number of output tokens per request, used only for sonnet dataset." ,
740-     )
741802    parser .add_argument (
742803        "--logprobs" ,
743804        type = int ,
@@ -748,42 +809,6 @@ def main(args: argparse.Namespace):
748809              "logprob is returned for each token; or (2) if beam search " 
749810              "is enabled 1 logprob per token is computed" ),
750811    )
751-     parser .add_argument (
752-         "--sonnet-prefix-len" ,
753-         type = int ,
754-         default = 200 ,
755-         help = 
756-         "Number of prefix tokens per request, used only for sonnet dataset." ,
757-     )
758-     parser .add_argument (
759-         "--random-input-len" ,
760-         type = int ,
761-         default = 1024 ,
762-         help = 
763-         "Number of input tokens per request, used only for random sampling." ,
764-     )
765-     parser .add_argument (
766-         "--random-output-len" ,
767-         type = int ,
768-         default = 128 ,
769-         help = 
770-         "Number of output tokens per request, used only for random sampling." ,
771-     )
772-     parser .add_argument (
773-         "--random-range-ratio" ,
774-         type = float ,
775-         default = 1.0 ,
776-         help = "Range of sampled ratio of input/output length, " 
777-         "used only for random sampling." ,
778-     )
779-     parser .add_argument (
780-         "--random-prefix-len" ,
781-         type = int ,
782-         default = 0 ,
783-         help = "Number of fixed prefix tokens before random " 
784-         " context. The length range of context in a random " 
785-         " request is [random-prefix-len, " 
786-         " random-prefix-len + random-prefix-len * random-range-ratio)." )
787812    parser .add_argument (
788813        "--request-rate" ,
789814        type = float ,
@@ -857,5 +882,85 @@ def main(args: argparse.Namespace):
857882        "Use \" --percentile-metrics\"  to select metrics." ,
858883    )
859884
885+     # group for dataset specific arguments 
886+     sonnet_group  =  parser .add_argument_group ("sonnet dataset options" )
887+     sonnet_group .add_argument (
888+         "--sonnet-input-len" ,
889+         type = int ,
890+         default = 550 ,
891+         help = 
892+         "Number of input tokens per request, used only for sonnet dataset." ,
893+     )
894+     sonnet_group .add_argument (
895+         "--sonnet-output-len" ,
896+         type = int ,
897+         default = 150 ,
898+         help = 
899+         "Number of output tokens per request, used only for sonnet dataset." ,
900+     )
901+     sonnet_group .add_argument (
902+         "--sonnet-prefix-len" ,
903+         type = int ,
904+         default = 200 ,
905+         help = 
906+         "Number of prefix tokens per request, used only for sonnet dataset." ,
907+     )
908+ 
909+     sharegpt_group  =  parser .add_argument_group ("sharegpt dataset options" )
910+     sharegpt_group .add_argument (
911+         "--sharegpt-output-len" ,
912+         type = int ,
913+         default = None ,
914+         help = "Output length for each request. Overrides the output length " 
915+         "from the ShareGPT dataset." )
916+ 
917+     random_group  =  parser .add_argument_group ("random dataset options" )
918+     random_group .add_argument (
919+         "--random-input-len" ,
920+         type = int ,
921+         default = 1024 ,
922+         help = 
923+         "Number of input tokens per request, used only for random sampling." ,
924+     )
925+     random_group .add_argument (
926+         "--random-output-len" ,
927+         type = int ,
928+         default = 128 ,
929+         help = 
930+         "Number of output tokens per request, used only for random sampling." ,
931+     )
932+     random_group .add_argument (
933+         "--random-range-ratio" ,
934+         type = float ,
935+         default = 1.0 ,
936+         help = "Range of sampled ratio of input/output length, " 
937+         "used only for random sampling." ,
938+     )
939+     random_group .add_argument (
940+         "--random-prefix-len" ,
941+         type = int ,
942+         default = 0 ,
943+         help = "Number of fixed prefix tokens before random " 
944+         " context. The length range of context in a random " 
945+         " request is [random-prefix-len, " 
946+         " random-prefix-len + random-prefix-len * random-range-ratio)." )
947+ 
948+     hf_group  =  parser .add_argument_group ("hf dataset options" )
949+     hf_group .add_argument ("--hf-subset" ,
950+                           type = str ,
951+                           default = None ,
952+                           help = "Subset of the HF dataset." )
953+     hf_group .add_argument ("--hf-split" ,
954+                           type = str ,
955+                           default = None ,
956+                           help = "Split of the HF dataset." )
957+     hf_group .add_argument (
958+         "--hf-output-len" ,
959+         type = int ,
960+         default = None ,
961+         help = "Output length for each request. Overrides the output lengths " 
962+         "from the sampled HF dataset." ,
963+     )
964+ 
860965    args  =  parser .parse_args ()
861966    main (args )
0 commit comments