
Conversation

@jp1924 (Contributor) commented Oct 27, 2024

What does this PR do?

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/load_data.py", line 29, in <module>
    model(**outputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llava/modeling_llava.py", line 524, in forward
    raise ValueError(
ValueError: Image features and image tokens do not match: tokens: 729, features 728

When a vision encoder that does not prepend a CLS token (such as SigLIP) is used with Llava or Llava-Next, the number of image tokens the processor inserts into the prompt no longer matches the number of image features the model produces, and the error above is raised.

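# From the Llava processor: the trailing +1 assumes the encoder prepends a CLS token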
height, width = get_image_size(to_numpy_array(pixel_values[0]))
num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1

Unlike ViT, SigLIP does not prepend a CLS token in its vision embedding layer; siglip-so400m-patch14-384, for example, produces only (384 // 14)² = 729 patch embeddings. The Llava processor nevertheless hardcodes a +1 for the CLS position, so the expected token count is always off by one and the error occurs even when vision_feature_select_strategy is set to full.

Therefore, I propose replacing the hardcoded +1 with a flag named vision_feature_use_cls.
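
A rough sketch of what this could look like in the processor (illustrative only; vision_feature_use_cls is the proposed flag and does not exist yet):

height, width = get_image_size(to_numpy_array(pixel_values[0]))
num_image_tokens = (height // self.patch_size) * (width // self.patch_size)
if self.vision_feature_use_cls:
    # only count the extra CLS position for encoders that actually prepend one (e.g. CLIP/ViT)
    num_image_tokens += 1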

Bug reproduction code:

import requests
from PIL import Image

from transformers import (
    AddedToken,
    AutoConfig,
    AutoImageProcessor,
    AutoTokenizer,
    LlavaConfig,
    LlavaForConditionalGeneration,
    LlavaNextConfig,
    LlavaNextForConditionalGeneration,
    LlavaNextImageProcessor,
    LlavaNextProcessor,
    LlavaProcessor,
)


device = "cpu"

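# Build a small Gemma-2 + SigLIP Llava from scratch; randomly initialized weights are enough to trigger the token-count check.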
IMG_TOKEN = "<|image|>"
language_name, vision_name = "google/gemma-2-9b", "google/siglip-so400m-patch14-384"
language_config = AutoConfig.from_pretrained(language_name)
vision_config = AutoConfig.from_pretrained(vision_name).vision_config

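# Shrink both towers to 2 layers so the model builds quickly; the mismatch does not depend on depth.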
vision_config.num_hidden_layers, language_config.num_hidden_layers = 2, 2

image_processor = AutoImageProcessor.from_pretrained(vision_name)
tokenizer = AutoTokenizer.from_pretrained(language_name)
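# Register the image placeholder as a special token so the processor can expand it to the expected number of image tokens.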
tokenizer.add_tokens(AddedToken(IMG_TOKEN, special=True, normalized=False), special_tokens=True)
language_config.vocab_size = len(tokenizer.get_vocab())
image_token_index = tokenizer.convert_tokens_to_ids(IMG_TOKEN)

prompts = [
    f"USER: {IMG_TOKEN}\nWhat are the things I should be cautious about when I visit this place? What should I bring with me? ASSISTANT:",
    f"USER: {IMG_TOKEN}\nWhat is this? ASSISTANT:",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

# LLaVA
vision_feature_select_strategy = "full"
processor = LlavaProcessor(
    tokenizer=tokenizer,
    image_processor=image_processor,
    patch_size=vision_config.patch_size,
    vision_feature_select_strategy=vision_feature_select_strategy,
    image_seq_length=vision_config.image_size,
    image_token_index=image_token_index,
    image_token=IMG_TOKEN,
)
config = LlavaConfig(
    vision_config=vision_config,
    text_config=language_config,
    image_seq_length=vision_config.image_size,
    image_token_index=image_token_index,
    vision_feature_select_strategy=vision_feature_select_strategy,
    loss_type="ForCausalLM",
    _attn_implementation="eager",
)
model = LlavaForConditionalGeneration(config)
model = model.train().to(device)

inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
inputs["labels"] = inputs["input_ids"]
inputs = inputs.to(device)

outputs = model(**inputs)
outputs


# LLaVA-NEXT
vision_feature_select_strategy = "full"
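# Grid pinpoints are multiples of SigLIP's 384-pixel input; crop_size keeps each crop at the 384x384 resolution the encoder expects.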
image_processor = LlavaNextImageProcessor.from_pretrained(
    vision_name,
    image_grid_pinpoints=[[768, 768], [384, 768], [384, 1152], [768, 384], [1152, 384]],
    crop_size={"height": config.vision_config.image_size, "width": config.vision_config.image_size},
)

processor = LlavaNextProcessor(
    image_processor=image_processor,
    tokenizer=tokenizer,
    vision_feature_select_strategy=vision_feature_select_strategy,
    image_seq_length=vision_config.image_size,
)
config = LlavaNextConfig(
    vision_config=vision_config,
    text_config=language_config,
    image_seq_length=vision_config.image_size,
    image_token_index=image_token_index,
    vision_feature_select_strategy=vision_feature_select_strategy,
    loss_type="ForCausalLM",
    _attn_implementation="eager",
)
model = LlavaNextForConditionalGeneration(config)

inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
inputs["labels"] = inputs["input_ids"]
inputs = inputs.to(device)

outputs = model(**inputs)
outputs
  • transformers version: 4.46.0
  • Platform: Linux-5.15.0-124-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.26.0
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Tensorflow version (GPU?): 2.15.1 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.7.0 (cpu)
  • Jax version: 0.4.13
  • JaxLib version: 0.4.13
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA A100 80GB PCIe

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts, @qubvel, @zucchini-nlp

@jp1924 (Contributor, Author) commented Oct 27, 2024

This is my personal opinion, but I don't think my implementation is particularly good.

Among the vision models in transformers, some (like ViT) always insert a CLS token, and there is currently no general way to tell whether a given encoder does or not.

That ambiguity is what caused this bug, so I don't think a flag-based approach is a fundamental solution.

Instead, I believe the best approach would be to add a value to the vision encoder's config that states whether the model prepends a CLS token. The processor could then add the extra token only when the encoder actually produces one, which would be a more robust solution.
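
For concreteness, something along these lines (the attribute name is just an example, not an existing config field):

num_image_tokens = (height // self.patch_size) * (width // self.patch_size)
# hypothetical config field that each vision model would declare
if getattr(vision_config, "prepends_cls_token", False):
    num_image_tokens += 1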

@zucchini-nlp (Member)

Hey! Thanks for reporting. This is a known bug that we only discovered after releasing the new processing logic. The fix will be in #33424 in a few weeks :)

@jp1924 (Contributor, Author) commented Oct 27, 2024

@zucchini-nlp

Oh! Thanks for the quick answer!

But I have to run Llava with SigLIP, and this error makes that annoying. As a temporary workaround, can I just check out that PR's branch and run it?

@jp1924 (Contributor, Author) commented Oct 27, 2024

Ah, I see it isn't merged yet. I'll work around it on my side until the fix lands. Thanks!

@zucchini-nlp (Member)

Yes, feel free to install from that PR in the meantime. It has taken a bit longer to merge because we had to discuss a long-term solution that works for most VLMs.
