diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md
index db9dfb313fb8..1a818fbf71d0 100644
--- a/docs/configuration/optimization.md
+++ b/docs/configuration/optimization.md
@@ -171,6 +171,7 @@ The availablilty of batch-level DP is based on model implementation.
 Currently, the following models support `mm_encoder_tp_mode="data"`:
 
 - Llama4 ()
+- MiniCPM-V-4 ()
 - Qwen2.5-VL ()
 - Step3 ()
 
diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
index 9e27200fb1c8..88b2a295905b 100644
--- a/vllm/model_executor/models/idefics2_vision_model.py
+++ b/vllm/model_executor/models/idefics2_vision_model.py
@@ -27,13 +27,15 @@
     Idefics2Config, Idefics2VisionConfig)
 
 from vllm.attention.layer import MultiHeadAttention
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.multimodal.utils import run_dp_sharded_vision_model
 
 
 class Idefics2VisionEmbeddings(nn.Module):
@@ -118,6 +120,7 @@ def __init__(
         config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -130,22 +133,43 @@ def __init__(
                              f" {self.num_heads}).")
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
-        self.qkv_proj = QKVParallelLinear(
-            self.embed_dim,
-            self.head_dim,
-            self.num_heads,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
-        self.out_proj = RowParallelLinear(
-            self.embed_dim,
-            self.embed_dim,
-            bias=True,
-            quant_config=quant_config,
-            prefix=f"{prefix}.out_proj",
-        )
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+
+        tp_size = (1 if use_data_parallel else
+                   get_tensor_model_parallel_world_size())
+        assert self.num_heads % tp_size == 0
+        self.num_heads_per_partition = self.num_heads // tp_size
+
+        if use_data_parallel:
+            self.q_size = self.num_heads * self.head_dim
+            self.qkv_proj = ReplicatedLinear(
+                self.embed_dim,
+                3 * self.q_size,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+            self.out_proj = ReplicatedLinear(
+                self.embed_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.out_proj",
+            )
+        else:
+            self.qkv_proj = QKVParallelLinear(
+                self.embed_dim,
+                self.head_dim,
+                self.num_heads,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+            self.out_proj = RowParallelLinear(
+                self.embed_dim,
+                self.embed_dim,
+                bias=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.out_proj",
+            )
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
 
@@ -169,18 +193,23 @@ def __init__(
         config: Idefics2VisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        self.fc1 = ColumnParallelLinear(
+        cls_fc1 = (ReplicatedLinear
+                   if use_data_parallel else ColumnParallelLinear)
+        self.fc1 = cls_fc1(
             config.hidden_size,
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
         )
-        self.fc2 = RowParallelLinear(
+        cls_fc2 = (ReplicatedLinear
+                   if use_data_parallel else RowParallelLinear)
+        self.fc2 = cls_fc2(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
@@ -202,17 +231,21 @@ def __init__(
         config: Idefics2Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = Idefics2VisionAttention(config,
-                                                 quant_config=quant_config,
-                                                 prefix=f"{prefix}.self_attn")
+        self.self_attn = Idefics2VisionAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_data_parallel=use_data_parallel)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = Idefics2VisionMLP(config,
                                      quant_config=quant_config,
-                                     prefix=f"{prefix}.mlp")
+                                     prefix=f"{prefix}.mlp",
+                                     use_data_parallel=use_data_parallel)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
 
@@ -254,6 +287,7 @@ def __init__(
         *,
         num_hidden_layers_override: Optional[int] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
@@ -267,7 +301,8 @@ def __init__(
         self.layers = nn.ModuleList([
             Idefics2EncoderLayer(config,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.layers.{layer_idx}")
+                                 prefix=f"{prefix}.layers.{layer_idx}",
+                                 use_data_parallel=use_data_parallel)
             for layer_idx in range(num_hidden_layers)
         ])
 
@@ -301,17 +336,20 @@ def __init__(
         num_hidden_layers_override: Optional[int] = None,
         require_post_norm: bool = True,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
 
         embed_dim = config.hidden_size
         self.config = config
+        self.use_data_parallel = use_data_parallel
         self.embeddings = Idefics2VisionEmbeddings(config)
         self.encoder = Idefics2Encoder(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
-            prefix=f"{prefix}.encoder")
+            prefix=f"{prefix}.encoder",
+            use_data_parallel=use_data_parallel)
 
         num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
@@ -340,10 +378,38 @@ def forward(
             patch_attention_mask=patch_attention_mask,
             tgt_sizes=tgt_sizes,
         )
-        encoder_outputs = self.encoder(hidden_states)
+        if self.use_data_parallel:
+            encoder_outputs = run_dp_sharded_vision_model(
+                hidden_states, self.encoder)
+        else:
+            encoder_outputs = self.encoder(hidden_states)
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
 
+    def _consolidate_qkv_weights(
+        self, weights: Iterable[tuple[str, torch.Tensor]]
+    ) -> Iterable[tuple[str, torch.Tensor]]:
+        qkv_idx_mappings = {
+            ".self_attn.q_proj": 0,
+            ".self_attn.k_proj": 1,
+            ".self_attn.v_proj": 2,
+        }
+        qkv_weights = {}
+        for name, loaded_weight in weights:
+            for weight_name, idx in qkv_idx_mappings.items():
+                if weight_name not in name:
+                    continue
+                new_name = name.replace(weight_name, ".self_attn.qkv_proj")
+                if new_name not in qkv_weights:
+                    qkv_weights[new_name] = [None] * 3
+                qkv_weights[new_name][idx] = loaded_weight
+                break
+            else:
+                yield name, loaded_weight
+        for key, weight in qkv_weights.items():
+            qkv_weight = torch.cat(weight, dim=0)
+            yield key, qkv_weight
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -356,6 +422,9 @@ def load_weights(self, weights: Iterable[tuple[str,
         loaded_params: set[str] = set()
         layer_count = len(self.encoder.layers)
 
+        if self.use_data_parallel:
+            weights = self._consolidate_qkv_weights(weights)
+
         for name, loaded_weight in weights:
             # skip pooling header
             if name.startswith("head."):
@@ -373,7 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str,
                 continue
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
+                if weight_name not in name or self.use_data_parallel:
                     continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 48ce1b9d38e2..a2a71bdd12b3 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -778,6 +778,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # and config class
         self.config = config
         self.multimodal_config = multimodal_config
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
         self.version = get_version_by_config(self.config)
         self.llm = self.init_llm(vllm_config=vllm_config,
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index 58c71d865dc7..834b2189e4be 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -461,6 +461,8 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                               num_chunks_per_rank, ...]
 
     vision_embeddings = vision_model(image_input_per_rank)
+    # Ensure tensor is contiguous before all_gather
+    vision_embeddings = vision_embeddings.contiguous()
     vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
                                                          dim=0)
     vision_embeddings = vision_embeddings[:num_chunks, ...]
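
For context on the `vllm/multimodal/utils.py` change, the following is a simplified
sketch of the pad-shard-encode-gather pattern that `run_dp_sharded_vision_model`
follows. The helper names come from `vllm.distributed`, but the function itself is
illustrative only (its name and the padding details are assumptions, not the library
implementation). The per-rank encoder output can be a non-contiguous view, which is
why the diff inserts `.contiguous()` before the all-gather:

# Illustrative sketch, not the vLLM implementation.
import torch
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)


def dp_sharded_vision_forward(pixel_chunks: torch.Tensor,
                              vision_model: torch.nn.Module) -> torch.Tensor:
    """Shard image chunks across TP ranks, encode locally, then gather."""
    tp_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()
    num_chunks = pixel_chunks.shape[0]

    # Pad the batch so every rank receives the same number of chunks.
    chunks_per_rank = -(-num_chunks // tp_size)  # ceil division
    num_pad = chunks_per_rank * tp_size - num_chunks
    if num_pad:
        padding = pixel_chunks.new_zeros((num_pad, *pixel_chunks.shape[1:]))
        pixel_chunks = torch.cat([pixel_chunks, padding], dim=0)

    # Each rank runs the full encoder on its own slice only.
    local = pixel_chunks[rank * chunks_per_rank:(rank + 1) * chunks_per_rank]
    local_out = vision_model(local)

    # The encoder output may be a non-contiguous view; the all-gather
    # collective expects a contiguous buffer (the fix this diff adds).
    local_out = local_out.contiguous()
    gathered = tensor_model_parallel_all_gather(local_out, dim=0)

    # Drop the padded chunks before returning.
    return gathered[:num_chunks]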
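
Usage sketch for the `optimization.md` entry above: launching MiniCPM-V with the
vision encoder in batch-level DP mode. This assumes `mm_encoder_tp_mode` is accepted
as an engine argument, as the docs hunk implies, and that `openbmb/MiniCPM-V-4` is
the intended checkpoint ID; both are illustrative rather than taken from this diff.

# Illustrative only: the model ID and engine arguments are assumptions.
from vllm import LLM

llm = LLM(
    model="openbmb/MiniCPM-V-4",  # assumed checkpoint ID
    trust_remote_code=True,       # MiniCPM-V ships custom model/processor code
    tensor_parallel_size=2,       # encoder DP sharding only matters when TP > 1
    mm_encoder_tp_mode="data",    # run the vision encoder data-parallel across TP ranks
)

outputs = llm.generate("Describe data-parallel vision encoding in one sentence.")
print(outputs[0].outputs[0].text)

The server-side equivalent would be the `--mm-encoder-tp-mode data` flag, assuming
the usual EngineArgs-to-CLI mapping.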