diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 22948aee4936..2c36dfbce7f6 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -126,20 +126,23 @@ def __init__(self,
                  bias: bool = False,
                  act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 use_data_parallel: bool = False):
         super().__init__()
         self.linear_fc1 = ColumnParallelLinear(in_features,
                                                hidden_features,
                                                bias=bias,
                                                quant_config=quant_config,
                                                return_bias=False,
-                                               prefix=f"{prefix}.linear_fc1")
+                                               prefix=f"{prefix}.linear_fc1",
+                                               disable_tp=use_data_parallel)
         self.linear_fc2 = RowParallelLinear(hidden_features,
                                             in_features,
                                             bias=bias,
                                             quant_config=quant_config,
                                             return_bias=False,
-                                            prefix=f"{prefix}.linear_fc2")
+                                            prefix=f"{prefix}.linear_fc2",
+                                            disable_tp=use_data_parallel)
         self.act_fn = act_fn
 
     def forward(self, x: torch.Tensor):
@@ -158,23 +161,27 @@ def __init__(
         norm_layer: Optional[Callable[[int], nn.Module]] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
         self.norm1 = norm_layer(dim)
         self.norm2 = norm_layer(dim)
-        self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
-                                            num_heads=num_heads,
-                                            projection_size=dim,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.attn")
+        self.attn = Qwen2_5_VisionAttention(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+            use_data_parallel=use_data_parallel)
         self.mlp = Qwen3_VisionMLP(dim,
                                    mlp_hidden_dim,
                                    act_fn=act_fn,
                                    bias=True,
                                    quant_config=quant_config,
-                                   prefix=f"{prefix}.mlp")
+                                   prefix=f"{prefix}.mlp",
+                                   use_data_parallel=use_data_parallel)
 
     def forward(
         self,
@@ -205,6 +212,7 @@ def __init__(
         use_postshuffle_norm: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -222,13 +230,15 @@ def __init__(
                                               self.hidden_size,
                                               bias=True,
                                               quant_config=quant_config,
-                                              prefix=f"{prefix}.linear_fc1")
+                                              prefix=f"{prefix}.linear_fc1",
+                                              disable_tp=use_data_parallel)
         self.act_fn = nn.GELU()
         self.linear_fc2 = RowParallelLinear(self.hidden_size,
                                             d_model,
                                             bias=True,
                                             quant_config=quant_config,
-                                            prefix=f"{prefix}.linear_fc2")
+                                            prefix=f"{prefix}.linear_fc2",
+                                            disable_tp=use_data_parallel)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.use_postshuffle_norm:
@@ -250,6 +260,7 @@ def __init__(
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = vision_config.hidden_size
@@ -260,6 +271,12 @@ def __init__(
         self.spatial_merge_unit = self.spatial_merge_size**2
         self.temporal_patch_size = vision_config.temporal_patch_size
         self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
+        self.use_data_parallel = use_data_parallel
+
+        # NOTE: This is used for creating empty tensor for all_gather for
+        # DP ViT. Here out_hidden_size is enlarged due to deepstack
+        self.out_hidden_size = (vision_config.out_hidden_size *
+                                (1 + len(self.deepstack_visual_indexes)))
 
         self.patch_embed = Qwen3_VisionPatchEmbed(
             patch_size=self.patch_size,
@@ -283,7 +300,8 @@ def __init__(
                 act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                 norm_layer=norm_layer,
                 quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}")
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel)
             for layer_idx in range(vision_config.depth)
         ])
 
@@ -294,6 +312,7 @@ def __init__(
             spatial_merge_size=self.spatial_merge_size,
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
+            use_data_parallel=use_data_parallel,
         )
 
         self.deepstack_merger_list = nn.ModuleList([
@@ -304,7 +323,8 @@ def __init__(
                 use_postshuffle_norm=True,
                 norm_layer=norm_layer,
                 quant_config=quant_config,
-                prefix=f"{prefix}.deepstack_merger_list.{layer_idx}")
+                prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
+                use_data_parallel=use_data_parallel)
             for layer_idx in range(len(self.deepstack_visual_indexes))
         ])
 
@@ -325,7 +345,14 @@ def device(self) -> torch.device:
 
     def rot_pos_emb(self, grid_thw):
         pos_ids = []
-        for t, h, w in grid_thw:
+        # Support both Tensor and list inputs for DP path
+        if isinstance(grid_thw, list):
+            grid_list = grid_thw
+            max_grid_size = max(max(h, w) for _, h, w in grid_list)
+        else:
+            grid_list = grid_thw.tolist()
+            max_grid_size = int(grid_thw[:, 1:].max().item())
+        for t, h, w in grid_list:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             hpos_ids = hpos_ids.reshape(
                 h // self.spatial_merge_size,
@@ -348,7 +375,6 @@ def rot_pos_emb(self, grid_thw):
             pos_ids.append(
                 torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
@@ -453,10 +479,18 @@ def forward(
         hidden_states = hidden_states + pos_embeds
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
+        if isinstance(grid_thw, list):
+            grid_thw_tensor = torch.tensor(grid_thw,
+                                           device=hidden_states.device,
+                                           dtype=torch.int32)
+        else:
+            grid_thw_tensor = grid_thw
+
         cu_seqlens = torch.repeat_interleave(
-            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
+            grid_thw_tensor[:, 0]).cumsum(
                 dim=0,
-                dtype=grid_thw.dtype
+                dtype=grid_thw_tensor.dtype
                 if torch.jit.is_tracing() else torch.int32,
             )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
@@ -984,6 +1018,9 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             "up_proj",
         ],
     }
+
+    supports_encoder_tp_data = True
+
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -1009,12 +1046,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
 
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
         self.visual = Qwen3_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
+            use_data_parallel=self.use_data_parallel,
         )
 
         self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
@@ -1177,7 +1216,15 @@ def _process_image_input(
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+            if self.use_data_parallel:
+                from vllm.multimodal.utils import (
+                    run_dp_sharded_mrope_vision_model)
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
+            else:
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each image item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
@@ -1199,7 +1246,16 @@ def _process_video_input(
         else:
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+            if self.use_data_parallel:
+                from vllm.multimodal.utils import (
+                    run_dp_sharded_mrope_vision_model)
+                return run_dp_sharded_mrope_vision_model(self.visual,
+                                                         pixel_values_videos,
+                                                         grid_thw_list,
+                                                         rope_type="rope_3d")
+            else:
+                video_embeds = self.visual(pixel_values_videos,
+                                           grid_thw=grid_thw)
 
         # Split concatenated embeddings for each video item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py
index a800e94ab1e5..d25bc71dcb59 100644
--- a/vllm/model_executor/models/qwen3_vl_moe.py
+++ b/vllm/model_executor/models/qwen3_vl_moe.py
@@ -315,12 +315,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
         self.visual = Qwen3_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
+            use_data_parallel=self.use_data_parallel,
        )
 
         self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
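Usage sketch (not part of the patch): the new path is selected when `multimodal_config.mm_encoder_tp_mode == "data"`, and the added `supports_encoder_tp_data = True` class attribute advertises that Qwen3-VL can run its vision encoder data-parallel while the language model stays tensor-parallel. Below is a minimal, hedged example of enabling it, assuming `mm_encoder_tp_mode` is forwarded to the engine arguments by the `LLM` constructor; the checkpoint name and GPU count are illustrative only.

    from vllm import LLM

    # Language-model weights stay tensor-parallel across 4 GPUs; the ViT is
    # replicated per rank (disable_tp=True on its linear layers) and the image
    # and video inputs are sharded across ranks via
    # run_dp_sharded_mrope_vision_model(..., rope_type="rope_3d").
    llm = LLM(
        model="Qwen/Qwen3-VL-30B-A3B-Instruct",  # illustrative checkpoint name
        tensor_parallel_size=4,
        mm_encoder_tp_mode="data",  # -> multimodal_config.mm_encoder_tp_mode == "data"
    )

As the NOTE added in Qwen3_VisionTransformer.__init__ explains, out_hidden_size is enlarged by a factor of 1 + len(deepstack_visual_indexes) so that each rank can allocate a correctly sized empty buffer for the all_gather of the DP ViT outputs.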