[Mistral] Add Flash Attention-2 support for mistral (#26464)
* add FA-2 support for mistral
* fixup
* add sliding windows
* fixing few nits
* v1 slicing cache - logits do not match
* add comment
* fix bugs
* more mem efficient
* add warning once
* add warning once
* oops
* fixup
* more comments
* copy
* add safety checker
* fixup
* Update src/transformers/models/mistral/modeling_mistral.py
Co-authored-by: Arthur <[email protected]>
* copied from
* up
* raise when padding side is right
* fixup
* add doc + few minor changes
* fixup
---------
Co-authored-by: Arthur <[email protected]>
model = MistralForCausalLM.from_pretrained("/output/path")
```

## Combining Mistral and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2, which includes the sliding window attention feature.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2. You can read more about it in the official documentation of the [`flash-attn`](https://github.com/Dao-AILab/flash-attention) repository. Also make sure to load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:
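
A minimal sketch of such a snippet, assuming the `use_flash_attention_2` argument that `from_pretrained` accepts in the `transformers` version targeted by this PR (the prompt and generation settings are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"

# Load the model in half-precision and enable the Flash Attention 2 code path
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt = "My favourite condiment is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```
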
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `mistralai/Mistral-7B-v0.1` checkpoint and the Flash Attention 2 version of the model.

The current implementation supports the sliding window attention mechanism and memory-efficient cache management.
To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`).

The Flash Attention 2 model also uses a more memory-efficient cache slicing mechanism: as recommended by the official implementation of Mistral, which uses a rolling cache, we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"`, and use the absolute position of the current token to compute the positional embedding.
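
For instance, here is a hedged sketch of batched generation under the `padding_side="left"` constraint (the prompts, pad-token choice, and generation settings are illustrative assumptions, not taken from the original docs):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Batched generation with the Flash Attention 2 path requires left padding
tokenizer.pad_token = tokenizer.eos_token  # Mistral defines no pad token by default
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
    device_map="auto",
)

prompts = ["My favourite condiment is", "The capital of France is"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20, pad_token_id=tokenizer.pad_token_id)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
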
## The Mistral Team

Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.

`docs/source/en/perf_infer_gpu_one.md` (1 addition, 0 deletions):

We natively support Flash Attention 2 for the following models:

- Llama
- Mistral
- Falcon

You can request Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported by the `BetterTransformer` API below.*
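
As a hedged illustration (the checkpoint name below is an assumption, not taken from this page), enabling Flash Attention 2 for any of the supported models follows the same pattern as the Mistral snippet above:

```python
import torch
from transformers import AutoModelForCausalLM

# Any of the supported architectures (Llama, Mistral, Falcon) can be loaded this way;
# the checkpoint below is only illustrative.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
)
```
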