[FA2] Cast to correct dtype
#26560
Conversation
ArthurZucker
left a comment
A small nit!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
ArthurZucker
left a comment
Okay! LGTM, but the issue is that the dtype could still change when we do `model.to(device)`, meaning this only fixes inference right after init, and only if `torch_dtype` is specified.
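A short sketch of the scenario described here, using the tiny test checkpoint mentioned later in this thread (purely illustrative): the dtype recorded at load time goes stale if the model is cast again afterwards.

import torch
from transformers import MistralForCausalLM

# torch_dtype is specified at init, so a dtype recorded at load time would be bfloat16 ...
model = MistralForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-MistralModel", torch_dtype=torch.bfloat16
)
# ... but a later cast changes the actual parameter dtype, so the recorded value
# no longer matches what the forward pass sees.
model = model.to(torch.float16)
print(next(model.parameters()).dtype)  # torch.float16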
Co-authored-by: Arthur <[email protected]>
Changes to the modeling utils (specifically to `to`) are a bit too specific to flash attention. Given that `torch_dtype` can be accessed in the XXXFlashAttention module via `self.config`, it makes more sense to have this in the attention layer (if possible?).
It's fine if we want this, but there might be a solution that only changes the attention module.
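For illustration only (not the PR's code), a helper along the lines of that suggestion, reading the target dtype from the config inside the attention module; the float16 fallback is an assumption.

import torch

def maybe_downcast_for_flash_attention(query_states, key_states, value_states, config):
    # If embeddings or layer norms upcast the hidden states to float32, cast the
    # attention inputs back to the dtype recorded on the config by `from_pretrained`.
    if query_states.dtype == torch.float32:
        target_dtype = getattr(config, "torch_dtype", None) or torch.float16
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states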
else:
    target_dtype = kwargs["dtype"]

if target_dtype is not None and target_dtype == torch.float32:
nice
I agree we should make all hacks / changes related to FA2 live inside the FA2 modules themselves. However, this might introduce multiple patches and other hacks for quantized modules. I think for now this approach is fine, but I agree we should go for a better one. As this would unblock some users for the next release, I left it as a TODO!
| " float16." | ||
| f"The input hidden states seems to be silently casted in float32, this might be related to" | ||
| f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" | ||
| f" {attention_dtype}. Make sure to pass the desired dtype when calling `from_pretrained` with `torch_dtype=your_dtype`" |
Seems to me that this line is 124 characters long and should be split into two. The previous lines could also drop the 'f' prefix, since there are no values to format in them.
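A sketch of the suggested change, assuming the string lives in a `logger.warning_once(...)` call as in the modeling code: the long line split in two, and the `f` prefix kept only on the piece that actually interpolates a value.

import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)
attention_dtype = torch.float16  # illustrative value

logger.warning_once(
    "The input hidden states seems to be silently casted in float32, this might be related to"
    " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
    f" {attention_dtype}. Make sure to pass the desired dtype when calling `from_pretrained`"
    " with `torch_dtype=your_dtype`"
)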
cc @hiyouga are you able to fine-tune in bf16 with this branch?
Ok, cool to relax the cast to fp16.
Wouldn't it be cleaner to put the changes relative to `to` in a FlashAttention-specific mixin so as to not require changes to modeling_utils.py? You wouldn't need to add an additional private property `_flash_attn_2_attention_dtype`, and you wouldn't need to edit the general `to` method either (just the FlashAttention-specific `to` method).
You may want to update the `_apply` method instead of `to`, however; I think with `to` you're not seeing calls like `model.float()`, which will convert your entire model to float32.
And it's probably an edge case, but with a class-level `to` override like this you're not keeping the following in check:
model = MistralForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralModel", use_flash_attention_2=True, torch_dtype=torch.float16)
model.model.layers.to(torch.float32)
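A rough sketch of that suggestion, with hypothetical names (not the PR's code): a mixin overriding `_apply`, which backs `.to()`, `.float()`, `.half()` and `.cuda()`, so every dtype change is intercepted in one place; it would only be mixed into models loaded with Flash Attention 2.

import torch
from torch import nn

class FlashAttention2CastGuard:
    # Hypothetical mixin: run the usual `_apply`, then reject ending up in float32,
    # mirroring the ValueError shown in the diff below. Purely illustrative.
    def _apply(self, fn, *args, **kwargs):
        module = super()._apply(fn, *args, **kwargs)
        params = list(module.parameters())
        if params and params[0].dtype == torch.float32:
            raise ValueError(
                "You cannot cast a model that has been loaded with Flash Attention 2 in `float32`"
            )
        return module

# Usage sketch: place the mixin before the nn.Module subclass so its `_apply` wins.
class GuardedLinear(FlashAttention2CastGuard, nn.Linear):
    pass

layer = GuardedLinear(4, 4).half()  # fine: ends up in float16
# layer.float()                     # would raise the ValueError above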
raise ValueError(
    "You cannot cast a model that has been loaded with Flash Attention 2 in `float32`"
)
Maybe mention how to work around this?
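For illustration, the message could name the workaround directly; the wording below is an assumption, not the PR's final text.

raise ValueError(
    "You cannot cast a model that has been loaded with Flash Attention 2 in `float32`."
    " Cast it to `float16` or `bfloat16` instead (e.g. `model.to(torch.bfloat16)`), or"
    " pass `torch_dtype=torch.float16` / `torch_dtype=torch.bfloat16` to `from_pretrained`."
)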
# TODO: @younesbelkada find a better way to do this directly in `xxxFlashAttention` modules
# currently it is not possible to retrieve the original dtype for quantized models.
Should this be investigated now?
I also thought that we could retrieve the data type from PS. Although it may fail in an edge case like the one Lysandre mentioned, we usually have a consistent data type during training.
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
    attention_dtype = self.config._flash_attn_2_attention_dtype
Not all attention modules have a `config` attribute. But I guess that's OK, we can just pass it forward.
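A brief sketch of that fallback, with illustrative names: the dtype is handed in by the caller rather than read off `self.config`, so attention modules without a config can still perform the cast.

import torch

def cast_fa2_inputs(hidden_states, attention_dtype=torch.float16):
    # `attention_dtype` is passed forward explicitly instead of being looked up
    # on a config attribute the module may not have.
    if hidden_states.dtype == torch.float32:
        hidden_states = hidden_states.to(attention_dtype)
    return hidden_states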
Closing this PR in favor of #26846
What does this PR do?
Fixes: #26451
Currently, performing bf16 fine-tuning with FA-2 leads to hidden states silently being cast to float16.
As it is challenging to retrieve the original dtype of the model when it is quantized, I propose storing that dtype in a private attribute, so that it can be retrieved conveniently without any hack to recover the correct dtype for quantized models.
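A minimal, self-contained sketch of that mechanism with assumed helper names (only `_flash_attn_2_attention_dtype` appears in the actual diff): record the requested dtype at load time, then read it back in the FA2 forward to undo a silent float32 upcast.

import torch

class DummyConfig:
    pass  # stand-in for a PretrainedConfig, just to keep the sketch self-contained

def record_fa2_dtype(config, torch_dtype):
    # At `from_pretrained` time: remember the requested dtype, which is hard to
    # recover from the weights themselves once the model is quantized.
    config._flash_attn_2_attention_dtype = torch_dtype or torch.float16

def cast_back_for_fa2(hidden_states, config):
    # In the FA2 forward: undo a silent upcast caused by fp32 embeddings / norms.
    if hidden_states.dtype == torch.float32:
        hidden_states = hidden_states.to(config._flash_attn_2_attention_dtype)
    return hidden_states

config = DummyConfig()
record_fa2_dtype(config, torch.bfloat16)
x = torch.randn(2, 4, 8)                   # float32, e.g. after an fp32 LayerNorm
print(cast_back_for_fa2(x, config).dtype)  # torch.bfloat16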