diff --git a/docs/models/hardware_supported_models/tpu.md b/docs/models/hardware_supported_models/tpu.md
index 8d3e28c259ec..7b0a5ba6e72d 100644
--- a/docs/models/hardware_supported_models/tpu.md
+++ b/docs/models/hardware_supported_models/tpu.md
@@ -16,8 +16,8 @@
 | meta-llama/Llama-4-*                                | Llama4ForConditionalGeneration | ❌ |
 | microsoft/Phi-3-mini-128k-instruct                  | Phi3ForCausalLM                | 🟨 |
 | microsoft/phi-4                                     | Phi3ForCausalLM                | ❌ |
-| google/gemma-3-27b-it                               | TransformersForMultimodalLM    | 🟨 |
-| google/gemma-3-4b-it                                | TransformersForMultimodalLM    | ❌ |
+| google/gemma-3-27b-it                               | Gemma3ForConditionalGeneration | 🟨 |
+| google/gemma-3-4b-it                                | Gemma3ForConditionalGeneration | ❌ |
 | deepseek-ai/DeepSeek-R1                             | DeepseekV3ForCausalLM          | ❌ |
 | deepseek-ai/DeepSeek-V3                             | DeepseekV3ForCausalLM          | ❌ |
 | RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8  | LlamaForCausalLM               | ✅ |
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 001a5b96174a..076093d625f3 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -641,6 +641,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ |
 | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
 | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
+| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
 | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
 | `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
 | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
@@ -670,6 +671,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ |
 | `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ |
 | `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. | | |
+| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ |
 | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ |
 | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ |
 | `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ |
@@ -694,8 +696,6 @@ Some models are supported only via the [Transformers backend](#transformers). Th
 | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
 |--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------|
 | `Emu3ForConditionalGeneration` | Emu3 | T + I | `BAAI/Emu3-Chat-hf` | ✅︎ | ✅︎ |
-| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
-| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | ✅︎ | ✅︎ |
 
 ^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
     • For example, to use DeepSeek-VL2 series models:
@@ -704,7 +704,21 @@ Some models are supported only via the [Transformers backend](#transformers). Th
 + Multiple items can be inputted per text prompt for this modality.
 
 !!! warning
-    For `Gemma3ForConditionalGeneration`, `{"do_pan_and_scan": true}` is not supported in Transformers backend yet.
+    Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
+    However, there are differences in how they handle text + image inputs:
+
+    V0 correctly implements the model's attention pattern:
+    - Uses bidirectional attention between the image tokens corresponding to the same image
+    - Uses causal attention for other tokens
+    - Implemented via (naive) PyTorch SDPA with masking tensors
+    - Note: May use significant memory for long prompts with image
+
+    V1 currently uses a simplified attention pattern:
+    - Uses causal attention for all tokens, including image tokens
+    - Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}`
+    - Will be updated in the future to support the correct behavior
+
+    This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
 
 !!! note
     `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
@@ -756,6 +770,9 @@ Some models are supported only via the [Transformers backend](#transformers). Th
     The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
     For more details, please see: 
 
+!!! warning
+    Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
+
 !!! note
     For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported.
 
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 35311a0ca7e1..6336b7f7c7ee 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -275,8 +275,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
         model=model_name,
         max_model_len=2048,
         max_num_seqs=2,
-        # TODO: Support this in transformers backend
-        # mm_processor_kwargs={"do_pan_and_scan": True},
+        mm_processor_kwargs={"do_pan_and_scan": True},
         limit_mm_per_prompt={modality: 1},
     )
 
diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py
index 5108da68cb0b..246b893be315 100644
--- a/tests/models/language/generation/test_gemma.py
+++ b/tests/models/language/generation/test_gemma.py
@@ -3,7 +3,7 @@
 import numpy as np
 import pytest
 
-MODELS = ["google/gemma-2b", "google/gemma-2-2b"]
+MODELS = ["google/gemma-2b", "google/gemma-2-2b", "google/gemma-3-4b-it"]
 
 
 @pytest.mark.parametrize("model", MODELS)
@@ -14,8 +14,14 @@ def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None:
             model,
             load_format="dummy",
         ) as llm:
-            normalizers = llm.apply_model(
-                lambda model: model.model.normalizer.cpu().item()
-            )
-            config = llm.llm.llm_engine.model_config.hf_config
+            if model == "google/gemma-3-4b-it":
+                normalizers = llm.llm.collective_rpc(
+                    lambda self: self.model_runner.model.language_model.model.normalizer.cpu().item()  # noqa: E501
+                )
+                config = llm.llm.llm_engine.model_config.hf_config.text_config
+            else:
+                normalizers = llm.llm.collective_rpc(
+                    lambda self: self.model_runner.model.model.normalizer.cpu().item()
+                )
+                config = llm.llm.llm_engine.model_config.hf_config
             assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3)
diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index 44bbc4479ca4..7118bbc5e780 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -113,6 +113,25 @@
         dtype="bfloat16" if current_platform.is_cpu() else "auto",
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    "paligemma": VLMTestInfo(
+        models=["google/paligemma-3b-mix-224"],
+        test_type=VLMTestType.IMAGE,
+        prompt_formatter=identity,
+        img_idx_to_prompt=lambda idx: "",
+        # Paligemma uses its own sample prompts because the default one fails
+        single_image_prompts=IMAGE_ASSETS.prompts(
+            {
+                "stop_sign": "caption es",
+                "cherry_blossom": "What is in the picture?",
+            }
+        ),
+        auto_cls=AutoModelForImageTextToText,
+        vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
+        dtype="bfloat16",
+        marks=[
+            pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")
+        ],
+    ),
     "qwen2_5_vl": VLMTestInfo(
         models=["Qwen/Qwen2.5-VL-3B-Instruct"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO),
@@ -177,24 +196,14 @@
     # Gemma3 has bidirectional mask on images
     "gemma3-transformers": VLMTestInfo(
         models=["google/gemma-3-4b-it"],
-        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
-        prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n",  # noqa: E501
-        single_image_prompts=IMAGE_ASSETS.prompts(
-            {
-                "stop_sign": "What's the content in the center of the image?",  # noqa: E501
-                "cherry_blossom": "What is the season?",
-            }
-        ),
-        multi_image_prompt="Describe the two images in detail.",  # noqa: E501
-        max_model_len=8192,
+        test_type=VLMTestType.IMAGE,
+        prompt_formatter=lambda vid_prompt: f"<'user\n{vid_prompt}\nmodel\n",  # noqa: E501
+        max_model_len=4096,
         auto_cls=AutoModelForImageTextToText,
-        # TODO: Support `do_pan_and_scan` in transformers backend
-        # patch_hf_runner=model_utils.gemma3_patch_hf_runner,
         vllm_output_post_proc=model_utils.gemma3_vllm_to_hf_output,
         image_size_factors=[(0.25, 0.5, 1.0)],
         vllm_runner_kwargs={
             "model_impl": "transformers",
-            # "mm_processor_kwargs": {"do_pan_and_scan": True},
         },
         marks=[pytest.mark.core_model],
     ),
@@ -213,27 +222,6 @@
         },
         marks=[pytest.mark.core_model],
     ),
-    # PaliGemma has PrefixLM attention
-    "paligemma-transformers": VLMTestInfo(
-        models=["google/paligemma-3b-mix-224"],
-        test_type=VLMTestType.IMAGE,
-        prompt_formatter=identity,
-        img_idx_to_prompt=lambda idx: "",
-        # PaliGemma uses its own sample prompts because the default one fails
-        single_image_prompts=IMAGE_ASSETS.prompts(
-            {
-                "stop_sign": "caption es",
-                "cherry_blossom": "What is in the picture?",
-            }
-        ),
-        auto_cls=AutoModelForImageTextToText,
-        vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
-        image_size_factors=[(0.25, 0.5, 1.0)],
-        vllm_runner_kwargs={
-            "model_impl": "transformers",
-        },
-        marks=[pytest.mark.core_model],
-    ),
     # Pixel values from processor are not 4D or 5D arrays
     "qwen2_5_vl-transformers": VLMTestInfo(
         models=["Qwen/Qwen2.5-VL-3B-Instruct"],
@@ -360,6 +348,24 @@
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         marks=[large_gpu_mark(min_gb=32)],
     ),
