
Conversation

@younesbelkada (Contributor) commented Oct 16, 2023

What does this PR do?

Replaces #26560
Fixes #26451

Proposes a simpler fix for FA-2 + PEFT + quantization fine-tuning, where users usually cast all other modules (e.g. LayerNorms) to fp32 for training stability.

With #26761 introduced, it is now much simpler to retrieve the model's original dtype. Note also that self.config._pre_quantization_dtype remains the single source of truth, since the original dtype cannot be read back from the weights of quantized models.
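
To illustrate the idea, a minimal sketch (the helper names below are hypothetical and not the actual code in the diff):

    import torch

    def get_target_dtype(module, config):
        # For quantized models the weights no longer expose the original dtype,
        # so config._pre_quantization_dtype (recorded at quantization time) is
        # the single source of truth; otherwise read it from a projection weight.
        if hasattr(config, "_pre_quantization_dtype"):
            return config._pre_quantization_dtype
        return module.q_proj.weight.dtype

    def cast_back_if_needed(hidden_states, target_dtype):
        # Flash Attention 2 only supports fp16/bf16 inputs; if LayerNorms or
        # embeddings were upcast to fp32 for training stability, the hidden
        # states arrive in fp32 and must be cast back before the FA-2 kernel.
        if hidden_states.dtype == torch.float32:
            hidden_states = hidden_states.to(target_dtype)
        return hidden_states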

cc @ArthurZucker @pacman100

Also added a test for this.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ArthurZucker (Collaborator) left a comment

Thanks, I think we can simplify a bit and remove the warning?

Comment on lines 417 to 421
                logger.warning_once(
-                   "The input hidden states seems to be silently casted in float32, this might be related to"
-                   " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                   " float16."
+                   f"The input hidden states seems to be silently casted in float32, this might be related to"
+                   f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                   f" {target_dtype}."
                )
Collaborator:
I think we can remove this now no?

Contributor (Author):
Hmm, I think we need to keep it to inform users about that.
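
For context on why the warning stays: the typical setup it targets looks roughly like this (a sketch assuming a bitsandbytes 4-bit + PEFT workflow; the model name is only an example and exact loading flags may differ across versions):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    from peft import prepare_model_for_kbit_training

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        torch_dtype=torch.float16,
        use_flash_attention_2=True,
    )
    # Upcasts the non-quantized modules (e.g. LayerNorms) to fp32 for training
    # stability; the fp32 hidden states then silently reach the FA-2 layers,
    # which is exactly the case the warning (and the cast-back) covers.
    model = prepare_model_for_kbit_training(model)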

@younesbelkada merged commit 5a73316 into huggingface:main on Oct 18, 2023
@younesbelkada deleted the fa-2-final-fix branch on Oct 18, 2023, 21:13
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* final fix for FA2 dtype

* try

* oops

* Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Arthur <[email protected]>

* apply fix everywhere

---------

Co-authored-by: Arthur <[email protected]>

Successfully merging this pull request may close these issues.

The hidden states in LlamaFlashAttention2 are cast in fp16 unexpectedly
