Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ Text Generation
- Idefics3
- T + I
- :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc.
-
- ✅︎
-
* - :code:`InternVLChatModel`
- InternVL2
Expand Down
55 changes: 46 additions & 9 deletions vllm/model_executor/models/idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.image import cached_get_image_processor
Expand All @@ -44,7 +45,7 @@
from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
from .interfaces import SupportsMultiModal
from .interfaces import SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
Expand All @@ -58,8 +59,6 @@ class Idefics3ImagePixelInputs(TypedDict):
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
"""
rows: List[int]
cols: List[int]
pixel_attention_mask: Optional[torch.BoolTensor]


Expand Down Expand Up @@ -356,8 +355,15 @@ def dummy_data_for_idefics3(
image_seq_len = processor.image_seq_len
max_llm_image_tokens = max_num_image_patches * image_seq_len * num_images

if seq_len - max_llm_image_tokens < 0:
raise RuntimeError(
f"Idefics3 cannot process {num_images} images in a prompt, "
"please increase max_model_len or reduce image limit by "
"--limit-mm-per-prompt.")

seq_data = SequenceData.from_prompt_token_counts(
(hf_config.image_token_id, max_llm_image_tokens), (0, seq_len))
(hf_config.image_token_id, max_llm_image_tokens),
(0, seq_len - max_llm_image_tokens))

width = height = hf_config.vision_config.image_size
image = Image.new("RGB", (width, height), color=0)
Expand Down Expand Up @@ -463,8 +469,6 @@ def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
rows = kwargs.pop("rows", None)
cols = kwargs.pop("cols", None)
pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)

if pixel_values is None and image_embeds is None:
Expand All @@ -489,8 +493,6 @@ def _parse_and_validate_image_input(
data=self._validate_pixel_values(
flatten_bn(pixel_values,
concat=True)),
rows=rows,
cols=cols,
pixel_attention_mask=flatten_bn(
pixel_attention_mask,
concat=True))
Expand Down Expand Up @@ -610,7 +612,33 @@ def forward(
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3)
@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision_model
"fc1",
"fc2",
"out_proj",
# text_model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
Expand Down Expand Up @@ -672,3 +700,12 @@ def sample(
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights)

def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="model.text_model",
connector="model.connector",
tower_model="model.vision_model")