Skip to content

Commit 4f01dbb

Browse files
committed
[chat templates} support loading audio from video (huggingface#36955)
* add audio from video * typos * delete print * comments
1 parent a6762d5 commit 4f01dbb

File tree

2 files changed

+129
-58
lines changed

2 files changed

+129
-58
lines changed

src/transformers/processing_utils.py

Lines changed: 64 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import typing
2424
import warnings
2525
from pathlib import Path
26-
from typing import Any, Callable, Optional, TypedDict, Union
26+
from typing import Any, Callable, Dict, List, Optional, TypedDict, Union
2727

2828
import numpy as np
2929
import typing_extensions
@@ -386,14 +386,10 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False):
386386
return_assistant_tokens_mask: Optional[bool] = False
387387

388388

389-
class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False):
389+
class ChatTemplateLoadKwargs(TypedDict, total=False):
390390
"""
391-
Keyword arguments for processor chat templates.
391+
Keyword arguments used to load multimodal data in processor chat templates.
392392
393-
tokenize (`bool`, *optional*, defaults to `False`):
394-
Whether to tokenize the output or not.
395-
return_dict (`bool`, defaults to `False`):
396-
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
397393
num_frames (`int`, *optional*):
398394
Number of frames to sample uniformly. If not passed, the whole video is loaded.
399395
video_load_backend (`str`, *optional*, defaults to `"pyav"`):
@@ -415,13 +411,26 @@ def sample_indices_fn(num_frames, fps, metadata, **kwargs):
415411
return np.linspace(start_idx, end_idx, num_frames, dtype=int)
416412
"""
417413

418-
tokenize: Optional[bool] = False
419-
return_dict: Optional[bool] = False
420414
num_frames: Optional[int] = None
421415
video_load_backend: Optional[str] = "pyav"
422416
video_fps: Optional[int] = None
423417
sampling_rate: Optional[int] = 16_000
424418
sample_indices_fn: Optional[Callable] = None
419+
load_audio_from_video: Optional[bool] = False
420+
421+
422+
class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateKwargs, total=False):
423+
"""
424+
Keyword arguments for processor's `apply_chat_template`.
425+
426+
tokenize (`bool`, *optional*, defaults to `False`):
427+
Whether to tokenize the output or not.
428+
return_dict (`bool`, defaults to `False`):
429+
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
430+
"""
431+
432+
tokenize: Optional[bool] = False
433+
return_dict: Optional[bool] = False
425434

426435

