[FA-2] Final fix for FA2 dtype #26846
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
ArthurZucker left a comment:
Thanks, I think we can simplify a bit and remove the warning?
```diff
 logger.warning_once(
-    "The input hidden states seems to be silently casted in float32, this might be related to"
-    " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-    " float16."
+    f"The input hidden states seems to be silently casted in float32, this might be related to"
+    f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+    f" {target_dtype}."
 )
```
I think we can remove this now, no?
Hmm, I think we need to keep it to inform users about that.
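For context, the warning fires inside the FA-2 attention forward pass, right before float32 inputs are cast back to the dtype flash attention expects. Below is a minimal sketch of that logic; module names such as `self.q_proj` are assumptions following LLaMA-style attention, not an exact copy of the implementation:

```python
input_dtype = query_states.dtype
if input_dtype == torch.float32:
    # Quantized models: the stored weight dtype is a packed storage dtype,
    # so rely on the dtype recorded before quantization instead.
    if hasattr(self.config, "_pre_quantization_dtype"):
        target_dtype = self.config._pre_quantization_dtype
    else:
        target_dtype = self.q_proj.weight.dtype

    logger.warning_once(
        f"The input hidden states seems to be silently casted in float32, this might be related to"
        f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        f" {target_dtype}."
    )

    # Cast back so flash attention receives fp16/bf16 inputs
    query_states = query_states.to(target_dtype)
    key_states = key_states.to(target_dtype)
    value_states = value_states.to(target_dtype)
```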
Commits:

* final fix for FA2 dtype
* try
* oops
* Update src/transformers/models/falcon/modeling_falcon.py (Co-authored-by: Arthur <[email protected]>)
* apply fix everywhere

Co-authored-by: Arthur <[email protected]>
What does this PR do?
Replaces #26560
Fixes #26451
Proposes a simpler fix for dealing with FA-2 + PEFT + quantization fine-tuning, where users usually cast all other modules (e.g. LayerNorms) to fp32 for training stability.
With #26761 being introduced, it is now much simpler to retrieve the model's original dtype. Note also that `self.config._pre_quantization_dtype` remains the single source of truth, as `to()` is not supported for quantized models.

cc @ArthurZucker @pacman100
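To illustrate why `_pre_quantization_dtype` has to be the source of truth, here is a hedged sketch (the checkpoint name is an assumption): for a bitsandbytes-quantized model, the stored weight dtype is a packed storage type that no longer reflects the compute dtype, and casting the model is not allowed.

```python
import torch
from transformers import AutoModelForCausalLM

# Assumption: any causal LM checkpoint works; 4-bit quantization via bitsandbytes
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", load_in_4bit=True)

# Packed storage dtype (e.g. torch.uint8), unusable as a cast target for FA-2 inputs
print(model.model.decoder.layers[0].self_attn.q_proj.weight.dtype)

# Dtype recorded before quantization, e.g. torch.float16
print(model.config._pre_quantization_dtype)

# `to(dtype)` is not supported for quantized models, so the original dtype
# cannot be recovered or restored through the weights themselves:
# model.to(torch.float16)  # raises an error
```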
Also added a nice test.
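A rough sketch of the scenario such a test exercises (the checkpoint name and exact setup are assumptions, not the test added in this PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"  # assumption: any FA-2-capable checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
    device_map="auto",
)

# Upcast LayerNorms to fp32, as PEFT + quantization training recipes commonly do
for module in model.modules():
    if isinstance(module, torch.nn.LayerNorm):
        module.to(torch.float32)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)

# The forward pass should succeed: FA-2 inputs are cast back to float16 internally
with torch.no_grad():
    _ = model(**inputs)
```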