 For most models, the prompt format should follow corresponding examples
 on HuggingFace model repository.
 """
+import random
+
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
@@ -23,7 +25,9 @@ def run_llava(question: str, modality: str):
 
     prompt = f"USER: <image>\n{question}\nASSISTANT:"
 
-    llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
+    llm = LLM(model="llava-hf/llava-1.5-7b-hf",
+              max_model_len=4096,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
 
@@ -33,7 +37,9 @@ def run_llava_next(question: str, modality: str):
     assert modality == "image"
 
     prompt = f"[INST] <image>\n{question} [/INST]"
-    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
+    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
+              max_model_len=8192,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
 
@@ -44,7 +50,9 @@ def run_llava_next_video(question: str, modality: str):
     assert modality == "video"
 
     prompt = f"USER: <video>\n{question} ASSISTANT:"
-    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
+    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
+              max_model_len=8192,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
 
@@ -61,7 +69,8 @@ def run_llava_onevision(question: str, modality: str):
     <|im_start|>assistant\n"
 
     llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
-              max_model_len=16384)
+              max_model_len=16384,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
 
@@ -71,7 +80,10 @@ def run_fuyu(question: str, modality: str):
     assert modality == "image"
 
     prompt = f"{question}\n"
-    llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2)
+    llm = LLM(model="adept/fuyu-8b",
+              max_model_len=2048,
+              max_num_seqs=2,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
 
@@ -107,6 +119,7 @@ def run_phi3v(question: str, modality: str):
         max_num_seqs=2,
         # Note - mm_processor_kwargs can also be passed to generate/chat calls
         mm_processor_kwargs={"num_crops": 16},
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     stop_token_ids = None
     return llm, prompt, stop_token_ids
@@ -118,7 +131,8 @@ def run_paligemma(question: str, modality: str):
 
     # PaliGemma has special prompt format for VQA
     prompt = "caption en"
-    llm = LLM(model="google/paligemma-3b-mix-224")
+    llm = LLM(model="google/paligemma-3b-mix-224",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
 
@@ -128,7 +142,9 @@ def run_chameleon(question: str, modality: str):
     assert modality == "image"
 
     prompt = f"{question}<image>"
-    llm = LLM(model="facebook/chameleon-7b", max_model_len=4096)
+    llm = LLM(model="facebook/chameleon-7b",
+              max_model_len=4096,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
 
@@ -154,6 +170,7 @@ def run_minicpmv(question: str, modality: str):
         max_model_len=4096,
         max_num_seqs=2,
         trust_remote_code=True,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     # NOTE The stop_token_ids are different for various versions of MiniCPM-V
     # 2.0
@@ -186,6 +203,7 @@ def run_h2ovl(question: str, modality: str):
         model=model_name,
         trust_remote_code=True,
         max_model_len=8192,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -211,6 +229,7 @@ def run_internvl(question: str, modality: str):
         model=model_name,
         trust_remote_code=True,
         max_model_len=4096,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -241,6 +260,7 @@ def run_nvlm_d(question: str, modality: str):
         trust_remote_code=True,
         max_model_len=4096,
         tensor_parallel_size=4,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -260,7 +280,8 @@ def run_blip2(question: str, modality: str):
     # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
     prompt = f"Question: {question} Answer:"
-    llm = LLM(model="Salesforce/blip2-opt-2.7b")
+    llm = LLM(model="Salesforce/blip2-opt-2.7b",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     stop_token_ids = None
     return llm, prompt, stop_token_ids
 
@@ -274,6 +295,7 @@ def run_qwen_vl(question: str, modality: str):
         trust_remote_code=True,
         max_model_len=1024,
         max_num_seqs=2,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     prompt = f"{question}Picture 1: <img></img>\n"
@@ -296,6 +318,7 @@ def run_qwen2_vl(question: str, modality: str):
             "min_pixels": 28 * 28,
             "max_pixels": 1280 * 28 * 28,
         },
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
@@ -315,6 +338,7 @@ def run_pixtral_hf(question: str, modality: str):
     llm = LLM(
         model=model_name,
         max_model_len=8192,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     prompt = f"<s>[INST]{question}\n[IMG][/INST]"
@@ -338,6 +362,7 @@ def run_mllama(question: str, modality: str):
         max_model_len=4096,
         max_num_seqs=16,
         enforce_eager=True,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     prompt = f"<|image|><|begin_of_text|>{question}"
@@ -355,6 +380,7 @@ def run_molmo(question, modality):
         model=model_name,
         trust_remote_code=True,
         dtype="bfloat16",
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
 
     prompt = question
@@ -371,7 +397,8 @@ def run_glm4v(question: str, modality: str):
               max_model_len=2048,
               max_num_seqs=2,
               trust_remote_code=True,
-              enforce_eager=True)
+              enforce_eager=True,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
     prompt = question
     stop_token_ids = [151329, 151336, 151338]
     return llm, prompt, stop_token_ids
@@ -394,6 +421,7 @@ def run_idefics3(question: str, modality: str):
                 "longest_edge": 3 * 364
             },
         },
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     prompt = (
         f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
@@ -410,7 +438,8 @@ def run_aria(question: str, modality: str):
     llm = LLM(model=model_name,
               tokenizer_mode="slow",
               trust_remote_code=True,
-              dtype="bfloat16")
+              dtype="bfloat16",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
 
     prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
               "<|im_end|>\n<|im_start|>assistant\n")
@@ -430,6 +459,7 @@ def run_mantis(question: str, modality: str):
         model="TIGER-Lab/Mantis-8B-siglip-llama3",
         max_model_len=4096,
         hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
     )
     stop_token_ids = [128009]
     return llm, prompt, stop_token_ids
@@ -494,6 +524,35 @@ def get_multi_modal_input(args):
     raise ValueError(msg)
 
 
+def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
+    """Repeats images with provided probability of "image_repeat_prob".
+    Used to simulate hit/miss for the MM preprocessor cache.
+    """
+    assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0)
+    no_yes = [0, 1]
+    probs = [1.0 - image_repeat_prob, image_repeat_prob]
+
+    inputs = []
+    cur_image = data
+    for i in range(num_prompts):
+        if image_repeat_prob is not None:
+            res = random.choices(no_yes, probs)[0]
+            if res == 0:
+                # No repeat => Modify one pixel
+                cur_image = cur_image.copy()
+                new_val = (i // 256 // 256, i // 256, i % 256)
+                cur_image.putpixel((0, 0), new_val)
+
+        inputs.append({
+            "prompt": prompt,
+            "multi_modal_data": {
+                modality: cur_image
+            }
+        })
+
+    return inputs
+
+
 def main(args):
     model = args.model_type
     if model not in model_example_map:
@@ -524,14 +583,29 @@ def main(args):
 
     else:
         # Batch inference
-        inputs = [{
-            "prompt": prompt,
-            "multi_modal_data": {
-                modality: data
-            },
-        } for _ in range(args.num_prompts)]
+        if args.image_repeat_prob is not None:
+            # Repeat images with specified probability of "image_repeat_prob"
+            inputs = apply_image_repeat(args.image_repeat_prob,
+                                        args.num_prompts, data, prompt,
+                                        modality)
+        else:
+            # Use the same image for all prompts
+            inputs = [{
+                "prompt": prompt,
+                "multi_modal_data": {
+                    modality: data
+                },
+            } for _ in range(args.num_prompts)]
+
+    if args.time_generate:
+        import time
+        start_time = time.time()
+        outputs = llm.generate(inputs, sampling_params=sampling_params)
+        elapsed_time = time.time() - start_time
+        print("-- generate time = {}".format(elapsed_time))
 
-    outputs = llm.generate(inputs, sampling_params=sampling_params)
+    else:
+        outputs = llm.generate(inputs, sampling_params=sampling_params)
 
     for o in outputs:
         generated_text = o.outputs[0].text
@@ -561,5 +635,23 @@ def main(args):
                         type=int,
                         default=16,
                         help='Number of frames to extract from the video.')
+
+    parser.add_argument(
+        '--image-repeat-prob',
+        type=float,
+        default=None,
+        help='Simulates the hit-ratio for multi-modal preprocessor cache'
+        ' (if enabled)')
+
+    parser.add_argument(
+        '--mm-cache-preprocessor',
+        action='store_true',
+        help='If True, enable caching of multi-modal preprocessor/mapper.')
+
+    parser.add_argument(
+        '--time-generate',
+        action='store_true',
+        help='If True, then print the total generate() call time')
+
     args = parser.parse_args()
     main(args)
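
A possible way to exercise the new flags together (the script path, model type, and prompt count below are illustrative assumptions, not taken from this diff):

    python examples/offline_inference_vision_language.py --model-type llava \
        --num-prompts 100 --mm-cache-preprocessor \
        --image-repeat-prob 0.5 --time-generate

With --image-repeat-prob 0.5, roughly half of the batched prompts reuse the previous image (a hit for the multi-modal preprocessor cache), while the rest get a one-pixel modification via apply_image_repeat() so they miss; --time-generate prints the wall-clock duration of the generate() call, which makes it easy to compare runs with and without --mm-cache-preprocessor.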