diff --git a/docs/models/hardware_supported_models/tpu.md b/docs/models/hardware_supported_models/tpu.md index 8d3e28c259ec..7b0a5ba6e72d 100644 --- a/docs/models/hardware_supported_models/tpu.md +++ b/docs/models/hardware_supported_models/tpu.md @@ -16,8 +16,8 @@ | meta-llama/Llama-4-* | Llama4ForConditionalGeneration | ❌ | | microsoft/Phi-3-mini-128k-instruct | Phi3ForCausalLM | 🟨 | | microsoft/phi-4 | Phi3ForCausalLM | ❌ | -| google/gemma-3-27b-it | TransformersForMultimodalLM | 🟨 | -| google/gemma-3-4b-it | TransformersForMultimodalLM | ❌ | +| google/gemma-3-27b-it | Gemma3ForConditionalGeneration | 🟨 | +| google/gemma-3-4b-it | Gemma3ForConditionalGeneration | ❌ | | deepseek-ai/DeepSeek-R1 | DeepseekV3ForCausalLM | ❌ | | deepseek-ai/DeepSeek-V3 | DeepseekV3ForCausalLM | ❌ | | RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8 | LlamaForCausalLM | ✅ | diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 001a5b96174a..076093d625f3 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -641,6 +641,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | +| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | | `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | @@ -670,6 +671,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | | `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | +| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | | `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ | @@ -694,8 +696,6 @@ Some models are supported only via the [Transformers backend](#transformers). Th | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------| | `Emu3ForConditionalGeneration` | Emu3 | T + I | `BAAI/Emu3-Chat-hf` | ✅︎ | ✅︎ | -| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | -| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | ✅︎ | ✅︎ | ^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM.     • For example, to use DeepSeek-VL2 series models: @@ -704,7 +704,21 @@ Some models are supported only via the [Transformers backend](#transformers). Th + Multiple items can be inputted per text prompt for this modality. !!! warning - For `Gemma3ForConditionalGeneration`, `{"do_pan_and_scan": true}` is not supported in Transformers backend yet. + Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. + However, there are differences in how they handle text + image inputs: + + V0 correctly implements the model's attention pattern: + - Uses bidirectional attention between the image tokens corresponding to the same image + - Uses causal attention for other tokens + - Implemented via (naive) PyTorch SDPA with masking tensors + - Note: May use significant memory for long prompts with image + + V1 currently uses a simplified attention pattern: + - Uses causal attention for all tokens, including image tokens + - Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}` + - Will be updated in the future to support the correct behavior + + This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. !!! note `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its @@ -756,6 +770,9 @@ Some models are supported only via the [Transformers backend](#transformers). Th The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now. For more details, please see: +!!! warning + Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. + !!! note For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported. diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 35311a0ca7e1..6336b7f7c7ee 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -275,8 +275,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData: model=model_name, max_model_len=2048, max_num_seqs=2, - # TODO: Support this in transformers backend - # mm_processor_kwargs={"do_pan_and_scan": True}, + mm_processor_kwargs={"do_pan_and_scan": True}, limit_mm_per_prompt={modality: 1}, ) diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py index 5108da68cb0b..246b893be315 100644 --- a/tests/models/language/generation/test_gemma.py +++ b/tests/models/language/generation/test_gemma.py @@ -3,7 +3,7 @@ import numpy as np import pytest -MODELS = ["google/gemma-2b", "google/gemma-2-2b"] +MODELS = ["google/gemma-2b", "google/gemma-2-2b", "google/gemma-3-4b-it"] @pytest.mark.parametrize("model", MODELS) @@ -14,8 +14,14 @@ def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None: model, load_format="dummy", ) as llm: - normalizers = llm.apply_model( - lambda model: model.model.normalizer.cpu().item() - ) - config = llm.llm.llm_engine.model_config.hf_config + if model == "google/gemma-3-4b-it": + normalizers = llm.llm.collective_rpc( + lambda self: self.model_runner.model.language_model.model.normalizer.cpu().item() # noqa: E501 + ) + config = llm.llm.llm_engine.model_config.hf_config.text_config + else: + normalizers = llm.llm.collective_rpc( + lambda self: self.model_runner.model.model.normalizer.cpu().item() + ) + config = llm.llm.llm_engine.model_config.hf_config assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 44bbc4479ca4..7118bbc5e780 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -113,6 +113,25 @@ dtype="bfloat16" if current_platform.is_cpu() else "auto", marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + "paligemma": VLMTestInfo( + models=["google/paligemma-3b-mix-224"], + test_type=VLMTestType.IMAGE, + prompt_formatter=identity, + img_idx_to_prompt=lambda idx: "", + # Paligemma uses its own sample prompts because the default one fails + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "caption es", + "cherry_blossom": "What is in the picture?", + } + ), + auto_cls=AutoModelForImageTextToText, + vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, + dtype="bfloat16", + marks=[ + pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask") + ], + ), "qwen2_5_vl": VLMTestInfo( models=["Qwen/Qwen2.5-VL-3B-Instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO), @@ -177,24 +196,14 @@ # Gemma3 has bidirectional mask on images "gemma3-transformers": VLMTestInfo( models=["google/gemma-3-4b-it"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts( - { - "stop_sign": "What's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "What is the season?", - } - ), - multi_image_prompt="Describe the two images in detail.", # noqa: E501 - max_model_len=8192, + test_type=VLMTestType.IMAGE, + prompt_formatter=lambda vid_prompt: f"<'user\n{vid_prompt}\nmodel\n", # noqa: E501 + max_model_len=4096, auto_cls=AutoModelForImageTextToText, - # TODO: Support `do_pan_and_scan` in transformers backend - # patch_hf_runner=model_utils.gemma3_patch_hf_runner, vllm_output_post_proc=model_utils.gemma3_vllm_to_hf_output, image_size_factors=[(0.25, 0.5, 1.0)], vllm_runner_kwargs={ "model_impl": "transformers", - # "mm_processor_kwargs": {"do_pan_and_scan": True}, }, marks=[pytest.mark.core_model], ), @@ -213,27 +222,6 @@ }, marks=[pytest.mark.core_model], ), - # PaliGemma has PrefixLM attention - "paligemma-transformers": VLMTestInfo( - models=["google/paligemma-3b-mix-224"], - test_type=VLMTestType.IMAGE, - prompt_formatter=identity, - img_idx_to_prompt=lambda idx: "", - # PaliGemma uses its own sample prompts because the default one fails - single_image_prompts=IMAGE_ASSETS.prompts( - { - "stop_sign": "caption es", - "cherry_blossom": "What is in the picture?", - } - ), - auto_cls=AutoModelForImageTextToText, - vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, - image_size_factors=[(0.25, 0.5, 1.0)], - vllm_runner_kwargs={ - "model_impl": "transformers", - }, - marks=[pytest.mark.core_model], - ), # Pixel values from processor are not 4D or 5D arrays "qwen2_5_vl-transformers": VLMTestInfo( models=["Qwen/Qwen2.5-VL-3B-Instruct"], @@ -360,6 +348,24 @@ image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[large_gpu_mark(min_gb=32)], ), + "gemma3": VLMTestInfo( + models=["google/gemma-3-4b-it"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "What's the content in the center of the image?", # noqa: E501 + "cherry_blossom": "What is the season?", + } + ), + multi_image_prompt="Describe the two images in detail.", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}}, + patch_hf_runner=model_utils.gemma3_patch_hf_runner, + num_logprobs=10, + ), "glm4v": VLMTestInfo( models=["zai-org/glm-4v-9b"], test_type=VLMTestType.IMAGE, diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 8f0caed4dd4f..0685a01da58f 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -328,6 +328,16 @@ def processor(*args, **kwargs): hf_model.processor = processor + orig_generate = hf_model.model.generate + + def _generate(self, *args, **kwargs): + # FIXME: https://github.com/huggingface/transformers/issues/38333 + kwargs["disable_compile"] = True + + return orig_generate(*args, **kwargs) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) + return hf_model diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 4e693b310277..a7308244523e 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -222,6 +222,7 @@ def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions: _ADD_SPECIAL_TOKENS_OVERRIDES = { "ovis": False, "ovis2_5": False, + "paligemma": False, "ultravox": False, "whisper": False, } @@ -333,6 +334,7 @@ def _test_processing_correctness_one( "deepseek-ai/deepseek-vl2-tiny", "baidu/ERNIE-4.5-VL-28B-A3B-PT", "adept/fuyu-8b", + "google/gemma-3-4b-it", "google/gemma-3n-E2B-it", "zai-org/glm-4v-9b", "zai-org/GLM-4.1V-9B-Thinking", @@ -369,6 +371,8 @@ def _test_processing_correctness_one( "AIDC-AI/Ovis1.6-Llama3.2-3B", "AIDC-AI/Ovis2-1B", "AIDC-AI/Ovis2.5-2B", + "google/paligemma-3b-mix-224", + "google/paligemma2-3b-ft-docci-448", "microsoft/Phi-3.5-vision-instruct", "microsoft/Phi-4-multimodal-instruct", "mistralai/Pixtral-12B-2409", diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index c0436e117975..024fe76e6ff7 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -48,6 +48,7 @@ "Idefics3ForConditionalGeneration", "LlavaForConditionalGeneration", "MiniCPMV", + "PaliGemmaForConditionalGeneration", ] REPO_ID_TO_SKIP = { "nm-testing/pixtral-12b-FP8-dynamic": "duplicated test", diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py new file mode 100644 index 000000000000..7c628fe93ce3 --- /dev/null +++ b/vllm/model_executor/models/gemma3_mm.py @@ -0,0 +1,710 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Any, Literal + +import torch +from torch import nn +from transformers import BatchFeature, Gemma3Config, Gemma3Processor +from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, + replace_token_matches, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .siglip import SiglipVisionModel +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class Gemma3ImagePixelInputs(TensorSchema): + """ + Dimensions: + - p: Number of patches total (over each image over each prompt in the + batch) + - c: Number of channels (3) + - h: Height of each patch + - w: Width of each patch + - bn: Batch size * number of images + """ + + type: Literal["pixel_values"] = "pixel_values" + + pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")] + + num_patches: Annotated[torch.Tensor, TensorShape("bn")] + + +Gemma3ImageInputs = Gemma3ImagePixelInputs + + +class Gemma3ProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(Gemma3Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(Gemma3Processor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None} + + def _resolve_image_kwargs( + self, + processor: Gemma3Processor, + keys: set[str], + ) -> dict[str, Any]: + image_processor = processor.image_processor + kwargs = processor._merge_kwargs( + Gemma3ProcessorKwargs, + tokenizer_init_kwargs=processor.tokenizer.init_kwargs, + ) + + images_kwargs = kwargs["images_kwargs"] + + def _resolve_kw(key: str): + val = getattr(image_processor, key) + if val is None: + val = images_kwargs[key] + + return val + + return {k: _resolve_kw(k) for k in keys} + + def get_num_crops( + self, + *, + image_width: int, + image_height: int, + processor: Gemma3Processor | None, + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + images_kwargs = self._resolve_image_kwargs( + processor, + { + "do_pan_and_scan", + "pan_and_scan_min_crop_size", + "pan_and_scan_max_num_crops", + "pan_and_scan_min_ratio_to_activate", + }, + ) + + do_pan_and_scan = images_kwargs["do_pan_and_scan"] + pan_and_scan_min_crop_size = images_kwargs["pan_and_scan_min_crop_size"] + pan_and_scan_max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] + pan_and_scan_min_ratio_to_activate = images_kwargs[ + "pan_and_scan_min_ratio_to_activate" + ] + + if not do_pan_and_scan: + return 0 + + if envs.VLLM_USE_V1: + logger.warning_once( + "`do_pan_and_scan=True` has suboptimal results on V1 " + "because of the simplified attention pattern being used." + ) + + # Based on Gemma3ImageProcessor.pan_and_scan + if image_width >= image_height: + if image_width / image_height < pan_and_scan_min_ratio_to_activate: + return 0 + + num_crops_w = min( + int(math.floor(image_width / pan_and_scan_min_crop_size)), + int(math.floor(image_width / image_height + 0.5)), + ) + + num_crops_w = max(2, num_crops_w) + num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) + num_crops_h = 1 + else: + if image_height / image_width < pan_and_scan_min_ratio_to_activate: + return 0 + + num_crops_h = min( + int(math.floor(image_height / pan_and_scan_min_crop_size)), + int(math.floor(image_height / image_width + 0.5)), + ) + + num_crops_h = max(2, num_crops_h) + num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) + num_crops_w = 1 + + crop_size_w = int(math.ceil(image_width / num_crops_w)) + crop_size_h = int(math.ceil(image_height / num_crops_h)) + + if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: + return 0 + + return num_crops_w * num_crops_h + + def get_image_repl( + self, + *, + image_width: int, + image_height: int, + processor: Gemma3Processor | None, + ) -> PromptUpdateDetails[str]: + if processor is None: + processor = self.get_hf_processor() + + boi_token = processor.boi_token + + num_crops = self.get_num_crops( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + + if num_crops == 0: + image_text = boi_token + else: + crops_image_tokens = " ".join(boi_token for _ in range(num_crops)) + image_text = ( + f"Here is the original image {boi_token} and here are some " + f"crops to help you see better {crops_image_tokens}" + ) + + repl_full = image_text.replace(boi_token, processor.full_image_sequence) + + tokenizer = processor.tokenizer + vocab = tokenizer.get_vocab() + image_token_id = vocab[tokenizer.image_token] + + return PromptUpdateDetails.select_token_id(repl_full, image_token_id) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Gemma3Processor | None, + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + num_crops = self.get_num_crops( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + image_seq_len = processor.image_seq_length + + return (num_crops + 1) * image_seq_len + + def get_image_size_with_most_features(self) -> ImageSize: + processor = self.get_hf_processor() + + images_kwargs = self._resolve_image_kwargs( + processor, {"pan_and_scan_max_num_crops"} + ) + max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] + + # Result in the max possible feature size (h:w = max_num_crops:1) + return ImageSize(height=50 * max_num_crops, width=50) + + +class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.boi_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) + } + + +class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + mm_kwargs, + tok_kwargs, + ) + + # HF processor pops the `num_crops` kwarg, which is needed by vLLM + if (images := mm_data.get("images")) is not None: + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) + image_sizes = [ + parsed_images.get_image_size(i) for i in range(len(parsed_images)) + ] + hf_processor = self.info.get_hf_processor(**mm_kwargs) + + num_crops = [ + self.info.get_num_crops( + image_width=size.width, + image_height=size.height, + processor=hf_processor, + ) + for size in image_sizes + ] + processed_outputs["num_patches"] = torch.tensor(num_crops) + 1 + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_patches = hf_inputs.get("num_patches", torch.empty(0)) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), + num_patches=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token = hf_processor.boi_token + + def get_replacement_gemma3(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + + image_size = images.get_image_size(item_idx) + return self.info.get_image_repl( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement_gemma3, + ) + ] + + def _apply_token_matches( + self, + prompt: list[int], + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates) + + # "\n\n\n" and "\n\n\n\n" are single tokens + # Since our replacement can insert "\n\n" next to "\n" + # tokens, we have to combine them to be consistent with + # the output of the tokenizer + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + newline_1 = vocab["\n"] + newline_2 = vocab["\n\n"] + newline_3 = vocab["\n\n\n"] + newline_4 = vocab["\n\n\n\n"] + + token_ids = replace_token_matches( + token_ids, + [newline_1, newline_2], + [newline_3], + ) + token_ids = replace_token_matches( + token_ids, + [newline_2, newline_1], + [newline_3], + ) + token_ids = replace_token_matches( + token_ids, + [newline_2, newline_2], + [newline_4], + ) + + return token_ids, res + + def _find_mm_placeholders( + self, + new_token_ids: list[int], + mm_prompt_updates: MultiModalPromptUpdates, + ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: + # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + newline_1 = vocab["\n"] + newline_2 = vocab["\n\n"] + newline_3 = vocab["\n\n\n"] + newline_4 = vocab["\n\n\n\n"] + + def get_repl_toks(tok: int) -> list[int]: + if tok == newline_3: + return [newline_1, newline_2] + if tok == newline_4: + return [newline_2, newline_2] + + return [tok] + + repl_token_ids = list[int]() + repl_orig_idxs = list[int]() + for orig_idx, orig_tok in enumerate(new_token_ids): + repl_toks = get_repl_toks(orig_tok) + repl_token_ids.extend(repl_toks) + repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) + + repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates) + + return { + modality: [ + PlaceholderFeaturesInfo( + modality=p.modality, + item_idx=p.item_idx, + start_idx=repl_orig_idxs[p.start_idx], + tokens=p.tokens, + is_embed=p.is_embed, + ) + for p in placeholders + ] + for modality, placeholders in repls.items() + } + + +class Gemma3MultiModalProjector(nn.Module): + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros( + config.vision_config.hidden_size, config.text_config.hidden_size + ) + ) + + self.mm_soft_emb_norm = GemmaRMSNorm( + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) + + self.patches_per_image = int( + config.vision_config.image_size // config.vision_config.patch_size + ) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d( + kernel_size=self.kernel_size, stride=self.kernel_size + ) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul( + normed_vision_outputs, self.mm_input_projection_weight + ) + return projected_vision_outputs.type_as(vision_outputs) + + +@MULTIMODAL_REGISTRY.register_processor( + Gemma3MultiModalProcessor, + info=Gemma3ProcessingInfo, + dummy_inputs=Gemma3DummyInputsBuilder, +) +class Gemma3ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA +): + merge_by_field_config = True + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers v4.52 + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "lm_head.": "language_model.lm_head.", + } + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return "" + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + + self.vision_tower = SiglipVisionModel( + config.vision_config, + quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), + ) + self.multi_modal_projector = Gemma3MultiModalProjector(config) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Gemma3ForCausalLM"], + ) + logit_scale = getattr(config, "logit_scale", 1.0) + + if hasattr(self.language_model, "logits_processor"): + # The logits processor can be unset if we're using + # automatic conversion to pooling model. + self.language_model.logits_processor.scale *= logit_scale + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + @property + def dtype(self): + return next(self.parameters()).dtype + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Gemma3ImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + num_patches = kwargs.pop("num_patches", None) + image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None, "Gemma3 does not support image_embeds." + if pixel_values is None: + return None + + image_size = self.config.vision_config.image_size + + return Gemma3ImagePixelInputs( + pixel_values=pixel_values, + num_patches=num_patches, + resolve_bindings={"h": image_size, "w": image_size}, + ) + + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + return vision_tower(pixel_values) + + def _process_image_input( + self, + image_input: Gemma3ImageInputs, + ) -> list[torch.Tensor]: + assert self.vision_tower is not None + + pixel_values = image_input["pixel_values"] + num_patches = image_input["num_patches"] + + image_features = self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) + image_embeds = self.multi_modal_projector(image_features) + + return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())] + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + return self._process_image_input(image_input) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + return hidden_states + + def prepare_attn_masks( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mask_dtype: torch.dtype, + **kwargs, + ): + kwargs["has_images"] = True + # NOTE(woosuk): Here, we distinguish the sequences by the position id 0. + # This is a HACK. Fix this. + start_indices = (positions == 0).cpu().nonzero() + num_seqs = len(start_indices) + seq_lens = [] + for i in range(num_seqs): + start_idx = start_indices[i].item() + if i < num_seqs - 1: + end_idx = start_indices[i + 1].item() + else: + end_idx = len(input_ids) + seq_lens.append(end_idx - start_idx) + kwargs["seq_lens"] = seq_lens + + global_attn_masks = [] + local_attn_masks = [] + start_idx = 0 + for seq_len in seq_lens: + end_idx = start_idx + seq_len + input_token_ids = input_ids[start_idx:end_idx] + start_idx = end_idx + # Create a global causal mask. + global_attn_mask = torch.empty( + 1, + 1, + seq_len, + seq_len, + dtype=mask_dtype, + device=input_ids.device, + ) + global_attn_mask.fill_(float("-inf")) + # Fill the lower triangle with 0. + global_attn_mask = global_attn_mask.triu(diagonal=1) + + # Consider the bidirectional attention between image tokens. + img_mask = torch.zeros_like(global_attn_mask) + img_pos = input_token_ids == self.config.image_token_index + img_mask[:, :, :, img_pos] += 1 + img_mask[:, :, img_pos, :] += 1 + global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) + global_attn_masks.append(global_attn_mask) + + sliding_window = self.config.text_config.sliding_window + if sliding_window is not None: + # Create a local causal mask with sliding window (1024). + local_attn_mask = torch.ones_like(global_attn_mask) + local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window) + local_attn_mask = torch.where( + local_attn_mask == 0, global_attn_mask, float("-inf") + ) + local_attn_masks.append(local_attn_mask) + kwargs["global_attn_masks"] = global_attn_masks + kwargs["local_attn_masks"] = local_attn_masks + return kwargs + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="multi_modal_projector", + tower_model="vision_tower", + ) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py new file mode 100644 index 000000000000..fb0b4b290467 --- /dev/null +++ b/vllm/model_executor/models/paligemma.py @@ -0,0 +1,412 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Literal, TypeAlias + +import torch +from torch import nn +from transformers import BatchFeature, PaliGemmaConfig + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptIndexTargets, + PromptInsertion, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .siglip import SiglipVisionModel +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) +from .vision import get_vision_encoder_info + +logger = init_logger(__name__) + + +class PaliGemmaImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + """ + + type: Literal["pixel_values"] = "pixel_values" + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] + + +class PaliGemmaImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] + + +PaliGemmaImageInputs: TypeAlias = ( + PaliGemmaImagePixelInputs | PaliGemmaImageEmbeddingInputs +) + + +class PaliGemmaMultiModalProjector(nn.Module): + def __init__(self, vision_hidden_size: int, projection_dim: int): + super().__init__() + + self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear(image_features) + return hidden_states + + +class PaliGemmaProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(PaliGemmaConfig) + + def get_vision_encoder_info(self): + return get_vision_encoder_info(self.get_hf_config()) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": 1} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + vision_encoder_info = self.get_vision_encoder_info() + + return vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ) + + +class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + max_image_size = vision_config.image_size + + num_images = mm_counts.get("image", 0) + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=max_image_size, + height=max_image_size, + num_images=num_images, + overrides=image_overrides, + ) + } + + +class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + tokenizer = self.info.get_tokenizer() + if not mm_data: + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_index + + tokenizer = self.info.get_tokenizer() + + bos_token_id = tokenizer.bos_token_id + assert isinstance(bos_token_id, int) + + def get_insertion(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + image_tokens = [image_token_id] * num_image_tokens + + return PromptUpdateDetails.select_token_id( + image_tokens + [bos_token_id], + embed_token_id=image_token_id, + ) + + # Paligemma 1 and 2 have different tokenizer.add_bos_token + # Insert *n + after for Paligemma 1 + # Insert *n + for Paligemma 2 + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.prefix( + [bos_token_id] if tokenizer.add_bos_token else [] + ), + insertion=get_insertion, + ) + ] + + def apply( + self, + prompt: str | list[int], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object] | None = None, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> MultiModalInputs: + mm_inputs = super().apply( + prompt, + mm_data, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_uuids=mm_uuids, + ) + prompt_token_ids = mm_inputs["prompt_token_ids"] + + tokenizer = self.info.get_tokenizer() + newline_prompt = "\n" + newline_token_id = tokenizer.encode(newline_prompt)[-1] # 108 + # Force to add newline at the end of prompt for paligemma's format + # This step can NOT be replacemented by current PromptUpdate methods + if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id: + prompt_token_ids.append(newline_token_id) + mm_inputs["prompt_token_ids"] = prompt_token_ids + + return mm_inputs + + +@MULTIMODAL_REGISTRY.register_processor( + PaliGemmaMultiModalProcessor, + info=PaliGemmaProcessingInfo, + dummy_inputs=PaliGemmaDummyInputsBuilder, +) +class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers v4.52 + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "lm_head.": "language_model.lm_head.", + } + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + self.vision_tower = SiglipVisionModel( + config.vision_config, + quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), + ) + self.multi_modal_projector = PaliGemmaMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + projection_dim=config.vision_config.projection_dim, + ) + + self.quant_config = quant_config + + if config.text_config.model_type == "gemma": + config.text_config.architectures = ["GemmaForCausalLM"] + else: + config.text_config.architectures = ["Gemma2ForCausalLM"] + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.language_model.logits_processor.scale *= logit_scale + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> PaliGemmaImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = flatten_bn(pixel_values, concat=True) + + h = w = self.config.vision_config.image_size + return PaliGemmaImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": h, "w": w}, + ) + + if image_embeds is not None: + image_embeds = flatten_bn(image_embeds, concat=True) + + return PaliGemmaImageEmbeddingInputs( + type="image_embeds", + data=image_embeds, + ) + + raise AssertionError("This line should be unreachable.") + + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + target_dtype = vision_tower.get_input_embeddings().weight.dtype + image_features = vision_tower(pixel_values.to(dtype=target_dtype)) + + return image_features + + def _process_image_input( + self, + image_input: PaliGemmaImageInputs, + ) -> torch.Tensor: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_tower is not None + pixel_values = image_input["data"] + image_features = self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) + + return self.multi_modal_projector(image_features) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + vision_embeddings = self._process_image_input(image_input) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa + vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5) + return vision_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index da1606a7568d..2fa269f77755 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -264,6 +264,7 @@ "Ernie4_5_VLMoeForConditionalGeneration", ), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "Gemma3nForConditionalGeneration": ( "gemma3n_mm", "Gemma3nForConditionalGeneration", @@ -333,6 +334,10 @@ "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), "Ovis2_5": ("ovis2_5", "Ovis2_5"), + "PaliGemmaForConditionalGeneration": ( + "paligemma", + "PaliGemmaForConditionalGeneration", + ), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"), # noqa: E501 @@ -405,14 +410,6 @@ "transformers", "TransformersMultiModalForCausalLM", ), - "Gemma3ForConditionalGeneration": ( - "transformers", - "TransformersMultiModalForCausalLM", - ), - "PaliGemmaForConditionalGeneration": ( - "transformers", - "TransformersMultiModalForCausalLM", - ), } _TRANSFORMERS_BACKEND_MODELS = { diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 9788bfeca109..4d6aa10500d3 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -59,6 +59,9 @@ "Qwen2ForCausalLM": _ROCM_SWA_REASON, "MistralForCausalLM": _ROCM_SWA_REASON, "MixtralForCausalLM": _ROCM_SWA_REASON, + "PaliGemmaForConditionalGeneration": ( + "ROCm flash attention does not yet fully support 32-bit precision on PaliGemma" + ), "Phi3VForCausalLM": ( "ROCm Triton flash attention may run into compilation errors due to " "excessive use of shared memory. If this happens, disable Triton FA "