[FEAT] Add Neftune into transformers Trainer #27141
Conversation
Thanks! Overall this looks very good and handy to use. I left a few comments for an initial review :)
Co-authored-by: Zach Mueller <[email protected]>
Added a test and a relevant documentation section; this PR is ready for final review!
Nice work! 💪
Just some small comments. Main one is to add a check for the deactivation logic.
src/transformers/trainer.py (Outdated)
```python
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer by removing the forward post hook.
if self.neftune_noise_alpha is not None:
    if is_peft_available() and isinstance(self.model, PeftModel):
        embeddings = unwrap_model(self.model.base_model).get_input_embeddings()
    else:
        embeddings = unwrap_model(self.model).get_input_embeddings()

    self.neftune_hook_handle.remove()
    del embeddings.neftune_noise_alpha
```
Let's make this into an equivalent method `_deactivate_neftune`
Done in e55ab8b
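For context, a deactivation helper along these lines could look like the sketch below; this is an approximation based on the snippet above, not necessarily the exact method introduced in e55ab8b:

```python
def _deactivate_neftune(self, model):
    # Sketch of a Trainer method mirroring the inline logic above; assumes the
    # names trainer.py already imports (is_peft_available, PeftModel, unwrap_model).
    if not hasattr(self, "neftune_hook_handle"):
        raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")

    if is_peft_available() and isinstance(model, PeftModel):
        embeddings = unwrap_model(model.base_model).get_input_embeddings()
    else:
        embeddings = unwrap_model(model).get_input_embeddings()

    # Remove the forward post-hook and the attribute set during activation.
    self.neftune_hook_handle.remove()
    del embeddings.neftune_noise_alpha
```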
src/transformers/trainer.py (Outdated)
```python
if is_peft_available() and isinstance(self.model, PeftModel):
    embeddings = unwrap_model(self.model.base_model).get_input_embeddings()
else:
    embeddings = unwrap_model(self.model).get_input_embeddings()
```
Is this logic used anywhere else? It looks general enough that we could have a `_get_model_input_embeddings` function (not necessarily to be done in this PR)
Happy to refactor this in a follow-up PR!
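For illustration, such a helper (hypothetical name `_get_model_input_embeddings`, not added in this PR) could be as simple as:

```python
def _get_model_input_embeddings(self, model):
    # Hypothetical Trainer helper: returns the input embedding module,
    # unwrapping the PEFT base model first when a PeftModel is passed.
    if is_peft_available() and isinstance(model, PeftModel):
        return unwrap_model(model.base_model).get_input_embeddings()
    return unwrap_model(model).get_input_embeddings()
```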
```python
# Make sure forward pass works fine
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(torch_device))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
```
A check should be made that it's correctly disabled after training has finished
The line `self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)` should check whether the forward hook has been correctly removed, so I think all should be good here.
Note also that the line is called right after training, so it should check that NEFTune is correctly disabled after training.
Added a slightly more elaborate test in ca8f8c4
Co-authored-by: amyeroberts <[email protected]>
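Roughly, the shape of such a test (an illustrative sketch, not the exact test from ca8f8c4; `tiny_model`, `train_dataset`, and `tmp_dir` are assumed fixtures):

```python
def test_neftune(self):
    # Illustrative sketch: train with NEFTune enabled, then verify it is
    # fully deactivated once training has finished.
    args = TrainingArguments(output_dir=tmp_dir, neftune_noise_alpha=0.4, max_steps=2)
    trainer = Trainer(model=tiny_model, args=args, train_dataset=train_dataset)

    trainer.train()

    # The forward post-hook must be gone after training...
    self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)

    # ...and a plain forward pass must still work, now without noise injection.
    _ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(torch_device))
```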
Awesome - thanks for iterating!
```python
if not hasattr(self, "neftune_hook_handle"):
    raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
```
Nice :)
* add v1 neftune
* use `unwrap_model` instead
* add test + docs
* Apply suggestions from code review
  Co-authored-by: Zach Mueller <[email protected]>
* more details
* fixup
* Update docs/source/en/main_classes/trainer.md
  Co-authored-by: amyeroberts <[email protected]>
* refactor a bit
* more elaborated test
* fix unwrap issue

---------

Co-authored-by: Zach Mueller <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
What does this PR do?
As per title
Fixes: huggingface/trl#923
Fixes: #26899
This PR adds NEFTune, a new technique for enhancing supervised fine-tuning results, proposed in https://arxiv.org/abs/2310.05914
I propose a very simple API: passing a valid `neftune_noise_alpha` argument when initializing the `TrainingArguments`. To avoid any surprising behaviour, we revert to the original forward method at the end of training. This is handled inside the inner training loop, which attaches the forward hook before training begins and removes it right after training the model.
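For illustration, enabling this from user code would look roughly like the sketch below (assuming a causal LM and an already tokenized `train_dataset`; the checkpoint name and alpha value are arbitrary):

```python
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

args = TrainingArguments(
    output_dir="opt-350m-neftune",
    neftune_noise_alpha=5.0,  # a positive value enables NEFTune noise on the input embeddings
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # train_dataset assumed prepared elsewhere
trainer.train()  # the noise hook is attached at the start of training and removed right after
```

Under the hood, the attached hook follows the formula from the paper: during training it adds uniform noise scaled by `alpha / sqrt(seq_len * hidden_dim)` to the embedding output. A minimal sketch of such a hook (details may differ from the merged implementation):

```python
import torch

def neftune_post_forward_hook(module, input, output):
    # Only inject noise while training; evaluation uses the clean embeddings.
    if module.training:
        seq_len, hidden_dim = output.size(1), output.size(2)
        mag_norm = module.neftune_noise_alpha / (seq_len * hidden_dim) ** 0.5
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output
```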