`docs/source/en/perf_infer_gpu_many.md` (+4 lines changed: 4 additions & 0 deletions)
@@ -22,6 +22,10 @@ Note: A multi GPU setup can use the majority of the strategies described in the
</Tip>
## Flash Attention 2
Flash Attention 2 integration also works in a multi-GPU setup; check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2).
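For illustration, here is a minimal sketch of what multi-GPU loading could look like, assuming the `use_flash_attention_2` flag described in that section and using `device_map="auto"` (which requires Accelerate) to shard the model across the available GPUs:

```python
import torch
from transformers import AutoModelForCausalLM

# sketch only: device_map="auto" spreads the layers across all visible GPUs;
# Flash Attention 2 then runs on each shard exactly as in the single-GPU case
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
    device_map="auto",
)
```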
## BetterTransformer
[BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
`docs/source/en/perf_infer_gpu_one.md` (+148 lines changed: 148 additions & 0 deletions)
@@ -17,6 +17,154 @@ rendered properly in your Markdown viewer.
In addition to this guide, relevant information can also be found in [the guide for training on a single GPU](perf_train_gpu_one) and [the guide for inference on CPUs](perf_infer_cpu).
## Flash Attention 2
<Tip>
Note that this feature is experimental and might change considerably in future versions. For instance, the Flash Attention 2 API might migrate to the `BetterTransformer` API in the near future.
</Tip>
Flash Attention 2 can considerably speed up the training and inference of transformer-based models. Flash Attention 2 was introduced in the [official Flash Attention repository](https://github.com/Dao-AILab/flash-attention) by Tri Dao et al. The scientific paper on Flash Attention can be found [here](https://arxiv.org/abs/2205.14135).
Make sure to follow the installation guide on the repository mentioned above to properly install Flash Attention 2. Once that package is installed, you can benefit from this feature.
We natively support Flash Attention 2 for the following models:
- Llama
- Falcon
You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, or even open a Pull Request to integrate the changes yourself. The supported models can be used for inference and training, including training with padding tokens, *which is currently not supported by the `BetterTransformer` API below.*
<Tip>
Flash Attention 2 can only be used when the model's dtype is `fp16` or `bf16`, and it runs only on NVIDIA GPU devices. Make sure to cast your model to the appropriate dtype and load it on a supported device before using this feature.
</Tip>
### Quick usage
To enable Flash Attention 2 in your model, add `use_flash_attention_2=True` to the `from_pretrained` arguments:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# any supported model works; Falcon is shown here as an example
model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # Flash Attention 2 requires fp16 or bf16
    use_flash_attention_2=True,
)
```
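As a quick sanity check, here is a minimal generation example (the prompt is arbitrary, and a CUDA device is assumed to be available):

```python
model.to("cuda")  # Flash Attention 2 only runs on CUDA devices
inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```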
You can benefit from considerable speedups for fine-tuning and inference, especially for long sequences. However, since Flash Attention does not support computing attention scores with padding tokens under the hood, we must manually pad / unpad the attention scores for batched inference when the sequence contains padding tokens. This leads to a significant slowdown for batched generation with padding tokens.
To overcome this, one should use Flash Attention without padding tokens in the sequence during training (e.g., by packing a dataset, i.e., concatenating sequences until reaching the maximum sequence length; an example is provided [here](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py#L516)).
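For illustration, here is a minimal packing sketch (a hypothetical helper, not the `run_clm.py` implementation; it simply concatenates tokenized examples and slices fixed-length chunks, dropping the ragged tail):

```python
def pack_sequences(tokenized_examples, max_length):
    """Concatenate token lists and cut them into fixed-size, padding-free chunks."""
    all_tokens = [tok for example in tokenized_examples for tok in example]
    usable = (len(all_tokens) // max_length) * max_length  # drop the ragged tail
    return [all_tokens[i : i + max_length] for i in range(0, usable, max_length)]

packed = pack_sequences([[1, 2, 3, 4], [5, 6, 7], [8, 9]], max_length=4)
# -> [[1, 2, 3, 4], [5, 6, 7, 8]]
```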
Below is the expected speedup you can get for a simple forward pass on [tiiuae/falcon-7b](https://hf.co/tiiuae/falcon-7b) with a sequence length of 4096 and various batch sizes, without padding tokens:

Below is the expected speedup you can get for a simple forward pass on [`meta-llama/Llama-7b-hf`](https://hf.co/meta-llama/Llama-7b-hf) with a sequence length of 4096 and various batch sizes, without padding tokens:
For sequences with padding tokens (training or generating with padding tokens), the input sequences need to be unpadded and repadded to compute the attention scores correctly. For relatively small sequence lengths this creates an overhead on a pure forward pass, leading to only a small speedup (in the benchmark, 30% of the input was filled with padding tokens). For large sequence lengths, however, you can still benefit from an interesting speedup for pure inference as well as training.
Note that Flash Attention makes the attention computation more memory efficient, meaning you can train with much larger sequence lengths without running into CUDA out-of-memory issues: because the full quadratic attention matrix is never materialized, the memory reduction can go up to 20x for large sequence lengths. Check out [the official flash attention repository](https://github.com/Dao-AILab/flash-attention) for more details.
## BetterTransformer

[BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
## Flash Attention 2

You can speed up the training throughput by using the Flash Attention 2 integration in transformers. Check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2) to learn more about how to load a model with Flash Attention 2 modules.
## Optimizer choice
The most common optimizer used to train transformer models is Adam or AdamW (Adam with weight decay). Adam achieves good convergence by storing the rolling average of the previous gradients; however, it adds an additional memory footprint of the order of the number of model parameters.
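As a sketch of how the optimizer can be chosen in practice (the `optim` value comes from the `TrainingArguments` API; `output_dir` is a placeholder):

```python
from transformers import TrainingArguments

# "adamw_torch" selects PyTorch's AdamW implementation
args = TrainingArguments(output_dir="out", optim="adamw_torch")
```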
"Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
1288
+
)
1289
+
1290
+
iftorch_dtypeisNone:
1291
+
logger.warning(
1292
+
"You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"