diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 4662176a1cc5..b0557d58d6dd 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -291,25 +291,6 @@ def forward(self, x: torch.Tensor): return x_down -def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): - """All-gather the input tensor interleavely across model parallel group.""" - import torch.distributed as dist - - gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] - dist.all_gather( - gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group - ) - - gathered_tensors_split = [ - torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors - ] - ordered_tensors = [ - tensor for pair in zip(*gathered_tensors_split) for tensor in pair - ] - result_tensor = torch.cat(ordered_tensors, dim=-1) - return result_tensor - - class Qwen2_5_VisionAttention(nn.Module): def __init__( self, @@ -383,21 +364,10 @@ def __init__( def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape - if self.tp_size > 1: - qkv = all_gather_interleave(qkv, self.qkv.hidden_size, self.tp_size) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) - # 3 * [s, b, head * head_dim] - if self.tp_size > 1: - splitter = partial( - dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size - ) - q = splitter(q)[self.tp_rank] - k = splitter(k)[self.tp_rank] - v = splitter(v)[self.tp_rank] - # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] new_shape = ( seq_len, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index bbebe7c0f928..ff04baee91d1 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -50,7 +50,7 @@ ) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.distributed import parallel_state, tensor_model_parallel_all_gather +from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU @@ -396,21 +396,10 @@ def __init__( def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape - if self.tp_size > 1: - qkv = tensor_model_parallel_all_gather(qkv) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) - # 3 * [s, b, head * head_dim] - if self.tp_size > 1: - splitter = partial( - dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size - ) - q = splitter(q)[self.tp_rank] - k = splitter(k)[self.tp_rank] - v = splitter(v)[self.tp_rank] - # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] new_shape = ( seq_len,