Conversation

@younesbelkada (Contributor) commented Oct 30, 2023

What does this PR do?

As per title

Fixes: huggingface/trl#923
Fixes: #26899

This PR adds NEFTune, a new technique for enhancing supervised fine-tuning results, proposed in: https://arxiv.org/abs/2310.05914

[Screenshot omitted]
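For intuition, here is a minimal sketch (not the exact code added in this PR) of how the noise injection can be written as a PyTorch forward hook on the input embeddings, following the alpha / sqrt(seq_len * hidden_dim) scaling from the paper; the `neftune_noise_alpha` attribute on the embedding module mirrors how this PR stores the value:

```python
import math
import torch

def neftune_forward_hook(module, inputs, output):
    # Sketch: during training only, add uniform noise to the embedding output,
    # scaled by neftune_noise_alpha / sqrt(seq_len * hidden_dim) as in the paper.
    if module.training:
        dims = output.size(1) * output.size(2)
        mag_norm = module.neftune_noise_alpha / math.sqrt(dims)
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output

# Attached to the embedding layer roughly like this (the handle is kept so the
# hook can be removed after training):
# hook_handle = model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
```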

I propose a very simple API: passing a valid `neftune_noise_alpha` argument when initializing `TrainingArguments` is enough to enable NEFTune. To avoid any surprising behaviour, we revert to the original forward method at the end of training. This is handled inside the inner training loop, which attaches the forward hook before training starts and makes sure to remove it right after training the model.
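For illustration, a minimal usage sketch of the proposed API (assuming a `model` and a tokenized `train_dataset` are defined elsewhere; the output directory and alpha value are placeholders):

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./out",        # placeholder
    neftune_noise_alpha=5.0,   # enables NEFTune; 5 is one of the values used in the paper
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()  # the noise hook is attached before training and removed right after
```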

HuggingFaceDocBuilderDev commented Oct 30, 2023

The documentation is not available anymore as the PR was closed or merged.

@muellerzr (Contributor) left a comment:

Thanks! Overall this looks very good and handy to use. I left a few comments for an initial review :)

@younesbelkada marked this pull request as ready for review on October 30, 2023 at 15:51
@younesbelkada (author):

Added a test and a relevant documentation section; this PR is ready for final review!

@amyeroberts (Contributor) left a comment:

Nice work! 💪

Just some small comments. The main one is to add a check for the deactivation logic.

Comment on lines 1981 to 1990
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer by removing the forward post hook.
if self.neftune_noise_alpha is not None:
    if is_peft_available() and isinstance(self.model, PeftModel):
        embeddings = unwrap_model(self.model.base_model).get_input_embeddings()
    else:
        embeddings = unwrap_model(self.model).get_input_embeddings()

    self.neftune_hook_handle.remove()
    del embeddings.neftune_noise_alpha
Contributor:

Let's make this into an equivalent method `_deactivate_neftune`.

@younesbelkada (author):

Done in e55ab8b
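For reference, a rough sketch of what such a helper could look like, simply wrapping the inlined logic quoted above (the actual method added in e55ab8b may differ in details):

```python
def _deactivate_neftune(self, model):
    # Sketch mirroring the inlined logic above: restore the original embedding
    # forward pass by removing the NEFTune forward hook and its alpha attribute.
    if is_peft_available() and isinstance(model, PeftModel):
        embeddings = unwrap_model(model.base_model).get_input_embeddings()
    else:
        embeddings = unwrap_model(model).get_input_embeddings()

    self.neftune_hook_handle.remove()
    del embeddings.neftune_noise_alpha
```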

Comment on lines 1984 to 1987
if is_peft_available() and isinstance(self.model, PeftModel):
    embeddings = unwrap_model(self.model.base_model).get_input_embeddings()
else:
    embeddings = unwrap_model(self.model).get_input_embeddings()
Contributor:

Is this logic used anywhere else? It looks general enough that we could have a _get_model_input_embeddings function (not necessarily to be done in this PR)

@younesbelkada (author):

Happy to refactor this in a follow-up PR!
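For the record, one possible shape for that hypothetical helper (suggested above, not implemented in this PR):

```python
def _get_model_input_embeddings(self, model):
    # Hypothetical helper: return the input-embedding module,
    # unwrapping PEFT wrappers when present.
    if is_peft_available() and isinstance(model, PeftModel):
        return unwrap_model(model.base_model).get_input_embeddings()
    return unwrap_model(model).get_input_embeddings()
```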


# Make sure forward pass works fine
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(torch_device))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
Contributor:

A check should be made that it's correctly disabled after training has finished

@younesbelkada (author):

The line

self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)

should check that the forward hook has been correctly removed, so I think all should be good here.

@younesbelkada (author):

Note also that this line is called right after training, so it checks that NEFTune is correctly disabled once training has finished.

@younesbelkada (author):

Added a slightly more elaborate test in ca8f8c4.
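Roughly, the kind of check being discussed looks like the sketch below (illustrative only, not the exact test in ca8f8c4; it assumes the `_activate_neftune` / `_deactivate_neftune` helpers from this PR take the model as argument):

```python
import torch

dummy_input = torch.LongTensor([[1, 0, 1]]).to(torch_device)

# With NEFTune active and the model in train mode, two forward passes through
# the embedding layer should differ because of the injected noise.
trainer.model.train()
trainer._activate_neftune(trainer.model)
emb1 = trainer.model.get_input_embeddings()(dummy_input)
emb2 = trainer.model.get_input_embeddings()(dummy_input)
self.assertFalse(torch.allclose(emb1, emb2))
trainer._deactivate_neftune(trainer.model)

# After a full training run, no forward hook should remain and the embedding
# outputs should be deterministic again.
trainer.train()
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
emb1 = trainer.model.get_input_embeddings()(dummy_input)
emb2 = trainer.model.get_input_embeddings()(dummy_input)
self.assertTrue(torch.allclose(emb1, emb2))
```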

@amyeroberts (Contributor) left a comment:

Awesome - thanks for iterating!

Comment on lines +659 to +660
if not hasattr(self, "neftune_hook_handle"):
    raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
Contributor:

Nice :)

@younesbelkada merged commit 309a906 into huggingface:main on Oct 31, 2023
@younesbelkada deleted the add-neftune branch on October 31, 2023 at 15:04
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* add v1 neftune

* use `unwrap_model` instead

* add test + docs

* Apply suggestions from code review

Co-authored-by: Zach Mueller <[email protected]>

* more details

* fixup

* Update docs/source/en/main_classes/trainer.md

Co-authored-by: amyeroberts <[email protected]>

* refactor a bit

* more elaborated test

* fix unwrap issue

---------

Co-authored-by: Zach Mueller <[email protected]>
Co-authored-by: amyeroberts <[email protected]>