2 changes: 1 addition & 1 deletion docs/source/en/index.md
@@ -255,7 +255,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Phi3](model_doc/phi3) | ✅ | ❌ | ❌ |
| [PhoBERT](model_doc/phobert) | ✅ | ✅ | ✅ |
| [Pix2Struct](model_doc/pix2struct) | ✅ | ❌ | ❌ |
| [Pixtral](model_doc/pixtral) | | ❌ | ❌ |
| [Pixtral](model_doc/pixtral) | | ❌ | ❌ |
| [PLBart](model_doc/plbart) | ✅ | ❌ | ❌ |
| [PoolFormer](model_doc/poolformer) | ✅ | ❌ | ❌ |
| [Pop2Piano](model_doc/pop2piano) | ✅ | ❌ | ❌ |
16 changes: 9 additions & 7 deletions docs/source/en/model_doc/pixtral.md
@@ -18,20 +18,22 @@ rendered properly in your Markdown viewer.

## Overview

The Pixtral model was released by the Mistral AI team on [Vllm](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found!

The Pixtral model was released by the Mistral AI team on [vLLM](https://github.com/vllm-project/vllm/pull/8377), where a version of the code can be found!

Tips:

- Pixtral is a multimodal model, the main contribution is the 2d ROPE on the images, and support for arbitrary image size (the images are not padded together nor are they resized)
- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
- Pixtral is a multimodal model, taking images and text as input, and producing text as output.
- This model follows the [Llava](llava) family, meaning image embeddings are placed instead of the `[IMG]` token placeholders. The model uses [`PixtralVisionModel`] for its vision encoder, and [`MistralForCausalLM`] for its language decoder.
- The main contribution is the 2d ROPE (rotary position embeddings) on the images, and support for arbitrary image sizes (the images are not padded together nor are they resized).
- The format for one or multiple prompts is the following:
```
"<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
```
Then, the processor will replace each `[IMG]` token with a number of `[IMG]` tokens that depends on the height and the width of the image. Each *row* of the image is separated by a `[IMG_BREAK]` token, and each image is separated by a `[IMG_END]` token.
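
A minimal sketch of this expansion, assuming a checkpoint with a registered processor (the id below is the one used in the tests; substitute the one you actually use):
```python
from PIL import Image
import requests

from transformers import AutoProcessor

# Checkpoint id taken from the test suite; swap in the checkpoint you actually use.
processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")

prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
url = "https://pixtral-vl.github.io/static/images/view.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=image, return_tensors="pt")
# The single [IMG] placeholder is expanded into one token per image patch, with
# [IMG_BREAK] between rows and [IMG_END] closing the image, so the sequence length
# grows with the image resolution.
print(inputs["input_ids"].shape)
```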

This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ)
This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) and [ArthurZ](https://huggingface.co/ArthurZ). The original code can be found [here](https://github.com/vllm-project/vllm/pull/8377).

## Usage

Here is an example of how to run it:

@@ -83,9 +85,9 @@ Each image captures a different scene, from a close-up of a dog to expansive nat

[[autodoc]] PixtralVisionConfig

## PixtralModel
## PixtralVisionModel

[[autodoc]] PixtralModel
[[autodoc]] PixtralVisionModel
- forward

## PixtralImageProcessor
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -2994,7 +2994,7 @@
"Pix2StructVisionModel",
]
)
_import_structure["models.pixtral"].extend(["PixtralModel", "PixtralPreTrainedModel"])
_import_structure["models.pixtral"].extend(["PixtralVisionModel", "PixtralPreTrainedModel"])
_import_structure["models.plbart"].extend(
[
"PLBartForCausalLM",
@@ -7486,8 +7486,8 @@
Pix2StructVisionModel,
)
from .models.pixtral import (
PixtralModel,
PixtralPreTrainedModel,
PixtralVisionModel,
)
from .models.plbart import (
PLBartForCausalLM,
2 changes: 1 addition & 1 deletion src/transformers/models/auto/modeling_auto.py
@@ -195,7 +195,7 @@
("persimmon", "PersimmonModel"),
("phi", "PhiModel"),
("phi3", "Phi3Model"),
("pixtral", "PixtralModel"),
("pixtral", "PixtralVisionModel"),
("plbart", "PLBartModel"),
("poolformer", "PoolFormerModel"),
("prophetnet", "ProphetNetModel"),
4 changes: 2 additions & 2 deletions src/transformers/models/pixtral/__init__.py
@@ -29,7 +29,7 @@
pass
else:
_import_structure["modeling_pixtral"] = [
"PixtralModel",
"PixtralVisionModel",
"PixtralPreTrainedModel",
]

@@ -53,8 +53,8 @@
pass
else:
from .modeling_pixtral import (
PixtralModel,
PixtralPreTrainedModel,
PixtralVisionModel,
)

try:
18 changes: 7 additions & 11 deletions src/transformers/models/pixtral/configuration_pixtral.py
@@ -22,9 +22,9 @@

class PixtralVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PixtralModel`]. It is used to instantiate an
Pixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Pixtral-9B.
This is the configuration class to store the configuration of a [`PixtralVisionModel`]. It is used to instantiate a
Pixtral vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to the vision encoder used by Pixtral-12B.

e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b)

@@ -52,19 +52,17 @@ class PixtralVisionConfig(PretrainedConfig):
Dropout probability for the attention layers.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie the word embeddings with the input embeddings.

Example:

```python
>>> from transformers import PixtralModel, PixtralVisionConfig, CLIPVisionConfig, LlamaConfig
>>> from transformers import PixtralVisionModel, PixtralVisionConfig

>>> # Initializing a Pixtral 12B style configuration
>>> # Initializing a Pixtral-12B style configuration
>>> config = PixtralVisionConfig()

>>> # Initializing a model from the pixtral 12B style configuration
>>> model = PixtralModel(configuration)
>>> # Initializing a model (with randomly initialized weights) from the configuration
>>> model = PixtralVisionModel(config)

>>> # Accessing the model configuration
>>> configuration = model.config
@@ -84,7 +82,6 @@ def __init__(
hidden_act="gelu",
attention_dropout=0.0,
rope_theta=10000.0,
tie_word_embeddings=False,
**kwargs,
):
super().__init__(**kwargs)
@@ -99,5 +96,4 @@ def __init__(
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.rope_theta = rope_theta
self.tie_word_embeddings = tie_word_embeddings
self.head_dim = hidden_size // num_attention_heads
38 changes: 14 additions & 24 deletions src/transformers/models/pixtral/modeling_pixtral.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -48,15 +48,13 @@ def position_ids_in_meshgrid(patch_embeds_list, max_width):
class PixtralRotaryEmbedding(nn.Module):
"""
The key with pixtral embedding is just that you have a frequency for each pixel position.
If you have height x width pixels (or embedding pixels)
If you have height x width pixels (or embedding pixels), then the frequency used for ROPE
is given by indexing the pre_computed frequency on the width and height.

then the frequency used for ROPE is given by indexing the pre_computed frequency on the
width and height.
What you output is of dimension (batch, height * width, dim) with dim the embed dim.

What you output is of dimension batch, height * width, dim with dim the embed dim.

This simply means that for each image hidden states, you are going to add
a corresponding positional embedding, based on it's index in the grid.
This simply means that for each image hidden state, you are going to add
a corresponding positional embedding, based on its index in the grid.
"""

def __init__(self, config, device):
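
# Illustrative sketch (not part of modeling_pixtral.py): one way to build 2d position ids
# with a meshgrid and index a precomputed frequency table, as described in the docstring
# above. Names, toy sizes, and the frequency layout are assumptions for illustration,
# not the model's exact implementation.
import torch

height, width, dim, max_width = 3, 4, 8, 16  # toy sizes
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))  # (dim // 2,)

# Flat position id per patch: row-major index into a (max_width x max_width) grid.
h_idx, w_idx = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
position_ids = (h_idx * max_width + w_idx).flatten()  # (height * width,)

# Precompute one frequency row per possible grid position, then select one per patch.
freqs = torch.outer(torch.arange(max_width * max_width).float(), inv_freq)
selected = freqs[position_ids]  # (height * width, dim // 2): one rotary frequency row per patch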
@@ -319,9 +317,7 @@ def forward(
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
Embeddings which serve as input to the Transformer.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

@@ -392,17 +388,13 @@ def forward(
and behavior.

Parameters:
config ([`PixtralVisionConfig`] or [`PixtralVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
config ([`PixtralVisionConfig`]):
Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
PIXTRAL_START_DOCSTRING,
)
class PixtralPreTrainedModel(PreTrainedModel):
config_class = PixtralVisionConfig
base_model_prefix = "model"
@@ -412,9 +404,6 @@ class PixtralPreTrainedModel(PreTrainedModel):
_supports_cache_class = True

def _init_weights(self, module):
# important: this ported version of Pixtral isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/LLaVA/tree/main/pixtral should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
@@ -433,8 +422,9 @@ def _init_weights(self, module):

PIXTRAL_INPUTS_DOCSTRING = r"""
Args:
pixel_values: list of N_img images of variable sizes,
each of shape (C, H, W)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`]
for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@@ -463,10 +453,10 @@ def generate_block_attention_mask(patch_embeds_list, tensor):


@add_start_docstrings(
"""The PIXTRAL model which consists of a vision backbone and a language model.""",
"The bare Pixtral vision encoder outputting raw hidden-states without any specific head on top.",
PIXTRAL_START_DOCSTRING,
)
class PixtralModel(PixtralPreTrainedModel):
class PixtralVisionModel(PixtralPreTrainedModel):
base_model_prefix = "vision_encoder"

def __init__(self, config):
4 changes: 2 additions & 2 deletions src/transformers/utils/dummy_pt_objects.py
@@ -7102,14 +7102,14 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class PixtralModel(metaclass=DummyObject):
class PixtralPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class PixtralPreTrainedModel(metaclass=DummyObject):
class PixtralVisionModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
20 changes: 10 additions & 10 deletions tests/models/pixtral/test_modeling_pixtral.py
@@ -21,8 +21,8 @@

from transformers import (
AutoProcessor,
PixtralModel,
PixtralVisionConfig,
PixtralVisionModel,
is_torch_available,
is_vision_available,
)
@@ -46,7 +46,7 @@
from PIL import Image


class PixtralModelTester:
class PixtralVisionModelTester:
def __init__(
self,
parent,
@@ -107,7 +107,7 @@ def get_config(self):
)

def create_and_check_model(self, config, pixel_values):
model = PixtralModel(config=config)
model = PixtralVisionModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
@@ -120,7 +120,7 @@ def create_and_check_model(self, config, pixel_values):
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

def create_and_check_model_with_projection(self, config, pixel_values):
model = PixtralModel(config=config)
model = PixtralVisionModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
@@ -140,17 +140,17 @@ def prepare_config_and_inputs_for_common(self):


@require_torch
class PixtralModelModelTest(ModelTesterMixin, unittest.TestCase):
class PixtralVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
"""
Model tester for `PixtralModel`.
Model tester for `PixtralVisionModel`.
"""

all_model_classes = (PixtralModel,) if is_torch_available() else ()
all_model_classes = (PixtralVisionModel,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False

def setUp(self):
self.model_tester = PixtralModelTester(self)
self.model_tester = PixtralVisionModelTester(self)
self.config_tester = ConfigTester(self, config_class=PixtralVisionConfig, has_text_modality=False)

@unittest.skip("model does not support input embeds")
@@ -261,7 +261,7 @@ def test_determinism(self):


@require_torch
class PixtralModelIntegrationTest(unittest.TestCase):
class PixtralVisionModelIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")

@@ -273,7 +273,7 @@ def tearDown(self):
@require_bitsandbytes
def test_small_model_integration_test(self):
# Let's make sure we test the preprocessing to replace what is used
model = PixtralModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)

prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
image_file = "https://pixtral-vl.github.io/static/images/view.jpg"