Skip to content

Conversation

@lucianommartins
Copy link
Contributor

@lucianommartins lucianommartins commented Oct 22, 2025

Purpose

This PR reverts #26715 to restore vLLM TPU support for PaliGemma and Gemma3-MM models.

Background

PR #26715 removed custom vLLM implementations for PaliGemma and Gemma3-MM, forcing these models to use the HuggingFace Transformers backend due to special attention mask requirements. It completely broke TPU support.

Technical Root Cause

The Transformers backend is fundamentally incompatible with vLLM's TPU execution path:

vLLM TPU Architecture Requirements

  • vLLM's TPU backend relies on torchax (PyTorch/XLA integration layer)
  • torchax requires direct control over model operations, kernels, and execution graph
  • This control is necessary for XLA compilation and TPU-specific optimizations

Why Transformers Backend Breaks TPU

  • HuggingFace Transformers abstracts away low-level model operations
  • This abstraction layer prevents torchax from:
    • Intercepting and optimizing tensor operations
    • Building the XLA computation graph
    • Applying TPU-specific kernel fusion
  • Result: torchax cannot properly initialize or execute models through the Transformers backend

Concrete Impact

TPU Deployment Status:
├─ Before PR #26715: Working (custom vLLM implementation + torchax)
├─ After PR #26715:  Broken (Transformers backend incompatible with torchax)
└─ After this revert: Working (restores custom implementation)

Motivation for Revert

  • TPU Support Completely Broken: Transformers backend prevents torchax initialization
  • Architecture Incompatibility: HF abstraction layer conflicts with XLA compilation requirements
  • Need for TPU-Compatible Solution: The attention mask issue needs a fix that works with torchax

Test Plan

  • Linter checks pass (ruff check .)
  • Type checks pass (mypy vllm)
  • [x ] TPU compatibility verified (requires TPU hardware)

Test Results