+    "gemma3": VLMTestInfo(
+        models=["google/gemma-3-4b-it"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n",  # noqa: E501
+        single_image_prompts=IMAGE_ASSETS.prompts(
+            {
+                "stop_sign": "What's the content in the center of the image?",  # noqa: E501
+                "cherry_blossom": "What is the season?",
+            }
+        ),
+        multi_image_prompt="Describe the two images in detail.",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModelForImageTextToText,
+        vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
+        patch_hf_runner=model_utils.gemma3_patch_hf_runner,
+        num_logprobs=10,
+    ),
     "glm4v": VLMTestInfo(
         models=["zai-org/glm-4v-9b"],
         test_type=VLMTestType.IMAGE,
diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py
index 8f0caed4dd4f..0685a01da58f 100644
--- a/tests/models/multimodal/generation/vlm_utils/model_utils.py
+++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py
@@ -328,6 +328,16 @@ def processor(*args, **kwargs):
 
     hf_model.processor = processor
 
+    orig_generate = hf_model.model.generate
+
+    def _generate(self, *args, **kwargs):
+        # FIXME: https://github.com/huggingface/transformers/issues/38333
+        kwargs["disable_compile"] = True
+
+        return orig_generate(*args, **kwargs)
+
+    hf_model.model.generate = types.MethodType(_generate, hf_model.model)
+
     return hf_model
 
 
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 4e693b310277..a7308244523e 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -222,6 +222,7 @@ def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions:
 _ADD_SPECIAL_TOKENS_OVERRIDES = {
     "ovis": False,
     "ovis2_5": False,
+    "paligemma": False,
     "ultravox": False,
     "whisper": False,
 }
@@ -333,6 +334,7 @@ def _test_processing_correctness_one(
         "deepseek-ai/deepseek-vl2-tiny",
         "baidu/ERNIE-4.5-VL-28B-A3B-PT",
         "adept/fuyu-8b",
+        "google/gemma-3-4b-it",
         "google/gemma-3n-E2B-it",
         "zai-org/glm-4v-9b",
         "zai-org/GLM-4.1V-9B-Thinking",
@@ -369,6 +371,8 @@ def _test_processing_correctness_one(
         "AIDC-AI/Ovis1.6-Llama3.2-3B",
         "AIDC-AI/Ovis2-1B",
         "AIDC-AI/Ovis2.5-2B",
+        "google/paligemma-3b-mix-224",
+        "google/paligemma2-3b-ft-docci-448",
         "microsoft/Phi-3.5-vision-instruct",
         "microsoft/Phi-4-multimodal-instruct",
         "mistralai/Pixtral-12B-2409",
diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py
index c0436e117975..024fe76e6ff7 100644
--- a/tests/models/multimodal/processing/test_tensor_schema.py
+++ b/tests/models/multimodal/processing/test_tensor_schema.py
@@ -48,6 +48,7 @@
     "Idefics3ForConditionalGeneration",
     "LlavaForConditionalGeneration",
     "MiniCPMV",
+    "PaliGemmaForConditionalGeneration",
 ]
 REPO_ID_TO_SKIP = {
     "nm-testing/pixtral-12b-FP8-dynamic": "duplicated test",
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
new file mode 100644
index 000000000000..7c628fe93ce3
--- /dev/null
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -0,0 +1,710 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Any, Literal
+
+import torch
+from torch import nn
+from transformers import BatchFeature, Gemma3Config, Gemma3Processor
+from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
+
+import vllm.envs as envs
+from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    MultiModalPromptUpdates,
+    MultiModalPromptUpdatesApplyResult,
+    PlaceholderFeaturesInfo,
+    PromptReplacement,
+    PromptUpdate,
+    PromptUpdateDetails,
+    replace_token_matches,
+)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsPP,
+)
+from .siglip import SiglipVisionModel
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+
+class Gemma3ImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - p: Number of patches total (over each image over each prompt in the
+          batch)
+        - c: Number of channels (3)
+        - h: Height of each patch
+        - w: Width of each patch
+        - bn: Batch size * number of images
+    """
+
+    type: Literal["pixel_values"] = "pixel_values"
+
+    pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")]
+
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
+
+
+Gemma3ImageInputs = Gemma3ImagePixelInputs
+
+
+class Gemma3ProcessingInfo(BaseProcessingInfo):
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(Gemma3Config)
+
+    def get_hf_processor(self, **kwargs: object):
+        return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
+
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+        return {"image": None}
+
+    def _resolve_image_kwargs(
+        self,
+        processor: Gemma3Processor,
+        keys: set[str],
+    ) -> dict[str, Any]:
+        image_processor = processor.image_processor
+        kwargs = processor._merge_kwargs(
+            Gemma3ProcessorKwargs,
+            tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
+        )
+
+        images_kwargs = kwargs["images_kwargs"]
+
+        def _resolve_kw(key: str):
+            val = getattr(image_processor, key)
+            if val is None:
+                val = images_kwargs[key]
+
+            return val
+
+        return {k: _resolve_kw(k) for k in keys}
+
+    def get_num_crops(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Gemma3Processor | None,
+    ) -> int:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        images_kwargs = self._resolve_image_kwargs(
+            processor,
+            {
+                "do_pan_and_scan",
+                "pan_and_scan_min_crop_size",
+                "pan_and_scan_max_num_crops",
+                "pan_and_scan_min_ratio_to_activate",
+            },
+        )
+
+        do_pan_and_scan = images_kwargs["do_pan_and_scan"]
+        pan_and_scan_min_crop_size = images_kwargs["pan_and_scan_min_crop_size"]
+        pan_and_scan_max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
+        pan_and_scan_min_ratio_to_activate = images_kwargs[
+            "pan_and_scan_min_ratio_to_activate"
+        ]
+
+        if not do_pan_and_scan:
+            return 0
+
+        if envs.VLLM_USE_V1:
+            logger.warning_once(
+                "`do_pan_and_scan=True` has suboptimal results on V1 "
+                "because of the simplified attention pattern being used."
+            )
+
+        # Based on Gemma3ImageProcessor.pan_and_scan
+        if image_width >= image_height:
+            if image_width / image_height < pan_and_scan_min_ratio_to_activate:
+                return 0
+
+            num_crops_w = min(
+                int(math.floor(image_width / pan_and_scan_min_crop_size)),
+                int(math.floor(image_width / image_height + 0.5)),
+            )
+
+            num_crops_w = max(2, num_crops_w)
+            num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
+            num_crops_h = 1
+        else:
+            if image_height / image_width < pan_and_scan_min_ratio_to_activate:
+                return 0
+
+            num_crops_h = min(
+                int(math.floor(image_height / pan_and_scan_min_crop_size)),
+                int(math.floor(image_height / image_width + 0.5)),
+            )
+
+            num_crops_h = max(2, num_crops_h)
+            num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
+            num_crops_w = 1
+
+        crop_size_w = int(math.ceil(image_width / num_crops_w))
+        crop_size_h = int(math.ceil(image_height / num_crops_h))
+
+        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
+            return 0
+
+        return num_crops_w * num_crops_h
+
+    def get_image_repl(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Gemma3Processor | None,
+    ) -> PromptUpdateDetails[str]:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        boi_token = processor.boi_token
+
+        num_crops = self.get_num_crops(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+
+        if num_crops == 0:
+            image_text = boi_token
+        else:
+            crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
+            image_text = (
+                f"Here is the original image {boi_token} and here are some "
+                f"crops to help you see better {crops_image_tokens}"
+            )
+
+        repl_full = image_text.replace(boi_token, processor.full_image_sequence)
+
+        tokenizer = processor.tokenizer
+        vocab = tokenizer.get_vocab()
+        image_token_id = vocab[tokenizer.image_token]
+
+        return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Gemma3Processor | None,
+    ) -> int:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        num_crops = self.get_num_crops(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+        image_seq_len = processor.image_seq_length
+
+        return (num_crops + 1) * image_seq_len
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        processor = self.get_hf_processor()
+
+        images_kwargs = self._resolve_image_kwargs(
+            processor, {"pan_and_scan_max_num_crops"}
+        )
+        max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
+
+        # Result in the max possible feature size (h:w = max_num_crops:1)
+        return ImageSize(height=50 * max_num_crops, width=50)
+
+
+class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
+    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+        num_images = mm_counts.get("image", 0)
+
+        processor = self.info.get_hf_processor()
+        image_token = processor.boi_token
+
+        return image_token * num_images
+
+    def get_dummy_mm_data(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+        mm_options: Mapping[str, BaseDummyOptions] | None = None,
+    ) -> MultiModalDataDict:
+        num_images = mm_counts.get("image", 0)
+
+        target_width, target_height = self.info.get_image_size_with_most_features()
+
+        image_overrides = mm_options.get("image") if mm_options else None
+
+        return {
+            "image": self._get_dummy_images(
+                width=target_width,
+                height=target_height,
+                num_images=num_images,
+                overrides=image_overrides,
+            )
+        }
+
+
+class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+        tok_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processed_outputs = super()._call_hf_processor(
+            prompt,
+            mm_data,
+            mm_kwargs,
+            tok_kwargs,
+        )
+
+        # HF processor pops the `num_crops` kwarg, which is needed by vLLM
+        if (images := mm_data.get("images")) is not None:
+            parsed_images = (
+                self._get_data_parser()
+                .parse_mm_data({"image": images})
+                .get_items("image", ImageProcessorItems)
+            )
+            image_sizes = [
+                parsed_images.get_image_size(i) for i in range(len(parsed_images))
+            ]
+            hf_processor = self.info.get_hf_processor(**mm_kwargs)
+
+            num_crops = [
+                self.info.get_num_crops(
+                    image_width=size.width,
+                    image_height=size.height,
+                    processor=hf_processor,
+                )
+                for size in image_sizes
+            ]
+            processed_outputs["num_patches"] = torch.tensor(num_crops) + 1
+
+        return processed_outputs
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
+        return dict(
+            pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
+            num_patches=MultiModalFieldConfig.batched("image"),
+        )
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, Any],
+        out_mm_kwargs: MultiModalKwargsItems,
+    ) -> Sequence[PromptUpdate]:
+        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+        image_token = hf_processor.boi_token
+
+        def get_replacement_gemma3(item_idx: int):
+            images = mm_items.get_items("image", ImageProcessorItems)
+
+            image_size = images.get_image_size(item_idx)
+            return self.info.get_image_repl(
+                image_width=image_size.width,
+                image_height=image_size.height,
+                processor=hf_processor,
+            )
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=image_token,
+                replacement=get_replacement_gemma3,
+            )
+        ]
+
+    def _apply_token_matches(
+        self,
+        prompt: list[int],
+        mm_prompt_updates: MultiModalPromptUpdates,
+    ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
+        token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates)
+
+        # "\n\n\n" and "\n\n\n\n" are single tokens
+        # Since our replacement can insert "\n\n" next to "\n"
+        # tokens, we have to combine them to be consistent with
+        # the output of the tokenizer
+        tokenizer = self.info.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+        newline_1 = vocab["\n"]
+        newline_2 = vocab["\n\n"]
+        newline_3 = vocab["\n\n\n"]
+        newline_4 = vocab["\n\n\n\n"]
+
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_1, newline_2],
+            [newline_3],
+        )
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_2, newline_1],
+            [newline_3],
+        )
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_2, newline_2],
+            [newline_4],
+        )
+
+        return token_ids, res
+
+    def _find_mm_placeholders(
+        self,
+        new_token_ids: list[int],
+        mm_prompt_updates: MultiModalPromptUpdates,
+    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
+        # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
+        tokenizer = self.info.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+        newline_1 = vocab["\n"]
+        newline_2 = vocab["\n\n"]
+        newline_3 = vocab["\n\n\n"]
+        newline_4 = vocab["\n\n\n\n"]
+
+        def get_repl_toks(tok: int) -> list[int]:
+            if tok == newline_3:
+                return [newline_1, newline_2]
+            if tok == newline_4:
+                return [newline_2, newline_2]
+
+            return [tok]
+
+        repl_token_ids = list[int]()
+        repl_orig_idxs = list[int]()
+        for orig_idx, orig_tok in enumerate(new_token_ids):
+            repl_toks = get_repl_toks(orig_tok)
+            repl_token_ids.extend(repl_toks)
+            repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
+
+        repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates)
+
+        return {
+            modality: [
+                PlaceholderFeaturesInfo(
+                    modality=p.modality,
+                    item_idx=p.item_idx,
+                    start_idx=repl_orig_idxs[p.start_idx],
+                    tokens=p.tokens,
+                    is_embed=p.is_embed,
+                )
+                for p in placeholders
+            ]
+            for modality, placeholders in repls.items()
+        }
+
+
+class Gemma3MultiModalProjector(nn.Module):
+    def __init__(self, config: Gemma3Config):
+        super().__init__()
+
+        self.mm_input_projection_weight = nn.Parameter(
+            torch.zeros(
+                config.vision_config.hidden_size, config.text_config.hidden_size
+            )
+        )
+
+        self.mm_soft_emb_norm = GemmaRMSNorm(
+            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
+        )
+
+        self.patches_per_image = int(
+            config.vision_config.image_size // config.vision_config.patch_size
+        )
+        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
+        self.kernel_size = self.patches_per_image // self.tokens_per_side
+        self.avg_pool = nn.AvgPool2d(
+            kernel_size=self.kernel_size, stride=self.kernel_size
+        )
+
+    def forward(self, vision_outputs: torch.Tensor):
+        batch_size, _, seq_length = vision_outputs.shape
+
+        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
+        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
+            batch_size, seq_length, self.patches_per_image, self.patches_per_image
+        )
+        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
+
+        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
+        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
+        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
+
+        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
+
+        projected_vision_outputs = torch.matmul(
+            normed_vision_outputs, self.mm_input_projection_weight
+        )
+        return projected_vision_outputs.type_as(vision_outputs)
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    Gemma3MultiModalProcessor,
+    info=Gemma3ProcessingInfo,
+    dummy_inputs=Gemma3DummyInputsBuilder,
+)
+class Gemma3ForConditionalGeneration(
+    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
+):
+    merge_by_field_config = True
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        }
+    )
+
+    @classmethod
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+        if modality.startswith("image"):
+            return ""
+
+        raise ValueError("Only image modality is supported")
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+        self.config = config
+        self.quant_config = quant_config
+        self.multimodal_config = multimodal_config
+
+        self.vision_tower = SiglipVisionModel(
+            config.vision_config,
+            quant_config,
+            prefix=maybe_prefix(prefix, "vision_tower"),
+        )
+        self.multi_modal_projector = Gemma3MultiModalProjector(config)
+
+        self.language_model = init_vllm_registered_model(
+            vllm_config=vllm_config,
+            hf_config=config.text_config,
+            prefix=maybe_prefix(prefix, "language_model"),
+            architectures=["Gemma3ForCausalLM"],
+        )
+        logit_scale = getattr(config, "logit_scale", 1.0)
+
+        if hasattr(self.language_model, "logits_processor"):
+            # The logits processor can be unset if we're using
+            # automatic conversion to pooling model.
+            self.language_model.logits_processor.scale *= logit_scale
+
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors
+        )
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> Gemma3ImageInputs | None:
+        pixel_values = kwargs.pop("pixel_values", None)
+        num_patches = kwargs.pop("num_patches", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+        assert image_embeds is None, "Gemma3 does not support image_embeds."
+        if pixel_values is None:
+            return None
+
+        image_size = self.config.vision_config.image_size
+
+        return Gemma3ImagePixelInputs(
+            pixel_values=pixel_values,
+            num_patches=num_patches,
+            resolve_bindings={"h": image_size, "w": image_size},
+        )
+
+    def _image_pixels_to_features(
+        self,
+        vision_tower: SiglipVisionModel,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
+        return vision_tower(pixel_values)
+
+    def _process_image_input(
+        self,
+        image_input: Gemma3ImageInputs,
+    ) -> list[torch.Tensor]:
+        assert self.vision_tower is not None
+
+        pixel_values = image_input["pixel_values"]
+        num_patches = image_input["num_patches"]
+
+        image_features = self._image_pixels_to_features(
+            self.vision_tower,
+            pixel_values,
+        )
+        image_embeds = self.multi_modal_projector(image_features)
+
+        return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]
+
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return []
+
+        return self._process_image_input(image_input)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        **kwargs: object,
+    ) -> IntermediateTensors:
+        if intermediate_tensors is not None:
+            inputs_embeds = None
+
+        hidden_states = self.language_model.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+            **kwargs,
+        )
+
+        return hidden_states
+
+    def prepare_attn_masks(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        mask_dtype: torch.dtype,
+        **kwargs,
+    ):
+        kwargs["has_images"] = True
+        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
+        # This is a HACK. Fix this.
+        start_indices = (positions == 0).cpu().nonzero()
+        num_seqs = len(start_indices)
+        seq_lens = []
+        for i in range(num_seqs):
+            start_idx = start_indices[i].item()
+            if i < num_seqs - 1:
+                end_idx = start_indices[i + 1].item()
+            else:
+                end_idx = len(input_ids)
+            seq_lens.append(end_idx - start_idx)
+        kwargs["seq_lens"] = seq_lens
+
+        global_attn_masks = []
+        local_attn_masks = []
+        start_idx = 0
+        for seq_len in seq_lens:
+            end_idx = start_idx + seq_len
+            input_token_ids = input_ids[start_idx:end_idx]
+            start_idx = end_idx
+            # Create a global causal mask.
+            global_attn_mask = torch.empty(
+                1,
+                1,
+                seq_len,
+                seq_len,
+                dtype=mask_dtype,
+                device=input_ids.device,
+            )
+            global_attn_mask.fill_(float("-inf"))
+            # Fill the lower triangle with 0.
+            global_attn_mask = global_attn_mask.triu(diagonal=1)
+
+            # Consider the bidirectional attention between image tokens.
+            img_mask = torch.zeros_like(global_attn_mask)
+            img_pos = input_token_ids == self.config.image_token_index
+            img_mask[:, :, :, img_pos] += 1
+            img_mask[:, :, img_pos, :] += 1
+            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
+            global_attn_masks.append(global_attn_mask)
+
+            sliding_window = self.config.text_config.sliding_window
+            if sliding_window is not None:
+                # Create a local causal mask with sliding window (1024).
+                local_attn_mask = torch.ones_like(global_attn_mask)
+                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
+                local_attn_mask = torch.where(
+                    local_attn_mask == 0, global_attn_mask, float("-inf")
+                )
+                local_attn_masks.append(local_attn_mask)
+        kwargs["global_attn_masks"] = global_attn_masks
+        kwargs["local_attn_masks"] = local_attn_masks
+        return kwargs
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        return self.language_model.compute_logits(hidden_states)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="multi_modal_projector",
+            tower_model="vision_tower",
+        )
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
new file mode 100644
index 000000000000..fb0b4b290467
--- /dev/null
+++ b/vllm/model_executor/models/paligemma.py
@@ -0,0 +1,412 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Literal, TypeAlias
+
+import torch
+from torch import nn
+from transformers import BatchFeature, PaliGemmaConfig
+
+from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalInputs,
+    MultiModalKwargsItems,
+    MultiModalUUIDDict,
+)
+from vllm.multimodal.parse import (
+    ImageEmbeddingItems,
+    ImageProcessorItems,
+    MultiModalDataItems,
+)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    PromptIndexTargets,
+    PromptInsertion,
+    PromptUpdate,
+    PromptUpdateDetails,
+)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .siglip import SiglipVisionModel
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    flatten_bn,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
+from .vision import get_vision_encoder_info
+
+logger = init_logger(__name__)
+
+
+class PaliGemmaImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+    """
+
+    type: Literal["pixel_values"] = "pixel_values"
+    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
+
+
+class PaliGemmaImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match language model backbone)
+    """
+
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
+
+
+PaliGemmaImageInputs: TypeAlias = (
+    PaliGemmaImagePixelInputs | PaliGemmaImageEmbeddingInputs
+)
+
+
+class PaliGemmaMultiModalProjector(nn.Module):
+    def __init__(self, vision_hidden_size: int, projection_dim: int):
+        super().__init__()
+
+        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
+
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.linear(image_features)
+        return hidden_states
+
+
+class PaliGemmaProcessingInfo(BaseProcessingInfo):
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(PaliGemmaConfig)
+
+    def get_vision_encoder_info(self):
+        return get_vision_encoder_info(self.get_hf_config())
+
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+        return {"image": 1}
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        vision_encoder_info = self.get_vision_encoder_info()
+
+        return vision_encoder_info.get_num_image_tokens(
+            image_width=image_width,
+            image_height=image_height,
+        )
+
+
+class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
+    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+        return ""
+
+    def get_dummy_mm_data(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+        mm_options: Mapping[str, BaseDummyOptions] | None = None,
+    ) -> MultiModalDataDict:
+        hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        max_image_size = vision_config.image_size
+
+        num_images = mm_counts.get("image", 0)
+
+        image_overrides = mm_options.get("image") if mm_options else None
+
+        return {
+            "image": self._get_dummy_images(
+                width=max_image_size,
+                height=max_image_size,
+                num_images=num_images,
+                overrides=image_overrides,
+            )
+        }
+
+
+class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+        tok_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        tokenizer = self.info.get_tokenizer()
+        if not mm_data:
+            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+            tok_kwargs=tok_kwargs,
+        )
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargsItems,
+    ) -> Sequence[PromptUpdate]:
+        hf_config = self.info.get_hf_config()
+        image_token_id = hf_config.image_token_index
+
+        tokenizer = self.info.get_tokenizer()
+
+        bos_token_id = tokenizer.bos_token_id
+        assert isinstance(bos_token_id, int)
+
+        def get_insertion(item_idx: int):
+            images = mm_items.get_items(
+                "image", (ImageEmbeddingItems, ImageProcessorItems)
+            )
+
+            if isinstance(images, ImageEmbeddingItems):
+                num_image_tokens = images.get_feature_size(item_idx)
+            else:
+                image_size = images.get_image_size(item_idx)
+                num_image_tokens = self.info.get_num_image_tokens(
+                    image_width=image_size.width,
+                    image_height=image_size.height,
+                )
+
+            image_tokens = [image_token_id] * num_image_tokens
+
+            return PromptUpdateDetails.select_token_id(
+                image_tokens + [bos_token_id],
+                embed_token_id=image_token_id,
+            )
+
+        # Paligemma 1 and 2 have different tokenizer.add_bos_token
+        # Insert *n +  after  for Paligemma 1
+        # Insert *n +  for Paligemma 2
+        return [
+            PromptInsertion(
+                modality="image",
+                target=PromptIndexTargets.prefix(
+                    [bos_token_id] if tokenizer.add_bos_token else []
+                ),
+                insertion=get_insertion,
+            )
+        ]
+
+    def apply(
+        self,
+        prompt: str | list[int],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        tokenization_kwargs: Mapping[str, object] | None = None,
+        mm_uuids: MultiModalUUIDDict | None = None,
+    ) -> MultiModalInputs:
+        mm_inputs = super().apply(
+            prompt,
+            mm_data,
+            hf_processor_mm_kwargs,
+            tokenization_kwargs,
+            mm_uuids=mm_uuids,
+        )
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
+
+        tokenizer = self.info.get_tokenizer()
+        newline_prompt = "\n"
+        newline_token_id = tokenizer.encode(newline_prompt)[-1]  # 108
+        # Force to add newline at the end of prompt for paligemma's format
+        # This step can NOT be replacemented by current PromptUpdate methods
+        if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
+            prompt_token_ids.append(newline_token_id)
+            mm_inputs["prompt_token_ids"] = prompt_token_ids
+
+        return mm_inputs
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    PaliGemmaMultiModalProcessor,
+    info=PaliGemmaProcessingInfo,
+    dummy_inputs=PaliGemmaDummyInputsBuilder,
+)
+class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            # mapping for new names in checkpoint saved after transformers v4.52
+            "model.language_model.": "language_model.model.",
+            "model.vision_tower.": "vision_tower.",
+            "model.multi_modal_projector.": "multi_modal_projector.",
+            "lm_head.": "language_model.lm_head.",
+        }
+    )
+
+    @classmethod
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+        if modality.startswith("image"):
+            return None
+
+        raise ValueError("Only image modality is supported")
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+        self.config = config
+        self.multimodal_config = multimodal_config
+
+        self.vision_tower = SiglipVisionModel(
+            config.vision_config,
+            quant_config,
+            prefix=maybe_prefix(prefix, "vision_tower"),
+        )
+        self.multi_modal_projector = PaliGemmaMultiModalProjector(
+            vision_hidden_size=config.vision_config.hidden_size,
+            projection_dim=config.vision_config.projection_dim,
+        )
+
+        self.quant_config = quant_config
+
+        if config.text_config.model_type == "gemma":
+            config.text_config.architectures = ["GemmaForCausalLM"]
+        else:
+            config.text_config.architectures = ["Gemma2ForCausalLM"]
+        self.language_model = init_vllm_registered_model(
+            vllm_config=vllm_config,
+            hf_config=config.text_config,
+            prefix=maybe_prefix(prefix, "language_model"),
+        )
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.language_model.logits_processor.scale *= logit_scale
+
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors
+        )
+
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> PaliGemmaImageInputs | None:
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if pixel_values is None and image_embeds is None:
+            return None
+
+        if pixel_values is not None:
+            pixel_values = flatten_bn(pixel_values, concat=True)
+
+            h = w = self.config.vision_config.image_size
+            return PaliGemmaImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                resolve_bindings={"h": h, "w": w},
+            )
+
+        if image_embeds is not None:
+            image_embeds = flatten_bn(image_embeds, concat=True)
+
+            return PaliGemmaImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
+
+    def _image_pixels_to_features(
+        self,
+        vision_tower: SiglipVisionModel,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
+        target_dtype = vision_tower.get_input_embeddings().weight.dtype
+        image_features = vision_tower(pixel_values.to(dtype=target_dtype))
+
+        return image_features
+
+    def _process_image_input(
+        self,
+        image_input: PaliGemmaImageInputs,
+    ) -> torch.Tensor:
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
+        assert self.vision_tower is not None
+        pixel_values = image_input["data"]
+        image_features = self._image_pixels_to_features(
+            self.vision_tower,
+            pixel_values,
+        )
+
+        return self.multi_modal_projector(image_features)
+
+    def get_language_model(self) -> torch.nn.Module:
+        return self.language_model
+
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return []
+        vision_embeddings = self._process_image_input(image_input)
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
+        vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
+        return vision_embeddings
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
+        **kwargs: object,
+    ) -> IntermediateTensors:
+        if intermediate_tensors is not None:
+            inputs_embeds = None
+
+        hidden_states = self.language_model.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )
+
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        return self.language_model.compute_logits(hidden_states)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index da1606a7568d..2fa269f77755 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -264,6 +264,7 @@
         "Ernie4_5_VLMoeForConditionalGeneration",
     ),
     "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
+    "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
     "Gemma3nForConditionalGeneration": (
         "gemma3n_mm",
         "Gemma3nForConditionalGeneration",
@@ -333,6 +334,10 @@
     "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
     "Ovis": ("ovis", "Ovis"),
     "Ovis2_5": ("ovis2_5", "Ovis2_5"),
+    "PaliGemmaForConditionalGeneration": (
+        "paligemma",
+        "PaliGemmaForConditionalGeneration",
+    ),
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
     "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"),  # noqa: E501
@@ -405,14 +410,6 @@
         "transformers",
         "TransformersMultiModalForCausalLM",
     ),
-    "Gemma3ForConditionalGeneration": (
-        "transformers",
-        "TransformersMultiModalForCausalLM",
-    ),
-    "PaliGemmaForConditionalGeneration": (
-        "transformers",
-        "TransformersMultiModalForCausalLM",
-    ),
 }
 
 _TRANSFORMERS_BACKEND_MODELS = {
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 9788bfeca109..4d6aa10500d3 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -59,6 +59,9 @@
     "Qwen2ForCausalLM": _ROCM_SWA_REASON,
     "MistralForCausalLM": _ROCM_SWA_REASON,
     "MixtralForCausalLM": _ROCM_SWA_REASON,
+    "PaliGemmaForConditionalGeneration": (
+        "ROCm flash attention does not yet fully support 32-bit precision on PaliGemma"
+    ),
     "Phi3VForCausalLM": (
         "ROCm Triton flash attention may run into compilation errors due to "
         "excessive use of shared memory. If this happens, disable Triton FA "