@@ -33,11 +33,14 @@
 import torch.nn.functional as F
 from transformers import BatchFeature
 from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
-from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
+    smart_resize as image_smart_resize)
 from transformers.models.qwen3_vl import (Qwen3VLProcessor,
                                           Qwen3VLVideoProcessor)
 from transformers.models.qwen3_vl.configuration_qwen3_vl import (
     Qwen3VLConfig, Qwen3VLVisionConfig)
+from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
+    smart_resize as video_smart_resize)
 from transformers.video_utils import VideoMetadata

 from vllm.attention.layer import check_upstream_fa_availability
@@ -85,6 +88,9 @@

 logger = init_logger(__name__)

+# Official recommended max pixels is 24576 * 32 * 32
+_MAX_FRAMES_PER_VIDEO = 24576
+

 class Qwen3_VisionPatchEmbed(nn.Module):

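A quick sanity check on the new constant (my own sketch, not part of the diff): the recommended pixel budget of `24576 * 32 * 32` equals 24576 vision tokens if each merged token covers a 32×32 pixel area (assuming Qwen3-VL's `patch_size = 16` and `spatial_merge_size = 2`), which is presumably why the same number doubles as the frame cap.

```python
# Sketch only: arithmetic behind _MAX_FRAMES_PER_VIDEO, assuming Qwen3-VL's
# patch_size = 16 and spatial_merge_size = 2 (one token spans 32x32 pixels).
patch_size, merge_size = 16, 2
token_area = (patch_size * merge_size) ** 2  # 1024 pixels per merged token
max_pixels = 24576 * 32 * 32                 # officially recommended budget
assert max_pixels // token_area == 24576     # pixel budget in tokens == frame cap
```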
@@ -593,24 +599,39 @@ def _get_vision_info(
         image_height: int,
         num_frames: int = 2,
         do_resize: bool = True,
-        image_processor: Optional[Qwen2VLImageProcessorFast],
+        image_processor: Optional[Union[Qwen2VLImageProcessorFast,
+                                        Qwen3VLVideoProcessor]],
     ) -> tuple[ImageSize, int]:
-        if image_processor is None:
+        if image_processor is None and num_frames > 1:
+            image_processor = self.get_video_processor()
+        elif image_processor is None:
             image_processor = self.get_image_processor()

+        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)
+
         hf_config = self.get_hf_config()
         vision_config = hf_config.vision_config
         patch_size = vision_config.patch_size
         merge_size = vision_config.spatial_merge_size
         temporal_patch_size = vision_config.temporal_patch_size

         if do_resize:
+            if is_video:
+                smart_resize = video_smart_resize
+                extra_kwargs = {
+                    "num_frames": num_frames,
+                    "temporal_factor": temporal_patch_size
+                }
+            else:
+                smart_resize = image_smart_resize
+                extra_kwargs = {}
             resized_height, resized_width = smart_resize(
                 height=image_height,
                 width=image_width,
                 factor=patch_size * merge_size,
                 min_pixels=image_processor.size["shortest_edge"],
                 max_pixels=image_processor.size["longest_edge"],
+                **extra_kwargs,
             )
             preprocessed_size = ImageSize(width=resized_width,
                                           height=resized_height)
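For reference, a minimal usage sketch of the image branch (values are hypothetical; in the patched code `min_pixels`/`max_pixels` come from the processor's `size` dict):

```python
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
    smart_resize as image_smart_resize)

# Hypothetical 1080p input; factor = patch_size (16) * spatial_merge_size (2).
h, w = image_smart_resize(height=1080, width=1920, factor=32,
                          min_pixels=256 * 32 * 32,
                          max_pixels=24576 * 32 * 32)
assert h % 32 == 0 and w % 32 == 0  # dims snap to the merged-patch grid
```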
@@ -629,6 +650,39 @@ def _get_vision_info(

         return preprocessed_size, num_vision_tokens

+    def _get_max_video_frames(self,
+                              max_tokens: int,
+                              start_num_frames: int = 2) -> int:
+        return super()._get_max_video_frames(max_tokens,
+                                             start_num_frames=start_num_frames)
+
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        return super().get_num_frames_with_most_features(
+            seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO)
+
+    def get_max_video_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+        video_soft_tokens = self.get_num_video_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=self.get_num_frames_with_most_features(
+                seq_len, mm_counts),
+            image_processor=None,
+        )
+
+        # NOTE: By default in Qwen3-VL, one video token is converted to
+        # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
+        formatted_video_soft_tokens = video_soft_tokens * 12.5
+        return int(formatted_video_soft_tokens)
+
     def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                               video_fps: float, merge_size: int):
         if not isinstance(indices, list):
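The `12.5` multiplier in `get_max_video_tokens` is just the NOTE's arithmetic made explicit: an average of 9.5 text tokens for the `"<{timestamp} seconds>"` string plus 3 special tokens per video token.

```python
# Worked example of the budget above (sketch, not part of the diff):
timestamp_tokens = 9.5  # average length of "<{timestamp} seconds>"
special_tokens = 3      # vision_start_token + video_token + vision_end_token
per_token = timestamp_tokens + special_tokens  # = 12.5
print(int(1000 * per_token))  # 1000 video soft tokens -> 12500 budgeted tokens
```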
@@ -698,15 +752,21 @@ def get_dummy_mm_data(
             self.info.get_image_size_with_most_features())
         target_num_frames = self.info.get_num_frames_with_most_features(
             seq_len, mm_counts)
+        target_video_size, _ = self.info._get_vision_info(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=target_num_frames,
+            image_processor=self.info.get_video_processor(),
+        )
         return {
             "image":
             self._get_dummy_images(width=target_width,
                                    height=target_height,
                                    num_images=num_images),
             "video":
             self._get_dummy_videos(
-                width=target_width,
-                height=target_height,
+                width=target_video_size.width,
+                height=target_video_size.height,
                 num_frames=target_num_frames,
                 num_videos=num_videos,
             ),
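Worth noting why the dummy video can no longer reuse `(target_width, target_height)`: the video branch of `smart_resize` also receives `num_frames` and `temporal_factor`, so for the same input it can settle on a different spatial size than the image branch, and the profiling dummies must match what preprocessing will actually produce. A minimal sketch of that call, mirroring the kwargs the patched `_get_vision_info` passes (values hypothetical; real bounds come from the video processor's `size` dict):

```python
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
    smart_resize as video_smart_resize)

# Same hypothetical 1080p input as above, now with the video-only kwargs;
# the shared pixel budget means more frames can imply a smaller per-frame size.
h, w = video_smart_resize(height=1080, width=1920, factor=32,
                          num_frames=16, temporal_factor=2,
                          min_pixels=256 * 32 * 32,
                          max_pixels=24576 * 32 * 32)
```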