6 changes: 2 additions & 4 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -836,10 +836,8 @@ def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
-max_seqlen, seqlens = (
-torch.zeros(1, device=cu_seqlens.device),
-torch.zeros(1, device=cu_seqlens.device),
-)
+max_seqlen = torch.zeros([], device=cu_seqlens.device)
+seqlens = torch.zeros(1, device=cu_seqlens.device)
Comment on lines -839 to +840 (Contributor Author):
max_seqlen is actually a scalar value, so this updates the dummy value to match.

if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
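This hunk matches the reviewer's point: `.max()` over the per-sequence lengths returns a 0-dim tensor, so the dummy for max_seqlen must be 0-dim too, while the dummy for seqlens stays 1-D. A standalone shape check (a sketch, not part of the PR):

# Standalone sketch: verifies the dummy shapes now line up with the
# real values computed from cu_seqlens.
import torch

cu_seqlens = torch.tensor([0, 3, 7, 12])    # cumulative sequence lengths
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([3, 4, 5]), 1-D
max_seqlen = seqlens.max()                  # tensor(5), shape () -- 0-dim

assert max_seqlen.shape == torch.zeros([]).shape  # new dummy matches
assert max_seqlen.shape != torch.zeros(1).shape   # old dummy was shape (1,)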
14 changes: 8 additions & 6 deletions vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -223,8 +223,8 @@ def forward(
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
-max_seqlen: int | None = None, # Only used for Flash Attention
-seqlens: list[int] | None = None, # Only used for xFormers
+max_seqlen: torch.Tensor, # Only used for Flash Attention
+seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -488,12 +488,13 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
-) -> tuple[int | None, list[int] | None]:
-max_seqlen, seqlens = None, None
+) -> tuple[torch.Tensor, torch.Tensor]:
+max_seqlen = torch.zeros([], device=cu_seqlens.device)
+seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == _Backend.FLASH_ATTN:
-max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
-seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens

def forward(
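Both replaced calls (.max().item() and .tolist()) force a device-to-host copy, and under torch.compile they are graph breaks by default; the tensor-only versions trace cleanly. A hedged sketch of the contrast, assuming torch.compile compatibility is the motivation (the forward-context note in the next hunk points the same way):

# Sketch only, not part of the diff: the tensor-only path compiles as
# a single graph, while .item() forces a host sync that Dynamo treats
# as a graph break (an error under fullgraph=True, unless
# torch._dynamo.config.capture_scalar_outputs is enabled).
import torch

@torch.compile(fullgraph=True)
def tensor_only(cu_seqlens: torch.Tensor) -> torch.Tensor:
    return (cu_seqlens[1:] - cu_seqlens[:-1]).max()  # stays on device

def with_item(cu_seqlens: torch.Tensor) -> int:
    # compiling this with fullgraph=True would fail at the .item() call
    return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

print(tensor_only(torch.tensor([0, 3, 7, 12])))  # tensor(5)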
@@ -1114,6 +1115,7 @@ def get_placeholder_str(cls, modality: str, i: int) -> str | None:

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
+self.vllm_config = vllm_config # needed for torch compile forward context
thinker_config: Qwen3OmniMoeThinkerConfig = (
vllm_config.model_config.hf_config.thinker_config
)
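The new attribute exists so the vision tower's forward pass can establish a forward context for its torch.compile'd submodules. A hedged sketch of that pattern; set_forward_context does live in vllm.forward_context, but the exact call below is an assumption, not code from this PR:

# Hedged sketch, NOT from the diff: illustrates why __init__ keeps a
# handle to vllm_config. Compiled submodules read per-step state from
# the ambient forward context, which is built from that config.
from vllm.forward_context import set_forward_context

def _run_visual(self, pixel_values):
    # assumed signature: (attn_metadata, vllm_config)
    with set_forward_context(None, self.vllm_config):
        return self.visual(pixel_values)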
13 changes: 7 additions & 6 deletions vllm/model_executor/models/qwen3_vl.py
@@ -231,8 +231,8 @@ def forward(
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
-max_seqlen: int | None = None, # Only used for Flash Attention
-seqlens: list[int] | None = None, # Only used for xFormers
+max_seqlen: torch.Tensor, # Only used for Flash Attention
+seqlens: torch.Tensor, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -512,15 +512,16 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
-) -> tuple[int | None, list[int] | None]:
-max_seqlen, seqlens = None, None
+) -> tuple[torch.Tensor, torch.Tensor]:
+max_seqlen = torch.zeros([], device=cu_seqlens.device)
+seqlens = torch.zeros(1, device=cu_seqlens.device)
if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
):
-max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
-seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens

def forward(
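Across all three files the net effect is a uniform contract: compute_attn_mask_seqlen always returns a (Tensor, Tensor) pair on cu_seqlens.device, with the dummies standing in on the fall-through path. A self-contained sketch of that post-PR behavior (backend strings stand in for the _Backend enum):

# Self-contained sketch of the post-PR contract; "FLASH_ATTN" etc.
# are stand-ins for the _Backend enum used in the real code.
import torch

def compute_attn_mask_seqlen(cu_seqlens: torch.Tensor, backend: str):
    max_seqlen = torch.zeros([], device=cu_seqlens.device)  # 0-dim dummy
    seqlens = torch.zeros(1, device=cu_seqlens.device)      # 1-D dummy
    if backend in ("FLASH_ATTN", "ROCM_AITER_FA"):
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
    elif backend == "XFORMERS":
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    return max_seqlen, seqlens

cu = torch.tensor([0, 3, 7, 12])
for be in ("FLASH_ATTN", "XFORMERS", "OTHER"):
    ms, sl = compute_attn_mask_seqlen(cu, be)
    print(be, tuple(ms.shape), tuple(sl.shape))  # tensors in every branch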