2323import  typing 
2424import  warnings 
2525from  pathlib  import  Path 
26- from  typing  import  Any , Callable , Optional , TypedDict , Union 
26+ from  typing  import  Any , Callable , Dict ,  List ,  Optional , TypedDict , Union 
2727
2828import  numpy  as  np 
2929import  typing_extensions 
@@ -386,14 +386,10 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False):
386386    return_assistant_tokens_mask : Optional [bool ] =  False 
387387
388388
389- class  ProcessorChatTemplateKwargs ( TokenizerChatTemplateKwargs , total = False ):
389+ class  ChatTemplateLoadKwargs ( TypedDict , total = False ):
390390    """ 
391-     Keyword arguments for  processor chat templates. 
391+     Keyword arguments used to load multimodal data in  processor chat templates. 
392392
393-     tokenize (`bool`, *optional*, defaults to `False`): 
394-         Whether to tokenize the output or not. 
395-     return_dict (`bool`, defaults to `False`): 
396-         Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. 
397393    num_frames (`int`, *optional*): 
398394        Number of frames to sample uniformly. If not passed, the whole video is loaded. 
399395    video_load_backend (`str`, *optional*, defaults to `"pyav"`): 
@@ -415,13 +411,26 @@ def sample_indices_fn(num_frames, fps, metadata, **kwargs):
415411                return np.linspace(start_idx, end_idx, num_frames, dtype=int) 
416412    """ 
417413
418-     tokenize : Optional [bool ] =  False 
419-     return_dict : Optional [bool ] =  False 
420414    num_frames : Optional [int ] =  None 
421415    video_load_backend : Optional [str ] =  "pyav" 
422416    video_fps : Optional [int ] =  None 
423417    sampling_rate : Optional [int ] =  16_000 
424418    sample_indices_fn : Optional [Callable ] =  None 
419+     load_audio_from_video : Optional [bool ] =  False 
420+ 
421+ 
422+ class  ProcessorChatTemplateKwargs (ChatTemplateLoadKwargs , TokenizerChatTemplateKwargs , total = False ):
423+     """ 
424+     Keyword arguments for processor's `apply_chat_template`. 
425+ 
426+     tokenize (`bool`, *optional*, defaults to `False`): 
427+         Whether to tokenize the output or not. 
428+     return_dict (`bool`, defaults to `False`): 
429+         Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. 
430+     """ 
431+ 
432+     tokenize : Optional [bool ] =  False 
433+     return_dict : Optional [bool ] =  False 
425434
426435
427436class  AllKwargsForChatTemplate (
@@ -1236,11 +1245,11 @@ def __call__(
12361245
12371246    def  _process_messages_for_chat_template (
12381247        self ,
1239-         conversation : list [ list [ dict [str , str ]]],
1240-         batch_images : list [ImageInput ],
1241-         batch_videos : list [VideoInput ],
1242-         batch_video_metadata : list [ list [ dict [str , any ]]],
1243-         ** chat_template_kwargs : Unpack [AllKwargsForChatTemplate ],
1248+         conversation : List [ List [ Dict [str , str ]]],
1249+         batch_images : List [ImageInput ],
1250+         batch_videos : List [VideoInput ],
1251+         batch_video_metadata : List [ List [ Dict [str , any ]]],
1252+         ** mm_load_kwargs : Unpack [ChatTemplateLoadKwargs ],
12441253    ):
12451254        """ 
12461255        Used within `apply_chat_template` when a model has a special way to process conversation history. For example, 
@@ -1311,18 +1320,18 @@ def apply_chat_template(
13111320                )
13121321
13131322        # Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template` 
1314-         # and for multimodal chat template  
1323+         # and for multimodal data loading. Everything else will be used in `__call__`  
13151324        tokenizer_template_kwargs  =  {}
13161325        for  tokenizer_key  in  TokenizerChatTemplateKwargs .__annotations__ .keys ():
1317-             tokenizer_value  =  getattr (TokenizerChatTemplateKwargs , tokenizer_key , None )
1318-             value  =  kwargs .pop (tokenizer_key , tokenizer_value )
1326+             default_value  =  getattr (TokenizerChatTemplateKwargs , tokenizer_key , None )
1327+             value  =  kwargs .pop (tokenizer_key , default_value )
13191328            tokenizer_template_kwargs [tokenizer_key ] =  value 
13201329
1321-         chat_template_kwargs  =  {}
1322-         for  key  in  ProcessorChatTemplateKwargs .__annotations__ .keys ():
1323-             processor_value  =  getattr (ProcessorChatTemplateKwargs ,  key , None )
1324-             value  =  kwargs .pop (key ,  processor_value )
1325-             chat_template_kwargs [ key ] =  value 
1330+         mm_load_kwargs  =  {}
1331+         for  mm_load_key  in  ChatTemplateLoadKwargs .__annotations__ .keys ():
1332+             default_value  =  getattr (ChatTemplateLoadKwargs ,  mm_load_key , None )
1333+             value  =  kwargs .pop (mm_load_key ,  default_value )
1334+             mm_load_kwargs [ mm_load_key ] =  value 
13261335
13271336        if  isinstance (conversation , (list , tuple )) and  (
13281337            isinstance (conversation [0 ], (list , tuple )) or  hasattr (conversation [0 ], "content" )
@@ -1333,13 +1342,8 @@ def apply_chat_template(
13331342            is_batched  =  False 
13341343            conversations  =  [conversation ]
13351344
1336-         num_frames  =  chat_template_kwargs .get ("num_frames" )
1337-         video_fps  =  chat_template_kwargs .get ("video_fps" )
1338-         video_load_backend  =  chat_template_kwargs .get ("video_load_backend" )
1339-         tokenize  =  chat_template_kwargs .get ("tokenize" )
1340-         return_dict  =  chat_template_kwargs .get ("return_dict" )
1341-         sample_indices_fn  =  chat_template_kwargs .get ("sample_indices_fn" )
1342-         sampling_rate  =  chat_template_kwargs .pop ("sampling_rate" )
1345+         tokenize  =  kwargs .pop ("tokenize" , False )
1346+         return_dict  =  kwargs .pop ("return_dict" , False )
13431347
13441348        if  tokenize :
13451349            batch_images , batch_videos  =  [], []
@@ -1369,31 +1373,37 @@ def apply_chat_template(
13691373                        if  key  in  vision_info  and  vision_info ["type" ] ==  "video" 
13701374                    ]
13711375
1372-                     # Audio models do not accept nested list of audios (yet!) 
1373-                     for  fname  in  audio_fnames :
1374-                         batch_audios .append (load_audio (fname , sampling_rate = sampling_rate ))
13751376                    for  fname  in  image_fnames :
13761377                        images .append (load_image (fname ))
1377-                     for  fname  in  video_fnames :
1378-                         if  isinstance (fname , (list , tuple )) and  isinstance (fname [0 ], str ):
1379-                             video  =  [np .array (load_image (image_fname )).T  for  image_fname  in  fname ]
1380-                             # create a 4D video because `load_video` always returns a 4D array 
1381-                             video  =  np .stack (video )
1382-                             metadata  =  None 
1383-                             logger .warning (
1384-                                 "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. " 
1385-                                 "If you model applies special processing based on metadata, please load the whole video and let the model sample frames." 
1386-                             )
1387-                         else :
1388-                             video , metadata  =  load_video (
1389-                                 fname ,
1390-                                 num_frames = num_frames ,
1391-                                 fps = video_fps ,
1392-                                 backend = video_load_backend ,
1393-                                 sample_indices_fn = sample_indices_fn ,
1394-                             )
1395-                         videos .append (video )
1396-                         video_metadata .append (metadata )
1378+ 
1379+                     # Audio models do not accept nested list of audios (yet!) so we construct a flat input audio list 
1380+                     if  not  mm_load_kwargs ["load_audio_from_video" ]:
1381+                         for  fname  in  audio_fnames :
1382+                             batch_audios .append (load_audio (fname , sampling_rate = mm_load_kwargs ["sampling_rate" ]))
1383+                     else :
1384+                         for  fname  in  video_fnames :
1385+                             if  isinstance (fname , (list , tuple )) and  isinstance (fname [0 ], str ):
1386+                                 video  =  [np .array (load_image (image_fname )).T  for  image_fname  in  fname ]
1387+                                 # create a 4D video because `load_video` always returns a 4D array 
1388+                                 video  =  np .stack (video )
1389+                                 metadata  =  None 
1390+                                 audios  =  None 
1391+                                 logger .warning (
1392+                                     "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. " 
1393+                                     "If your model uses this metadata during processing, please load the whole video and let the model sample frames instead." 
1394+                                 )
1395+                             else :
1396+                                 video , metadata  =  load_video (
1397+                                     fname ,
1398+                                     num_frames = mm_load_kwargs ["num_frames" ],
1399+                                     fps = mm_load_kwargs ["video_fps" ],
1400+                                     backend = mm_load_kwargs ["video_load_backend" ],
1401+                                     sample_indices_fn = mm_load_kwargs ["sample_indices_fn" ],
1402+                                 )
1403+                                 audios  =  load_audio (fname , sampling_rate = mm_load_kwargs ["sampling_rate" ])
1404+                             batch_audios .append (audios )
1405+                             videos .append (video )
1406+                             video_metadata .append (metadata )
13971407
13981408                # Currently all processors can accept nested list of batches, but not flat list of visuals 
13991409                # So we'll make a batched list of images and let the processor handle it 
@@ -1409,7 +1419,7 @@ def apply_chat_template(
14091419                batch_images = batch_images ,
14101420                batch_videos = batch_videos ,
14111421                batch_video_metadata = batch_video_metadata ,
1412-                 ** chat_template_kwargs ,
1422+                 ** mm_load_kwargs ,
14131423            )
14141424
14151425        prompt  =  self .tokenizer .apply_chat_template (
@@ -1438,7 +1448,7 @@ def apply_chat_template(
14381448                text = prompt ,
14391449                images = batch_images  if  batch_images  else  None ,
14401450                videos = batch_videos  if  batch_videos  else  None ,
1441-                 audios = batch_audios  if  batch_audios  else  None ,
1451+                 audio = batch_audios  if  batch_audios  else  None ,
14421452                ** kwargs ,
14431453            )
14441454            if  return_dict :
0 commit comments