
Commit 368a58e

younesbelkada, fxmarty, patrickvonplaten, ArthurZucker, and LysandreJik authored
[core] Integrate Flash attention 2 in most used models (#25598)
* v1
* oops
* working v1
* fixup
* add some TODOs
* fixup
* padding support + try with module replacement
* nit
* alternative design
* oops
* add `use_cache` support for llama
* v1 falcon
* nit
* a bit of refactor
* nit
* nits nits
* add v1 padding support falcon (even though it seemed to work before)
* nit
* falcon works
* fixup
* v1 tests
* nit
* fix generation llama flash
* update tests
* fix tests + nits
* fix copies
* fix nit
* test- padding mask
* stype
* add more mem efficient support
* Update src/transformers/modeling_utils.py
  Co-authored-by: Patrick von Platen <[email protected]>
* fixup
* nit
* fixup
* remove it from config when saving
* fixup
* revert docstring
* add more checks
* use values
* oops
* new version
* fixup
* add same trick for falcon
* nit
* add another test
* change tests
* fix issues with GC and also falcon
* fixup
* oops
* Update src/transformers/models/falcon/modeling_falcon.py
  Co-authored-by: Arthur <[email protected]>
* add init_rope
* updates
* fix copies
* fixup
* fixup
* more clarification
* fixup
* right padding tests
* add docs
* add FA in docker image
* more clarifications
* add some figures
* add todo
* rectify comment
* Change to FA2
* Update docs/source/en/perf_infer_gpu_one.md
  Co-authored-by: Arthur <[email protected]>
* split in two lines
* change test name
* add more tests
* some clean up
* remove `rearrange` deps
* add more docs
* revert changes on dockerfile
* Revert "revert changes on dockerfile"
  This reverts commit 8d72a66.
* revert changes on dockerfile
* Apply suggestions from code review
  Co-authored-by: Lysandre Debut <[email protected]>
* address some comments
* docs
* use inheritance
* Update src/transformers/testing_utils.py
  Co-authored-by: Lysandre Debut <[email protected]>
* fixup
* Apply suggestions from code review
  Co-authored-by: Arthur <[email protected]>
* Update src/transformers/modeling_utils.py
* final comments
* clean up
* style
* add cast + warning for PEFT models
* fixup

---------

Co-authored-by: Felix Marty <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Lysandre Debut <[email protected]>
1 parent dcbfd93 commit 368a58e

File tree

14 files changed: +934, -14 lines changed


docs/source/en/perf_infer_gpu_many.md

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,10 @@ Note: A multi GPU setup can use the majority of the strategies described in the
 
 </Tip>
 
+## Flash Attention 2
+
+Flash Attention 2 integration also works in a multi-GPU setup; check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2).
+
 ## BetterTransformer
 
 [BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
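
For illustration, here is a minimal sketch of what the multi-GPU setup referred to above could look like, assuming at least two CUDA GPUs, an installed `flash-attn` (>2.0) package, and `accelerate` for the `device_map="auto"` dispatch; the checkpoint and prompt are arbitrary choices, not part of this diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"  # any FA2-supported checkpoint (Llama or Falcon at this point)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Shard the model over the visible GPUs while enabling the Flash Attention 2 code path.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # FA2 requires fp16 or bf16
    device_map="auto",           # dispatch layers across the available GPUs
    use_flash_attention_2=True,
)

inputs = tokenizer("Flash Attention 2 on several GPUs:", return_tensors="pt").to(0)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```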

docs/source/en/perf_infer_gpu_one.md

Lines changed: 148 additions & 0 deletions
@@ -17,6 +17,154 @@ rendered properly in your Markdown viewer.
 
 In addition to this guide, relevant information can be found as well in [the guide for training on a single GPU](perf_train_gpu_one) and [the guide for inference on CPUs](perf_infer_cpu).
 
+## Flash Attention 2
+
+<Tip>
+
+Note that this feature is experimental and might considerably change in future versions. For instance, the Flash Attention 2 API might migrate to the `BetterTransformer` API in the near future.
+
+</Tip>
+
+Flash Attention 2 can considerably speed up transformer-based model training and inference. Flash Attention 2 was introduced in the [official Flash Attention repository](https://github.com/Dao-AILab/flash-attention) by Tri Dao et al. The scientific paper on Flash Attention can be found [here](https://arxiv.org/abs/2205.14135).
+
+Make sure to follow the installation guide in the repository mentioned above to properly install Flash Attention 2. Once that package is installed, you can benefit from this feature.
+
+We natively support Flash Attention 2 for the following models:
+
+- Llama
+- Falcon
+
+You can request Flash Attention 2 support for more models by opening an issue on GitHub, or even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported by the `BetterTransformer` API below.*
+
+<Tip>
+
+Flash Attention 2 can only be used when the model's dtype is `fp16` or `bf16` and it runs only on NVIDIA GPU devices. Make sure to cast your model to the appropriate dtype and load it on a supported device before using this feature.
+
+</Tip>
+
+### Quick usage
+
+To enable Flash Attention 2 in your model, add `use_flash_attention_2` to the `from_pretrained` arguments:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    use_flash_attention_2=True,
+)
+```
+
+And use it for generation or fine-tuning.
+
+### Expected speedups
+
+You can benefit from considerable speedups for fine-tuning and inference, especially for long sequences. However, since Flash Attention does not support computing attention scores with padding tokens under the hood, we must manually pad / unpad the attention scores for batched inference when the sequence contains padding tokens. This leads to a significant slowdown for batched generation with padding tokens.
+
+To overcome this, one should use Flash Attention without padding tokens in the sequence during training (e.g., by packing a dataset, i.e., concatenating sequences until reaching the maximum sequence length). An example is provided [here](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py#L516).
+
+Below is the expected speedup you can get for a simple forward pass on [tiiuae/falcon-7b](https://hf.co/tiiuae/falcon-7b) with a sequence length of 4096 and various batch sizes, without padding tokens:
+
+<div style="text-align: center">
+<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/falcon-7b-inference-large-seqlen.png">
+</div>
+
+Below is the expected speedup you can get for a simple forward pass on [`meta-llama/Llama-7b-hf`](https://hf.co/meta-llama/Llama-7b-hf) with a sequence length of 4096 and various batch sizes, without padding tokens:
+
+<div style="text-align: center">
+<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-7b-inference-large-seqlen.png">
+</div>
+
+For sequences with padding tokens (training with padding tokens or generating with padding tokens), we need to unpad / pad the input sequences to correctly compute the attention scores. For relatively small sequence lengths, a pure forward pass creates an overhead that leads to only a small speedup (in the benchmark below, 30% of the input is filled with padding tokens):
+
+<div style="text-align: center">
+<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-small-seqlen-padding.png">
+</div>
+
+But for large sequence lengths you can benefit from an interesting speedup for pure inference (and also training).
+
+Note that Flash Attention makes the attention computation more memory efficient, meaning you can train with much larger sequence lengths without running into CUDA OOM issues. It can lead to a memory reduction of up to 20x for large sequence lengths. Check out [the official flash attention repository](https://github.com/Dao-AILab/flash-attention) for more details.
+
+<div style="text-align: center">
+<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
+</div>
+
+### Advanced usage
+
+You can combine this feature with many existing features for model optimization. Check out a few examples below:
+
+### Combining Flash Attention 2 and 8-bit models
+
+You can combine this feature together with 8-bit quantization:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    load_in_8bit=True,
+    use_flash_attention_2=True,
+)
+```
+
+### Combining Flash Attention 2 and 4-bit models
+
+You can combine this feature together with 4-bit quantization:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    load_in_4bit=True,
+    use_flash_attention_2=True,
+)
+```
+
+### Combining Flash Attention 2 and PEFT
+
+You can combine this feature together with PEFT for training adapters using Flash Attention 2 under the hood:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
+from peft import LoraConfig
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    load_in_4bit=True,
+    use_flash_attention_2=True,
+)
+
+lora_config = LoraConfig(
+    r=8,
+    task_type="CAUSAL_LM"
+)
+
+model.add_adapter(lora_config)
+
+... # train your model
+```
+
 ## BetterTransformer
 
 [BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview) converts 🤗 Transformers models to use the PyTorch-native fastpath execution, which calls optimized kernels like Flash Attention under the hood.
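
The "Expected speedups" section above recommends packing the training data so that no padding tokens are needed. Below is a minimal sketch of that packing step, assuming a plain list of raw texts and an illustrative `block_size`; the linked `run_clm.py` example applies the same idea via `datasets.map`:

```python
from itertools import chain

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
block_size = 2048  # target packed sequence length (illustrative value)

texts = ["first document ...", "second document ...", "third document ..."]

# Tokenize, concatenate everything into one long stream, then cut it into
# fixed-size blocks: every example has exactly block_size tokens, so batched
# training never needs padding and Flash Attention 2 stays on its fast path.
token_stream = list(chain.from_iterable(tokenizer(t)["input_ids"] for t in texts))
total_length = (len(token_stream) // block_size) * block_size

packed_examples = [
    token_stream[i : i + block_size] for i in range(0, total_length, block_size)
]
```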

docs/source/en/perf_train_gpu_one.md

Lines changed: 4 additions & 0 deletions
@@ -228,6 +228,10 @@ For additional information on tf32 vs other precisions, please refer to the foll
 [RTX-3090](https://github.com/huggingface/transformers/issues/14608#issuecomment-1004390803) and
 [A100](https://github.com/huggingface/transformers/issues/15026#issuecomment-1004543189).
 
+## Flash Attention 2
+
+You can speed up the training throughput by using the Flash Attention 2 integration in transformers. Check out the appropriate section in the [single GPU section](./perf_infer_gpu_one#Flash-Attention-2) to learn more about how to load a model with Flash Attention 2 modules.
+
 ## Optimizer choice
 
 The most common optimizer used to train transformer models is Adam or AdamW (Adam with weight decay). Adam achieves
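
As a rough illustration of the throughput claim above, a model loaded with `use_flash_attention_2=True` drops straight into the usual `Trainer` loop; the dummy padding-free dataset and the hyperparameters below are placeholders, not values taken from this commit:

```python
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Training code is unchanged; only the attention modules differ at load time.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
)

# Toy padding-free dataset: every example already has a fixed length.
block_size = 2048
dummy_ids = [[tokenizer.eos_token_id] * block_size for _ in range(16)]
train_dataset = Dataset.from_dict({"input_ids": dummy_ids, "labels": dummy_ids})

args = TrainingArguments(
    output_dir="falcon-7b-fa2",  # placeholder output directory
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    bf16=True,
    max_steps=10,
    logging_steps=1,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=default_data_collator,
)
trainer.train()
```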

src/transformers/configuration_utils.py

Lines changed: 5 additions & 0 deletions
@@ -855,6 +855,9 @@ def to_diff_dict(self) -> Dict[str, Any]:
 
         self.dict_torch_dtype_to_str(serializable_config_dict)
 
+        if "_flash_attn_2_enabled" in serializable_config_dict:
+            del serializable_config_dict["_flash_attn_2_enabled"]
+
         return serializable_config_dict
 
     def to_dict(self) -> Dict[str, Any]:
@@ -871,6 +874,8 @@ def to_dict(self) -> Dict[str, Any]:
             del output["_auto_class"]
         if "_commit_hash" in output:
             del output["_commit_hash"]
+        if "_flash_attn_2_enabled" in output:
+            del output["_flash_attn_2_enabled"]
 
         # Transformers version when serializing the model
         output["transformers_version"] = __version__
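
A small sketch of what these two additions guarantee; the private flag is set by hand here purely for illustration, whereas in practice `_check_and_enable_flash_attn_2` sets it:

```python
from transformers import FalconConfig

config = FalconConfig()
config._flash_attn_2_enabled = True  # normally set internally when use_flash_attention_2=True

# The flag is a runtime switch only: both serialization paths strip it, so it
# never ends up in a saved config.json.
assert "_flash_attn_2_enabled" not in config.to_dict()
assert "_flash_attn_2_enabled" not in config.to_diff_dict()
```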

src/transformers/modeling_utils.py

Lines changed: 86 additions & 0 deletions
@@ -70,6 +70,7 @@
     is_accelerate_available,
     is_auto_gptq_available,
     is_bitsandbytes_available,
+    is_flash_attn_available,
     is_offline_mode,
     is_optimum_available,
     is_peft_available,
@@ -1116,6 +1117,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     is_parallelizable = False
     supports_gradient_checkpointing = False
 
+    # Flash Attention 2 support
+    _supports_flash_attn_2 = False
+
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
         """
@@ -1239,6 +1243,84 @@ def can_generate(cls) -> bool:
             return False
         return True
 
+    @classmethod
+    def _check_and_enable_flash_attn_2(
+        cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
+    ) -> PretrainedConfig:
+        """
+        If you don't know about Flash Attention, check out the official repository of flash attention:
+        https://github.com/Dao-AILab/flash-attention
+
+        For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API; have a look at this
+        specific section of the documentation to learn more about it:
+        https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
+
+        The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
+        half precision and not run on CPU.
+
+        If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
+        can initialize the correct attention module.
+        """
+        if not cls._supports_flash_attn_2:
+            raise ValueError(
+                "The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
+                "request support for this architecture: https://github.com/huggingface/transformers/issues/new"
+            )
+
+        if not is_flash_attn_available():
+            raise ImportError(
+                "Flash Attention 2.0 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for"
+                " installing it."
+            )
+        else:
+            flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
+            is_flash_greater_than_2 = flash_attention_version > version.parse("2.0.0")
+            if not is_flash_greater_than_2:
+                raise ValueError(
+                    f"You need flash_attn package version to be greater than 2.0. Make sure to have that version installed - detected version {flash_attention_version}"
+                )
+
+        _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
+
+        if _is_bettertransformer:
+            raise ValueError(
+                "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
+            )
+
+        if torch_dtype is None:
+            logger.warning(
+                "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
+            )
+        elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
+            raise ValueError(
+                f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to"
+                " unexpected behaviour."
+            )
+
+        if device_map is None:
+            if torch.cuda.is_available():
+                logger.warning(
+                    "You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU"
+                    " after initializing it on CPU with `model.to('cuda')`."
+                )
+            else:
+                raise ValueError(
+                    "You are attempting to use Flash Attention 2.0 with a model initialized on CPU and with no GPU available. "
+                    "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
+                    "or initialising the model on CPU and then moving it to GPU."
+                )
+        elif (
+            device_map is not None
+            and isinstance(device_map, dict)
+            and ("cpu" in device_map.values() or "disk" in device_map.values())
+        ):
+            raise ValueError(
+                "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
+                "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
+            )
+        config._flash_attn_2_enabled = True
+        return config
+
     def enable_input_require_grads(self):
         """
         Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
@@ -2374,6 +2456,7 @@ def from_pretrained(
         variant = kwargs.pop("variant", None)
         _adapter_model_path = kwargs.pop("_adapter_model_path", None)
         adapter_name = kwargs.pop("adapter_name", "default")
+        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
 
         if is_fsdp_enabled():
             low_cpu_mem_usage = True
@@ -2977,6 +3060,9 @@ def from_pretrained(
         elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
             init_contexts.append(init_empty_weights())
 
+        if use_flash_attention_2:
+            config = cls._check_and_enable_flash_attn_2(config, torch_dtype=torch_dtype, device_map=device_map)
+
         with ContextManagers(init_contexts):
             model = cls(config, *model_args, **model_kwargs)
 
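
Taken together, the new `use_flash_attention_2` kwarg and the `_check_and_enable_flash_attn_2` classmethod gate the feature at load time. Here is a sketch of the resulting behaviour, assuming a CUDA GPU, the `flash-attn` (>2.0) package, and `accelerate` for the `device_map`; the error text in the comment is paraphrased:

```python
import torch
from transformers import AutoModelForCausalLM

model_id = "tiiuae/falcon-7b"  # one of the two architectures wired up in this commit

# An unsupported dtype is rejected by the new check before the model is instantiated.
try:
    AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, use_flash_attention_2=True)
except ValueError as err:
    print(err)  # "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes..."

# A supported setup: half precision and a GPU-only device_map.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map={"": 0},
    use_flash_attention_2=True,
)

# The classmethod records its decision on the config; the flag is stripped again when saving.
assert getattr(model.config, "_flash_attn_2_enabled", False)
```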

src/transformers/models/deprecated/open_llama/modeling_open_llama.py

Lines changed: 0 additions & 1 deletion
@@ -364,7 +364,6 @@ def __init__(self, config: OpenLlamaConfig):
         self.input_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
     def forward(
         self,
         hidden_states: torch.Tensor,

0 commit comments