Merged
Commits
399 commits
8ba624d
current changes
ArthurZucker Sep 9, 2024
2dd4828
nit
zucchini-nlp Sep 9, 2024
29d9f0d
Add cross_attention_mask to processor
Sep 9, 2024
829ee71
multi-image fixed
zucchini-nlp Sep 9, 2024
8d27b9e
Add cross_attention_mask to processor
Sep 9, 2024
01dd827
cross attn works in all cases
zucchini-nlp Sep 9, 2024
547ffe2
WIP refactoring function for image processor
Sep 8, 2024
a814074
WIP refactoring image processor functions
Sep 8, 2024
75a8608
Refactor preprocess to use global loops instead of list nested list c…
Sep 8, 2024
e97e7c6
Docstrings
Sep 8, 2024
9a26b55
Add channels unification
Sep 9, 2024
aa9b752
fix dtype issues
zucchini-nlp Sep 9, 2024
ac1c665
Update docstrings and format
Sep 9, 2024
720b5e2
Consistent max_image_tiles
Sep 9, 2024
5d5183c
current script
ArthurZucker Sep 9, 2024
522c87a
updates
ArthurZucker Sep 9, 2024
b3bc43a
Add convert to rgb
Sep 9, 2024
f36e877
Add image processor tests
Sep 9, 2024
90e962a
updates!
ArthurZucker Sep 9, 2024
6436a0c
update
ArthurZucker Sep 9, 2024
2dbb5f7
god damn it I am dumb sometimes
ArthurZucker Sep 9, 2024
472ade3
Precompute aspect ratios
Sep 9, 2024
0c76794
now this works, full match
ArthurZucker Sep 10, 2024
185b431
fix :wink:
ArthurZucker Sep 10, 2024
0951023
nits
ArthurZucker Sep 10, 2024
17fd76a
style
ArthurZucker Sep 10, 2024
74b2911
Merge remote-tracking branch 'origin/refactor-mlamma' into merge-ever…
zucchini-nlp Sep 10, 2024
b673b11
Merge remote-tracking branch 'origin/cache-generation' into merge-eve…
zucchini-nlp Sep 10, 2024
6086d49
fix model and conversion
ArthurZucker Sep 10, 2024
93112a8
nit
ArthurZucker Sep 10, 2024
8157032
nit
ArthurZucker Sep 10, 2024
674abef
kinda works
zucchini-nlp Sep 10, 2024
9423a4e
hack for sdpa non-contiguous bias
zucchini-nlp Sep 10, 2024
bd051f7
nits here and there
ArthurZucker Sep 10, 2024
1a36a3b
latest changes
ArthurZucker Sep 10, 2024
3e79305
Merge branch 'merge-everything' of github.com:huggingface/new-model-a…
ArthurZucker Sep 10, 2024
3bfb77a
merge?
ArthurZucker Sep 10, 2024
3204a82
run forward
ArthurZucker Sep 10, 2024
41f4646
Add aspect_ratio_mask
qubvel Sep 10, 2024
5d18eb7
vision attention mask
ArthurZucker Sep 10, 2024
c9844f3
Merge branch 'add_mllama' of github.com:huggingface/new-model-additio…
ArthurZucker Sep 10, 2024
e5bfa68
update script and config variable names
ArthurZucker Sep 11, 2024
6aede44
nit
ArthurZucker Sep 11, 2024
52ffa45
nits
ArthurZucker Sep 11, 2024
8766649
be able to load
ArthurZucker Sep 11, 2024
ab53962
style
ArthurZucker Sep 11, 2024
db348d4
nits
ArthurZucker Sep 11, 2024
0237806
there
ArthurZucker Sep 11, 2024
a5208de
nits
ArthurZucker Sep 11, 2024
1f710d6
make forward run
ArthurZucker Sep 11, 2024
6e2a6ea
small update
zucchini-nlp Sep 11, 2024
d2f64aa
enable generation multi-turn
zucchini-nlp Sep 11, 2024
301afe7
nit
zucchini-nlp Sep 11, 2024
f7d088e
nit
zucchini-nlp Sep 11, 2024
f9a8f33
Merge branch 'merge-everything' of github.com:huggingface/new-model-a…
ArthurZucker Sep 11, 2024
dfc9361
Clean up a bit for errors and typos
Sep 11, 2024
a8a9229
A bit more constant fixes
Sep 11, 2024
d8e4402
90B keys and shapes match
Sep 11, 2024
7fe06fe
Fix for 11B model
Sep 11, 2024
44d9cc0
Fixup, remove debug part
Sep 12, 2024
1ad83cd
Docs
Sep 12, 2024
9483220
Make max_aspect_ratio_id to be minimal
Sep 12, 2024
6b76850
Update image processing code to match new implementation
Sep 12, 2024
734a80f
Adjust conversion for final checkpoint state
Sep 12, 2024
447b152
Change dim in repeat_interleave (according to meta code)
Sep 13, 2024
84c544e
tmp fix for num_tiles
Sep 13, 2024
e69ef74
Fix for conversion (gate<->up, q/k_proj rope permute)
Sep 13, 2024
9509d9c
nits
zucchini-nlp Sep 13, 2024
ee86085
codestyle
zucchini-nlp Sep 13, 2024
6e90e91
Vision encoder fixes
Sep 13, 2024
c9601e3
pass cross attn mask further
zucchini-nlp Sep 13, 2024
b04a9e2
Refactor aspect ratio mask
Sep 13, 2024
261d544
Disable text-only generation
Sep 13, 2024
b76d62e
Fix cross attention layers order, remove q/k norm rotation for cross …
Sep 15, 2024
205413f
Refactor gated position embeddings
Sep 15, 2024
b4e56e9
fix bugs but needs test with new weights
zucchini-nlp Sep 15, 2024
1e640fd
rope scaling should be llama3
zucchini-nlp Sep 16, 2024
2539204
Fix rope scaling name
Sep 16, 2024
0036813
Remove debug for linear layer
Sep 16, 2024
07da55c
fix copies
zucchini-nlp Sep 13, 2024
db1f442
Make mask prepare private func
Sep 16, 2024
c041a66
Remove linear patch embed
qubvel Sep 16, 2024
66e1333
Make precomputed embeddings as nn.Embedding module
qubvel Sep 16, 2024
9cad38e
MllamaPrecomputedAspectRatioEmbedding with config init
qubvel Sep 16, 2024
d2fe072
Remove unused self.output_dim
qubvel Sep 16, 2024
17c138d
nit, intermediate layers
qubvel Sep 16, 2024
4d83beb
Rename ln and pos_embed
qubvel Sep 16, 2024
aa6c083
vision_chunk_size -> image_size
qubvel Sep 16, 2024
bb076e7
return_intermediate -> intermediate_layers_indices
qubvel Sep 16, 2024
fc5775d
vision_input_dim -> hidden_size
qubvel Sep 16, 2024
f053357
Fix copied from statements
qubvel Sep 16, 2024
f51ccec
fix most tests
zucchini-nlp Sep 16, 2024
b896288
Fix more copied from
qubvel Sep 16, 2024
264a98c
layer_id->layer_idx
qubvel Sep 16, 2024
da55d6b
Comment
qubvel Sep 16, 2024
154f3c9
Fix tests for processor
qubvel Sep 16, 2024
76b1b8d
Copied from for _prepare_4d_causal_attention_mask_with_cache_position
qubvel Sep 16, 2024
e173d73
Style fix
qubvel Sep 16, 2024
315242f
Add MllamaForCausalLM
qubvel Sep 16, 2024
70f78ff
Merge branch 'refactor-mlamma' into refactor-configuration
qubvel Sep 16, 2024
4fbb348
WIP fixing tests
qubvel Sep 16, 2024
eb411aa
Remove duplicated layers
qubvel Sep 16, 2024
78ddd94
Remove dummy file
qubvel Sep 16, 2024
7770b7f
Fix style
qubvel Sep 16, 2024
037d834
Fix consistency
qubvel Sep 16, 2024
69bf639
Fix some TODOs
qubvel Sep 16, 2024
567fe6a
fix language_model instantiation, add docstring
qubvel Sep 16, 2024
9decb19
Move docstring, remove todos for precomputed embeds (we cannot init t…
qubvel Sep 16, 2024
3e677c9
Add initial docstrings
qubvel Sep 16, 2024
41a4d8a
Fix
qubvel Sep 16, 2024
0350540
Merge pull request #10 from huggingface/refactor-configuration
qubvel Sep 16, 2024
ebb4abc
fix some tests
zucchini-nlp Sep 16, 2024
e502eb2
lets skip these
zucchini-nlp Sep 16, 2024
f371a44
nits, remove print, style
qubvel Sep 16, 2024
68d3e1b
Add one more copied from
qubvel Sep 17, 2024
34d04ce
Improve test message
qubvel Sep 17, 2024
cb762ec
Make validate func private
qubvel Sep 17, 2024
4b3a8dd
Fix dummy objects
qubvel Sep 17, 2024
642ed76
Refactor `data_format` a bit + add comment
qubvel Sep 17, 2024
39d7c1a
typos/nits
qubvel Sep 17, 2024
1adadd6
fix dummy objects and imports
qubvel Sep 17, 2024
b445b5b
Merge branch 'refactor-mlamma' of github.com:huggingface/new-model-ad…
qubvel Sep 17, 2024
34c28d9
Add chat template config json
qubvel Sep 17, 2024
34ebc93
remove num_kv_heads from vision attention
qubvel Sep 17, 2024
3b7c06a
fix
ydshieh Sep 17, 2024
e27b3b2
move some commits and add more tests
zucchini-nlp Sep 18, 2024
4425062
fix test
zucchini-nlp Sep 18, 2024
0d12f59
Remove `update_key_name` from modeling utils
qubvel Sep 18, 2024
fe374f2
remove num-kv-heads again
zucchini-nlp Sep 18, 2024
0c7609a
some preliminary docs
zucchini-nlp Sep 18, 2024
d4b547b
Update chat template + tests
qubvel Sep 18, 2024
82f5745
Merge branch 'refactor-mlamma' of github.com:huggingface/new-model-ad…
qubvel Sep 18, 2024
1c515cc
nit, conversion script max_num_tiles from params
qubvel Sep 18, 2024
a968e5e
Fix warning for text-only generation
qubvel Sep 18, 2024
c1a2310
Update conversion script for instruct models
qubvel Sep 19, 2024
2a49b97
Update chat template in conversion + test
qubvel Sep 19, 2024
ca94ea0
add tests for CausalLM model
zucchini-nlp Sep 19, 2024
5d98a72
model_max_length, avoid null chat_template
pcuenca Sep 19, 2024
717d579
Merge pull request #13 from huggingface/mllama-converter-updates
pcuenca Sep 19, 2024
2f3c084
Refactor conversion script
qubvel Sep 20, 2024
70729df
Fix forward
qubvel Sep 20, 2024
31724eb
Fix integration tests
qubvel Sep 20, 2024
380c020
Refactor vision config + docs
qubvel Sep 20, 2024
a8253ca
Fix default
qubvel Sep 20, 2024
d03f20e
Refactor text config
qubvel Sep 20, 2024
40f58f9
Doc fixes
qubvel Sep 20, 2024
4c4509d
Remove unused args, fix docs example
qubvel Sep 20, 2024
19f67e2
Squashed commit of the following:
qubvel Sep 20, 2024
241cfc4
Fix num_channels
qubvel Sep 20, 2024
bc4ea49
Add mllama text and mllama vision models
qubvel Sep 20, 2024
1cc2417
Fixing repo consistency
qubvel Sep 20, 2024
e1ffc7e
Style fix
qubvel Sep 20, 2024
6ce09f8
Fixing repo consistency
qubvel Sep 20, 2024
6a75a0c
Fixing unused config params
qubvel Sep 20, 2024
3402477
Fix failed tests after refactoring
qubvel Sep 20, 2024
4c70fa4
hidden_activation -> hidden_act for text mlp
qubvel Sep 23, 2024
c9b7938
Remove from_pretrained from sub-configs
qubvel Sep 23, 2024
e6da565
Apply suggestions from code review
qubvel Sep 23, 2024
cf870d1
Merge branch 'refactor-mlamma' of github.com:huggingface/new-model-ad…
qubvel Sep 23, 2024
708f068
Update src/transformers/models/mllama/convert_mllama_weights_to_hf.py
qubvel Sep 23, 2024
1782fa6
Reuse lambda in conversion script
qubvel Sep 23, 2024
4648cb7
Remove run.py
qubvel Sep 23, 2024
4c9d2a0
Update docs/source/en/model_doc/mllama.md
qubvel Sep 23, 2024
0a17136
Update src/transformers/models/mllama/processing_mllama.py
qubvel Sep 23, 2024
99da538
Remove unused LlamaTokenizerFast
qubvel Sep 23, 2024
ecf0120
Fix logging
qubvel Sep 23, 2024
5c5582a
Refactor gating
qubvel Sep 23, 2024
a73f415
Remove cycle for collecting intermediate states
qubvel Sep 23, 2024
f19c674
Refactor text-only check, add integration test for text-only
qubvel Sep 23, 2024
8e2faf7
Revert from pretrained to configs
qubvel Sep 23, 2024
242af3b
Fix example
qubvel Sep 23, 2024
b473efe
Add auto `bos_token` adding in processor
qubvel Sep 23, 2024
933a48f
Fix tips
qubvel Sep 23, 2024
fd4d1b1
Update src/transformers/models/auto/tokenization_auto.py
qubvel Sep 23, 2024
dc892be
Enable supports_gradient_checkpointing model flag
qubvel Sep 23, 2024
8a0a614
Merge branch 'refactor-mlamma' of github.com:huggingface/new-model-ad…
qubvel Sep 23, 2024
45b2e4c
add eager/sdpa options
zucchini-nlp Sep 23, 2024
e754702
don't skip attn tests and bring back GC skips (did i really remove th…
zucchini-nlp Sep 23, 2024
9ec36f5
Fix signature, but get error with None gradient
qubvel Sep 23, 2024
6647660
Fix output attention tests
qubvel Sep 23, 2024
7997676
Disable GC back
qubvel Sep 23, 2024
5bc72b2
Change no split modules
qubvel Sep 23, 2024
46486b1
Fix dropout
qubvel Sep 23, 2024
ea9b1bd
Style
qubvel Sep 23, 2024
febaf6d
Add Mllama to sdpa list
qubvel Sep 23, 2024
b6db674
Add post init for vision model
qubvel Sep 23, 2024
01d1394
Refine config for MllamaForCausalLMModelTest and skipped tests for Ca…
qubvel Sep 23, 2024
444c9d5
if skipped, say it, don't pass
ArthurZucker Sep 23, 2024
fd22057
Clean vision tester config
qubvel Sep 23, 2024
55146a4
Doc for args
qubvel Sep 23, 2024
781f0c1
Merge pull request #14 from huggingface/fix-gradient-checkpointing
qubvel Sep 23, 2024
c53420d
Update tests/models/mllama/test_modeling_mllama.py
qubvel Sep 23, 2024
d657a4e
Add cross_attention_mask to test
qubvel Sep 23, 2024
abc07b3
typehint
qubvel Sep 23, 2024
15f5fea
Remove todo
qubvel Sep 23, 2024
d175518
Enable gradient checkpointing
qubvel Sep 23, 2024
6be06fd
Docstring
qubvel Sep 23, 2024
94585bd
Style
qubvel Sep 23, 2024
6b7baca
Fixing and skipping some tests for new cache
qubvel Sep 23, 2024
3dd8cbe
Mark flaky test
qubvel Sep 24, 2024
8e98212
Skip `test_sdpa_can_compile_dynamic` test
qubvel Sep 24, 2024
1f22643
Fixing some offload tests
qubvel Sep 24, 2024
718e4e0
Add direct GenerationMixin inheritance
qubvel Sep 24, 2024
2a09ec4
Remove unused code
qubvel Sep 24, 2024
7558cbf
Add initializer_range to vision config
qubvel Sep 24, 2024
a7c7569
Merge branch 'refactor-mlamma' of github.com:huggingface/new-model-ad…
ArthurZucker Sep 24, 2024
fd3519c
update the test to make sure we show if split
ArthurZucker Sep 24, 2024
6058246
fix gc?
ArthurZucker Sep 24, 2024
3b1d617
Fix repo consistency
qubvel Sep 24, 2024
f5e1582
Undo modeling utils debug changes
qubvel Sep 24, 2024
f310328
Fix link
qubvel Sep 24, 2024
3034407
mllama -> Mllama
qubvel Sep 24, 2024
fe23b4c
[mllama] -> [Mllama]
qubvel Sep 24, 2024
4edb097
Enable compile test for CausalLM model (text-only)
qubvel Sep 24, 2024
e83f89c
Fix TextModel prefix
qubvel Sep 24, 2024
848df51
Update doc
qubvel Sep 24, 2024
3c2b37e
Docs for forward, type hints, and vision model prefix
qubvel Sep 24, 2024
aa4323a
make sure to reset
ArthurZucker Sep 25, 2024
f35d128
Merge branch 'refactor-mlamma' of github.com:huggingface/new-model-ad…
ArthurZucker Sep 25, 2024
2fcfa0c
fix init
ArthurZucker Sep 25, 2024
bc8caea
small script refactor and styling
ArthurZucker Sep 25, 2024
680d809
nit
ArthurZucker Sep 25, 2024
e9e197e
updates!
ArthurZucker Sep 25, 2024
0842a60
some nits
ArthurZucker Sep 25, 2024
6c58488
Interpolate embeddings for 560 size and update integration tests
qubvel Sep 25, 2024
1d3a266
nit
ArthurZucker Sep 25, 2024
c6788dc
does not support static cache!
ArthurZucker Sep 25, 2024
f40ce28
update
ArthurZucker Sep 25, 2024
554ea46
Merge pull request #15 from huggingface/embed-interpolation
qubvel Sep 25, 2024
7c4747e
Merge branch 'main' of github.com:huggingface/new-model-addition-mlla…
ArthurZucker Sep 25, 2024
af23a1a
fix
ArthurZucker Sep 25, 2024
3fb241b
nit2
ArthurZucker Sep 25, 2024
d685124
this?
ArthurZucker Sep 25, 2024
fba6b53
Fix conversion
qubvel Sep 25, 2024
4aedcea
Style
qubvel Sep 25, 2024
928bcf6
4x memory improvement with image cache AFAIK
ArthurZucker Sep 25, 2024
ef5f5a4
Merge branch 'refactor-mlamma' of github.com:huggingface/new-model-ad…
ArthurZucker Sep 25, 2024
19d7e89
Token decorator for tests
qubvel Sep 25, 2024
63aff72
Skip failing tests
qubvel Sep 25, 2024
1d1dc6c
update processor errors
ArthurZucker Sep 25, 2024
ed90bda
fix split issues
ArthurZucker Sep 25, 2024
dafa298
Merge branch 'refactor-mlamma' of github.com:huggingface/new-model-ad…
ArthurZucker Sep 25, 2024
b0eff7e
style
ArthurZucker Sep 25, 2024
af0b3eb
weird
ArthurZucker Sep 25, 2024
db626e2
style
ArthurZucker Sep 25, 2024
4999784
fix failing tests
ArthurZucker Sep 25, 2024
c798599
update
ArthurZucker Sep 25, 2024
3d8405d
nit fixing the whisper tests
ArthurZucker Sep 25, 2024
a17100c
fix path
ArthurZucker Sep 25, 2024
0a991ce
update
ArthurZucker Sep 25, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Expand Up @@ -860,6 +860,8 @@
title: MatCha
- local: model_doc/mgp-str
title: MGP-STR
- local: model_doc/mllama
title: mllama
- local: model_doc/nougat
title: Nougat
- local: model_doc/oneformer
Expand Down
1 change: 1 addition & 0 deletions docs/source/en/index.md
Expand Up @@ -214,6 +214,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Mimi](model_doc/mimi) | ✅ | ❌ | ❌ |
| [Mistral](model_doc/mistral) | ✅ | ✅ | ✅ |
| [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ |
| [Mllama](model_doc/mllama) | ✅ | ❌ | ❌ |
| [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ |
| [MMS](model_doc/mms) | ✅ | ✅ | ✅ |
| [MobileBERT](model_doc/mobilebert) | ✅ | ✅ | ❌ |
Expand Down
124 changes: 124 additions & 0 deletions docs/source/en/model_doc/mllama.md
@@ -0,0 +1,124 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Mllama

## Overview

The Llama 3.2-Vision collection of multimodal large language models (LLMs) comprises pretrained and instruction-tuned image-reasoning generative models in 11B and 90B sizes (text \+ images in / text out). The Llama 3.2-Vision instruction-tuned models are optimized for visual recognition, image reasoning, captioning, and answering general questions about an image.

**Model Architecture:** Llama 3.2-Vision is built on top of Llama 3.1 text-only model, which is an auto-regressive language model that uses an optimized transformer architecture. The tuned versions use supervised fine-tuning (SFT) and reinforcement learning with human feedback (RLHF) to align with human preferences for helpfulness and safety. To support image recognition tasks, the Llama 3.2-Vision model uses a separately trained vision adapter that integrates with the pre-trained Llama 3.1 language model. The adapter consists of a series of cross-attention layers that feed image encoder representations into the core LLM.
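
To picture the adapter described above, here is a minimal, self-contained sketch of gated cross-attention between text and image states. It is illustrative only: the module name, dimensions, and gating scheme are assumptions for this example, not the actual Mllama implementation.

```python
# Illustrative sketch of a gated cross-attention block: text hidden states attend to
# vision-encoder outputs. Names and sizes are assumptions, not the real Mllama modules.
import torch
import torch.nn as nn

class TextToImageCrossAttention(nn.Module):
    def __init__(self, hidden_size: int = 4096, num_heads: int = 32):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        # a zero-initialized gate lets training start from the unchanged text-only model
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, text_states: torch.Tensor, image_states: torch.Tensor) -> torch.Tensor:
        # queries come from the language model, keys/values from the vision encoder
        attn_out, _ = self.attn(text_states, image_states, image_states)
        return text_states + torch.tanh(self.gate) * attn_out

text_states = torch.randn(1, 10, 4096)     # (batch, text tokens, hidden size)
image_states = torch.randn(1, 1601, 4096)  # (batch, image tokens, hidden size)
print(TextToImageCrossAttention()(text_states, image_states).shape)  # torch.Size([1, 10, 4096])
```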

## Usage Tips

- For image+text and text inputs use `MllamaForConditionalGeneration`.
- For text-only inputs use `MllamaForCausalLM` for generation to avoid loading the vision tower.
- Each sample can contain multiple images, and the number of images can vary between samples. The processor pads the inputs to the maximum number of images across samples and to a maximum number of tiles within each image (a short batching sketch follows these tips).
- The text passed to the processor should contain `"<|image|>"` tokens at the positions where images should be inserted.
- The processor has its own `apply_chat_template` method to convert chat messages to text that can then be passed as text to the processor.
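
The sketch below illustrates the padding behavior for a batch where samples have different numbers of images. The output keys and the exact `pixel_values` shape are assumptions here; check them against the processor you actually load.

```python
# Batching sketch: sample 1 has two images, sample 2 has one; the processor pads
# along the image (and tile) dimensions. Output shapes below are assumptions.
import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")

url = "https://llava-vl.github.io/static/images/view.jpg"
image = Image.open(requests.get(url, stream=True).raw)

texts = [
    "<|image|>Describe this image.<|image|>Now compare it with the second one.",
    "<|image|>What does the image show?",
]
images = [[image, image], [image]]  # one list of images per text sample

inputs = processor(text=texts, images=images, padding=True, return_tensors="pt")
# expected: (batch, max_num_images, max_num_tiles, channels, tile height, tile width)
print(inputs["pixel_values"].shape)
```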

## Usage Example

#### Instruct model
```python
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

messages = [
[
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What does the image show?"}
]
}
],
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)

url = "https://llava-vl.github.io/static/images/view.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=25)
print(processor.decode(output[0]))
```

#### Base model
```python
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision"
model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

prompt = "<|image|>If I had to write a haiku for this one"
url = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(model.device)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
print(processor.decode(output[0], skip_special_tokens=True))
```
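
#### Text-only generation
The model also accepts text-only prompts (no `<|image|>` token and no `images` argument); the tips above mention `MllamaForCausalLM` as a lighter-weight option that skips the vision tower. The snippet below is a minimal sketch using the conditional-generation class:
```python
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision"
model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

prompt = "If I had to write a haiku about autumn, it would be:"
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
print(processor.decode(output[0], skip_special_tokens=True))
```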


## MllamaConfig

[[autodoc]] MllamaConfig

## MllamaProcessor

[[autodoc]] MllamaProcessor


## MllamaImageProcessor

[[autodoc]] MllamaImageProcessor

## MllamaForConditionalGeneration

[[autodoc]] MllamaForConditionalGeneration
- forward

## MllamaForCausalLM

[[autodoc]] MllamaForCausalLM
- forward

## MllamaTextModel

[[autodoc]] MllamaTextModel
- forward

## MllamaVisionModel

[[autodoc]] MllamaVisionModel
- forward
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
Expand Up @@ -235,6 +235,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
Expand Down
28 changes: 28 additions & 0 deletions src/transformers/__init__.py
Expand Up @@ -577,6 +577,10 @@
"models.mimi": ["MimiConfig"],
"models.mistral": ["MistralConfig"],
"models.mixtral": ["MixtralConfig"],
"models.mllama": [
"MllamaConfig",
"MllamaProcessor",
],
"models.mluke": [],
"models.mobilebert": [
"MobileBertConfig",
Expand Down Expand Up @@ -1195,6 +1199,7 @@
)
_import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
_import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
_import_structure["models.mllama"].extend(["MllamaImageProcessor"])
_import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
_import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"])
_import_structure["models.mobilevit"].extend(["MobileViTFeatureExtractor", "MobileViTImageProcessor"])
Expand Down Expand Up @@ -2700,6 +2705,16 @@
"MixtralPreTrainedModel",
]
)
_import_structure["models.mllama"].extend(
[
"MllamaForCausalLM",
"MllamaForConditionalGeneration",
"MllamaPreTrainedModel",
"MllamaProcessor",
"MllamaTextModel",
"MllamaVisionModel",
]
)
_import_structure["models.mobilebert"].extend(
[
"MobileBertForMaskedLM",
Expand Down Expand Up @@ -5367,6 +5382,10 @@
)
from .models.mistral import MistralConfig
from .models.mixtral import MixtralConfig
from .models.mllama import (
MllamaConfig,
MllamaProcessor,
)
from .models.mobilebert import (
MobileBertConfig,
MobileBertTokenizer,
Expand Down Expand Up @@ -6023,6 +6042,7 @@
MaskFormerFeatureExtractor,
MaskFormerImageProcessor,
)
from .models.mllama import MllamaImageProcessor
from .models.mobilenet_v1 import (
MobileNetV1FeatureExtractor,
MobileNetV1ImageProcessor,
Expand Down Expand Up @@ -7256,6 +7276,14 @@
MixtralModel,
MixtralPreTrainedModel,
)
from .models.mllama import (
MllamaForCausalLM,
MllamaForConditionalGeneration,
MllamaPreTrainedModel,
MllamaProcessor,
MllamaTextModel,
MllamaVisionModel,
)
from .models.mobilebert import (
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
Expand Down
81 changes: 53 additions & 28 deletions src/transformers/cache_utils.py
Expand Up @@ -80,10 +80,12 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
if self.key_cache[layer_idx] != []:
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
if self.value_cache[layer_idx] != []:
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

@property
def seen_tokens(self):
Expand Down Expand Up @@ -358,10 +360,14 @@ class DynamicCache(Cache):
```
"""

def __init__(self) -> None:
def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
super().__init__()
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
if num_hidden_layers is None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
else:
self.key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
self.value_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
Expand Down Expand Up @@ -420,6 +426,11 @@ def update(
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
# content on layer cache can be a tensor and checking not tensor causes errors
# so we explicitly check for the empty list
elif self.key_cache[layer_idx] == []:
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
Expand All @@ -429,7 +440,7 @@ def update(
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
if len(self.key_cache) <= layer_idx:
if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []):
return 0
return self.key_cache[layer_idx].shape[-2]

Expand All @@ -446,10 +457,12 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
return legacy_cache

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility."""
cache = cls()
cache = cls(num_hidden_layers)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
Expand All @@ -468,30 +481,34 @@ def crop(self, max_length: int):

self._seen_tokens = max_length
for idx in range(len(self.key_cache)):
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
if self.key_cache[idx] != []:
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]:
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
out = []
for i in range(0, full_batch_size, split_size):
current_split = DynamicCache()
current_split = DynamicCache(num_hidden_layers)
current_split._seen_tokens = self._seen_tokens
current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
out.append(current_split)
return out

@classmethod
def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int) -> "DynamicCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
cache = cls()
cache = cls(num_hidden_layers)
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
cache.update(layer_keys, layer_values, idx)
key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
if key_cache != []:
layer_keys = torch.cat(key_cache, dim=0)
layer_values = torch.cat(value_cache, dim=0)
cache.update(layer_keys, layer_values, idx)
return cache

def batch_repeat_interleave(self, repeats: int):
Expand Down Expand Up @@ -1391,10 +1408,13 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:

@classmethod
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
) -> "EncoderDecoderCache":
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
cache = cls(
self_attention_cache=DynamicCache(num_hidden_layers),
cross_attention_cache=DynamicCache(num_hidden_layers),
)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx][:2]
Expand All @@ -1407,7 +1427,10 @@ def from_legacy_cache(

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.self_attention_cache.key_cache) <= layer_idx:
# check if empty list because in case of static cache it will be a tensor and we can't check `if not torch.Tensor`
if self.self_attention_cache.key_cache == []:
return 0
if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []:
return 0
return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

Expand Down Expand Up @@ -1448,24 +1471,26 @@ def crop(self, maximum_length: int):
self.check_dynamic_cache(self.crop.__name__)
self.self_attention_cache.crop(maximum_length)

def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
def batch_split(
self, full_batch_size: int, split_size: int, num_hidden_layers: int
) -> "List[EncoderDecoderCache]":
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
self.check_dynamic_cache(self.batch_split.__name__)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)

out = []
for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
out.append(EncoderDecoderCache(self_attn, cross_attn))
return out

@classmethod
def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
def from_batch_splits(cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int) -> "EncoderDecoderCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()
self_attention_cache = DynamicCache(num_hidden_layers)
cross_attention_cache = DynamicCache(num_hidden_layers)
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
Expand Down
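
A minimal sketch of the `DynamicCache` behavior introduced by the diff above: pre-sizing the cache with `num_hidden_layers` creates empty per-layer slots, and layers that never receive key/value states (for example, Mllama's cross-attention layers during text-only generation) keep an empty slot and report a length of 0. This follows the PR's version of `cache_utils.py`; later releases may change the constructor signature.

```python
# Sketch based on the diff above: a pre-sized DynamicCache tolerates layers
# that are never updated (their slot stays an empty list).
import torch
from transformers import DynamicCache

cache = DynamicCache(num_hidden_layers=4)

key = torch.zeros(1, 8, 3, 64)    # (batch, heads, seq_len, head_dim)
value = torch.zeros(1, 8, 3, 64)
cache.update(key, value, layer_idx=0)  # only layer 0 receives states

print(cache.key_cache[0].shape)   # torch.Size([1, 8, 3, 64])
print(cache.key_cache[1])         # []  -- untouched layer keeps its empty slot
print(cache.get_seq_length(1))    # 0
```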