
Commit ae9a344

[Mistral] Add Flash Attention-2 support for mistral (#26464)
* add FA-2 support for mistral
* fixup
* add sliding windows
* fixing few nits
* v1 slicing cache - logits do not match
* add comment
* fix bugs
* more mem efficient
* add warning once
* add warning once
* oops
* fixup
* more comments
* copy
* add safety checker
* fixup
* Update src/transformers/models/mistral/modeling_mistral.py

Co-authored-by: Arthur <[email protected]>

* copied from
* up
* raise when padding side is right
* fixup
* add doc + few minor changes
* fixup

---------

Co-authored-by: Arthur <[email protected]>
1 parent 1a2e966 commit ae9a344

File tree

5 files changed, +435 -5 lines changed


docs/source/en/model_doc/mistral.md

Lines changed: 45 additions & 0 deletions
@@ -82,6 +82,51 @@ tokenizer = LlamaTokenizer.from_pretrained("/output/path")
model = MistralForCausalLM.from_pretrained("/output/path")
```

## Combining Mistral and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2. You can read more about it in the official documentation of the [`flash-attn`](https://github.com/Dao-AILab/flash-attention) repository. Also make sure to load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda"  # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

>>> prompt = "My favourite condiment is"

>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
>>> model.to(device)

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected output"
```

### Expected speedups

Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `mistralai/Mistral-7B-v0.1` checkpoint and the Flash Attention 2 version of the model.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/mistral-7b-inference-large-seqlen.png">
</div>

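If you want to get a rough sense of the speedup on your own hardware, a minimal timing sketch along the lines below can be used. This is an illustration only and not the setup behind the diagram above; the prompt, generation length, and warm-up strategy are arbitrary choices, and a CUDA GPU with `flash-attn` installed is assumed.

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def time_generation(use_flash_attention_2: bool, prompt: str = "My favourite condiment is") -> float:
    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1",
        torch_dtype=torch.float16,
        use_flash_attention_2=use_flash_attention_2,
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

    # Warm-up run so one-time CUDA initialization cost is not measured.
    model.generate(**inputs, max_new_tokens=16, do_sample=False)
    torch.cuda.synchronize()

    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=256, do_sample=False)
    torch.cuda.synchronize()
    return time.perf_counter() - start


print("native           :", time_generation(use_flash_attention_2=False))
print("flash attention 2:", time_generation(use_flash_attention_2=True))
```

The model is reloaded inside the helper so each configuration is timed in isolation; on memory-constrained GPUs you may want to delete the first model and call `torch.cuda.empty_cache()` between the two runs.
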
### Sliding Window Attention

The current implementation supports the sliding window attention mechanism and memory-efficient cache management.
To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`).

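If you are unsure whether the installed `flash-attn` build is recent enough, a quick check such as the one below can help. This is just a convenience sketch (the `packaging` helper is installed alongside `transformers`); the `2.3.0` threshold is the requirement stated above.

```python
# Sanity-check that the installed flash-attn supports sliding window attention (>= 2.3.0).
import flash_attn
from packaging import version

if version.parse(flash_attn.__version__) < version.parse("2.3.0"):
    raise RuntimeError(
        f"flash-attn {flash_attn.__version__} is too old for sliding window attention; "
        "upgrade with `pip install -U flash-attn --no-build-isolation`."
    )
```
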
The Flash Attention-2 model also uses a more memory-efficient cache slicing mechanism. As recommended by the official implementation of the Mistral model, which uses a rolling cache, we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"`, and use the absolute position of the current token to compute the positional embedding.

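To illustrate the idea behind the fixed-size rolling cache, here is a simplified sketch of the concept. It is not the actual code in `modeling_mistral.py`; the helper name and the `(batch, num_heads, seq_len, head_dim)` tensor layout are assumptions made for the example.

```python
import torch


def update_rolling_cache(past_key, past_value, new_key, new_value, sliding_window):
    """Append the current key/value states, then keep only the most recent positions."""
    key = torch.cat([past_key, new_key], dim=2)
    value = torch.cat([past_value, new_value], dim=2)
    # The cache never grows beyond `sliding_window` positions (config.sliding_window),
    # so memory usage stays fixed instead of growing with the generated sequence.
    if key.shape[2] > sliding_window:
        key = key[:, :, -sliding_window:, :]
        value = value[:, :, -sliding_window:, :]
    return key, value
```

For batched generation, the tokenizer can simply be instantiated with left padding, e.g. `AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", padding_side="left")`.
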
## The Mistral Team

Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.

docs/source/en/perf_infer_gpu_one.md

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ Make sure to follow the installation guide on the repository mentioned above to
We natively support Flash Attention 2 for the following models:

- Llama
- Mistral
- Falcon

You can request Flash Attention 2 support for more models by opening an issue on GitHub, or even open a Pull Request to integrate the changes yourself. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported by the `BetterTransformer` API below.*
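As a rough illustration of what training-style usage with padding tokens looks like once Flash Attention 2 is enabled, here is a minimal sketch. The checkpoint, the placeholder texts, and the choice to derive labels from the mask-filled input ids are assumptions for the example, not a recommended training recipe.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token  # Mistral's tokenizer has no padding token by default

# A padded batch of two sequences of different lengths.
batch = tokenizer(
    ["My favourite condiment is", "A somewhat longer second training example"],
    padding=True,
    return_tensors="pt",
).to("cuda")

# Standard causal language-modeling labels; padded positions are ignored in the loss.
labels = batch["input_ids"].masked_fill(batch["attention_mask"] == 0, -100)

outputs = model(**batch, labels=labels)
print(outputs.loss)
```
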
