133 changes: 119 additions & 14 deletions vllm/model_executor/models/molmo.py
@@ -21,7 +21,6 @@
from vllm.attention.layer import MultiHeadAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
-from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
@@ -80,6 +79,7 @@
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
merge_multimodal_embeddings,
)

# TODO: hard-coded for now. Consider making it configurable.
@@ -119,6 +119,13 @@ class MolmoImageInputs(TensorSchema):
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
]
# A boolean mask indicating which image features correspond to patch tokens.

image_input_idx: Annotated[
Optional[Union[torch.Tensor, list[torch.Tensor]]],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
]
# Index mapping for patch reordering to maintain spatial order

num_crops: Annotated[torch.Tensor, TensorShape("bn")]


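For intuition, an illustrative toy example (not part of the diff) of what the new field encodes: the vision backbone emits patch features in crop order, and image_input_idx[i] is the spatial position that patch i should occupy, with -1 marking entries that are not real patches (matching the image_input_idx >= 0 check in _reorder_patches_by_spatial_position later in this diff).

    import torch

    # Hypothetical values for one image with four feature slots:
    feat_is_patch = torch.tensor([True, True, False, True])
    image_input_idx = torch.tensor([2, 0, -1, 1])
    # Feature 0 belongs at spatial position 2, feature 1 at position 0,
    # entry 2 is padding (-1, not a patch), and feature 3 at position 1.
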
@@ -842,7 +849,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
["hidden_states", "residual"], config.hidden_size
)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
@@ -1178,6 +1188,8 @@ def __call__(
assert num_crops.sum() == len(feat_is_patch)

outputs["feat_is_patch"] = feat_is_patch
# Keep the original index mapping
outputs["image_input_idx"] = image_input_idx
outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id

@@ -1249,19 +1261,13 @@ def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
-mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
target_width, target_height = self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)

-image_overrides = mm_options.get("image") if mm_options else None

return {
"image": self._get_dummy_images(
-width=target_width,
-height=target_height,
-num_images=num_images,
-overrides=image_overrides,
+width=target_width, height=target_height, num_images=num_images
Contributor: Why do you want to change this part? I think it should be irrelevant.

)
}

@@ -1300,6 +1306,7 @@ def _get_mm_fields_config(
images=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
image_input_idx=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
)
@@ -1445,6 +1452,7 @@ def _parse_and_validate_image_input(
images = kwargs.pop("images", None)
image_masks = kwargs.pop("image_masks", None)
feat_is_patch = kwargs.pop("feat_is_patch", None)
image_input_idx = kwargs.pop("image_input_idx", None)
num_crops = kwargs.pop("num_crops", None)

if images is None:
@@ -1467,6 +1475,7 @@
images=images,
image_masks=image_masks,
feat_is_patch=feat_is_patch,
image_input_idx=image_input_idx,
num_crops=num_crops,
)

@@ -1477,6 +1486,7 @@ def _process_image_input(
images = image_input["images"]
image_masks = image_input["image_masks"]
feat_is_patch = image_input["feat_is_patch"]
image_input_idx = image_input["image_input_idx"]
num_crops = image_input["num_crops"]

# Call the vision backbone on the whole batch at once
@@ -1493,14 +1503,85 @@
),
).squeeze(0)

# Check if we have image_input_idx for spatial reordering
if image_input_idx is not None:
# Split features by image and implement patch reordering
image_input_idx_flat = flatten_bn(image_input_idx, concat=True)
result_embeddings = []

for feats, f_is_patch, img_input_idx in zip(
image_features_flat.split(num_crops.tolist()),
feat_is_patch_flat.split(num_crops.tolist()),
image_input_idx_flat.split(num_crops.tolist()),
):
# Apply spatial reordering based on image_input_idx
reordered_features = self._reorder_patches_by_spatial_position(
feats, f_is_patch, img_input_idx
)
result_embeddings.append(reordered_features)

return result_embeddings
else:
# Fallback to old behavior when image_input_idx is not available
# Only the features corresponding to patch tokens are relevant
return [
feats[f_is_patch]
for feats, f_is_patch in zip(
image_features_flat.split(num_crops.tolist()),
feat_is_patch_flat.split(num_crops.tolist()),
)
]

def _reorder_patches_by_spatial_position(
self,
image_features: torch.Tensor,
feat_is_patch: torch.Tensor,
image_input_idx: torch.Tensor,
) -> torch.Tensor:
"""
Reorder patches from crop order to spatial order using image_input_idx.

Args:
image_features: (num_patches, feature_dim) - features in crop order
feat_is_patch: (num_patches,) - mask for valid patches
image_input_idx: (num_patches,) - target positions in spatial order

Returns:
torch.Tensor: (num_valid_patches, feature_dim) - spatially ordered features
"""
# Filter valid patches and their indices
valid_mask = feat_is_patch & (image_input_idx >= 0)

if not valid_mask.any():
return torch.empty(
0,
image_features.shape[-1],
dtype=image_features.dtype,
device=image_features.device,
)

valid_features = image_features[valid_mask]
valid_indices = image_input_idx[valid_mask]

# Create output tensor with spatial ordering
max_idx = valid_indices.max().item()
output_features = torch.zeros(
max_idx + 1,
image_features.shape[-1],
dtype=image_features.dtype,
device=image_features.device,
)

# Place features at their spatial positions
output_features[valid_indices] = valid_features

# Return only the positions that have features
occupied_positions = torch.zeros(
max_idx + 1, dtype=torch.bool, device=image_features.device
)
occupied_positions[valid_indices] = True

return output_features[occupied_positions]
Comment on lines +1566 to +1584
Contributor (severity: high):

The current implementation for reordering patches is a bit complex and potentially inefficient. It creates two large intermediate tensors (output_features and occupied_positions) whose size is determined by max_idx, which could be large and consume unnecessary memory. A more direct and memory-efficient approach is to sort the valid features directly by their spatial indices. This improves both readability and performance.

        # Spatially reorder the valid features based on their indices.
        # Sorting is more direct and memory-efficient than creating large
        # intermediate tensors for scattering and then gathering.
        _, sort_indices = torch.sort(valid_indices)
        return valid_features[sort_indices]
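
As a sanity check on this suggestion (a hypothetical snippet, not part of the PR), the scatter-then-gather version from the diff and the sort-based version agree whenever each valid spatial index appears at most once, which the scatter version implicitly assumes as well, since colliding indices would overwrite each other:

    import torch

    torch.manual_seed(0)
    valid_features = torch.randn(3, 4)       # patch features in crop order
    valid_indices = torch.tensor([5, 0, 2])  # unique target spatial slots

    # Scatter/gather version, as in the diff
    max_idx = int(valid_indices.max()) + 1
    output_features = torch.zeros(max_idx, 4)
    output_features[valid_indices] = valid_features
    occupied = torch.zeros(max_idx, dtype=torch.bool)
    occupied[valid_indices] = True
    scatter_result = output_features[occupied]

    # Sort version, as suggested above
    _, sort_indices = torch.sort(valid_indices)
    sort_result = valid_features[sort_indices]

    assert torch.equal(scatter_result, sort_result)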


def get_language_model(self) -> torch.nn.Module:
return self.model
@@ -1512,6 +1593,23 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:

return self._process_image_input(image_input)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None and len(multimodal_embeddings) != 0:
assert self.img_patch_id is not None

inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.img_patch_id,
)
return inputs_embeds
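
For reference, merge_multimodal_embeddings (imported above from the models' utils module) scatters the image embeddings into the placeholder positions of the text embeddings; its effect is roughly the following sketch (simplified, not the helper's actual implementation):

    # Rows whose token id equals the image-patch placeholder are replaced
    # by the flattened multimodal embeddings, in order of appearance.
    mask = input_ids == self.img_patch_id
    inputs_embeds[mask] = torch.cat(list(multimodal_embeddings))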

def forward(
self,
input_ids: torch.LongTensor,
@@ -1523,6 +1621,13 @@ def forward(
if intermediate_tensors is not None:
inputs_embeds = None

# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
Contributor: I believe we don't need to maintain the code for v0 now.

vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings)
input_ids = None

hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)