@@ -33,11 +33,14 @@
 import torch.nn.functional as F
 from transformers import BatchFeature
 from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
-from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
+    smart_resize as image_smart_resize)
 from transformers.models.qwen3_vl import (Qwen3VLProcessor,
                                           Qwen3VLVideoProcessor)
 from transformers.models.qwen3_vl.configuration_qwen3_vl import (
     Qwen3VLConfig, Qwen3VLVisionConfig)
+from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
+    smart_resize as video_smart_resize)
 from transformers.video_utils import VideoMetadata

 from vllm.attention.layer import check_upstream_fa_availability
@@ -84,6 +87,9 @@

 logger = init_logger(__name__)

+# Official recommended max pixels is 24576 * 32 * 32
+_MAX_FRAMES_PER_VIDEO = 24576
+


 class Qwen3_VisionPatchEmbed(nn.Module):
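Note on the constant above: the frame cap reuses the 24576 figure from the recommended pixel budget. A minimal arithmetic sketch, assuming (this is not stated in the diff) that the 32 * 32 factor is patch_size * spatial_merge_size, i.e. one 32x32 pixel block per merged vision token:

    # Sketch only; patch_size=16 and merge_size=2 are assumed Qwen3-VL defaults.
    patch_size = 16
    merge_size = 2
    block = patch_size * merge_size       # 32 px per merged-token side
    max_pixels = 24576 * block * block    # the recommended cap from the comment
    print(max_pixels)                     # 25165824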
@@ -592,24 +598,39 @@ def _get_vision_info(
         image_height: int,
         num_frames: int = 2,
         do_resize: bool = True,
-        image_processor: Optional[Qwen2VLImageProcessorFast],
+        image_processor: Optional[Union[Qwen2VLImageProcessorFast,
+                                        Qwen3VLVideoProcessor]],
     ) -> tuple[ImageSize, int]:
-        if image_processor is None:
+        if image_processor is None and num_frames > 1:
+            image_processor = self.get_video_processor()
+        elif image_processor is None:
             image_processor = self.get_image_processor()

+        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)
+
         hf_config = self.get_hf_config()
         vision_config = hf_config.vision_config
         patch_size = vision_config.patch_size
         merge_size = vision_config.spatial_merge_size
         temporal_patch_size = vision_config.temporal_patch_size

         if do_resize:
+            if is_video:
+                smart_resize = video_smart_resize
+                extra_kwargs = {
+                    "num_frames": num_frames,
+                    "temporal_factor": temporal_patch_size
+                }
+            else:
+                smart_resize = image_smart_resize
+                extra_kwargs = {}
             resized_height, resized_width = smart_resize(
                 height=image_height,
                 width=image_width,
                 factor=patch_size * merge_size,
                 min_pixels=image_processor.size["shortest_edge"],
                 max_pixels=image_processor.size["longest_edge"],
+                **extra_kwargs,
             )
             preprocessed_size = ImageSize(width=resized_width,
                                           height=resized_height)
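For reference, a hedged usage sketch of the video-aware resize path wired up above. The keyword names mirror the hunk; the concrete numbers are invented example values, not processor defaults:

    from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
        smart_resize as video_smart_resize)

    # Example values only; in the model code these come from vision_config
    # and image_processor.size, as shown in the hunk above.
    resized_height, resized_width = video_smart_resize(
        num_frames=16,             # sampled frame count
        height=720,
        width=1280,
        temporal_factor=2,         # vision_config.temporal_patch_size
        factor=32,                 # patch_size * spatial_merge_size
        min_pixels=128 * 32 * 32,  # stand-in for size["shortest_edge"]
        max_pixels=768 * 32 * 32,  # stand-in for size["longest_edge"]
    )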
@@ -628,6 +649,39 @@ def _get_vision_info(

         return preprocessed_size, num_vision_tokens

+    def _get_max_video_frames(self,
+                              max_tokens: int,
+                              start_num_frames: int = 2) -> int:
+        return super()._get_max_video_frames(
+            max_tokens, start_num_frames=start_num_frames)
+
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        return super().get_num_frames_with_most_features(
+            seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO)
+
+    def get_max_video_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+        video_soft_tokens = self.get_num_video_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=self.get_num_frames_with_most_features(
+                seq_len, mm_counts),
+            image_processor=None,
+        )
+
+        # NOTE: By default in Qwen3-VL, one video token is converted to
+        # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
+        formatted_video_soft_tokens = video_soft_tokens * 12.5
+        return int(formatted_video_soft_tokens)
+
     def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                               video_fps: float, merge_size: int):
         if not isinstance(indices, list):
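The 12.5 factor in get_max_video_tokens is just the per-token budget spelled out in the NOTE above:

    # Arithmetic behind the 12.5 multiplier (from the NOTE in the hunk).
    timestamp_text = 9.5  # "<{timestamp} seconds>" averages ~9.5 text tokens
    wrapper_tokens = 3    # vision_start_token + video_token + vision_end_token
    assert timestamp_text + wrapper_tokens == 12.5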
@@ -697,15 +751,21 @@ def get_dummy_mm_data(
             self.info.get_image_size_with_most_features())
         target_num_frames = self.info.get_num_frames_with_most_features(
             seq_len, mm_counts)
+        target_video_size, _ = self.info._get_vision_info(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=target_num_frames,
+            image_processor=self.info.get_video_processor(),
+        )
         return {
             "image":
             self._get_dummy_images(width=target_width,
                                    height=target_height,
                                    num_images=num_images),
             "video":
             self._get_dummy_videos(
-                width=target_width,
-                height=target_height,
+                width=target_video_size.width,
+                height=target_video_size.height,
                 num_frames=target_num_frames,
                 num_videos=num_videos,
             ),
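Why the dummy videos are sized with target_video_size rather than the raw image target: the video resize path (with num_frames and temporal_factor) can settle on a different per-frame size than the image path, and the profiling dummies must match what the video processor will actually emit. A minimal sketch of the comparison, with `info` standing in for self.info:

    # Sketch; `info` is a stand-in for self.info in the real method.
    image_size, _ = info._get_vision_info(
        image_width=target_width,
        image_height=target_height,
        image_processor=info.get_image_processor())
    video_size, _ = info._get_vision_info(
        image_width=target_width,
        image_height=target_height,
        num_frames=target_num_frames,
        image_processor=info.get_video_processor())
    # The two sizes can differ, so dummy videos are built at video_size.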