
Conversation

@younesbelkada (Contributor)

What does this PR do?

Fixes: #26451

Currently, performing bf16 fine-tuning with FA-2 leads to the hidden states being silently cast to float16.
Since it is challenging to retrieve the original dtype of the model once it has been quantized, I propose storing that dtype in a private attribute so it can be retrieved conveniently, without any hack to recover the correct dtype for quantized models.
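For illustration only, a minimal sketch of the idea with a hypothetical helper; the private attribute name mirrors the one used in the diff discussed below:

import torch

def maybe_cast_back(hidden_states: torch.Tensor, config) -> torch.Tensor:
    """Cast hidden states that were silently upcasted to float32 back to the dtype recorded at load time."""
    if hidden_states.dtype == torch.float32:
        # `_flash_attn_2_attention_dtype` would be stored once in `from_pretrained`
        # (e.g. from the `torch_dtype` argument), before any quantization happens,
        # so it stays retrievable even for quantized models.
        return hidden_states.to(config._flash_attn_2_attention_dtype)
    return hidden_states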

@ArthurZucker (Collaborator) left a comment

A small nit!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ArthurZucker (Collaborator) left a comment

Okay! LGTM, but the issue is that the dtype could change when we do `model.to(device)`, meaning this only fixes inference right after init if `torch_dtype` is specified.

@ArthurZucker (Collaborator) left a comment

The changes to the modeling utils (specifically `to`) are a bit too specific to Flash Attention. Given that `torch_dtype` can be accessed inside `XXXFlashAttention` via `self.config`, it makes more sense to have this in the attention module (if possible?).

This is fine if we want it, but there might be a solution that only changes the attention module.
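Roughly, that suggestion could look like the following inside the attention module; the helper name and the float16 fallback are illustrative, not part of the PR:

import torch

def pick_attention_dtype(query_states: torch.Tensor, config) -> torch.dtype:
    # If the hidden states were silently upcasted to fp32 (e.g. by fp32 layer norms),
    # cast back to the dtype the model was loaded in, read from the attention's config.
    if query_states.dtype != torch.float32:
        return query_states.dtype
    target_dtype = getattr(config, "torch_dtype", None)
    return target_dtype if target_dtype is not None else torch.float16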

else:
    target_dtype = kwargs["dtype"]

if target_dtype is not None and target_dtype == torch.float32:
Collaborator

nice

@younesbelkada (Contributor, Author) commented Oct 3, 2023

I agree we should keep all FA2-related hacks and changes inside the FA2 modules themselves. However, that might introduce multiple patches and other hacks for quantized modules. I think this approach is fine for now since it unblocks some users for the next release, but I agree we should go for a better one; I left it as a TODO!

" float16."
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {attention_dtype}. Make sure to pass the desired dtype when calling `from_pretrained` with `torch_dtype=your_dtype`"

It seems to me that this line is 124 characters long and should be split in two. The previous lines could be left without the 'f' prefix, since there are no values to be formatted in them.
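For illustration, one way to apply that suggestion; the surrounding logging setup and the `attention_dtype` value are placeholders, not the actual code in the PR:

import logging

import torch

logger = logging.getLogger(__name__)
attention_dtype = torch.float16  # placeholder for the dtype the input is cast back to

logger.warning(
    "The input hidden states seems to be silently casted in float32, this might be related to"
    " the fact you have upcasted embedding or layer norm layers in float32. We will cast back"
    f" the input in {attention_dtype}. Make sure to pass the desired dtype when calling"
    " `from_pretrained` with `torch_dtype=your_dtype`"
)

Only the line that interpolates a value keeps the f prefix, and each string line now stays under roughly 100 characters.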

@younesbelkada (Contributor, Author)

cc @hiyouga, are you able to fine-tune in bf16 with this branch?

@LysandreJik (Member) left a comment

Ok, cool to relax the cast to fp16.

Wouldn't it be cleaner to put the changes related to `to` in a FlashAttention-specific mixin, so as not to require changes to modeling_utils.py? You wouldn't need to add an additional private property `_flash_attn_2_attention_dtype`, and you wouldn't need to edit the general `to` method either (just the FlashAttention-specific `to` method).

You may want to update the `_apply` method instead of `to`, however; I think with `to` you're not seeing calls like `model.float()`, which will convert your entire model to float32.
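A rough sketch of what such a mixin could look like, here overriding `to` (an `_apply`-based variant would be similar); the class and model names below are made up:

import torch
import torch.nn as nn


class FlashAttention2CastGuard:
    """Hypothetical mixin: keeps the float32-cast guard out of the generic `to` in modeling_utils.py."""

    def to(self, *args, **kwargs):
        # Same argument parsing torch uses internally in nn.Module.to, to see whether a dtype was requested.
        _, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
        if dtype == torch.float32:
            raise ValueError(
                "You cannot cast a model that has been loaded with Flash Attention 2 in `float32`"
            )
        return super().to(*args, **kwargs)


class TinyModel(FlashAttention2CastGuard, nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4, dtype=torch.float16)


model = TinyModel()
model.to(torch.bfloat16)    # allowed
# model.to(torch.float32)   # would raise a ValueError
# model.proj.to(torch.float32)  # ...but calls on submodules still bypass the guard, as noted below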

And it's probably an edge case, but with a class-level `to` override like this you're not keeping the following in check:

import torch
from transformers import MistralForCausalLM
model = MistralForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralModel", use_flash_attention_2=True, torch_dtype=torch.float16)
model.model.layers.to(torch.float32)

Comment on lines +2192 to +2194
raise ValueError(
    "You cannot cast a model that has been loaded with Flash Attention 2 in `float32`"
)
Member

Maybe mention how to work around this?

Comment on lines +2178 to +2179
# TODO: @younesbelkada find a better way to do this directly in `xxxFlashAttention` modules
# currently it is not possible to retrieve the original dtype for quantized models.
Member

Should this be investigated now?

@hiyouga (Contributor) commented Oct 11, 2023

I also thought that we could retrieve the data type from self.config.torch_dtype of LlamaFlashAttention2.

PS: Although it may fail in an edge case like the one Lysandre described, we usually have a consistent data type during training.

# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
    attention_dtype = self.config._flash_attn_2_attention_dtype
Contributor

Not all attention modules have a `config` attribute. But I guess that's OK; we can just pass it forward.

@younesbelkada (Contributor, Author)

Closing this PR in favor of #26846



Linked issue (may be closed by this PR): The hidden states in LlamaFlashAttention2 are cast in fp16 unexpectedly

7 participants