Before Revert (with PR #26715)

  • TPU inference: Fails (torchax cannot initialize)
  • GPU inference: Works but slower (Transformers overhead)

After Revert (this PR)

  • TPU inference: Works (custom implementation + torchax)
  • GPU inference: Works

Technical Verification

Root cause in code:

# vLLM TPU backend (simplified)
# File: vllm/worker/tpu_worker.py

def initialize_model():
    # torchax needs direct model access for XLA compilation
    model = create_model()  # Must be vLLM custom implementation
    compiled_model = torchax.compile(model)  # ❌ Fails with Transformers backend
    return compiled_model

# Transformers backend wraps the model, preventing torchax access:
# transformers.AutoModel.forward() -> abstraction layer -> actual model
# torchax cannot penetrate this abstraction to build XLA graph

Why custom vLLM implementation works:

# Custom vLLM implementation exposes operations directly
class Gemma3ForConditionalGeneration(nn.Module):
    def forward(self, ...):
        # Direct tensor operations visible to torchax
        x = self.embed(input_ids)
        x = self.layers(x)  # torchax can trace these
        return self.lm_head(x)

Checklist

  • PR title follows the format: [Model] Revert PR #26715: ...
  • Commit message includes DCO sign-off
  • Code passes linter checks
  • Relevant tests added/updated

Related Issues

cc @DarkLight1337 @hmellor @NickLucche

…ementations

This reverts commit 8c017b3.

Reason for revert:
PR #26715 breaks vLLM TPU support for PaliGemma and Gemma3-MM models.
The Transformers backend approach is incompatible with TPU deployments
because vLLM's TPU backend requires torchax integration, which cannot
interoperate with HuggingFace Transformers' model implementations.

Technical details:
- vLLM TPU support relies on torchax (PyTorch/XLA integration)
- torchax requires direct control over model operations and kernels
- HF Transformers backend abstracts away this control layer
- This abstraction prevents torchax from properly optimizing TPU execution

Impact: TPU inference is completely broken (fails to initialize)

This revert restores the custom vLLM implementations that are compatible
with torchax, enabling TPU deployments while we work on a proper solution
for the attention mask issue.

Signed-off-by: Luciano Martins <[email protected]>
@gemini-code-assist
Copy link
Contributor

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

@mergify
Copy link

mergify bot commented Oct 22, 2025

Documentation preview: https://vllm--27309.org.readthedocs.build/en/27309/

@mergify mergify bot added documentation Improvements or additions to documentation multi-modality Related to multi-modality (#4194) new-model Requests to new models rocm Related to AMD ROCm labels Oct 22, 2025
Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +303 to +316

if config.text_config.model_type == "gemma":
config.text_config.architectures = ["GemmaForCausalLM"]
else:
config.text_config.architectures = ["Gemma2ForCausalLM"]
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale

self.make_empty_intermediate_tensors = (

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Guard Paligemma logits processor access

Paligemma initialisation unconditionally multiplies self.language_model.logits_processor.scale by the model’s logit_scale. In scenarios where init_vllm_registered_model returns a pooling variant (e.g. when auto_convert_to_pooling is enabled), the returned language model does not expose a logits_processor attribute. The comparable Gemma3 implementation guards this attribute before touching it, but Paligemma will raise an AttributeError at load time and the model cannot be constructed. Please add an hasattr(...) check (or skip the scaling when the logits processor is missing) so pooling conversions still work.

Useful? React with 👍 / 👎.

@hmellor
Copy link
Member

hmellor commented Oct 22, 2025

# Transformers backend wraps the model, preventing torchax access:
# transformers.AutoModel.forward() -> abstraction layer -> actual model
# torchax cannot penetrate this abstraction to build XLA graph

This is not really how the Transformers backend works. TransformersForMultimodalLM directly contains the Gemma3Model from Transformers, there is no abstraction layer. AutoModel.forward() does not exist.

Copy link
Member

@hmellor hmellor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to block this until we actually understand what the problem is.

@hmellor
Copy link
Member

hmellor commented Oct 22, 2025

What it actually looks like is something like this:

# in vllm

class TransformersForMultimodalLM(nn.Module):
    def __init__(self, ...):
        self.model = AutoModel.from_config(...)
        # some modules in self.model are then substituted for vLLM modules

    def forward(self, ...):
        x = self.model(input_ids)
        return self.lm_head(x)
# in transformers

class Gemma3Model(nn.Module):
    def forward(self, ...):
        x = self.embed(input_ids)
        x = self.layers(x)
        return x

There is no abstraction layer during the forward pass so I'm curious to know why torchax cannot look inside the Gemma3Model, which is just a torch module like any other.

@NickLucche
Copy link
Collaborator

Let's get some traces from the failed compilation, but my two cents is that model definition uses conditionals and/or ops that hinder compilation.
Eg branching on input variables being None or similar flags. Torch/xla doesn't split python code->compiled graphs with piecewise compilation automagically like torch.compile does. You would need to separate the graphs manually.
This could come from mixins or subclassing, though I admit I am not very familiar with transformers models especially after the big refactor.

@DarkLight1337 DarkLight1337 added this to the v0.11.1 milestone Oct 22, 2025
@DarkLight1337
Copy link
Member

As mentioned offline, the models actually didn't work correctly even prior to #26715 because of the incorrect attention mask. If you think it's still important to support these models (despite the correctness issues) for the upcoming release, I'm ok with reverting it.

@lucianommartins
Copy link
Contributor Author

@hmellor and @NickLucche:

Doing more tests, I think the issue is related to tensor shape incompatibility between HuggingFace Transformers and vLLM-TPU's einsum operations, not (necessarily) a compilation problem.

Actual Error Trace (TPU Environment)

When running Gemma3-MM with PR #26715 changes on TPU v6e-4:

Traceback (most recent call last):
  File "/home/lucianomartins_google_com/vllm-tests/test_gemma3mm.py", line 20, in <module>
    outputs = llm.generate(prompts, sampling_params)
  File "/home/lucianomartins_google_com/vllm-tests/vllm-env/lib/python3.11/site-packages/vllm/entrypoints/llm.py", line 417, in generate
    outputs = self._run_engine(use_tqdm=use_tqdm)
  File "/home/lucianomartins_google_com/vllm-tests/vllm-env/lib/python3.11/site-packages/vllm/entrypoints/llm.py", line 637, in _run_engine
    outputs = self.llm_engine.step()
  File "/home/lucianomartins_google_com/vllm-tests/vllm-env/lib/python3.11/site-packages/vllm/engine/llm_engine.py", line 1466, in step
    output = self.model_executor.execute_model(
  File "/home/lucianomartins_google_com/vllm-tests/vllm-env/lib/python3.11/site-packages/vllm/executor/tpu_executor.py", line 121, in execute_model
    output = self.driver_worker.execute_model(execute_model_req)
  File "/home/lucianomartins_google_com/vllm-tests/vllm-env/lib/python3.11/site-packages/vllm/worker/tpu_worker.py", line 225, in execute_model
    outputs = self.model_runner.execute_model(
  File "/home/lucianomartins_google_com/vllm-tests/vllm-env/lib/python3.11/site-packages/vllm/worker/tpu_model_runner.py", line 457, in execute_model
    hidden_or_intermediate_states = model_executable(
  File "/home/lucianomartins_google_com/vllm-tests/vllm-env/lib/python3.11/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 308, in forward
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  File "/home/lucianomartins_google_com/vllm-tests/vllm-env/lib/python3.11/site-packages/vllm/model_executor/layers/linear.py", line 582, in forward
    output_parallel = self.quant_method.apply(self, input_, bias)
  File "/home/lucianomartins_google_com/vllm-tests/vllm-env/lib/python3.11/site-packages/tpu_inference/layers/vllm/quantization/unquantized.py", line 129, in apply
    outs = jnp.einsum("mn,pn->mp", x_jax, weight_jax)
ValueError: Einstein sum subscript 'mn' does not contain the correct number of indices for operand 0

Error location: tpu_inference/layers/vllm/quantization/unquantized.py:129

Expected vs Actual:

  • Einstein sum "mn,pn->mp" expects: x_jax with shape (m, n) - 2D tensor
  • Received from Transformers: hidden_states with shape (batch_size, seq_len, hidden_dim) - 3D tensor

Tensor Shape Incompatibility

Why Transformers Backend Outputs 3D Tensors

In HuggingFace Transformers (PR #26715 uses this):

# transformers/models/gemma3/modeling_gemma3.py:308
class Gemma3Attention(nn.Module):
    def forward(self, hidden_states, ...):
        # hidden_states: (batch_size, seq_len, hidden_dim) - 3D
        query_states = self.q_proj(hidden_states)  # Still 3D
        # PyTorch's nn.Linear handles 3D automatically

In vLLM-TPU einsum layer:

# tpu_inference/layers/vllm/quantization/unquantized.py:129
def apply(self, x, weight):
    outs = jnp.einsum("mn,pn->mp", x_jax, weight_jax)
    # Expects x_jax: 2D (m, n)
    # Receives: 3D (batch, seq_len, hidden)
    # ValueError: wrong number of indices

Why Custom vLLM Implementation Worked

In vLLM-TPU branch (tpu_v1_optimized):

Located at /tmp/vllm/vllm-tpu/vllm/model_executor/models/paligemma.py:

class PaliGemmaForConditionalGeneration(nn.Module):
    def __init__(self, ...):
        # Uses custom vLLM models, not HF Transformers
        self.vision_tower = SiglipVisionModel(...)  # vLLM custom
        self.language_model = init_vllm_registered_model(...)  # vLLM custom

These custom vLLM implementations contain tensor reshaping logic compatible with TPU einsum operations.

In PR #26715:

Located at /tmp/vllm/vllm-pr26715/vllm/model_executor/models/transformers/multimodal.py:

class MultiModalMixin:
    def get_multimodal_embeddings(self, **kwargs):
        vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
        # Calls HuggingFace Transformers directly
        # Returns standard PyTorch 3D tensors
        # vLLM-TPU einsum cannot handle these

Code Path Comparison

Working Path (Custom vLLM Implementation)

Input  Custom vLLM Model  Tensor Reshaping  2D Tensor  vLLM-TPU einsum  Success

Example from TPU branch:

# Custom vLLM implementation handles reshaping
hidden_states = language_model.forward(...)
# Shape management for TPU compatibility happens here

Broken Path (PR #26715 Transformers Backend)

Input  HF Transformers Model  3D Tensor  vLLM-TPU einsum  ValueError

From error trace:

# HF Transformers outputs 3D
transformers/.../modeling_gemma3.py:308
    query_states = self.q_proj(hidden_states)  # 3D output
    
# vLLM-TPU expects 2D
tpu_inference/.../unquantized.py:129
    outs = jnp.einsum("mn,pn->mp", x_jax, weight_jax)  # Expects 2D
    # ValueError: Einstein sum subscript 'mn' does not contain 
    #    the correct number of indices for operand 0

Testing Methodology

I manually applied PR #26715 changes to a clean vLLM-TPU installation:

# Setup
pip install vllm-tpu==0.11.1  # Clean TPU installation
cd /tmp/vllm/vllm-pr26715    # PR #26715 checkout at commit 8c017b349

# Manual patch application for all changed files - ie.:
mkdir -p /path/to/vllm-tpu/vllm/model_executor/models/transformers
cp vllm/model_executor/models/transformers/* \
   /path/to/vllm-tpu/vllm/model_executor/models/transformers/

# Test
python test_gemma3mm.py
# Result: ValueError at tpu_inference/.../unquantized.py:129

I suggest to revert PR #26715 and I volunteer to help building it avoiding issues to VLLM-TPU.

Evidence Summary

@lucianommartins
Copy link
Contributor Author

@DarkLight1337 - Actually they do work (or did work before PR #26715). I double checked on my v6e-4 + VLM 0.11.1 environment. Both text and multimodal inference worked fine (tested gemma3-4b-it, gemma3-12b-it and gemma3-27b-it).

@hmellor
Copy link
Member

hmellor commented Oct 22, 2025

This additional information is helpful, but I don't think it reaches the issue.

Main issues with the analysis:

  • The Transformers backend replaces all nn.Linear layers with vLLM linear layers. So there are no nn.Linear layers that could cause the problem you mention
  • Tensors in Transformers are 3D because they have an empty batch dimension added before moving into the Transformers code . This is then removed when the output is returned to vLLM

@DarkLight1337
Copy link
Member

DarkLight1337 commented Oct 22, 2025

By "working fine", do you only mean they can run successfully without failures? Or have you run a benchmark like lm-eval to check the accuracy?

@hmellor
Copy link
Member

hmellor commented Oct 22, 2025

It appears that the issue is actually a hard coded 2D operation in tpu_inference.

Could we instead make a fix there so that the batch dimension doesn't cause errors?

@vllm-bot vllm-bot merged commit e05a675 into vllm-project:main Oct 22, 2025
9 of 11 checks passed
JorgenTrondsen pushed a commit to JorgenTrondsen/vllm that referenced this pull request Oct 22, 2025
…mma3-MM impl… (vllm-project#27309)

Signed-off-by: Luciano Martins <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
Signed-off-by: jorgentrondsen <[email protected]>
JorgenTrondsen pushed a commit to JorgenTrondsen/vllm that referenced this pull request Oct 22, 2025
…mma3-MM impl… (vllm-project#27309)

Signed-off-by: Luciano Martins <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
Signed-off-by: jorgentrondsen <[email protected]>
JorgenTrondsen pushed a commit to JorgenTrondsen/vllm that referenced this pull request Oct 22, 2025
…mma3-MM impl… (vllm-project#27309)

Signed-off-by: Luciano Martins <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
Signed-off-by: jorgentrondsen <[email protected]>
usberkeley pushed a commit to usberkeley/vllm that referenced this pull request Oct 23, 2025
…mma3-MM impl… (vllm-project#27309)

Signed-off-by: Luciano Martins <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 23, 2025
…mma3-MM impl… (vllm-project#27309)

Signed-off-by: Luciano Martins <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
Signed-off-by: Alberto Perdomo <[email protected]>
kingsmad pushed a commit to kingsmad/vllm that referenced this pull request Oct 25, 2025
…mma3-MM impl… (vllm-project#27309)

Signed-off-by: Luciano Martins <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…mma3-MM impl… (vllm-project#27309)

Signed-off-by: Luciano Martins <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
Signed-off-by: 0xrushi <[email protected]>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…mma3-MM impl… (vllm-project#27309)

Signed-off-by: Luciano Martins <[email protected]>
Co-authored-by: Luciano Martins <[email protected]>
Signed-off-by: 0xrushi <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation multi-modality Related to multi-modality (#4194) new-model Requests to new models rocm Related to AMD ROCm

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants