diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index 734841d0dc98..0aa35c543319 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -21,7 +21,6 @@
 from vllm.attention.layer import MultiHeadAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
@@ -80,6 +79,7 @@
     make_empty_intermediate_tensors_factory,
     make_layers,
     maybe_prefix,
+    merge_multimodal_embeddings,
 )
 
 # TODO: hard-coded for now. Consider making it configurable.
@@ -119,6 +119,13 @@ class MolmoImageInputs(TensorSchema):
         TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
     ]
     # A boolean mask indicating which image features correspond to patch tokens.
+
+    image_input_idx: Annotated[
+        Optional[Union[torch.Tensor, list[torch.Tensor]]],
+        TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
+    ]
+    # Index mapping for patch reordering to maintain spatial order
+
     num_crops: Annotated[torch.Tensor, TensorShape("bn")]
 
     img_patch_id: torch.Tensor
@@ -842,7 +849,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             ["hidden_states", "residual"], config.hidden_size
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
     def forward(
@@ -1178,6 +1188,8 @@ def __call__(
         assert num_crops.sum() == len(feat_is_patch)
 
         outputs["feat_is_patch"] = feat_is_patch
+        # Keep the original index mapping
+        outputs["image_input_idx"] = image_input_idx
         outputs["num_crops"] = num_crops
         outputs["img_patch_id"] = self.image_patch_id
 
@@ -1249,19 +1261,13 @@ def get_dummy_mm_data(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
-        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
     ) -> MultiModalDataDict:
         target_width, target_height = self.info.get_image_size_with_most_features()
         num_images = mm_counts.get("image", 0)
 
-        image_overrides = mm_options.get("image") if mm_options else None
-
         return {
             "image": self._get_dummy_images(
-                width=target_width,
-                height=target_height,
-                num_images=num_images,
-                overrides=image_overrides,
+                width=target_width, height=target_height, num_images=num_images
             )
         }
 
@@ -1300,6 +1306,7 @@ def _get_mm_fields_config(
             images=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
             image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
             feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
+            image_input_idx=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
             num_crops=MultiModalFieldConfig.batched("image"),
             img_patch_id=MultiModalFieldConfig.shared("image", num_images),
         )
@@ -1445,6 +1452,7 @@ def _parse_and_validate_image_input(
         images = kwargs.pop("images", None)
         image_masks = kwargs.pop("image_masks", None)
         feat_is_patch = kwargs.pop("feat_is_patch", None)
+        image_input_idx = kwargs.pop("image_input_idx", None)
         num_crops = kwargs.pop("num_crops", None)
 
         if images is None:
@@ -1467,6 +1475,7 @@
             images=images,
             image_masks=image_masks,
             feat_is_patch=feat_is_patch,
+            image_input_idx=image_input_idx,
             num_crops=num_crops,
         )
 
@@ -1477,6 +1486,7 @@ def _process_image_input(
         images = image_input["images"]
         image_masks = image_input["image_masks"]
         feat_is_patch = image_input["feat_is_patch"]
+        image_input_idx = image_input["image_input_idx"]
         num_crops = image_input["num_crops"]
 
         # Call the vision backbone on the whole batch at once
@@ -1493,14 +1503,85 @@ def _process_image_input(
             ),
         ).squeeze(0)
 
-        # Only the features corresponding to patch tokens are relevant
-        return [
-            feats[f_is_patch]
-            for feats, f_is_patch in zip(
+        # Check if we have image_input_idx for spatial reordering
+        if image_input_idx is not None:
+            # Split features by image and implement patch reordering
+            image_input_idx_flat = flatten_bn(image_input_idx, concat=True)
+            result_embeddings = []
+
+            for feats, f_is_patch, img_input_idx in zip(
                 image_features_flat.split(num_crops.tolist()),
                 feat_is_patch_flat.split(num_crops.tolist()),
+                image_input_idx_flat.split(num_crops.tolist()),
+            ):
+                # Apply spatial reordering based on image_input_idx
+                reordered_features = self._reorder_patches_by_spatial_position(
+                    feats, f_is_patch, img_input_idx
+                )
+                result_embeddings.append(reordered_features)
+
+            return result_embeddings
+        else:
+            # Fallback to old behavior when image_input_idx is not available
+            # Only the features corresponding to patch tokens are relevant
+            return [
+                feats[f_is_patch]
+                for feats, f_is_patch in zip(
+                    image_features_flat.split(num_crops.tolist()),
+                    feat_is_patch_flat.split(num_crops.tolist()),
+                )
+            ]
+
+    def _reorder_patches_by_spatial_position(
+        self,
+        image_features: torch.Tensor,
+        feat_is_patch: torch.Tensor,
+        image_input_idx: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Reorder patches from crop order to spatial order using image_input_idx.
+
+        Args:
+            image_features: (num_patches, feature_dim) - features in crop order
+            feat_is_patch: (num_patches,) - mask for valid patches
+            image_input_idx: (num_patches,) - target positions in spatial order
+
+        Returns:
+            torch.Tensor: (num_valid_patches, feature_dim) - spatially ordered features
+        """
+        # Filter valid patches and their indices
+        valid_mask = feat_is_patch & (image_input_idx >= 0)
+
+        if not valid_mask.any():
+            return torch.empty(
+                0,
+                image_features.shape[-1],
+                dtype=image_features.dtype,
+                device=image_features.device,
             )
-        ]
+
+        valid_features = image_features[valid_mask]
+        valid_indices = image_input_idx[valid_mask]
+
+        # Create output tensor with spatial ordering
+        max_idx = valid_indices.max().item()
+        output_features = torch.zeros(
+            max_idx + 1,
+            image_features.shape[-1],
+            dtype=image_features.dtype,
+            device=image_features.device,
+        )
+
+        # Place features at their spatial positions
+        output_features[valid_indices] = valid_features
+
+        # Return only the positions that have features
+        occupied_positions = torch.zeros(
+            max_idx + 1, dtype=torch.bool, device=image_features.device
+        )
+        occupied_positions[valid_indices] = True
+
+        return output_features[occupied_positions]
 
     def get_language_model(self) -> torch.nn.Module:
         return self.model
@@ -1512,6 +1593,23 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
 
         return self._process_image_input(image_input)
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None and len(multimodal_embeddings) != 0:
+            assert self.img_patch_id is not None
+
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                self.img_patch_id,
+            )
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -1523,6 +1621,13 @@ def forward(
         if intermediate_tensors is not None:
             inputs_embeds = None
+        # NOTE: In v1, inputs_embeds is always generated at model runner, this
+        # condition is for v0 compatibility.
+        elif inputs_embeds is None:
+            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings)
+            input_ids = None
+
         hidden_states = self.model(
             input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
         )
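
Reviewer note (not part of the patch): below is a minimal standalone sketch of the scatter-based reordering that the new `_reorder_patches_by_spatial_position` helper performs, with the per-crop dimensions flattened to one row per patch for simplicity. `reorder_patches` is a hypothetical free-function rendering of the method, for illustration only.

```python
import torch


def reorder_patches(
    image_features: torch.Tensor,   # (num_patches, feature_dim), crop order
    feat_is_patch: torch.Tensor,    # (num_patches,), bool mask of real patches
    image_input_idx: torch.Tensor,  # (num_patches,), target spatial positions
) -> torch.Tensor:
    # Keep only real patches with a non-negative target position.
    valid = feat_is_patch & (image_input_idx >= 0)
    if not valid.any():
        return image_features.new_empty(0, image_features.shape[-1])

    feats = image_features[valid]
    idx = image_input_idx[valid]

    # Scatter crop-ordered rows into a dense buffer indexed by spatial position...
    out = image_features.new_zeros(int(idx.max()) + 1, image_features.shape[-1])
    out[idx] = feats

    # ...then keep only the occupied slots, now in ascending spatial order.
    occupied = torch.zeros(out.shape[0], dtype=torch.bool, device=out.device)
    occupied[idx] = True
    return out[occupied]


# Toy run: three patches arrive in crop order with spatial targets [2, 0, 1];
# the fourth entry is padding (feat_is_patch=False, index -1) and is dropped.
features = torch.arange(8.0).reshape(4, 2)
is_patch = torch.tensor([True, True, True, False])
input_idx = torch.tensor([2, 0, 1, -1])
print(reorder_patches(features, is_patch, input_idx))
# Rows come out in spatial order: original rows 1, 2, 0.
```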
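Reviewer note (not part of the patch): the new `get_input_embeddings` override relies on `merge_multimodal_embeddings` to splice the reordered vision features into the token embeddings wherever `input_ids` equals `img_patch_id`. Below is a simplified single-sequence sketch of that splice, assuming the multimodal embeddings arrive as one pre-concatenated tensor (the real utility in `vllm.model_executor.models.utils` also handles list-of-tensor inputs).

```python
import torch


def merge_embeddings_sketch(
    input_ids: torch.Tensor,      # (seq_len,)
    inputs_embeds: torch.Tensor,  # (seq_len, hidden_size)
    mm_embeds: torch.Tensor,      # (num_placeholder_tokens, hidden_size)
    placeholder_id: int,
) -> torch.Tensor:
    # Overwrite placeholder-token rows with multimodal embeddings, in order.
    mask = input_ids == placeholder_id
    assert int(mask.sum()) == mm_embeds.shape[0], "placeholder/embedding count mismatch"
    merged = inputs_embeds.clone()
    merged[mask] = mm_embeds
    return merged


# Toy run: token id 7 plays the role of img_patch_id.
ids = torch.tensor([1, 7, 7, 2])
text_embeds = torch.zeros(4, 3)
vision_embeds = torch.ones(2, 3)
print(merge_embeddings_sketch(ids, text_embeds, vision_embeds, placeholder_id=7))
# Rows 1 and 2 are now the vision embeddings.
```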