427436
class AllKwargsForChatTemplate(
@@ -1236,11 +1245,11 @@ def __call__(
12361245

12371246
def _process_messages_for_chat_template(
12381247
self,
1239-
conversation: list[list[dict[str, str]]],
1240-
batch_images: list[ImageInput],
1241-
batch_videos: list[VideoInput],
1242-
batch_video_metadata: list[list[dict[str, any]]],
1243-
**chat_template_kwargs: Unpack[AllKwargsForChatTemplate],
1248+
conversation: List[List[Dict[str, str]]],
1249+
batch_images: List[ImageInput],
1250+
batch_videos: List[VideoInput],
1251+
batch_video_metadata: List[List[Dict[str, any]]],
1252+
**mm_load_kwargs: Unpack[ChatTemplateLoadKwargs],
12441253
):
12451254
"""
12461255
Used within `apply_chat_template` when a model has a special way to process conversation history. For example,
@@ -1311,18 +1320,18 @@ def apply_chat_template(
13111320
)
13121321

13131322
# Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template`
1314-
# and for multimodal chat template
1323+
# and for multimodal data loading. Everything else will be used in `__call__`
13151324
tokenizer_template_kwargs = {}
13161325
for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys():
1317-
tokenizer_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
1318-
value = kwargs.pop(tokenizer_key, tokenizer_value)
1326+
default_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
1327+
value = kwargs.pop(tokenizer_key, default_value)
13191328
tokenizer_template_kwargs[tokenizer_key] = value
13201329

1321-
chat_template_kwargs = {}
1322-
for key in ProcessorChatTemplateKwargs.__annotations__.keys():
1323-
processor_value = getattr(ProcessorChatTemplateKwargs, key, None)
1324-
value = kwargs.pop(key, processor_value)
1325-
chat_template_kwargs[key] = value
1330+
mm_load_kwargs = {}
1331+
for mm_load_key in ChatTemplateLoadKwargs.__annotations__.keys():
1332+
default_value = getattr(ChatTemplateLoadKwargs, mm_load_key, None)
1333+
value = kwargs.pop(mm_load_key, default_value)
1334+
mm_load_kwargs[mm_load_key] = value
13261335

13271336
if isinstance(conversation, (list, tuple)) and (
13281337
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
@@ -1333,13 +1342,8 @@ def apply_chat_template(
13331342
is_batched = False
13341343
conversations = [conversation]
13351344

1336-
num_frames = chat_template_kwargs.get("num_frames")
1337-
video_fps = chat_template_kwargs.get("video_fps")
1338-
video_load_backend = chat_template_kwargs.get("video_load_backend")
1339-
tokenize = chat_template_kwargs.get("tokenize")
1340-
return_dict = chat_template_kwargs.get("return_dict")
1341-
sample_indices_fn = chat_template_kwargs.get("sample_indices_fn")
1342-
sampling_rate = chat_template_kwargs.pop("sampling_rate")
1345+
tokenize = kwargs.pop("tokenize", False)
1346+
return_dict = kwargs.pop("return_dict", False)
13431347

13441348
if tokenize:
13451349
batch_images, batch_videos = [], []
@@ -1369,31 +1373,37 @@ def apply_chat_template(
13691373
if key in vision_info and vision_info["type"] == "video"
13701374
]
13711375

1372-
# Audio models do not accept nested list of audios (yet!)
1373-
for fname in audio_fnames:
1374-
batch_audios.append(load_audio(fname, sampling_rate=sampling_rate))
13751376
for fname in image_fnames:
13761377
images.append(load_image(fname))
1377-
for fname in video_fnames:
1378-
if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
1379-
video = [np.array(load_image(image_fname)).T for image_fname in fname]
1380-
# create a 4D video because `load_video` always returns a 4D array
1381-
video = np.stack(video)
1382-
metadata = None
1383-
logger.warning(
1384-
"When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
1385-
"If you model applies special processing based on metadata, please load the whole video and let the model sample frames."
1386-
)
1387-
else:
1388-
video, metadata = load_video(
1389-
fname,
1390-
num_frames=num_frames,
1391-
fps=video_fps,
1392-
backend=video_load_backend,
1393-
sample_indices_fn=sample_indices_fn,
1394-
)
1395-
videos.append(video)
1396-
video_metadata.append(metadata)
1378+
1379+
# Audio models do not accept nested list of audios (yet!) so we construct a flat input audio list
1380+
if not mm_load_kwargs["load_audio_from_video"]:
1381+
for fname in audio_fnames:
1382+
batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
1383+
else:
1384+
for fname in video_fnames:
1385+
if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
1386+
video = [np.array(load_image(image_fname)).T for image_fname in fname]
1387+
# create a 4D video because `load_video` always returns a 4D array
1388+
video = np.stack(video)
1389+
metadata = None
1390+
audios = None
1391+
logger.warning(
1392+
"When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
1393+
"If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
1394+
)
1395+
else:
1396+
video, metadata = load_video(
1397+
fname,
1398+
num_frames=mm_load_kwargs["num_frames"],
1399+
fps=mm_load_kwargs["video_fps"],
1400+
backend=mm_load_kwargs["video_load_backend"],
1401+
sample_indices_fn=mm_load_kwargs["sample_indices_fn"],
1402+
)
1403+
audios = load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])
1404+
batch_audios.append(audios)
1405+
videos.append(video)
1406+
video_metadata.append(metadata)
13971407

13981408
# Currently all processors can accept nested list of batches, but not flat list of visuals
13991409
# So we'll make a batched list of images and let the processor handle it
@@ -1409,7 +1419,7 @@ def apply_chat_template(
14091419
batch_images=batch_images,
14101420
batch_videos=batch_videos,
14111421
batch_video_metadata=batch_video_metadata,
1412-
**chat_template_kwargs,
1422+
**mm_load_kwargs,
14131423
)
14141424

14151425
prompt = self.tokenizer.apply_chat_template(
@@ -1438,7 +1448,7 @@ def apply_chat_template(
14381448
text=prompt,
14391449
images=batch_images if batch_images else None,
14401450
videos=batch_videos if batch_videos else None,
1441-
audios=batch_audios if batch_audios else None,
1451+
audio=batch_audios if batch_audios else None,
14421452
**kwargs,
14431453
)
14441454
if return_dict:

tests/test_processing_common.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,10 +1097,7 @@ def test_chat_template_video_custom_sampling(self):
10971097
{
10981098
"role": "user",
10991099
"content": [
1100-
{
1101-
"type": "video",
1102-
"path": video_file_path,
1103-
},
1100+
{"type": "video", "path": video_file_path},
11041101
{"type": "text", "text": "What is shown in this video?"},
11051102
],
11061103
},
@@ -1189,6 +1186,70 @@ def _process_messages_for_chat_template(
11891186
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
11901187
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)
11911188

1189+
@require_librosa
1190+
@require_av
1191+
def test_audio_chat_template_from_video(self):
1192+
processor = self.get_processor()
1193+
if processor.chat_template is None:
1194+
self.skipTest("Processor has no chat template")
1195+
1196+
signature = inspect.signature(processor.__call__)
1197+
if "videos" not in {*signature.parameters.keys()} or (
1198+
signature.parameters.get("videos") is not None
1199+
and signature.parameters["videos"].annotation == inspect._empty
1200+
):
1201+
self.skipTest(f"{self.processor_class} does not suport video inputs")
1202+
1203+
if "feature_extractor" not in self.processor_class.attributes:
1204+
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
1205+
1206+
video_file_path = hf_hub_download(
1207+
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
1208+
)
1209+
messages = [
1210+
{
1211+
"role": "user",
1212+
"content": [
1213+
{"type": "video", "path": video_file_path},
1214+
{"type": "text", "text": "Which of these animals is making the sound?"},
1215+
],
1216+
},
1217+
{
1218+
"role": "assistant",
1219+
"content": [{"type": "text", "text": "It is a cow."}],
1220+
},
1221+
{
1222+
"role": "user",
1223+
"content": [
1224+
{
1225+
"type": "audio",
1226+
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
1227+
},
1228+
{"type": "text", "text": "Is it the same sound?"},
1229+
],
1230+
},
1231+
]
1232+
1233+
formatted_prompt = processor.apply_chat_template([messages], add_generation_prompt=True, tokenize=False)
1234+
self.assertEqual(len(formatted_prompt), 1) # batch size=1
1235+
1236+
out_dict = processor.apply_chat_template(
1237+
messages,
1238+
add_generation_prompt=True,
1239+
tokenize=True,
1240+
return_dict=True,
1241+
return_tensors="np",
1242+
load_audio_from_video=True,
1243+
)
1244+
self.assertTrue(self.audio_input_name in out_dict)
1245+
self.assertTrue(self.video_input_name in out_dict)
1246+
1247+
# should always have input_ids and attention_mask
1248+
self.assertEqual(len(out_dict["input_ids"]), 1) # batch-size=1
1249+
self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1
1250+
self.assertEqual(len(out_dict[self.audio_input_name]), 2) # 2 audios in the conversation
1251+
self.assertEqual(len(out_dict[self.video_input_name]), 1) # 1 video in the conversation
1252+
11921253
@require_librosa
11931254
def test_audio_chat_template_single(self):
11941255
processor = self.get_processor()

0 commit comments

Comments
 (0)