-
-
Notifications
You must be signed in to change notification settings - Fork 11.5k
[Bugfix][Multi Modal] Fix incorrect output in Molmo #26518
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,7 +21,6 @@ | |
| from vllm.attention.layer import MultiHeadAttention | ||
| from vllm.compilation.decorators import support_torch_compile | ||
| from vllm.config import CacheConfig, VllmConfig | ||
| from vllm.config.multimodal import BaseDummyOptions | ||
| from vllm.distributed import ( | ||
| get_pp_group, | ||
| get_tensor_model_parallel_rank, | ||
|
|
@@ -80,6 +79,7 @@ | |
| make_empty_intermediate_tensors_factory, | ||
| make_layers, | ||
| maybe_prefix, | ||
| merge_multimodal_embeddings, | ||
| ) | ||
|
|
||
| # TODO: hard-coded for now. Consider making it configurable. | ||
|
|
@@ -119,6 +119,13 @@ class MolmoImageInputs(TensorSchema): | |
| TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}), | ||
| ] | ||
| # A boolean mask indicating which image features correspond to patch tokens. | ||
|
|
||
| image_input_idx: Annotated[ | ||
| Optional[Union[torch.Tensor, list[torch.Tensor]]], | ||
| TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}), | ||
| ] | ||
| # Index mapping for patch reordering to maintain spatial order | ||
|
|
||
| num_crops: Annotated[torch.Tensor, TensorShape("bn")] | ||
|
|
||
|
|
||
|
|
@@ -842,7 +849,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | |
| ["hidden_states", "residual"], config.hidden_size | ||
| ) | ||
|
|
||
| def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: | ||
| def get_input_embeddings( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| return self.embed_tokens(input_ids) | ||
|
|
||
| def forward( | ||
|
|
@@ -1178,6 +1188,8 @@ def __call__( | |
| assert num_crops.sum() == len(feat_is_patch) | ||
|
|
||
| outputs["feat_is_patch"] = feat_is_patch | ||
| # Keep the original index mapping | ||
| outputs["image_input_idx"] = image_input_idx | ||
| outputs["num_crops"] = num_crops | ||
| outputs["img_patch_id"] = self.image_patch_id | ||
|
|
||
|
|
@@ -1249,19 +1261,13 @@ def get_dummy_mm_data( | |
| self, | ||
| seq_len: int, | ||
| mm_counts: Mapping[str, int], | ||
| mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, | ||
| ) -> MultiModalDataDict: | ||
| target_width, target_height = self.info.get_image_size_with_most_features() | ||
| num_images = mm_counts.get("image", 0) | ||
|
|
||
| image_overrides = mm_options.get("image") if mm_options else None | ||
|
|
||
| return { | ||
| "image": self._get_dummy_images( | ||
| width=target_width, | ||
| height=target_height, | ||
| num_images=num_images, | ||
| overrides=image_overrides, | ||
| width=target_width, height=target_height, num_images=num_images | ||
| ) | ||
| } | ||
|
|
||
|
|
@@ -1300,6 +1306,7 @@ def _get_mm_fields_config( | |
| images=MultiModalFieldConfig.flat_from_sizes("image", num_crops), | ||
| image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops), | ||
| feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops), | ||
| image_input_idx=MultiModalFieldConfig.flat_from_sizes("image", num_crops), | ||
| num_crops=MultiModalFieldConfig.batched("image"), | ||
| img_patch_id=MultiModalFieldConfig.shared("image", num_images), | ||
| ) | ||
|
|
@@ -1445,6 +1452,7 @@ def _parse_and_validate_image_input( | |
| images = kwargs.pop("images", None) | ||
| image_masks = kwargs.pop("image_masks", None) | ||
| feat_is_patch = kwargs.pop("feat_is_patch", None) | ||
| image_input_idx = kwargs.pop("image_input_idx", None) | ||
| num_crops = kwargs.pop("num_crops", None) | ||
|
|
||
| if images is None: | ||
|
|
@@ -1467,6 +1475,7 @@ def _parse_and_validate_image_input( | |
| images=images, | ||
| image_masks=image_masks, | ||
| feat_is_patch=feat_is_patch, | ||
| image_input_idx=image_input_idx, | ||
| num_crops=num_crops, | ||
| ) | ||
|
|
||
|
|
@@ -1477,6 +1486,7 @@ def _process_image_input( | |
| images = image_input["images"] | ||
| image_masks = image_input["image_masks"] | ||
| feat_is_patch = image_input["feat_is_patch"] | ||
| image_input_idx = image_input["image_input_idx"] | ||
| num_crops = image_input["num_crops"] | ||
|
|
||
| # Call the vision backbone on the whole batch at once | ||
|
|
@@ -1493,14 +1503,85 @@ def _process_image_input( | |
| ), | ||
| ).squeeze(0) | ||
|
|
||
| # Only the features corresponding to patch tokens are relevant | ||
| return [ | ||
| feats[f_is_patch] | ||
| for feats, f_is_patch in zip( | ||
| # Check if we have image_input_idx for spatial reordering | ||
| if image_input_idx is not None: | ||
| # Split features by image and implement patch reordering | ||
| image_input_idx_flat = flatten_bn(image_input_idx, concat=True) | ||
| result_embeddings = [] | ||
|
|
||
| for feats, f_is_patch, img_input_idx in zip( | ||
| image_features_flat.split(num_crops.tolist()), | ||
| feat_is_patch_flat.split(num_crops.tolist()), | ||
| image_input_idx_flat.split(num_crops.tolist()), | ||
| ): | ||
| # Apply spatial reordering based on image_input_idx | ||
| reordered_features = self._reorder_patches_by_spatial_position( | ||
| feats, f_is_patch, img_input_idx | ||
| ) | ||
| result_embeddings.append(reordered_features) | ||
|
|
||
| return result_embeddings | ||
| else: | ||
| # Fallback to old behavior when image_input_idx is not available | ||
| # Only the features corresponding to patch tokens are relevant | ||
| return [ | ||
| feats[f_is_patch] | ||
| for feats, f_is_patch in zip( | ||
| image_features_flat.split(num_crops.tolist()), | ||
| feat_is_patch_flat.split(num_crops.tolist()), | ||
| ) | ||
| ] | ||
|
|
||
| def _reorder_patches_by_spatial_position( | ||
| self, | ||
| image_features: torch.Tensor, | ||
| feat_is_patch: torch.Tensor, | ||
| image_input_idx: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Reorder patches from crop order to spatial order using image_input_idx. | ||
|
|
||
| Args: | ||
| image_features: (num_patches, feature_dim) - features in crop order | ||
| feat_is_patch: (num_patches,) - mask for valid patches | ||
| image_input_idx: (num_patches,) - target positions in spatial order | ||
|
|
||
| Returns: | ||
| torch.Tensor: (num_valid_patches, feature_dim) - spatially ordered features | ||
| """ | ||
| # Filter valid patches and their indices | ||
| valid_mask = feat_is_patch & (image_input_idx >= 0) | ||
|
|
||
| if not valid_mask.any(): | ||
| return torch.empty( | ||
| 0, | ||
| image_features.shape[-1], | ||
| dtype=image_features.dtype, | ||
| device=image_features.device, | ||
| ) | ||
| ] | ||
|
|
||
| valid_features = image_features[valid_mask] | ||
| valid_indices = image_input_idx[valid_mask] | ||
|
|
||
| # Create output tensor with spatial ordering | ||
| max_idx = valid_indices.max().item() | ||
| output_features = torch.zeros( | ||
| max_idx + 1, | ||
| image_features.shape[-1], | ||
| dtype=image_features.dtype, | ||
| device=image_features.device, | ||
| ) | ||
|
|
||
| # Place features at their spatial positions | ||
| output_features[valid_indices] = valid_features | ||
|
|
||
| # Return only the positions that have features | ||
| occupied_positions = torch.zeros( | ||
| max_idx + 1, dtype=torch.bool, device=image_features.device | ||
| ) | ||
| occupied_positions[valid_indices] = True | ||
|
|
||
| return output_features[occupied_positions] | ||
|
Comment on lines
+1566
to
+1584
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The current implementation for reordering patches is a bit complex and potentially inefficient. It creates two large intermediate tensors (`output_features` and `occupied_positions`) to scatter the features and then gather them back. Sorting the valid features by their indices is more direct:
# Spatially reorder the valid features based on their indices.
# Sorting is more direct and memory-efficient than creating large
# intermediate tensors for scattering and then gathering.
_, sort_indices = torch.sort(valid_indices)
return valid_features[sort_indices] |
||
|
|
||
| def get_language_model(self) -> torch.nn.Module: | ||
| return self.model | ||
|
|
@@ -1512,6 +1593,23 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: | |
|
|
||
| return self._process_image_input(image_input) | ||
|
|
||
| def get_input_embeddings( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| multimodal_embeddings: Optional[MultiModalEmbeddings] = None, | ||
| ) -> torch.Tensor: | ||
| inputs_embeds = self.model.get_input_embeddings(input_ids) | ||
| if multimodal_embeddings is not None and len(multimodal_embeddings) != 0: | ||
| assert self.img_patch_id is not None | ||
|
|
||
| inputs_embeds = merge_multimodal_embeddings( | ||
| input_ids, | ||
| inputs_embeds, | ||
| multimodal_embeddings, | ||
| self.img_patch_id, | ||
| ) | ||
| return inputs_embeds | ||
|
|
||
| def forward( | ||
| self, | ||
| input_ids: torch.LongTensor, | ||
|
|
@@ -1523,6 +1621,13 @@ def forward( | |
| if intermediate_tensors is not None: | ||
| inputs_embeds = None | ||
|
|
||
| # NOTE: In v1, inputs_embeds is always generated at model runner, this | ||
| # condition is for v0 compatibility. | ||
| elif inputs_embeds is None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I believe we no longer need to maintain the code path for v0. |
||
| vision_embeddings = self.get_multimodal_embeddings(**kwargs) | ||
| inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) | ||
| input_ids = None | ||
|
|
||
| hidden_states = self.model( | ||
| input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds | ||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you want to change this part? I think it is unrelated to this fix.