Fix RMSNormGated in Zamba2 #35943
Conversation
Rebase zamba2
rebase on upstream
This reverts commit 9007a52.
Co-authored-by: Arthur <[email protected]>
```diff
-class Zamba2RMSNormGated(MambaRMSNormGated):
-    pass
+class Zamba2RMSNormGated(torch.nn.Module):
```
I think this will also affect the mamba2 code then (as codestral mamba also uses ngroups > 1), so I'd be in favor of implementing this in the mamba2 code and then using modular here.
cc @molbap
I think so, but as I'm not a maintainer I'll leave the decision to the others 👀
cc @molbap I think!
Hey! Just want to make sure: this is a fix and not a new feature specific to a model, right?
@vasqu I know it is tempting to also add this for mamba2, but AFAIK this was not the original RMSNorm used there, no?
TL;DR: we don't make modeling changes unless they are bug fixes in general. If you have a new RMS norm, it's a new model for us 😉
@ArthurZucker I think it was an oversight in the original implementation for Mamba2 over here - @pglorio shows the relevant code snippets. I can't estimate how it changes the model, though, or whether the slow tests would need to be adjusted accordingly.
Interesting. We have 1-1 matching results for the codestral model, so my intuition would say we removed it because the input args made it a no-op, but I might be wrong.
I'll probably check some other time whether the logits match - maybe they are indeed equivalent 👀
Okay, let's only apply it to zamba for now!
* First commit
* Finish model implementation
* First commit
* Finish model implementation
* Register zamba2
* generated modeling and configuration
* generated modeling and configuration
* added hybrid cache
* fix attention_mask in mamba
* dropped unused loras
* fix flash2
* config docstrings
* fix config and fwd pass
* make fixup fixes
* text_modeling_zamba2
* small fixes
* make fixup fixes
* Fix modular model converter
* added inheritances in modular, renamed zamba cache
* modular rebase
* new modular conversion
* fix generated modeling file
* fixed import for Zamba2RMSNormGated
* modular file cleanup
* make fixup and model tests
* dropped inheritance for Zamba2PreTrainedModel
* make fixup and unit tests
* Add inheritance of rope from GemmaRotaryEmbedding
* moved rope to model init
* drop del self.self_attn and del self.feed_forward
* fix tests
* renamed lora -> adapter
* rewrote adapter implementation
* fixed tests
* Fix torch_forward in mamba2 layer
* Fix torch_forward in mamba2 layer
* Fix torch_forward in mamba2 layer
* Dropped adapter in-place sum
* removed rope from attention init
* updated rope
* created get_layers method
* make fixup fix
* make fixup fixes
* make fixup fixes
* update to new attention standard
* update to new attention standard
* make fixup fixes
* minor fixes
* cache_position
* removed cache_position postion_ids use_cache
* remove config from modular
* removed config from modular (2)
* import apply_rotary_pos_emb from llama
* fixed rope_kwargs
* Instantiate cache in Zamba2Model
* fix cache
* fix @slow decorator
* small fix in modular file
* Update docs/source/en/model_doc/zamba2.md (Co-authored-by: Arthur <[email protected]>)
* several minor fixes
* inherit mamba2decoder fwd and drop position_ids in mamba
* removed docstrings from modular
* reinstate zamba2 attention decoder fwd
* use regex for tied keys
* Revert "use regex for tied keys" (This reverts commit 9007a52.)
* use regex for tied keys
* add cpu to slow forward tests
* dropped config.use_shared_mlp_adapter
* Update docs/source/en/model_doc/zamba2.md (Co-authored-by: Arthur <[email protected]>)
* re-convert from modular
* extended Zamba2RMSNormGated to n_groups>1
* removed einops import
* set _supports_sdpa = True
* add use_mem_eff_path flag for fused mamba2 fwd
* added docstring for use_mem_eff_ath flag

---------

Co-authored-by: root <[email protected]>
Co-authored-by: Arthur <[email protected]>
What does this PR do?
This PR fixes `Zamba2RMSNormGated` to allow for `config.mamba_ngroups > 1`. The Zamba2 7B checkpoints have `config.mamba_ngroups = 2`, so this change is necessary for a correct forward pass.

I defined `Zamba2RMSNormGated` inside modular_zamba2.py instead of importing it, as this definition differs from the one in modeling_mamba2.py. The implementation in this PR is the torch version of the mamba-ssm implementation of the original mamba2 (used here, with the torch implementation given here).
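For readers who want a concrete picture, below is a minimal PyTorch sketch of a gated RMSNorm with group-wise normalization in the spirit of the mamba-ssm gated norm referenced above. The class name, argument names, and exact structure are illustrative assumptions and not necessarily identical to the code merged in this PR.

```python
from typing import Optional

import torch
from torch import nn


class GatedGroupRMSNorm(nn.Module):
    """Illustrative gated RMSNorm that normalizes each of `n_groups` chunks of
    the hidden dimension separately (hypothetical name and signature)."""

    def __init__(self, hidden_size: int, n_groups: int = 1, eps: float = 1e-6):
        super().__init__()
        if hidden_size % n_groups != 0:
            raise ValueError("hidden_size must be divisible by n_groups")
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.n_groups = n_groups
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor, gate: Optional[torch.Tensor] = None) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        if gate is not None:
            # Apply the SiLU-activated gate before normalizing, as in the mamba-ssm gated norm.
            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
        # Split the hidden dimension into n_groups chunks and compute RMS statistics per chunk.
        *batch_dims, hidden_size = hidden_states.shape
        grouped = hidden_states.reshape(*batch_dims, self.n_groups, hidden_size // self.n_groups)
        variance = grouped.pow(2).mean(-1, keepdim=True)
        grouped = grouped * torch.rsqrt(variance + self.eps)
        hidden_states = grouped.reshape(*batch_dims, hidden_size)
        return self.weight * hidden_states.to(input_dtype)
```

With `n_groups = 1` this reduces to the usual gated RMSNorm over the whole hidden dimension; with two groups, each half of the hidden states gets its own RMS statistic, which is the behaviour the description above says the Zamba2 7B checkpoints (`config.mamba_ngroups = 2`) require.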
Who can review?
@ArthurZucker @Cyrilvallez