
Commit ee6bdf5

younesbelkada authored and ArthurZucker committed

[FA-2] Final fix for FA2 dtype (huggingface#26846)

* final fix for FA2 dtype
* try
* oops
* Update src/transformers/models/falcon/modeling_falcon.py
  Co-authored-by: Arthur <[email protected]>
* apply fix everywhere

---------

Co-authored-by: Arthur <[email protected]>
1 parent 8c7dcdf commit ee6bdf5
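
For reference, the fix applied in all three attention implementations boils down to picking a cast target instead of hardcoding float16. Below is a minimal standalone sketch of that logic; the function names and the `config` / `proj_weight` arguments are illustrative placeholders, not the actual transformers code paths.

import torch


def resolve_fa2_cast_dtype(config, proj_weight: torch.Tensor) -> torch.dtype:
    # Prefer the dtype the model had before quantization, if the config recorded one
    # (transformers stores it as `_pre_quantization_dtype`); otherwise fall back to
    # the dtype of an attention projection weight.
    return getattr(config, "_pre_quantization_dtype", proj_weight.dtype)


def cast_qkv_for_flash_attention(query, key, value, config, proj_weight):
    # Flash Attention 2 only supports fp16/bf16 inputs, so float32 hidden states
    # (e.g. after upcasting layer norms for PEFT training) are cast back down
    # before the kernel call.
    if query.dtype == torch.float32:
        target_dtype = resolve_fa2_cast_dtype(config, proj_weight)
        query, key, value = query.to(target_dtype), key.to(target_dtype), value.to(target_dtype)
    return query, key, value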

File tree

src/transformers/models/falcon/modeling_falcon.py
src/transformers/models/llama/modeling_llama.py
src/transformers/models/mistral/modeling_mistral.py
tests/test_modeling_common.py

4 files changed: +69 -19 lines

src/transformers/models/falcon/modeling_falcon.py

Lines changed: 9 additions & 6 deletions
@@ -613,15 +613,18 @@ def forward(
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_layer.dtype
         if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            target_dtype = getattr(self.config, "_pre_quantization_dtype", self.query_key_value.weight.dtype)
+
             logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to"
-                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                " float16."
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
             )

-            query_layer = query_layer.to(torch.float16)
-            key_layer = key_layer.to(torch.float16)
-            value_layer = value_layer.to(torch.float16)
+            query_layer = query_layer.to(target_dtype)
+            key_layer = key_layer.to(target_dtype)
+            value_layer = value_layer.to(target_dtype)

         attn_output = self._flash_attention_forward(
             query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout

src/transformers/models/llama/modeling_llama.py

Lines changed: 11 additions & 7 deletions
@@ -469,20 +469,24 @@ def forward(

         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in float16 just to be sure everything works as expected.
+        # cast them back in the correct dtype just to be sure everything works as expected.
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
         # in fp32. (LlamaRMSNorm handles it correctly)
+
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)
+
             logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to"
-                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                " float16."
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
             )

-            query_states = query_states.to(torch.float16)
-            key_states = key_states.to(torch.float16)
-            value_states = value_states.to(torch.float16)
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)

         attn_output = self._flash_attention_forward(
             query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate

src/transformers/models/mistral/modeling_mistral.py

Lines changed: 9 additions & 6 deletions
@@ -408,15 +408,18 @@ def forward(
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)
+
             logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to"
-                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                " float16."
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
             )

-            query_states = query_states.to(torch.float16)
-            key_states = key_states.to(torch.float16)
-            value_states = value_states.to(torch.float16)
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)

         # Reashape to the expected shape for Flash Attention
         query_states = query_states.transpose(1, 2)

tests/test_modeling_common.py

Lines changed: 40 additions & 0 deletions
@@ -64,6 +64,7 @@
     is_pt_flax_cross_test,
     is_pt_tf_cross_test,
     require_accelerate,
+    require_bitsandbytes,
     require_flash_attn,
     require_safetensors,
     require_torch,
@@ -2959,6 +2960,45 @@ def test_flash_attn_2_generate_use_cache(self):
                     dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
                 )

+    @require_flash_attn
+    @require_torch_gpu
+    @require_bitsandbytes
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_fp32_ln(self):
+        import torch
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_flash_attn_2:
+                return
+
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
+
+                model = model_class.from_pretrained(
+                    tmpdirname,
+                    torch_dtype=torch.float16,
+                    use_flash_attention_2=True,
+                    low_cpu_mem_usage=True,
+                    load_in_4bit=True,
+                )
+
+                for _, param in model.named_parameters():
+                    # upcast only layer norms
+                    if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+                        param.data = param.data.to(torch.float32)
+
+                _ = model(input_ids=dummy_input)
+
+                # with attention mask
+                _ = model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
+

 global_rng = random.Random()
0 commit comments
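
The new test mirrors the scenario the fix targets: a 4-bit quantized model loaded with Flash Attention 2 whose remaining fp16/bf16 parameters (in practice, the layer norms) are upcast to float32, as PEFT-style training setups do. Roughly, a user-level script exercising the same path might look like the sketch below; the checkpoint id is a placeholder, and the flags shown (`use_flash_attention_2`, `load_in_4bit`) are the ones the test itself uses at the time of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/some-fa2-checkpoint"  # placeholder: any FA2-capable causal LM

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
    load_in_4bit=True,
)

# Mimic PEFT-style setups that upcast the norm layers to float32 for stability.
for _, param in model.named_parameters():
    if param.dtype in (torch.float16, torch.bfloat16):
        param.data = param.data.to(torch.float32)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# Before this fix the float32 hidden states were hard-cast to float16; now they are
# cast to the pre-quantization dtype (or the projection weight dtype) instead.
_ = model(**inputs)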

Comments
 (0)