
Conversation

@pglorio (Contributor) commented Jan 28, 2025

What does this PR do?

This PR fixes Zamba2RMSNormGated to allow for config.mamba_ngroups > 1. The Zamba2 7B checkpoints have config.mamba_ngroups = 2, so this change is necessary for a correct forward pass.

I defined Zamba2RMSNormGated inside modular_zamba2.py instead of importing it, since it differs from the definition in modeling_mamba2.py. The implementation in this PR is a torch port of the gated RMSNorm from the original mamba-ssm mamba2 implementation (used here; the torch reference implementation is here).
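
For illustration, below is a minimal torch sketch of the grouped gated RMSNorm described above. The class name, `group_size` argument, and exact signature are illustrative rather than the merged code; it follows the mamba-ssm reference in applying the SiLU gate before the per-group normalization:

```python
import torch
from torch import nn


class GroupedRMSNormGated(nn.Module):
    """Sketch of a gated RMSNorm that normalizes each of the
    hidden_size // group_size groups separately, as needed when
    config.mamba_ngroups > 1."""

    def __init__(self, hidden_size: int, group_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.group_size = group_size

    def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor = None) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        if gate is not None:
            # Gate is applied before normalization, matching the mamba-ssm reference.
            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
        *prefix_dims, hidden_size = hidden_states.shape
        n_groups = hidden_size // self.group_size
        # Compute the RMS statistic per group instead of over the full hidden dimension.
        grouped = hidden_states.view(*prefix_dims, n_groups, self.group_size)
        variance = grouped.pow(2).mean(-1, keepdim=True)
        grouped = grouped * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = grouped.view(*prefix_dims, hidden_size)
        return self.weight * hidden_states.to(input_dtype)
```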

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @Cyrilvallez

@pglorio changed the title from "Zamba2" to "Fix RMSNormGated in Zamba2" on Jan 28, 2025

class Zamba2RMSNormGated(MambaRMSNormGated):
    pass
class Zamba2RMSNormGated(torch.nn.Module):
Contributor
I think this will also affect the mamba2 code then (as codestral mamba also uses ngroups > 1) - so I'd be for implementing this in the mamba2 code and use modular then.

cc @molbap

@pglorio (Contributor Author) Jan 29, 2025
@vasqu @molbap sounds good. Should I go ahead and update mamba2?

Contributor
I think so, but as I'm no maintainer I leave the decision to the others 👀

@Rocketknight1 (Member)

cc @molbap I think!

@ArthurZucker (Collaborator) left a comment

Hey! Just want to make sure, this is a fix and not a new feature specific to a model, right?
@vasqu I know it is tempting to also add this for mamba2, but AFAIK this was not the original RMSNorm used then, no?

TL;DR: in general we don't make modeling changes unless they are bug fixes. If you have a new RMS norm, it's a new model for us 😉

@vasqu (Contributor) commented Feb 4, 2025

@ArthurZucker I think it was an oversight in the original Mamba2 implementation over here - the code snippets @pglorio linked show that ngroups is indeed used in the gated RMS norm, e.g. here and here.

I can't estimate how it changes the model, though, or whether the slow tests would need to be updated accordingly.

@ArthurZucker (Collaborator)

Interesting. We have 1-1 matching results for the codestral model, so my intuition would be that we removed it because the input args made it a no-op, but I might be wrong.
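
For what it's worth, with a single group the grouped formulation reduces exactly to the standard gated RMSNorm, which would explain the 1-1 codestral match. A quick standalone sanity check (tensor shapes chosen only for illustration, not using the library classes):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = torch.randn(2, 5, 64)
gate = torch.randn(2, 5, 64)
eps = 1e-6

# Standard gated RMSNorm: apply the SiLU gate, then normalize over the full last dim.
gated = hidden * F.silu(gate)
reference = gated * torch.rsqrt(gated.pow(2).mean(-1, keepdim=True) + eps)

# Grouped variant with a single group spanning the full width: the reshape is a
# no-op, so the per-group variance equals the full variance and the outputs match.
grouped = gated.view(2, 5, 1, 64)
grouped = grouped * torch.rsqrt(grouped.pow(2).mean(-1, keepdim=True) + eps)
print(torch.allclose(reference, grouped.view(2, 5, 64)))  # True
```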

@vasqu (Contributor) commented Feb 4, 2025

I'll probably check some other time whether the logits match - maybe they are indeed equivalent 👀

@ArthurZucker (Collaborator) left a comment

Okay, let's only apply it to zamba for now!

@ArthurZucker merged commit a93b805 into huggingface:main on Feb 4, 2025 (14 checks passed)
elvircrn pushed a commit to elvircrn/transformers that referenced this pull request Feb 13, 2025
* First commit

* Finish model implementation

* First commit

* Finish model implementation

* Register zamba2

* generated modeling and configuration

* generated modeling and configuration

* added hybrid cache

* fix attention_mask in mamba

* dropped unused loras

* fix flash2

* config docstrings

* fix config and fwd pass

* make fixup fixes

* text_modeling_zamba2

* small fixes

* make fixup fixes

* Fix modular model converter

* added inheritances in modular, renamed zamba cache

* modular rebase

* new modular conversion

* fix generated modeling file

* fixed import for Zamba2RMSNormGated

* modular file cleanup

* make fixup and model tests

* dropped inheritance for Zamba2PreTrainedModel

* make fixup and unit tests

* Add inheritance of rope from GemmaRotaryEmbedding

* moved rope to model init

* drop del self.self_attn and del self.feed_forward

* fix tests

* renamed lora -> adapter

* rewrote adapter implementation

* fixed tests

* Fix torch_forward in mamba2 layer

* Fix torch_forward in mamba2 layer

* Fix torch_forward in mamba2 layer

* Dropped adapter in-place sum

* removed rope from attention init

* updated rope

* created get_layers method

* make fixup fix

* make fixup fixes

* make fixup fixes

* update to new attention standard

* update to new attention standard

* make fixup fixes

* minor fixes

* cache_position

* removed cache_position postion_ids use_cache

* remove config from modular

* removed config from modular (2)

* import apply_rotary_pos_emb from llama

* fixed rope_kwargs

* Instantiate cache in Zamba2Model

* fix cache

* fix @slow decorator

* small fix in modular file

* Update docs/source/en/model_doc/zamba2.md

Co-authored-by: Arthur <[email protected]>

* several minor fixes

* inherit mamba2decoder fwd and drop position_ids in mamba

* removed docstrings from modular

* reinstate zamba2 attention decoder fwd

* use regex for tied keys

* Revert "use regex for tied keys"

This reverts commit 9007a52.

* use regex for tied keys

* add cpu to slow forward tests

* dropped config.use_shared_mlp_adapter

* Update docs/source/en/model_doc/zamba2.md

Co-authored-by: Arthur <[email protected]>

* re-convert from modular

* extended Zamba2RMSNormGated to n_groups>1

* removed einops import

* set _supports_sdpa = True

* add use_mem_eff_path flag for fused mamba2 fwd

* added docstring for use_mem_eff_ath flag

---------

Co-authored-by: root <[email protected]>
Co-authored-by: Arthur <[email protected]>
sbucaille pushed a commit to sbucaille/transformers that referenced this pull request Feb 16, 2025
@tdoublep mentioned this pull request on Sep 13, 2025