Llama: RoPE refactor #32135
Conversation
#31999, which propagates the changes to all models, will fix this.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
amyeroberts left a comment
Thanks for all the work consolidating the rope logic!
Mostly some small questions and nits. My main comment is about the testing of all the compute functions.
Are all of the arguments expected, even if optional?
No, not at all :) The validation function exists to (among other things) detect incorrect parameter configurations.
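For illustration, a minimal sketch of the kind of validation being described (the function and key names here are hypothetical, not the actual rope utils API):

```python
# Hypothetical validator sketch: flag missing required keys and warn about
# unrecognized ones, so incorrect rope parameter configurations surface early.
def validate_linear_rope_scaling(rope_scaling: dict) -> None:
    required = {"rope_type", "factor"}
    received = set(rope_scaling)
    if missing := required - received:
        raise KeyError(f"Missing required rope_scaling keys: {sorted(missing)}")
    if unexpected := received - required:
        print(f"Unrecognized rope_scaling keys (ignored): {sorted(unexpected)}")
    factor = rope_scaling["factor"]
    if not isinstance(factor, (int, float)) or factor < 1.0:
        raise ValueError(f"`factor` must be a number >= 1, got {factor!r}")
```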
All of these should be tested in a rope utils test module, including checks that they accept both rope_kwargs and config, and that the two paths are equivalent.
Added "rope_kwargs and config and their equivalence" ✅
Numerical checks will be a todo for the post-release follow-up PR (#31999)
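A minimal sketch of what such an equivalence check could look like (the helper and attribute names are assumptions for illustration, not the actual test module):

```python
import torch

# Toy stand-in for a rope compute function that can read its parameters either
# from a config object or from explicit keyword arguments.
def compute_default_inv_freq(config=None, *, base=None, dim=None):
    if config is not None:
        base, dim = config.rope_theta, config.head_dim
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

class DummyConfig:
    rope_theta = 10000.0
    head_dim = 64

def test_config_and_kwargs_are_equivalent():
    from_config = compute_default_inv_freq(DummyConfig())
    from_kwargs = compute_default_inv_freq(base=10000.0, dim=64)
    torch.testing.assert_close(from_config, from_kwargs)
```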
This works and is consistent with the other checks above. We should really make sure to check the rescaling values against specific numerical values in the tests for the compute methods as well. This test tells us things have changed, but not whether the change is in the right direction or of the right magnitude (a sketch of such a check follows this exchange).
Fair, but that is a test that requires some numerical diving. Given our release goals -- would it be okay for me to add a todo/open an issue?
As long as it's actually done, then yes ;)
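A minimal sketch of the kind of numerical check discussed above, using linear scaling as the example (toy helper and values, assuming linear position interpolation simply divides the default inverse frequencies by the factor):

```python
import torch

def compute_default_inv_freq(base: float, dim: int) -> torch.Tensor:
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

def test_linear_scaling_exact_values():
    base, dim, factor = 10000.0, 8, 4.0
    default = compute_default_inv_freq(base, dim)   # [1.0, 0.1, 0.01, 0.001]
    scaled = default / factor                       # linear interpolation rescaling
    expected = torch.tensor([0.25, 0.025, 0.0025, 0.00025])
    torch.testing.assert_close(scaled, expected)    # checks direction *and* magnitude
```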
ArthurZucker left a comment
LGTM
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
should it be rope scaling rather than rope init? nit!
I'd rather go with init -- the default rope (i.e. not scaled) uses this path as well
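As context for the naming discussion, a rough sketch of the dispatch pattern (illustrative only; the real ROPE_INIT_FUNCTIONS table and its function signatures may differ): the unscaled default RoPE goes through the same lookup as the scaled variants, hence "init" rather than "scaling".

```python
import torch

def _compute_default(base: float, dim: int, factor: float = 1.0) -> torch.Tensor:
    # Plain RoPE inverse frequencies, no scaling applied.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

def _compute_linear_scaling(base: float, dim: int, factor: float = 1.0) -> torch.Tensor:
    # Linear position interpolation: stretch every wavelength by `factor`.
    return _compute_default(base, dim) / factor

# Illustrative dispatch table; "default" (unscaled) is just another entry.
ROPE_INIT_FUNCTIONS = {
    "default": _compute_default,
    "linear": _compute_linear_scaling,
}

rope_type = "default"  # in practice read from the model config's rope settings
rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
inv_freq = rope_init_fn(base=10000.0, dim=64)
```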
Ok this should leave enough freedom
Though the fact that we don't have a nested config makes it simpler; the checks are run somewhere else, so it's pretty much equivalent.
Nice to see that go away!
amyeroberts left a comment
Beautiful - thanks for adding and iterating!
🤗
YaRN (Yet another RoPE extension method) combines the NTK-By-Parts Interpolation and Attention Scaling methods, improving upon existing RoPE interpolation methods for longer context window sizes. Fine-tuned models maintain their original performance across benchmarks while enabling efficient extrapolation and transfer learning for quicker convergence, especially in compute-limited environments.
We implement YaRN and Dynamic-YaRN for the following list of models:
- LLaMA
- Falcon
- GPT-NeoX
- Olmo
- Persimmon
- Phi
- StableLM
- OpenLLaMA
New unit tests are added to assert YaRN's correct behavior on both short and long sequence inputs. For more details, please refer to https://arxiv.org/abs/2309.00071.
Co-authored-by: Miguel Almeida <[email protected]>
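For context, a rough sketch of the NTK-By-Parts interpolation plus attention scaling described above, following the formulas in the linked paper (parameter names and defaults such as beta_fast=32 and beta_slow=1 come from the paper; this is an illustration, not the exact library implementation):

```python
import math
import torch

def compute_yarn_inv_freq(base=10000.0, dim=128, factor=4.0,
                          original_max_position_embeddings=4096,
                          beta_fast=32.0, beta_slow=1.0):
    # Dimension index at which `num_rotations` full rotations fit in the original window.
    def correction_dim(num_rotations):
        return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))
                / (2 * math.log(base)))

    pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs              # high-frequency dims: keep as-is
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)   # low-frequency dims: interpolate

    # NTK-by-parts: blend the two regimes with a linear ramp between the correction bounds.
    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim // 2 - 1)
    ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / max(high - low, 1e-3), 0, 1)
    extrapolation_weight = 1.0 - ramp                      # 1 -> extrapolate, 0 -> interpolate

    inv_freq = (inv_freq_interpolation * (1 - extrapolation_weight)
                + inv_freq_extrapolation * extrapolation_weight)
    attention_factor = 0.1 * math.log(factor) + 1.0        # attention scaling ("temperature")
    return inv_freq, attention_factor
```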
Iterate on YaRN implementation for LLaMA and remove diff from remaining
models for increased PR modularity.
This commit includes the following changes:
- Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries (see the sketch after this commit message)
- Remove unnecessary attributes ('extrapolation_factor' and 'finetuned')
from YaRN classes
- Inherit 'forward' method in YaRN classes from superclass
- Rename 'yarn' method to 'compute_yarn_scaling'
- Extend YaRN tests with further assertions
- Fix style inconsistencies
Co-authored-by: Miguel Monte e Freitas <[email protected]>
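As an illustration of the merged dictionary mentioned in the first item above, a hypothetical example (the key names are assumptions drawn from the commit description, not a guaranteed schema):

```python
# One rope_scaling dict carrying both the scaling type and the YaRN-specific
# parameters, instead of a separate yarn_rope_scaling dict; the removed
# 'extrapolation_factor' and 'finetuned' attributes no longer appear.
rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 4096,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
}
```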
- Comply with the tensor building logic introduced in huggingface#30743
- Add referencing to the optimized Attention Factor equation
- Remove Dynamic YaRN for a more agile deployment
Co-authored-by: mig-mfreitas <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
Co-authored-by: Arthur <[email protected]>
Merged the YaRN PR (precursor); now merging this one as soon as CI goes green.
The YaRN PR is failing code quality checks on main. Could you make sure to rebase and then run make fix-copies etc. here before merge?
What does this PR do?
Same as #31999, but with llama being the only changed model.
Confirmed: slow tests are "passing" (same failures as main)
👉 RUN_SLOW=1 py.test -vv tests/models/llama/test_modeling_llama.py
👉 RUN_SLOW=1 py.test -vv tests/utils/test_cache_utils.py
👉 RUN_SLOW=1 py.test -vv tests/utils/test_modeling_rope_utils.py (new tests)
Throughput benchmarks: no changes vs previous main 💔