
Commit 27d7638

Authored by DarkLight1337, NickLucche, and ywang96
[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: NickLucche <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
1 parent 1761739 commit 27d7638
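
In short: the merge used to locate placeholder positions by comparing `input_ids` against a placeholder token ID inside `merge_multimodal_embeddings`; after this change the positions are supplied explicitly as a boolean index mask produced during input processing. A minimal, self-contained sketch of the two approaches (toy tensors and helper names, not vLLM's implementation):

```python
import torch


def merge_by_token_id(input_ids, inputs_embeds, mm_embeds, placeholder_token_id):
    """Old-style merge: re-derive placeholder positions from the token IDs."""
    mask = input_ids == placeholder_token_id
    out = inputs_embeds.clone()
    out[mask] = mm_embeds.to(out.dtype)
    return out


def merge_by_index(inputs_embeds, mm_embeds, is_multimodal):
    """New-style merge: placeholder positions arrive as an explicit mask."""
    out = inputs_embeds.clone()
    out[is_multimodal] = mm_embeds.to(out.dtype)
    return out


# Toy request: a 6-token prompt whose positions 2-4 hold image placeholders.
input_ids = torch.tensor([5, 7, 99, 99, 99, 11])   # 99 = placeholder token
inputs_embeds = torch.zeros(6, 4)                  # text embeddings
mm_embeds = torch.ones(3, 4)                       # image embeddings
is_multimodal = torch.tensor([False, False, True, True, True, False])

assert torch.equal(
    merge_by_token_id(input_ids, inputs_embeds, mm_embeds, 99),
    merge_by_index(inputs_embeds, mm_embeds, is_multimodal),
)
```

The commit title suggests the token-ID lookup was the source of the bug being fixed; merging by index makes the placeholder positions explicit instead of re-deriving them from the prompt tokens.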


80 files changed: +965 -1138 lines changed

docs/contributing/model/multimodal.md

Lines changed: 5 additions & 28 deletions
@@ -66,35 +66,12 @@ Further update the model as follows:
 !!! important
     The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request.
 
-- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.
-
-    ??? code
-
-        ```python
-        from .utils import merge_multimodal_embeddings
-
-        class YourModelForImage2Seq(nn.Module):
-            ...
-
-            def get_input_embeddings(
-                self,
-                input_ids: torch.Tensor,
-                multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-            ) -> torch.Tensor:
-
-                # `get_input_embeddings` should already be implemented for the language
-                # model as one of the requirements of basic vLLM model implementation.
-                inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-
-                if multimodal_embeddings is not None:
-                    inputs_embeds = merge_multimodal_embeddings(
-                        input_ids=input_ids,
-                        inputs_embeds=inputs_embeds,
-                        multimodal_embeddings=multimodal_embeddings,
-                        placeholder_token_id=self.config.image_token_index)
+!!! note
+    By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in
+    [PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing.
+    This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
 
-                return inputs_embeds
-        ```
+    You may override this method if additional logic is required for your model when merging embeddings.
 
 - Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
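
The new note says the default merge is driven by [PlaceholderRange] locations and that `get_input_embeddings` (defined on `SupportsMultiModal`) can be overridden. A hypothetical override, assuming the keyword-style `is_multimodal` mask that the updated call sites in this commit pass; class and attribute names are illustrative, not a real vLLM model:

```python
from typing import Optional

import torch
import torch.nn as nn


class ToyImage2Seq(nn.Module):
    """Illustrative only -- not a real vLLM model class."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[list[torch.Tensor]] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
            # Model-specific logic (extra projections, etc.) would go here
            # before scattering the embeddings into the placeholder slots.
            flat = torch.cat(list(multimodal_embeddings), dim=0)
            inputs_embeds = inputs_embeds.clone()
            inputs_embeds[is_multimodal] = flat.to(inputs_embeds.dtype)
        return inputs_embeds


model = ToyImage2Seq()
ids = torch.tensor([1, 2, 3, 3, 4])
mask = torch.tensor([False, False, True, True, False])
out = model.get_input_embeddings(ids, [torch.randn(2, 8)], is_multimodal=mask)
```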

vllm/config/model.py

Lines changed: 6 additions & 1 deletion
@@ -509,9 +509,14 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
             else:  # task == "auto"
                 pass
         else:
+            debug_info = {
+                "architectures": architectures,
+                "is_generative_model": is_generative_model,
+                "is_pooling_model": is_pooling_model,
+            }
             raise AssertionError("The model should be a generative or "
                                  "pooling model when task is set to "
-                                 f"{self.task!r}.")
+                                 f"{self.task!r}. Found: {debug_info}")
 
         self.runner = runner
         self.convert = convert

vllm/model_executor/models/aria.py

Lines changed: 6 additions & 19 deletions
@@ -38,8 +38,7 @@
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
 from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    is_pp_missing_parameter, maybe_prefix,
-                    merge_multimodal_embeddings)
+                    is_pp_missing_parameter, maybe_prefix)
 
 
 class AriaImagePixelInputs(TensorSchema):
@@ -605,19 +604,6 @@ def get_multimodal_embeddings(self,
         multimodal_embeddings = self._process_image_input(image_input)
         return multimodal_embeddings
 
-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-    ) -> torch.Tensor:
-        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if multimodal_embeddings is not None \
-            and len(multimodal_embeddings) != 0:
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                self.config.image_token_index)
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -628,10 +614,11 @@
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if inputs_embeds is None:
             multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      multimodal_embeddings)
+            inputs_embeds = self.get_input_embeddings(
+                input_ids,
+                multimodal_embeddings,
+                is_multimodal=input_ids == self.config.image_token_index,
+            )
             input_ids = None
 
         hidden_states = self.language_model(
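
This call-site pattern, building `is_multimodal` from `input_ids` and handing it to `get_input_embeddings`, repeats in aya_vision.py, blip2.py, chameleon.py, and cohere2_vision.py below. A schematic of the forward-time flow with toy names (not the Aria code):

```python
import torch
import torch.nn as nn

IMAGE_TOKEN_INDEX = 99  # stand-in for self.config.image_token_index


def prepare_inputs_embeds(
    input_ids: torch.Tensor,
    embed_tokens: nn.Embedding,
    multimodal_embeddings: list[torch.Tensor],
) -> torch.Tensor:
    # 1. Build the placeholder mask explicitly at the call site.
    is_multimodal = input_ids == IMAGE_TOKEN_INDEX
    # 2. Merge by index: scatter the multimodal embeddings into those slots
    #    (the role played by get_input_embeddings(..., is_multimodal=...)).
    inputs_embeds = embed_tokens(input_ids).clone()
    inputs_embeds[is_multimodal] = torch.cat(multimodal_embeddings, dim=0)
    # 3. The model then runs on inputs_embeds and drops input_ids,
    #    mirroring `input_ids = None` in the diffs.
    return inputs_embeds


embeds = prepare_inputs_embeds(
    torch.tensor([1, 99, 99, 2]),
    nn.Embedding(128, 8),
    [torch.randn(2, 8)],
)
```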

vllm/model_executor/models/aya_vision.py

Lines changed: 6 additions & 21 deletions
@@ -33,8 +33,7 @@
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings)
+                    init_vllm_registered_model, maybe_prefix)
 
 
 class AyaVisionImagePixelInputs(TensorSchema):
@@ -417,23 +416,6 @@ def get_multimodal_embeddings(self,
 
         return self._process_image_input(image_input, **kwargs)
 
-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-    ) -> torch.Tensor:
-        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if multimodal_embeddings is not None \
-            and len(multimodal_embeddings) != 0:
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                multimodal_embeddings=multimodal_embeddings,
-                placeholder_token_id=self.config.image_token_index,
-            )
-
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -449,8 +431,11 @@
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      vision_embeddings)
+            inputs_embeds = self.get_input_embeddings(
+                input_ids,
+                vision_embeddings,
+                is_multimodal=input_ids == self.config.image_token_index,
+            )
             input_ids = None
 
         hidden_states = self.language_model.model(

vllm/model_executor/models/bert.py

Lines changed: 12 additions & 0 deletions
@@ -348,6 +348,9 @@ def __init__(
         self.encoder = BertEncoder(vllm_config=vllm_config,
                                    prefix=f"{prefix}.encoder")
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -457,6 +460,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                               prefix=maybe_prefix(prefix, "model"))
         self.pooler = self._build_pooler(pooler_config)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -588,6 +594,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             ),
         })
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.bert.get_input_embeddings(input_ids)
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         loaded_params = loader.load_weights(weights)
@@ -637,6 +646,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             Pooler.for_encode(pooler_config),
         })
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.bert.get_input_embeddings(input_ids)
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         loaded_params = loader.load_weights(weights)

vllm/model_executor/models/bert_with_rope.py

Lines changed: 6 additions & 0 deletions
@@ -426,6 +426,9 @@ def __init__(self,
                                    prefix=f"{prefix}.encoder")
         self.pooler = BertPooler(self.config) if add_pooling_layer else None
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -673,6 +676,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loaded_params = loader.load_weights(weights)
         return loaded_params
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.new.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
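
bert.py, bert_with_rope.py, and (further down) chatglm.py only gain thin `get_input_embeddings` delegations. The diff does not state the motivation, but a plausible reading is that the index-based merge path expects every model, including encoder-only and pooling models, to expose a uniform token-ID-to-embedding entry point. A toy sketch of the delegation pattern (names are illustrative, not the vLLM classes):

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Stand-in for an inner BertModel-style module."""

    def __init__(self, vocab_size: int = 16, hidden_size: int = 4):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden_size)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embeddings(input_ids)


class ToyEmbeddingModel(nn.Module):
    """Stand-in for the outer pooling/embedding model."""

    def __init__(self):
        super().__init__()
        self.model = ToyEncoder()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Same thin delegation added throughout this commit.
        return self.model.get_input_embeddings(input_ids)


embeds = ToyEmbeddingModel().get_input_embeddings(torch.tensor([1, 2, 3]))
```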

vllm/model_executor/models/blip2.py

Lines changed: 6 additions & 16 deletions
@@ -27,7 +27,7 @@
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix, merge_multimodal_embeddings)
+                    maybe_prefix)
 
 # We use this internally as placeholders since there is no image token
 # defined on the HuggingFace repo
@@ -631,19 +631,6 @@ def get_multimodal_embeddings(self,
         vision_embeddings = self._process_image_input(image_input)
         return vision_embeddings
 
-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-    ) -> torch.Tensor:
-        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if multimodal_embeddings is not None \
-            and len(multimodal_embeddings) != 0:
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                _IMAGE_TOKEN_ID)
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -689,8 +676,11 @@
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      vision_embeddings)
+            inputs_embeds = self.get_input_embeddings(
+                input_ids,
+                vision_embeddings,
+                is_multimodal=input_ids == _IMAGE_TOKEN_ID,
+            )
             input_ids = None
 
         hidden_states = self.language_model.model(input_ids,

vllm/model_executor/models/chameleon.py

Lines changed: 7 additions & 17 deletions
@@ -44,7 +44,7 @@
                         SupportsQuant)
 from .utils import (flatten_bn, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
-                    maybe_prefix, merge_multimodal_embeddings)
+                    maybe_prefix)
 
 logger = init_logger(__name__)
 
@@ -1002,20 +1002,6 @@ def get_multimodal_embeddings(self,
         vision_embeddings = self.model.get_input_embeddings(image_tokens)
         return vision_embeddings
 
-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-    ) -> torch.Tensor:
-
-        inputs_embeds = self.model.get_input_embeddings(input_ids)
-        if multimodal_embeddings is not None \
-            and len(multimodal_embeddings) != 0:
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                self.model.vocabulary_mapping.image_token_id)
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1032,8 +1018,12 @@
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      vision_embeddings)
+            image_token_id = self.model.vocabulary_mapping.image_token_id
+            inputs_embeds = self.get_input_embeddings(
+                input_ids,
+                vision_embeddings,
+                is_multimodal=input_ids == image_token_id,
+            )
             input_ids = None
 
         hidden_states = self.model(input_ids,

vllm/model_executor/models/chatglm.py

Lines changed: 3 additions & 0 deletions
@@ -433,6 +433,9 @@ def __init__(
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.get_input_embeddings(input_ids)
+
     def compute_logits(
         self,
         hidden_states: torch.Tensor,

vllm/model_executor/models/cohere2_vision.py

Lines changed: 6 additions & 21 deletions
@@ -37,8 +37,7 @@
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    init_vllm_registered_model, maybe_prefix,
-                    merge_multimodal_embeddings)
+                    init_vllm_registered_model, maybe_prefix)
 
 
 class Cohere2VisionImagePixelInputs(TensorSchema):
@@ -430,23 +429,6 @@ def get_multimodal_embeddings(self,
 
         return self._process_image_input(image_input, **kwargs)
 
-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-    ) -> torch.Tensor:
-        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if multimodal_embeddings is not None \
-            and len(multimodal_embeddings) != 0:
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                multimodal_embeddings=multimodal_embeddings,
-                placeholder_token_id=self.config.image_token_id,
-            )
-
-        return inputs_embeds
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -462,8 +444,11 @@
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(input_ids,
-                                                      vision_embeddings)
+            inputs_embeds = self.get_input_embeddings(
+                input_ids,
+                vision_embeddings,
+                is_multimodal=input_ids == self.config.image_token_id,
+            )
             input_ids = None
 
         hidden_states = self.language_model.model(
