Merged

Add Aria #34157

141 commits
f37181e
First
Oct 11, 2024
16ab157
Try to make it work
Oct 11, 2024
48828e8
Working init
Oct 14, 2024
b663c25
First working pipeline!
Oct 14, 2024
8df558c
Simplify code
Oct 14, 2024
60ad089
Fix tests
Oct 14, 2024
74642ec
Small fix
aymeric-roucher Oct 15, 2024
2c88807
Add GenerationMixin import
Oct 15, 2024
5279e43
Update doc
Oct 15, 2024
96a1fbf
Import sorting
Oct 15, 2024
d5ab4d1
Simplify by removing TokenDispatcher class
Oct 15, 2024
bf6ab44
Add small arg changes
Oct 15, 2024
f40d1cb
Simplify modular
Oct 16, 2024
ff5a37f
Simplify code a lot
Oct 17, 2024
dc29a7d
Fix tests
Oct 17, 2024
c711335
Simplify activation function
Oct 17, 2024
fc00526
Correct attention classes
Oct 18, 2024
c52d1de
Simplify processing
aymeric-roucher Oct 21, 2024
ceddfc2
Fixes
aymeric-roucher Oct 21, 2024
9dd624f
Clean size conversion
aymeric-roucher Oct 21, 2024
69578be
Style
aymeric-roucher Oct 21, 2024
994bb0a
Fix vision attention in AriaEncoderLayer
aymeric-roucher Oct 22, 2024
0188d4c
Fix tests
aymeric-roucher Oct 22, 2024
3fa73e0
Merge branch 'main' into add-aria
aymeric-roucher Oct 22, 2024
886237a
Fix tokenizer test
aymeric-roucher Oct 22, 2024
7faf143
Change sdpa
aymeric-roucher Oct 22, 2024
20babb7
Formatting
aymeric-roucher Oct 22, 2024
183db61
Fix torch.empty and cuda tests
aymeric-roucher Oct 23, 2024
f87dd8c
Try new weights init
aymeric-roucher Oct 23, 2024
3b49743
Try empty init parameters
aymeric-roucher Oct 23, 2024
6e56821
Fix initialized_range
aymeric-roucher Oct 23, 2024
4ecb46f
Should fix some tests
aymeric-roucher Oct 23, 2024
a06e425
Add num_logits_to_keep
aymeric-roucher Oct 23, 2024
09a5092
Add back sdpa fix
aymeric-roucher Oct 23, 2024
5630658
Not sure what I'm doing at that point
aymeric-roucher Oct 23, 2024
a560b26
Fix tests
aymeric-roucher Oct 23, 2024
d471b87
Test initialization tests
aymeric-roucher Oct 24, 2024
6a8c805
Test different pad token
aymeric-roucher Oct 24, 2024
ae19ca6
Streamline modular_aria format
aymeric-roucher Oct 24, 2024
0c8aa0a
Remove AriaVisionModel by just using Idefics3
aymeric-roucher Oct 25, 2024
41a4733
Final weights
aymeric-roucher Oct 25, 2024
bdd7ac0
Update weight conversion script
aymeric-roucher Oct 26, 2024
d2bf502
Remove AriaVisionModel entirely
aymeric-roucher Oct 28, 2024
c42db55
Update tests with Idefics3VisionConfig
aymeric-roucher Oct 28, 2024
cb75cc2
Make style
aymeric-roucher Oct 28, 2024
56b0a5e
Remove attention classes
aymeric-roucher Oct 28, 2024
1c9fabb
Fix phantom model in configuration_auto
aymeric-roucher Oct 28, 2024
82352b8
Amendment
aymeric-roucher Oct 28, 2024
3e91861
Modifications following Pablo's comments
aymeric-roucher Oct 29, 2024
8008228
Simplify following pablos comments
aymeric-roucher Oct 29, 2024
113d4ad
Offload image processing
aymeric-roucher Oct 29, 2024
c82fcee
Working image processing
aymeric-roucher Oct 29, 2024
c658e22
Refactor function keep_ratio_resize_and_pixel_mask
aymeric-roucher Oct 30, 2024
0467498
Simplify image preprocessing
aymeric-roucher Oct 31, 2024
55a963a
Apply modular conversion
aymeric-roucher Oct 31, 2024
7e70407
Answer comments
aymeric-roucher Nov 6, 2024
cdb9a7d
Integrate 2
aymeric-roucher Nov 6, 2024
cac130c
Protect imports
aymeric-roucher Nov 6, 2024
dab0b62
Adapt AriaProcessor args to common format
aymeric-roucher Nov 6, 2024
a5625cf
Small fix
aymeric-roucher Nov 6, 2024
45d11f9
Remove _extract_kwargs
aymeric-roucher Nov 6, 2024
8d2d75c
Harmonize modular and other files
aymeric-roucher Nov 6, 2024
55758ef
Rename variables
aymeric-roucher Nov 6, 2024
22b97bd
Rename AriaForCaualLM to AriaTextForCausalLM
aymeric-roucher Nov 6, 2024
cac3ca8
Try fixing FA2
aymeric-roucher Nov 7, 2024
3650cdf
improve sequential gemm import
aymeric-roucher Nov 7, 2024
2363b99
Formatting
aymeric-roucher Nov 7, 2024
1f13198
Renaming
aymeric-roucher Nov 7, 2024
fb51aa6
Try fixing unprotected imports
aymeric-roucher Nov 7, 2024
9a327cb
Harmonize modular with files
aymeric-roucher Nov 19, 2024
586e53b
Answer comments
aymeric-roucher Nov 22, 2024
d782f4b
Remove legacy image input merging
aymeric-roucher Nov 23, 2024
acdae0b
More simplifications following comments
aymeric-roucher Nov 23, 2024
0c56a9d
Remove TopKRouter
aymeric-roucher Nov 24, 2024
a6f75d3
Remove resize_token_embeddings
aymeric-roucher Nov 24, 2024
38f1d3a
Add data_format to image processing
aymeric-roucher Nov 24, 2024
f158836
Add vision feature layer in config
aymeric-roucher Nov 24, 2024
9451d4b
Update
aymeric-roucher Nov 24, 2024
4fe6478
Format docstrings
aymeric-roucher Nov 24, 2024
d533357
Fix docstrings
aymeric-roucher Nov 24, 2024
db97796
Merge branch 'main' into add-aria
aymeric-roucher Nov 24, 2024
cf4bd56
Working version post merge
aymeric-roucher Nov 24, 2024
0476083
Fix pretrained models
aymeric-roucher Nov 24, 2024
09390c1
Harmonize files
aymeric-roucher Nov 24, 2024
a569c6c
Hopefully fix imports
aymeric-roucher Nov 25, 2024
5276f3f
Remove dependency from processor to image processor
aymeric-roucher Nov 26, 2024
aa93d6b
Update dummy objects
aymeric-roucher Nov 26, 2024
991ddab
Clean processor
aymeric-roucher Nov 26, 2024
b31fea8
Pass generation with input embeds
aymeric-roucher Nov 27, 2024
f8be039
Style
aymeric-roucher Nov 27, 2024
d56c158
Harmonize modular
aymeric-roucher Nov 27, 2024
ce84dcf
Try fixing weight init
aymeric-roucher Nov 27, 2024
5cc3a99
Remove image token from processing
aymeric-roucher Nov 27, 2024
dab4d0f
Try fix imports
aymeric-roucher Nov 27, 2024
1e7b83e
Try fix imports 2
aymeric-roucher Nov 27, 2024
43b5f0a
Working modular
Cyrilvallez Nov 28, 2024
bdd6c4f
and style
Cyrilvallez Nov 28, 2024
3df30fd
Repair image processing
aymeric-roucher Nov 28, 2024
f9d8d69
Merge remote-tracking branch 'origin/add-aria' into add-aria
aymeric-roucher Nov 28, 2024
e08ecf0
Style
aymeric-roucher Nov 28, 2024
248aa9d
Working inference
aymeric-roucher Nov 28, 2024
1ea3d17
Fix batch token counting
aymeric-roucher Nov 29, 2024
9b13ef1
Improve docstrings
aymeric-roucher Nov 29, 2024
6d98a0e
Add image processing tests
aymeric-roucher Nov 29, 2024
265ca08
Add image processing and processing tests
aymeric-roucher Nov 29, 2024
a4d8a1f
Directly copy llava next functions
aymeric-roucher Nov 29, 2024
f30fb5b
Merge branch 'main' into add-aria
aymeric-roucher Nov 29, 2024
a4ce9e9
Remove chat template
aymeric-roucher Nov 29, 2024
63e2276
Fix docstrings
aymeric-roucher Nov 29, 2024
15f21e2
Update conversion script
aymeric-roucher Nov 29, 2024
e73febc
Update src/transformers/models/aria/convert_aria_weights_to_hf.py
aymeric-roucher Dec 3, 2024
e03a05d
Update src/transformers/models/aria/configuration_aria.py
aymeric-roucher Dec 3, 2024
56942fe
Update src/transformers/models/aria/modular_aria.py
aymeric-roucher Dec 3, 2024
acfeb4b
Update src/transformers/models/aria/modular_aria.py
aymeric-roucher Dec 3, 2024
4e6688b
Answer comments
aymeric-roucher Dec 3, 2024
a006f6a
Simplify more elements
aymeric-roucher Dec 3, 2024
d45186e
Improve projector_patch_to_query_dict max value handling
aymeric-roucher Dec 3, 2024
1d924e0
Slight simplification of input type and device modification in gemm e…
aymeric-roucher Dec 3, 2024
cf42acc
Fix import errors
aymeric-roucher Dec 3, 2024
ca30b6e
Update fa2 support
aymeric-roucher Dec 3, 2024
67e5dbb
Fix test
aymeric-roucher Dec 3, 2024
87981b0
Add cpu back
aymeric-roucher Dec 3, 2024
142e061
Improve init
aymeric-roucher Dec 3, 2024
09fe137
Fix doc checks
aymeric-roucher Dec 4, 2024
acc9968
Soft dependencies handling
aymeric-roucher Dec 4, 2024
ec55502
Fix init import order
aymeric-roucher Dec 4, 2024
0af60a4
Merge branch 'main' into add-aria
aymeric-roucher Dec 4, 2024
f529bf8
Fix experts gemm selection
aymeric-roucher Dec 4, 2024
76d116b
Add idefics3 docs
aymeric-roucher Dec 4, 2024
ae7f5d0
Fix some docstring checks
aymeric-roucher Dec 4, 2024
959702b
Fix docstrings
aymeric-roucher Dec 4, 2024
461d14d
Try fix for unused config.intermediate_size
aymeric-roucher Dec 4, 2024
a109506
Try removing unusued config args - v2
aymeric-roucher Dec 4, 2024
be0e5a9
Remove moe_intermediate_size
aymeric-roucher Dec 4, 2024
8a45000
Add sdpa support
aymeric-roucher Dec 4, 2024
09dd7d4
Try fix docstrings
aymeric-roucher Dec 4, 2024
9c3dd8a
Update the conversion script 3
aymeric-roucher Dec 4, 2024
8fd065a
Final comment answer
aymeric-roucher Dec 5, 2024
fe62f6c
Merge branch 'main' into add-aria flaky 2
aymeric-roucher Dec 5, 2024
76ee868
Fix CUDA errors 3
aymeric-roucher Dec 6, 2024
956cea2
Remove duplicate init 2
aymeric-roucher Dec 6, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -810,6 +810,8 @@
  title: ALIGN
- local: model_doc/altclip
  title: AltCLIP
- local: model_doc/aria
  title: Aria
- local: model_doc/blip
  title: BLIP
- local: model_doc/blip-2
3 changes: 3 additions & 0 deletions docs/source/en/index.md
@@ -62,6 +62,8 @@ Flax), PyTorch, and/or TensorFlow.
| [ALBERT](model_doc/albert) | ✅ | ✅ | ✅ |
| [ALIGN](model_doc/align) | ✅ | ❌ | ❌ |
| [AltCLIP](model_doc/altclip) | ✅ | ❌ | ❌ |
| [Aria](model_doc/aria) | ✅ | ❌ | ❌ |
| [AriaText](model_doc/aria_text) | ✅ | ❌ | ❌ |
| [Audio Spectrogram Transformer](model_doc/audio-spectrogram-transformer) | ✅ | ❌ | ❌ |
| [Autoformer](model_doc/autoformer) | ✅ | ❌ | ❌ |
| [Bark](model_doc/bark) | ✅ | ❌ | ❌ |
@@ -172,6 +174,7 @@ Flax), PyTorch, and/or TensorFlow.
| [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ |
| [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ |
| [Idefics3](model_doc/idefics3) | ✅ | ❌ | ❌ |
| [Idefics3VisionTransformer](model_doc/idefics3_vision) | ❌ | ❌ | ❌ |
| [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ |
| [Informer](model_doc/informer) | ✅ | ❌ | ❌ |
| [InstructBLIP](model_doc/instructblip) | ✅ | ❌ | ❌ |
106 changes: 106 additions & 0 deletions docs/source/en/model_doc/aria.md
@@ -0,0 +1,106 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Aria

## Overview

The Aria model was proposed in [Aria: An Open Multimodal Native Mixture-of-Experts Model](https://huggingface.co/papers/2410.05993) by Li et al. from the Rhymes.AI team.

Aria is an open multimodal-native model with best-in-class performance across a wide range of multimodal, language, and coding tasks. It uses a Mixture-of-Experts architecture that activates 3.9B parameters per visual token and 3.5B parameters per text token.

The abstract from the paper is the following:

*Information comes in diverse modalities. Multimodal native AI models are essential to integrate real-world information and deliver comprehensive understanding. While proprietary multimodal native models exist, their lack of openness imposes obstacles for adoptions, let alone adaptations. To fill this gap, we introduce Aria, an open multimodal native model with best-in-class performance across a wide range of multimodal, language, and coding tasks. Aria is a mixture-of-expert model with 3.9B and 3.5B activated parameters per visual token and text token, respectively. It outperforms Pixtral-12B and Llama3.2-11B, and is competitive against the best proprietary models on various multimodal tasks. We pre-train Aria from scratch following a 4-stage pipeline, which progressively equips the model with strong capabilities in language understanding, multimodal understanding, long context window, and instruction following. We open-source the model weights along with a codebase that facilitates easy adoptions and adaptations of Aria in real-world applications.*

This model was contributed by [m-ric](https://huggingface.co/m-ric).
The original code can be found [here](https://github.com/rhymes-ai/Aria).

## Usage tips

Here's how to use the model for vision tasks:
```python
import requests
import torch
from PIL import Image

from transformers import AriaProcessor, AriaForConditionalGeneration

model_id_or_path = "rhymes-ai/Aria"

model = AriaForConditionalGeneration.from_pretrained(
    model_id_or_path, device_map="auto"
)

processor = AriaProcessor.from_pretrained(model_id_or_path)

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"text": "what is the image?", "type": "text"},
        ],
    }
]

text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=text, images=image, return_tensors="pt")
inputs.to(model.device)

output = model.generate(
    **inputs,
    max_new_tokens=15,
    stop_strings=["<|im_end|>"],
    tokenizer=processor.tokenizer,
    do_sample=True,
    temperature=0.9,
)
output_ids = output[0][inputs["input_ids"].shape[1]:]
response = processor.decode(output_ids, skip_special_tokens=True)
```
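
The same model and processor can also be used without images. A minimal text-only sketch, assuming `AriaProcessor` accepts text-only input (the prompt is illustrative):

```python
messages = [
    {
        "role": "user",
        "content": [{"text": "Write a haiku about multimodal models.", "type": "text"}],
    }
]

text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=text, return_tensors="pt")  # no images passed
inputs.to(model.device)

output = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```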


## AriaImageProcessor

[[autodoc]] AriaImageProcessor

## AriaProcessor

[[autodoc]] AriaProcessor

## AriaTextConfig

[[autodoc]] AriaTextConfig

## AriaConfig

[[autodoc]] AriaConfig

## AriaTextModel

[[autodoc]] AriaTextModel

## AriaTextForCausalLM

[[autodoc]] AriaTextForCausalLM

## AriaForConditionalGeneration

[[autodoc]] AriaForConditionalGeneration
- forward
7 changes: 7 additions & 0 deletions docs/source/en/model_doc/idefics3.md
@@ -51,6 +51,13 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts)

[[autodoc]] Idefics3Config

## Idefics3VisionConfig

[[autodoc]] Idefics3VisionConfig

## Idefics3VisionTransformer

[[autodoc]] Idefics3VisionTransformer

## Idefics3Model

2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -37,6 +37,7 @@ FlashAttention-2 is experimental and may change considerably in future versions.
2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them

FlashAttention-2 is currently supported for the following architectures:
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
@@ -216,6 +217,7 @@ PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.o

For now, Transformers supports SDPA inference and training for the following architectures:
* [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel)
* [Aria](https://huggingface.co/docs/transformers/model_doc/aria#transformers.AriaForConditionalGeneration)
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
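
For reference, the attention backend is chosen at load time through `attn_implementation`; a minimal sketch for Aria (the dtype choice is an assumption, and the FlashAttention-2 path additionally requires the `flash-attn` package and a supported GPU):

```python
import torch
from transformers import AriaForConditionalGeneration

# SDPA backend (PyTorch's scaled_dot_product_attention)
model_sdpa = AriaForConditionalGeneration.from_pretrained(
    "rhymes-ai/Aria",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)

# FlashAttention-2 backend
model_fa2 = AriaForConditionalGeneration.from_pretrained(
    "rhymes-ai/Aria",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```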
32 changes: 32 additions & 0 deletions src/transformers/__init__.py
@@ -170,6 +170,11 @@
"AltCLIPTextConfig",
"AltCLIPVisionConfig",
],
"models.aria": [
"AriaConfig",
"AriaProcessor",
"AriaTextConfig",
],
"models.audio_spectrogram_transformer": [
"ASTConfig",
"ASTFeatureExtractor",
@@ -1176,6 +1181,7 @@
_import_structure["image_processing_base"] = ["ImageProcessingMixin"]
_import_structure["image_processing_utils"] = ["BaseImageProcessor"]
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
_import_structure["models.aria"].extend(["AriaImageProcessor"])
_import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"])
_import_structure["models.bit"].extend(["BitImageProcessor"])
_import_structure["models.blip"].extend(["BlipImageProcessor"])
@@ -1406,6 +1412,15 @@
"AltCLIPVisionModel",
]
)
_import_structure["models.aria"].extend(
[
"AriaForConditionalGeneration",
"AriaPreTrainedModel",
"AriaTextForCausalLM",
"AriaTextModel",
"AriaTextPreTrainedModel",
]
)
_import_structure["models.audio_spectrogram_transformer"].extend(
[
"ASTForAudioClassification",
@@ -2461,6 +2476,8 @@
"Idefics3Model",
"Idefics3PreTrainedModel",
"Idefics3Processor",
"Idefics3VisionConfig",
"Idefics3VisionTransformer",
]
)
_import_structure["models.ijepa"].extend(
@@ -5033,6 +5050,11 @@
        AltCLIPTextConfig,
        AltCLIPVisionConfig,
    )
    from .models.aria import (
        AriaConfig,
        AriaProcessor,
        AriaTextConfig,
    )
    from .models.audio_spectrogram_transformer import (
        ASTConfig,
        ASTFeatureExtractor,
@@ -6096,6 +6118,7 @@
        from .image_processing_base import ImageProcessingMixin
        from .image_processing_utils import BaseImageProcessor
        from .image_utils import ImageFeatureExtractionMixin
        from .models.aria import AriaImageProcessor
        from .models.beit import BeitFeatureExtractor, BeitImageProcessor
        from .models.bit import BitImageProcessor
        from .models.blip import BlipImageProcessor
@@ -6325,6 +6348,13 @@
            AltCLIPTextModel,
            AltCLIPVisionModel,
        )
        from .models.aria import (
            AriaForConditionalGeneration,
            AriaPreTrainedModel,
            AriaTextForCausalLM,
            AriaTextModel,
            AriaTextPreTrainedModel,
        )
        from .models.audio_spectrogram_transformer import (
            ASTForAudioClassification,
            ASTModel,
@@ -7189,6 +7219,8 @@
            Idefics3Model,
            Idefics3PreTrainedModel,
            Idefics3Processor,
            Idefics3VisionConfig,
            Idefics3VisionTransformer,
        )
        from .models.ijepa import (
            IJepaForImageClassification,
1 change: 1 addition & 0 deletions src/transformers/generation/utils.py
@@ -1465,6 +1465,7 @@ def _prepare_generated_length(
        elif (
            model_input_name == "inputs_embeds"
            and input_ids_length != inputs_tensor.shape[1]
            and input_ids_length != 0
            and not self.config.is_encoder_decoder
        ):
            generation_config.max_length -= inputs_tensor.shape[1]
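
The new `input_ids_length != 0` guard covers generation that starts from `inputs_embeds` with no prompt `input_ids`: in that case the requested length should not be shortened by the embedding length. A hedged, usage-level sketch of the scenario (the model and prompt are illustrative, not taken from this PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt_ids = tokenizer("The Aria model", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(prompt_ids)

# Generating from embeddings only: internally input_ids starts empty,
# so the full max_new_tokens budget should remain available.
out = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=10)
print(out.shape)
```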
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -16,6 +16,7 @@
    albert,
    align,
    altclip,
    aria,
    audio_spectrogram_transformer,
    auto,
    autoformer,
30 changes: 30 additions & 0 deletions src/transformers/models/aria/__init__.py
@@ -0,0 +1,30 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_aria import *
    from .image_processing_aria import *
    from .modeling_aria import *
    from .processing_aria import *

else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
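
With the lazy module in place, the public Aria classes registered in `src/transformers/__init__.py` above resolve only on first access; for example:

```python
# These imports trigger lazy resolution of the aria submodule.
from transformers import AriaConfig, AriaForConditionalGeneration, AriaProcessor
```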