Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
f352b0d
upstream
ywang96 Sep 12, 2025
667e973
fix & add co-author
ywang96 Sep 12, 2025
8fe70a7
Update tests/models/registry.py
ywang96 Sep 12, 2025
5f6afa1
add missing processor test
Isotr0py Sep 12, 2025
9c5808f
revert str
ywang96 Sep 12, 2025
c7fe668
fix processor test hashing
Isotr0py Sep 12, 2025
aa2330f
fix frames indices
Isotr0py Sep 12, 2025
e1f8397
fix hit_rate 1.0
Isotr0py Sep 12, 2025
9c7939c
typo
Isotr0py Sep 12, 2025
5d4f6dd
fix placeholder replacement
Isotr0py Sep 12, 2025
0d88363
fix video example
Isotr0py Sep 13, 2025
574884c
Merge branch 'main' into upstream-qwen-3-vl
ywang96 Sep 13, 2025
6ec3968
fix vit backend
ywang96 Sep 13, 2025
0f80a19
fix online serving metadata
Isotr0py Sep 13, 2025
ba54870
avoid hardcode fps=1
Isotr0py Sep 14, 2025
f7c37a9
oops fps=1
Isotr0py Sep 14, 2025
a6c5d7e
Merge branch 'main' into upstream-qwen-3-vl
ywang96 Sep 15, 2025
cbf6dee
Merge branch 'main' into upstream-qwen-3-vl
Isotr0py Sep 15, 2025
bec9e7e
catch up and fix processor test
Isotr0py Sep 15, 2025
5f3cf0e
Merge branch 'main' into upstream-qwen-3-vl
ywang96 Sep 16, 2025
5027f31
fix model path
ywang96 Sep 16, 2025
d0133e2
fix qwen_vl_utils compatibility
Isotr0py Sep 16, 2025
07e4f52
fix fps
ywang96 Sep 16, 2025
78afc5b
Merge branch 'main' into upstream-qwen-3-vl
ywang96 Sep 16, 2025
44f89b2
fix
ywang96 Sep 16, 2025
c7ea6f7
cleanup
ywang96 Sep 16, 2025
10bd983
do not modify metadata
ywang96 Sep 16, 2025
a4c0d34
fix example and online fps
Isotr0py Sep 16, 2025
b1a33a5
dp vit
ywang96 Sep 16, 2025
98f7019
revert
ywang96 Sep 16, 2025
f15e09d
Merge branch 'main' into qwen3-vl-dp-vit
ywang96 Sep 17, 2025
d888781
cleanup
ywang96 Sep 17, 2025
3d3fbe5
Merge branch 'main' into qwen3-vl-dp-vit
ywang96 Sep 17, 2025
9fd7448
clarify
ywang96 Sep 17, 2025
bf353c9
Merge branch 'main' into qwen3-vl-dp-vit
ywang96 Sep 17, 2025
0fa1dcc
Merge branch 'main' into qwen3-vl-dp-vit
ywang96 Sep 17, 2025
7b2cdd9
Merge branch 'main' into qwen3-vl-dp-vit
ywang96 Sep 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 75 additions & 19 deletions vllm/model_executor/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,23 @@ def __init__(self,
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
use_data_parallel: bool = False):
super().__init__()
self.linear_fc1 = ColumnParallelLinear(in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc1")
prefix=f"{prefix}.linear_fc1",
disable_tp=use_data_parallel)
self.linear_fc2 = RowParallelLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.linear_fc2")
prefix=f"{prefix}.linear_fc2",
disable_tp=use_data_parallel)
self.act_fn = act_fn

def forward(self, x: torch.Tensor):
Expand All @@ -158,23 +161,27 @@ def __init__(
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.attn = Qwen2_5_VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel)
self.mlp = Qwen3_VisionMLP(dim,
mlp_hidden_dim,
act_fn=act_fn,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)

def forward(
self,
Expand Down Expand Up @@ -205,6 +212,7 @@ def __init__(
use_postshuffle_norm: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
Expand All @@ -222,13 +230,15 @@ def __init__(
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc1")
prefix=f"{prefix}.linear_fc1",
disable_tp=use_data_parallel)
self.act_fn = nn.GELU()
self.linear_fc2 = RowParallelLinear(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.linear_fc2")
prefix=f"{prefix}.linear_fc2",
disable_tp=use_data_parallel)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_postshuffle_norm:
Expand All @@ -250,6 +260,7 @@ def __init__(
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
Expand All @@ -260,6 +271,12 @@ def __init__(
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.use_data_parallel = use_data_parallel

# NOTE: This is used to create the empty tensor for all_gather in the
# DP ViT path. out_hidden_size is scaled by (1 + number of deepstack
# visual indexes) to account for deepstack.
self.out_hidden_size = (vision_config.out_hidden_size *
(1 + len(self.deepstack_visual_indexes)))

self.patch_embed = Qwen3_VisionPatchEmbed(
patch_size=self.patch_size,
Expand All @@ -283,7 +300,8 @@ def __init__(
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}")
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(vision_config.depth)
])

Expand All @@ -294,6 +312,7 @@ def __init__(
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)

self.deepstack_merger_list = nn.ModuleList([
Expand All @@ -304,7 +323,8 @@ def __init__(
use_postshuffle_norm=True,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}")
prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(len(self.deepstack_visual_indexes))
])

Expand All @@ -325,7 +345,14 @@ def device(self) -> torch.device:

def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
# Support both Tensor and list inputs for DP path
if isinstance(grid_thw, list):
grid_list = grid_thw
max_grid_size = max(max(h, w) for _, h, w in grid_list)
else:
grid_list = grid_thw.tolist()
max_grid_size = int(grid_thw[:, 1:].max().item())
for t, h, w in grid_list:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
Expand All @@ -348,7 +375,6 @@ def rot_pos_emb(self, grid_thw):
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
Expand Down Expand Up @@ -453,10 +479,18 @@ def forward(
hidden_states = hidden_states + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)

if isinstance(grid_thw, list):
grid_thw_tensor = torch.tensor(grid_thw,
device=hidden_states.device,
dtype=torch.int32)
else:
grid_thw_tensor = grid_thw

cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
grid_thw_tensor[:, 0]).cumsum(
dim=0,
dtype=grid_thw.dtype
dtype=grid_thw_tensor.dtype
if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
Expand Down Expand Up @@ -984,6 +1018,9 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
"up_proj",
],
}

supports_encoder_tp_data = True

# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
Expand All @@ -1009,12 +1046,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):

self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)

self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
Expand Down Expand Up @@ -1177,7 +1216,15 @@ def _process_image_input(
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
if self.use_data_parallel:
from vllm.multimodal.utils import (
run_dp_sharded_mrope_vision_model)
return run_dp_sharded_mrope_vision_model(self.visual,
pixel_values,
grid_thw_list,
rope_type="rope_3d")
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

# Split concatenated embeddings for each image item.
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
Expand All @@ -1199,7 +1246,16 @@ def _process_video_input(
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
if self.use_data_parallel:
from vllm.multimodal.utils import (
run_dp_sharded_mrope_vision_model)
return run_dp_sharded_mrope_vision_model(self.visual,
pixel_values_videos,
grid_thw_list,
rope_type="rope_3d")
else:
video_embeds = self.visual(pixel_values_videos,
grid_thw=grid_thw)

# Split concatenated embeddings for each video item.
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/qwen3_vl_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)

self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
Expand Down