4343_TEST_PROMPTS = [os .path .join (_TEST_DIR , "prompts" , "example.txt" )]
4444_LONG_PROMPTS = [os .path .join (_TEST_DIR , "prompts" , "summary.txt" )]
4545
46- PromptImageInput = Union [List [Image .Image ], List [List [Image .Image ]]]
47- PromptAudioInput = Union [List [Tuple [np .ndarray , int ]],
48- List [List [Tuple [np .ndarray , int ]]]]
49- PromptVideoInput = Union [List [np .ndarray ], List [List [np .ndarray ]]]
46+ _M = TypeVar ("_M" )
47+ _PromptMultiModalInput = Union [List [_M ], List [List [_M ]]]
48+
49+ PromptImageInput = _PromptMultiModalInput [Image .Image ]
50+ PromptAudioInput = _PromptMultiModalInput [Tuple [np .ndarray , int ]]
51+ PromptVideoInput = _PromptMultiModalInput [np .ndarray ]
5052
5153
5254def _read_prompts (filename : str ) -> List [str ]:
@@ -318,12 +320,12 @@ def get_inputs(
318320 "text" : prompt ,
319321 "return_tensors" : "pt" ,
320322 }
321- if images is not None and images [i ] is not None :
322- processor_kwargs ["images" ] = images [ i ]
323- if videos is not None and videos [i ] is not None :
324- processor_kwargs ["videos" ] = videos [ i ]
325- if audios is not None and audios [i ] is not None :
326- audio , sr = audios [ i ]
323+ if images is not None and ( image := images [i ]) is not None :
324+ processor_kwargs ["images" ] = image
325+ if videos is not None and ( video := videos [i ]) is not None :
326+ processor_kwargs ["videos" ] = video
327+ if audios is not None and ( audio_tuple := audios [i ]) is not None :
328+ audio , sr = audio_tuple
327329 processor_kwargs ["audio" ] = audio
328330 processor_kwargs ["sampling_rate" ] = sr
329331
@@ -338,7 +340,7 @@ def generate(
338340 self ,
339341 prompts : List [str ],
340342 images : Optional [PromptImageInput ] = None ,
341- videos : Optional [List [ np . ndarray ] ] = None ,
343+ videos : Optional [PromptVideoInput ] = None ,
342344 audios : Optional [PromptAudioInput ] = None ,
343345 ** kwargs : Any ,
344346 ) -> List [Tuple [List [List [int ]], List [str ]]]:
@@ -368,7 +370,7 @@ def generate_greedy(
368370 prompts : List [str ],
369371 max_tokens : int ,
370372 images : Optional [PromptImageInput ] = None ,
371- videos : Optional [List [ np . ndarray ] ] = None ,
373+ videos : Optional [PromptVideoInput ] = None ,
372374 audios : Optional [PromptAudioInput ] = None ,
373375 ** kwargs : Any ,
374376 ) -> List [Tuple [List [int ], str ]]:
@@ -409,7 +411,7 @@ def generate_greedy_logprobs(
409411 prompts : List [str ],
410412 max_tokens : int ,
411413 images : Optional [PromptImageInput ] = None ,
412- videos : Optional [List [ np . ndarray ] ] = None ,
414+ videos : Optional [PromptVideoInput ] = None ,
413415 audios : Optional [PromptAudioInput ] = None ,
414416 ** kwargs : Any ,
415417 ) -> List [List [torch .Tensor ]]:
@@ -488,7 +490,7 @@ def generate_greedy_logprobs_limit(
488490 num_logprobs : int ,
489491 images : Optional [PromptImageInput ] = None ,
490492 audios : Optional [PromptAudioInput ] = None ,
491- videos : Optional [List [ np . ndarray ] ] = None ,
493+ videos : Optional [PromptVideoInput ] = None ,
492494 ** kwargs : Any ,
493495 ) -> List [TokensTextLogprobs ]:
494496 all_inputs = self .get_inputs (prompts ,
@@ -657,15 +659,18 @@ def get_inputs(
657659 inputs = [TextPrompt (prompt = prompt ) for prompt in prompts ]
658660 if images is not None :
659661 for i , image in enumerate (images ):
660- inputs [i ]["multi_modal_data" ] = {"image" : image }
662+ if image is not None :
663+ inputs [i ]["multi_modal_data" ] = {"image" : image }
661664
662665 if videos is not None :
663666 for i , video in enumerate (videos ):
664- inputs [i ]["multi_modal_data" ] = {"video" : video }
667+ if video is not None :
668+ inputs [i ]["multi_modal_data" ] = {"video" : video }
665669
666670 if audios is not None :
667671 for i , audio in enumerate (audios ):
668- inputs [i ]["multi_modal_data" ] = {"audio" : audio }
672+ if audio is not None :
673+ inputs [i ]["multi_modal_data" ] = {"audio" : audio }
669674
670675 return inputs
671676
@@ -837,13 +842,20 @@ def generate_beam_search(
837842 returned_outputs .append ((token_ids , texts ))
838843 return returned_outputs
839844
840- def encode (self , prompts : List [str ]) -> List [List [float ]]:
841- req_outputs = self .model .encode (prompts )
842- outputs = []
843- for req_output in req_outputs :
844- embedding = req_output .outputs .embedding
845- outputs .append (embedding )
846- return outputs
845+ def encode (
846+ self ,
847+ prompts : List [str ],
848+ images : Optional [PromptImageInput ] = None ,
849+ videos : Optional [PromptVideoInput ] = None ,
850+ audios : Optional [PromptAudioInput ] = None ,
851+ ) -> List [List [float ]]:
852+ inputs = self .get_inputs (prompts ,
853+ images = images ,
854+ videos = videos ,
855+ audios = audios )
856+
857+ req_outputs = self .model .encode (inputs )
858+ return [req_output .outputs .embedding for req_output in req_outputs ]
847859
848860 def __enter__ (self ):
849861 return self
0 commit comments