diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 9ed80cfb0b78..7508f0968863 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -959,6 +959,8 @@ title: FLAVA - local: model_doc/gemma3 title: Gemma3 + - local: model_doc/gemma3n + title: Gemma3n - local: model_doc/git title: GIT - local: model_doc/glm4v diff --git a/docs/source/en/model_doc/gemma3n.md b/docs/source/en/model_doc/gemma3n.md new file mode 100644 index 000000000000..7f38c3b18c92 --- /dev/null +++ b/docs/source/en/model_doc/gemma3n.md @@ -0,0 +1,204 @@ + + + +
+<div style="float: right;">
+    <div class="flex flex-wrap space-x-1">
+        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
+        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
+    </div>
+</div>
# Gemma3n

## Overview

Gemma3n is a multimodal model with pretrained and instruction-tuned variants, available in E4B and E2B sizes. While
large portions of the language model architecture are shared with prior Gemma releases, there are many new additions in
this model, including [Alternating Updates][altup] (AltUp), [Learned Augmented Residual Layer][laurel] (LAuReL),
[MatFormer][matformer], Per-Layer Embeddings (PLE), activation sparsity, and KV cache sharing. The language model uses
a similar attention pattern to [Gemma 3](./gemma3.md), alternating 4 local sliding window self-attention layers for
every global self-attention layer, with a maximum context length of 32k tokens. Gemma 3n introduces
[MobileNet v5][mobilenetv5] as the vision encoder, using a default resolution of 768x768 pixels, and adds a
[Universal Speech Model][usm] (USM) as the audio encoder.

The instruction-tuned variant was post-trained with knowledge distillation and reinforcement learning.

You can find all the original Gemma 3n checkpoints under the [Gemma 3n][gemma3n-collection] release.

> [!TIP]
> Click on the Gemma 3n models in the right sidebar for more examples of how to apply Gemma to different vision, audio,
> and language tasks.

The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
import torch
from transformers import pipeline

pipeline = pipeline(
    task="image-text-to-text",
    model="google/gemma-3n-e4b",
    device=0,
    torch_dtype=torch.bfloat16
)
pipeline(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
    text="What is shown in this image?"
)
```

</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

model = Gemma3nForConditionalGeneration.from_pretrained(
    "google/gemma-3n-e4b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)
processor = AutoProcessor.from_pretrained(
    "google/gemma-3n-e4b-it",
    padding_side="left"
)

messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."}
        ]
    },
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to("cuda")

output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
print(processor.decode(output[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="transformers CLI">

```bash
echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model google/gemma-3n-e2b --device 0
```

</hfoption>
</hfoptions>

## Notes

- Use [`Gemma3nForConditionalGeneration`] for image-audio-and-text, image-and-text, image-and-audio, audio-and-text,
  image-only and audio-only inputs.
- Gemma 3n supports multiple images per input, but make sure the images are correctly batched before passing them to
  the processor. Each batch should be a list of one or more images.
+ + ```py + url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4=" + url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + + messages =[ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."} + ] + }, + { + "role": "user", + "content": [ + {"type": "image", "url": url_cow}, + {"type": "image", "url": url_cat}, + {"type": "text", "text": "Which image is cuter?"}, + ] + }, + ] + ``` +- Text passed to the processor should have a `` token wherever an image should be inserted. +- Gemma 3n accept at most one target audio clip per input, though multiple audio clips can be provided in few-shot + prompts, for example. +- Text passed to the processor should have a `` token wherever an audio clip should be inserted. +- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs. + +## Gemma3nAudioFeatureExtractor + +[[autodoc]] Gemma3nAudioFeatureExtractor + +## Gemma3nProcessor + +[[autodoc]] Gemma3nProcessor + +## Gemma3nTextConfig + +[[autodoc]] Gemma3nTextConfig + +## Gemma3nVisionConfig + +[[autodoc]] Gemma3nVisionConfig + +## Gemma3nAudioConfig + +[[autodoc]] Gemma3nAudioConfig + +## Gemma3nConfig + +[[autodoc]] Gemma3nConfig + +## Gemma3nTextModel + +[[autodoc]] Gemma3nTextModel + - forward + +## Gemma3nModel + +[[autodoc]] Gemma3nModel + - forward + +## Gemma3nForCausalLM + +[[autodoc]] Gemma3nForCausalLM + - forward + +## Gemma3nForConditionalGeneration + +[[autodoc]] Gemma3nForConditionalGeneration + - forward + +[altup]: https://proceedings.neurips.cc/paper_files/paper/2023/hash/f2059277ac6ce66e7e5543001afa8bb5-Abstract-Conference.html +[attention-mask-viz]: https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139 +[gemma3n-collection]: https://huggingface.co/collections/google/gemma-3n +[laurel]: https://arxiv.org/abs/2411.07501 +[matformer]: https://arxiv.org/abs/2310.07707 +[usm]: https://arxiv.org/abs/2303.01037 diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 71ad6eaadeb5..3b9e3e65df6f 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -140,6 +140,10 @@ ("gemma2", "Gemma2Config"), ("gemma3", "Gemma3Config"), ("gemma3_text", "Gemma3TextConfig"), + ("gemma3n", "Gemma3nConfig"), + ("gemma3n_audio", "Gemma3nAudioConfig"), + ("gemma3n_text", "Gemma3nTextConfig"), + ("gemma3n_vision", "Gemma3nVisionConfig"), ("git", "GitConfig"), ("glm", "GlmConfig"), ("glm4", "Glm4Config"), @@ -518,6 +522,10 @@ ("gemma2", "Gemma2"), ("gemma3", "Gemma3ForConditionalGeneration"), ("gemma3_text", "Gemma3ForCausalLM"), + ("gemma3n", "Gemma3nForConditionalGeneration"), + ("gemma3n_audio", "Gemma3nAudioEncoder"), + ("gemma3n_text", "Gemma3nForCausalLM"), + ("gemma3n_vision", "TimmWrapperModel"), ("git", "GIT"), ("glm", "GLM"), ("glm4", "GLM4"), @@ -839,6 +847,9 @@ ("clip_text_model", "clip"), ("aria_text", "aria"), ("gemma3_text", "gemma3"), + ("gemma3n_audio", "gemma3n"), + ("gemma3n_text", "gemma3n"), + ("gemma3n_vision", "gemma3n"), ("glm4v_text", "glm4v"), ("idefics3_vision", "idefics3"), ("siglip_vision_model", "siglip"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py 
b/src/transformers/models/auto/feature_extraction_auto.py index d54ca4b0f5ae..3595de53bbda 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -61,6 +61,7 @@ ("dpt", "DPTFeatureExtractor"), ("encodec", "EncodecFeatureExtractor"), ("flava", "FlavaFeatureExtractor"), + ("gemma3n", "Gemma3nAudioFeatureExtractor"), ("glpn", "GLPNFeatureExtractor"), ("granite_speech", "GraniteSpeechFeatureExtractor"), ("groupvit", "CLIPFeatureExtractor"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index b99dd365f57c..bee0335338cc 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -88,6 +88,7 @@ ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")), ("fuyu", ("FuyuImageProcessor",)), ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), + ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")), ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")), ("glpn", ("GLPNImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index add9d09b0e2d..08b91dc1ea5f 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -132,6 +132,10 @@ ("gemma2", "Gemma2Model"), ("gemma3", "Gemma3Model"), ("gemma3_text", "Gemma3TextModel"), + ("gemma3n", "Gemma3nModel"), + ("gemma3n_audio", "Gemma3nAudioEncoder"), + ("gemma3n_text", "Gemma3nTextModel"), + ("gemma3n_vision", "TimmWrapperModel"), ("git", "GitModel"), ("glm", "GlmModel"), ("glm4", "Glm4Model"), @@ -583,6 +587,8 @@ ("gemma2", "Gemma2ForCausalLM"), ("gemma3", "Gemma3ForConditionalGeneration"), ("gemma3_text", "Gemma3ForCausalLM"), + ("gemma3n", "Gemma3nForConditionalGeneration"), + ("gemma3n_text", "Gemma3nForCausalLM"), ("git", "GitForCausalLM"), ("glm", "GlmForCausalLM"), ("glm4", "Glm4ForCausalLM"), @@ -906,6 +912,7 @@ ("emu3", "Emu3ForConditionalGeneration"), ("fuyu", "FuyuForCausalLM"), ("gemma3", "Gemma3ForConditionalGeneration"), + ("gemma3n", "Gemma3nForConditionalGeneration"), ("git", "GitForCausalLM"), ("glm4v", "Glm4vForConditionalGeneration"), ("got_ocr2", "GotOcr2ForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index bccfe3e6d57f..e5bd673f6390 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -66,6 +66,7 @@ ("flava", "FlavaProcessor"), ("fuyu", "FuyuProcessor"), ("gemma3", "Gemma3Processor"), + ("gemma3n", "Gemma3nProcessor"), ("git", "GitProcessor"), ("glm4v", "Glm4vProcessor"), ("got_ocr2", "GotOcr2Processor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 0456e1945ca4..c8656a710746 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -236,6 +236,20 @@ "GemmaTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "gemma3n", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "gemma3n_text", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() 
else None, + ), + ), ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/gemma3n/__init__.py b/src/transformers/models/gemma3n/__init__.py new file mode 100644 index 000000000000..229e91827036 --- /dev/null +++ b/src/transformers/models/gemma3n/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_gemma3n import * + from .feature_extraction_gemma3n import * + from .modeling_gemma3n import * + from .processing_gemma3n import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/gemma3n/configuration_gemma3n.py b/src/transformers/models/gemma3n/configuration_gemma3n.py new file mode 100644 index 000000000000..ca1a06717748 --- /dev/null +++ b/src/transformers/models/gemma3n/configuration_gemma3n.py @@ -0,0 +1,680 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma3n.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Sequence +from typing import Any, Optional, Union + +from ...configuration_utils import PretrainedConfig, layer_type_validation +from ...modeling_rope_utils import rope_config_validation +from ...utils import is_timm_available, logging, requires_backends + + +if is_timm_available(): + from timm.data import ImageNetInfo, infer_imagenet_subset + + +logger = logging.get_logger(__name__) + + +class Gemma3nTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3nTextModel`]. 
    It is used to instantiate a Gemma3nTextModel model according to the specified arguments, defining the model
    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
    Gemma 3n E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).

    Configuration objects inherit from [`Gemma3nTextConfig`] and can be used to control the model outputs. Read
    the documentation from [`Gemma3nTextConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 262400):
            Vocabulary size of the Gemma3nText model. Defines the number of different tokens that can be represented
            by the `input_ids` passed when calling [`Gemma3nTextModel`].
        vocab_size_per_layer_input (`int`, *optional*, defaults to 262144):
            Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        hidden_size_per_layer_input (`int`, *optional*, defaults to 256):
            Dimension of the hidden representations for per-layer embeddings.
        intermediate_size (`int` or `Sequence[int]`, *optional*, defaults to 16384):
            Dimension of the MLP representations. MatFormer configurations may wish to provide a sequence of integers
            to account for variable intermediate_size values across layers. In such cases,
            `len(intermediate_size) == num_hidden_layers`.
        num_hidden_layers (`int`, *optional*, defaults to 35):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details, check out this
            [paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, will default to `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the decoder. Will default to
            `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
            activation function.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings used in global attention.
            NOTE: if you apply a new rope type and you expect the model to work on longer `max_position_embeddings`,
            we recommend you update this value accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        rope_local_base_freq (float, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings for local attention.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        sliding_window (`int`, *optional*, defaults to 512):
            This is the size of the sliding window used by local attention layers.
        layer_types (`Optional`, *optional*):
            A sequence of strings defining the attention type for each layer as either "sliding_attention" or
            "full_attention". If not provided, `layer_types` will be inferred from `num_hidden_layers` using a pattern
            of four "sliding_attention" layers followed by one "full_attention". The last layer in the model should
            always be a "full_attention" layer.
+ final_logit_softcapping (`float`, *optional*, defaults to 30.0): + Scaling factor when applying tanh softcapping on the logits. + altup_active_idx (`int`, *optional*, defaults to 0): + The index of the prediction from which AltUp will compute additional predictions or correct + altup_coef_clip (`float`, *optional*, defaults to 120.0): + The maximum amplitude of an AltUp prediction or correction coeficient weight. + altup_correct_scale (`bool`, *optional*, defaults to `True`): + If True, apply the `AltUp.correct_output_scale` to the corrected prediction at `altup_active_idx`. + altup_num_inputs (`int`, *optional*, defaults to 4): + The number of predictions that AltUp should be make given the input sequence. + num_kv_shared_layers (`int`, *optional*, defaults to 15): + The number of layer that share KV cache values. During the forward pass, the last `num_kv_shared_layers` + layers in the model "share" the KV values in that each local and global layer in this range uses the KV + cache values computed for the last local or global layer, respectively, before entering this range. The + value should be `num_kv_shared_layers` should be a scalar of `sliding_window_pattern`. + laurel_rank (int, *optional*, defaults to 64): + The intermediate size for the linear projections in the Learned Augmented Residual Layer. + activation_sparsity_pattern (Sequence[float], *optional*, defaults to `(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)`): + The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must + explicitly provide a sparsity value for each layer in the model. + + ```python + >>> from transformers import Gemma3nTextModel, Gemma3nTextConfig + + >>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration + >>> configuration = Gemma3nTextConfig() + + >>> # Initializing a model from the gemma3n_text-E4B style configuration + >>> model = Gemma3nTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "gemma3n_text" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: int = 262_400, + vocab_size_per_layer_input: int = 262_144, + hidden_size: int = 2048, + hidden_size_per_layer_input: int = 256, + intermediate_size: Union[int, Sequence[int]] = 16_384, + num_hidden_layers: int = 35, + num_attention_heads: int = 8, + num_key_value_heads: int = 2, + head_dim: int = 256, + hidden_activation: str = "gelu_pytorch_tanh", + max_position_embeddings: int = 32_768, + initializer_range: float = 0.02, + rms_norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = 0, + eos_token_id: int = 1, + bos_token_id: int = 2, + rope_theta: float = 1_000_000.0, + rope_scaling: Optional[dict[str, Any]] = None, + rope_local_base_freq: float = 10_000.0, + attention_bias: bool = False, + attention_dropout: 
float = 0.0, + sliding_window: int = 512, + layer_types: Optional[Sequence[str]] = None, + final_logit_softcapping: float = 30.0, + altup_active_idx: int = 0, + altup_coef_clip: float = 120.0, + altup_correct_scale: bool = True, + altup_num_inputs: int = 4, + num_kv_shared_layers: int = 15, + laurel_rank: int = 64, + activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + if isinstance(intermediate_size, Sequence) and (intsize_len := len(intermediate_size)) != num_hidden_layers: + raise ValueError( + "intermediate_size must have an explicit intermediate size for every layer or one for all layers. " + f"Expected {num_hidden_layers} values but got {intsize_len}." + ) + elif not isinstance(intermediate_size, Sequence): + intermediate_size = [intermediate_size] * num_hidden_layers + + self.vocab_size = vocab_size + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.layer_types = layer_types + + self.rope_local_base_freq = rope_local_base_freq + self.rope_scaling = rope_scaling + rope_config_validation(self) + + if layer_types is None: + self.layer_types = [ + "full_attention" if i % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers) + ] + else: + self.layer_types = layer_types + + layer_type_validation(self.layer_types) + + self.hidden_size_per_layer_input = hidden_size_per_layer_input + self.num_kv_shared_layers = num_kv_shared_layers + + self.altup_active_idx = altup_active_idx + self.altup_coef_clip = altup_coef_clip + self.altup_correct_scale = altup_correct_scale + self.altup_num_inputs = altup_num_inputs + + self.laurel_rank = laurel_rank + + if activation_sparsity_pattern is None: + activation_sparsity_pattern = [0.0] * num_hidden_layers + + if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers: + raise ValueError( + "activation_sparsity_pattern must have an explicit activation sparsity value for every layer." + f"Expected {num_hidden_layers} values but got {len_asp}." + ) + self.activation_sparsity_pattern = activation_sparsity_pattern + + +class Gemma3nAudioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`], based on Gogole's + [Universal Speech Model](). It is used to instantiate an Gemma3nAudioEncoder model according to the specified + arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar + configuration to that of the Gemma 3n E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B). + + Configuration objects that inherit from [`Gemma3nAudioConfig`] and can be used to control the model outputs. 
Read + the documentation from [`Gemma3nAudioConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 128): + Vocabulary size of the additional hard-token embeddings for audio model. These augment the embeddings + included in the `Gemma3nTextModel` to provide, e.g., the end of audio and audio soft token placeholder + tokens when converting `input_ids` to embeddings in the `Gemma3nForConditionalGeneration` model. + vocab_offset (`int`, *optional*, defaults to 262272): + Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the + 0-indexed `Gemma3nMultimodalEmbedder.embedding` table. + input_feat_size (`int`, *optional*, defaults to 128): + The number of channels in each mel-spectrogram frame. + hidden_size (`int`, *optional*, defaults to 1536): + Dimension of the hidden representations. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + gradient_clipping (`float`, *optional*, defaults to 10000000000.0): + Clipping value used to stablize extremely large gradient values. + conf_attention_chunk_size (`int`, *optional*, defaults to 12): + The sub-sequence size for local attention processing inside the Conformer ("conf") section of the + Universal Speech Model. + conf_attention_context_left (`int`, *optional*, defaults to 13): + The left context size of the local attention inside the Conformer ("conf") section of the + Universal Speech Model. + conf_attention_context_right (`int`, *optional*, defaults to 0): + The right context size of the local attention inside the Conformer ("conf") section of the + Universal Speech Model. + conf_attention_logit_cap (`float`, *optional*, defaults to 50.0): + Logit cap applied during local attention inside the Conformer ("conf") section of the + Universal Speech Model. + conf_num_attention_heads (`int`, *optional*, defaults to 8): + The number of attention heads in local attention inside the Conformer ("conf") section of the + Universal Speech Model. + conf_num_hidden_layers (`int`, *optional*, defaults to 12): + The number of layers that use local attention inside the Conformer ("conf") section of the + Universal Speech Model. + conf_conv_kernel_size (`int`, *optional*, defaults to 5): + Convolution kernel size for the conformer block inside the Conformer ("conf") section of the + Universal Speech Model. + conf_reduction_factor (`int`, *optional*, defaults to 4): + Reduction factor used in the conformer block inside the Conformer ("conf") section of the + Universal Speech Model. + conf_residual_weight (`float`, *optional*, defaults to 0.5): + Residual connection weight inside the Conformer ("conf") section of the + Universal Speech Model. + sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`): + The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection + ("sscp") section of the Universal Speech Model. + sscp_conv_group_norm_eps (`float`, *optional*, defaults to 0.001): + Epsilon used in group normalization in the subsample convolution projection in the Sub-sample Convolution + Projection ("sscp") section of the Universal Speech Model. + sscp_conv_kernel_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((3, 3), (3, 3))`): + Kernel sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample + Convolution Projection ("sscp") section of the Universal Speech Model. 
The kernel sizes are specified as a + tuple of height and width for each layer, where the height corresponds to the time dimension and the width + corresponds to the frequency dimension. + sscp_conv_stride_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((2, 2), (2, 2))`): + Stride sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample + Convolution Projection ("sscp") section of the Universal Speech Model. The stride sizes are specified as a + tuple of height and width for each layer, where the height corresponds to the time dimension and the width + corresponds to the frequency dimension. + + Example: + + ```python + >>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder + + >>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration + >>> configuration = Gemma3nAudioConfig() + + >>> # Initializing a model from the gemma3n_audio-E4B style configuration + >>> model = Gemma3nAudioEncoder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "gemma3n_audio" + + def __init__( + self, + vocab_size: int = 128, + vocab_offset: int = 262_144 + 128, # text vocab size + vision vocab size + input_feat_size: int = 128, + hidden_size: int = 1536, + rms_norm_eps: float = 1e-6, + gradient_clipping: float = 10_000_000_000.0, + conf_attention_chunk_size: int = 12, + conf_attention_context_left: int = 13, + conf_attention_context_right: int = 0, + conf_attention_logit_cap: float = 50.0, + conf_num_attention_heads: int = 8, + conf_num_hidden_layers: int = 12, + conf_conv_kernel_size: int = 5, + conf_reduction_factor: int = 4, + conf_residual_weight: float = 0.5, + sscp_conv_channel_size: tuple[int, int] = (128, 32), + sscp_conv_group_norm_eps: float = 1e-3, + sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = ( + (3, 3), + (3, 3), + ), + sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = ( + (2, 2), + (2, 2), + ), + **kwargs, + ): + super().__init__(**kwargs) + self.input_feat_size = input_feat_size + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.vocab_size = vocab_size + self.vocab_offset = vocab_offset + self.gradient_clipping = gradient_clipping + self.conf_attention_chunk_size = conf_attention_chunk_size + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_context_right = conf_attention_context_right + self.conf_attention_logit_cap = conf_attention_logit_cap + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_num_hidden_layers = conf_num_hidden_layers + self.conf_conv_kernel_size = conf_conv_kernel_size + self.conf_reduction_factor = conf_reduction_factor + self.conf_residual_weight = conf_residual_weight + self.sscp_conv_channel_size = sscp_conv_channel_size + self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps + self.sscp_conv_kernel_size = sscp_conv_kernel_size + self.sscp_conv_stride_size = sscp_conv_stride_size + + +class Gemma3nVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to + instantiate an timm model model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B + vision tower, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B). 
+ + Configuration objects inherit from [`Gemma3nVisionConfig`] and can be used to control the model outputs. Read the + documentation from [`Gemma3nVisionConfig`] for more information. + + Config loads imagenet label descriptions and stores them in `id2label` attribute, `label2id` attribute for default + imagenet models is set to `None` due to occlusions in the label descriptions. + + Args: + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + do_pooling (`bool`, *optional*, defaults to `False`): + Whether to do pooling for the last_hidden_state in `TimmWrapper` or not. + architecture (`str`, *optional*, defaults to `"mobilenetv5_300m_enc"`): + Determines vision architecture for TimmWrapper. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + vocab_size (`int`, *optional*, defaults to 128): + Vocabulary size of the additional hard-token embeddings for vision model. + vocab_offset (`int`, *optional*, defaults to 262144): + Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the + 0-indexed `Gemma3nMultimodalEmbedder.embedding` table. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + + Example: + ```python + >>> from transformers import Gemma3nVisionConfig, TimmWrapper + + >>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration + >>> configuration = Gemma3nVisionConfig() + + >>> # Initializing a gemma3n_vision-E4B-style TimmWrapper from the configuration + >>> model = TimmWrapper(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "gemma3n_vision" + + def __init__( + self, + initializer_range: float = 0.02, + do_pooling: bool = False, + architecture: str = "mobilenetv5_300m_enc", + hidden_size: int = 2048, + vocab_size: int = 128, + vocab_offset: int = 262_144, + rms_norm_eps: float = 1e-06, + model_args: Optional[dict] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.initializer_range = initializer_range + self.do_pooling = do_pooling + self.model_args = model_args # named "model_args" for BC with timm + self.architecture = architecture + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.vocab_offset = vocab_offset + self.rms_norm_eps = rms_norm_eps + + @classmethod + def from_dict(cls, config_dict: dict[str, Any], **kwargs): + label_names = config_dict.get("label_names", None) + is_custom_model = "num_labels" in kwargs or "id2label" in kwargs + + # if no labels added to config, use imagenet labeller in timm + if label_names is None and not is_custom_model: + requires_backends(cls, ["timm"]) + imagenet_subset = infer_imagenet_subset(config_dict) + if imagenet_subset: + dataset_info = ImageNetInfo(imagenet_subset) + synsets = dataset_info.label_names() + label_descriptions = dataset_info.label_descriptions(as_dict=True) + label_names = [label_descriptions[synset] for synset in synsets] + + if label_names is not None and not is_custom_model: + kwargs["id2label"] = dict(enumerate(label_names)) + + # if all label names are unique, create label2id mapping as well + if len(set(label_names)) == len(label_names): + kwargs["label2id"] = {name: i for i, name in enumerate(label_names)} + else: + kwargs["label2id"] = None + + # timm config stores the `num_classes` attribute in both the root of config and in the 
"pretrained_cfg" dict. + # We are removing these attributes in order to have the native `transformers` num_labels attribute in config + # and to avoid duplicate attributes + num_labels_in_kwargs = kwargs.pop("num_labels", None) + num_labels_in_dict = config_dict.pop("num_classes", None) + + # passed num_labels has priority over num_classes in config_dict + kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict + + # pop num_classes from "pretrained_cfg", + # it is not necessary to have it, only root one is used in timm + if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]: + config_dict["pretrained_cfg"].pop("num_classes", None) + + return super().from_dict(config_dict, **kwargs) + + def to_dict(self) -> dict[str, Any]: + output = super().to_dict() + output["num_classes"] = self.num_labels + output["label_names"] = list(self.id2label.values()) + output.pop("id2label", None) + output.pop("label2id", None) + return output + + +class Gemma3nConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3nForConditionalGeneration`]. It is used to + instantiate a Gemma3nForConditionalGeneration according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of + Gemma3n-E4B. + + e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`Union[Gemma3nTextConfig, dict]`, *optional*): + The config object of the text backbone. + vision_config (`Union[AutoConfig, dict]`, *optional*): + Custom vision config or dict. + audio_config (`Union[AutoConfig, dict]`, *optional*): + Custom audio config or dict. + audio_soft_tokens_per_image (`int`, *optional*, defaults to 188): + The number of soft tokens per audio clip. + vision_soft_tokens_per_image (`int`, *optional*, defaults to 256): + The number of soft tokens per image. + boi_token_id (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_id (`int`, *optional*, defaults to 262144): + The end-of-image token index to wrap the image prompt. + image_token_id (`int`, *optional*, defaults to 262145): + The image token index to encode the image prompt. + boa_token_id (`int`, *optional*, defaults to 256000): + The begin-of-audio token index to wrap the audio prompt. + eoa_token_id (`int`, *optional*, defaults to 262272): + The end-of-audio token index to wrap the audio prompt. + audio_token_id (`int`, *optional*, defaults to 262273): + The audio token index to encode the audio prompt. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + + Example: + + ```python + >>> from transformers import Gemma3nForConditionalGeneration, Gemma3nConfig, Gemma3nTextConfig + + >>> # Initializing a MobileNet vision config, which is loaded from TIMM + >>> vision_config = Gemma3nVisionConfig() + + >>> # Initializing a Gemma3n Audio config + >>> audio_config = Gemma3nAudioConfig() + + >>> # Initializing a Gemma3n Text config + >>> text_config = Gemma3nTextConfig() + + >>> # Initializing a Gemma3n gemma-3-4b style configuration + >>> configuration = Gemma3nConfig(text_config, vision_config, audio_config) + + >>> # Initializing a model from the gemma-3-4b style configuration + >>> model = Gemma3nTextConfig(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma3n" + sub_configs = { + "text_config": Gemma3nTextConfig, + "vision_config": Gemma3nVisionConfig, + "audio_config": Gemma3nAudioConfig, + } + + def __init__( + self, + text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None, + vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None, + audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None, + audio_soft_tokens_per_image: int = 188, + vision_soft_tokens_per_image: int = 256, + boi_token_id: int = 255_999, + eoi_token_id: int = 262_144, + image_token_id: int = 262_145, + boa_token_id: int = 256_000, + eoa_token_id: int = 262_272, + audio_token_id: int = 262_273, + initializer_range: float = 0.02, + **kwargs, + ): + super().__init__(**kwargs) + + if isinstance(text_config, dict): + text_config = Gemma3nTextConfig(**text_config) + elif text_config is None: + text_config = Gemma3nTextConfig() + logger.info("text_config is None. Using default Gemma3nTextConfig.") + + if isinstance(vision_config, dict): + vision_config = Gemma3nVisionConfig(**vision_config) + elif vision_config is None: + vision_config = Gemma3nVisionConfig() + logger.info("vision_config is None. Using default Gemma3nVisionConfig.") + + if isinstance(audio_config, dict): + audio_config = Gemma3nAudioConfig(**audio_config) + elif audio_config is None: + audio_config = Gemma3nAudioConfig() + logger.info("audio_config is None. Using default Gemma3nAudioConfig.") + + self.text_config = text_config + self.vision_config = vision_config + self.audio_config = audio_config + + self.audio_soft_tokens_per_image = audio_soft_tokens_per_image + self.vision_soft_tokens_per_image = vision_soft_tokens_per_image + self.boi_token_id = boi_token_id + self.eoi_token_id = eoi_token_id + self.image_token_id = image_token_id + self.boa_token_id = boa_token_id + self.eoa_token_id = eoa_token_id + self.audio_token_id = audio_token_id + self.initializer_range = initializer_range + + +__all__ = ["Gemma3nAudioConfig", "Gemma3nConfig", "Gemma3nTextConfig", "Gemma3nVisionConfig"] diff --git a/src/transformers/models/gemma3n/convert_gemma3n_weights.py b/src/transformers/models/gemma3n/convert_gemma3n_weights.py new file mode 100644 index 000000000000..2f25ca56d465 --- /dev/null +++ b/src/transformers/models/gemma3n/convert_gemma3n_weights.py @@ -0,0 +1,807 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint. + +python src/transformers/models/gemma3n/convert_gemma3n_weights.py \ + --variant='gemma3n_e4b' \ + --tokenizer_path="$HOME/nano3/checkpoints/tokenizer/gemma-3n-tokenizer.model" \ + --checkpoint_path="$HOME/nano3/checkpoints/g251_orbax/" \ + --output_path="$HOME/nano3/checkpoints/g251_vision_encoder/" +""" + +import json +import os +import re +from collections.abc import Iterable, Mapping +from typing import Any + +import accelerate +import numpy as np +import torch +import tree +from absl import app, flags, logging +from orbax import checkpoint as obc + +from transformers import ( + Gemma3nAudioConfig, + Gemma3nAudioFeatureExtractor, + Gemma3nConfig, + Gemma3nForConditionalGeneration, + Gemma3nProcessor, + Gemma3nTextConfig, + Gemma3nVisionConfig, + GemmaTokenizerFast, + GenerationConfig, + SiglipImageProcessorFast, +) +from transformers.image_utils import PILImageResampling + + +# ==== Internal Constants and Classes ==== + + +_CHAT_TEMPLATE = """{{ bos_token }} +{%- if messages[0]['role'] == 'system' -%} + {%- if messages[0]['content'] is string -%} + {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%} + {%- else -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{%- for message in loop_messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {{ '' + role + '\n' + (first_user_prefix if loop.first else "") }} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'audio' -%} + {{ '' }} + {%- elif item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + {{ '\n' }} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{'model\n'}} +{%- endif -%} +""" + +_DTYPES = {"float32", "bfloat16", "float16"} + +_SLIDING_WINDOW_PATTERN = 5 + +_AUDIO_ENCODER_PARAMETER = "AudioEncoder/encoder" +_AUDIO_ENCODER_CONFORMER = f"{_AUDIO_ENCODER_PARAMETER}/conformer/stacked_layers" +_AUDIO_ENCODER_SSCP = f"{_AUDIO_ENCODER_PARAMETER}/feature" + +_TRANSFORMER_PARAMETER = "transformer" +_TRANSFORMER_ALTUP_PROJ = f"{_TRANSFORMER_PARAMETER}/altup_projection_" +_TRANSFORMER_ALTUP_UNEMB = f"{_TRANSFORMER_PARAMETER}/altup_unembed_projection_" +_TRANSFORMER_DECODER_BLOCK = f"{_TRANSFORMER_PARAMETER}/stacked_layers/attention_type_" +_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK) +_TRANSFORMER_EMBEDDER = f"{_TRANSFORMER_PARAMETER}/embedder" 
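
# Illustrative sketch (not part of the original conversion script): the Orbax checkpoint stores decoder layers
# stacked by attention type under f"{_TRANSFORMER_DECODER_BLOCK}<k>", so entry i of group k corresponds to HF
# decoder layer index _SLIDING_WINDOW_PATTERN * i + k. The hypothetical helper below only makes that mapping
# explicit; convert_transformer_weights() further down computes the same index inline.
def _stacked_layer_to_hf_layer_index(attention_type_index: int, stacked_index: int) -> int:
    """Maps (attention-type group, position in the stacked tensor) to a decoder layer index."""
    # e.g. group 0, entry 1 -> layer 5; group 3, entry 2 -> layer 13; group 4, entry 6 -> layer 34 (E4B).
    return _SLIDING_WINDOW_PATTERN * stacked_index + attention_type_index
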
+_TRANSFORMER_FINAL_NORM = "transformer/final_norm" +_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/" +_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX) + +# _MOBILE_NET_CONFIG = Gemma3nVisionConfig.from_pretrained("") + +_MOBILE_NET_PREFIX = "mobilenet" +_MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES = [3, 8, 45, 84] +_MOBILE_NET_CONV = "block_group_conv2d_" +_MOBILE_NET_FIB = "block_group_fused_ib_" +_MOBILE_NET_MQA = "block_group_mmqa_" +_MOBILE_NET_MSFA = "block_adapter_" +_MOBILE_NET_UIB = "block_group_uib_" +_MOBILE_NET_UIB_HAS_DW_START = { + (1, 0), + (1, 1), + (1, 2), + (1, 3), + (1, 4), + (2, 0), + (2, 1), + (2, 2), + (2, 3), + (2, 4), + (2, 5), + (2, 6), + (2, 7), + (3, 0), +} +_MOBILE_NET_UIB_HAS_DW_MID = { + (1, 0), + (2, 0), + (3, 0), +} + +_VARIANT_GEMMA_3_2B = "gemma3n_e2b" +_VARIANT_GEMMA_3_4B = "gemma3n_e4b" +_VARIANTS: Mapping[str, Gemma3nConfig] = { + _VARIANT_GEMMA_3_2B: Gemma3nConfig( + text_config=Gemma3nTextConfig( + intermediate_size=2048 * 4, + num_hidden_layers=30, + activation_sparsity_pattern=(0.95,) * 10 + (0.0,) * 20, + num_kv_shared_layers=10, + ), + vision_config=Gemma3nVisionConfig(), + audio_config=Gemma3nAudioConfig(), + ), + _VARIANT_GEMMA_3_4B: Gemma3nConfig( + text_config=Gemma3nTextConfig(), + vision_config=Gemma3nVisionConfig(), + audio_config=Gemma3nAudioConfig(), + ), +} + + +# ==== Flags ==== + +_AUDIO_DTYPE = flags.DEFINE_enum( + name="audio_dtype", + default="bfloat16", + help="The floating point precision (aka dtype) of the model.", + enum_values=_DTYPES, +) + +_CHECKPOINT_PATH = flags.DEFINE_string( + name="checkpoint_path", + default=None, + help="Path to the Orbax checkpoint.", + required=True, +) + +_INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool( + name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer" +) + +_OUTPUT_PATH = flags.DEFINE_string( + name="output_path", + default=None, + help="Path to store the HF checkpoint.", + required=True, +) + +_TRANSFORMER_DTYPE = flags.DEFINE_enum( + name="text_dtype", + default="bfloat16", + help="The floating point precision (aka dtype) of the model.", + enum_values=_DTYPES, +) + +_TOKENIZER_PATH = flags.DEFINE_string( + name="tokenizer_path", + default=None, + help="Path to the SentencePiece model file.", + required=True, +) + +_VARIANT = flags.DEFINE_enum( + name="variant", + default=_VARIANT_GEMMA_3_4B, + help="The model variant to convert.", + enum_values=set(_VARIANTS.keys()), +) + +_VERBOSE = flags.DEFINE_bool( + name="verbose", + default=False, + help="If true, log the path, shape, and dtype of every converted layer.", +) + +_VISION_DTYPE = flags.DEFINE_enum( + name="vision_dtype", + default="bfloat16", + help="The floating point precision (aka dtype) of the model.", + enum_values=_DTYPES, +) + + +def convert_audio_encoder_weights( + config: Gemma3nAudioConfig, + path: str, + param: str, + weights: np.ndarray, +) -> Iterable[tuple[str, np.ndarray]]: + converted_paths: list[str] = [] + converted_weights: list[Any] = [] + + if path.startswith(_AUDIO_ENCODER_CONFORMER): + assert weights.shape[0] == config.conf_num_hidden_layers + + for i, matrix in enumerate(weights): + if "fflayer_end" in path: + base = f"conformer.{i}.ffw_layer_end" + + if path.endswith("ffn_layer1"): + converted_paths.append(f"{base}.ffw_layer_1.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("ffn_layer2"): + converted_paths.append(f"{base}.ffw_layer_2.weight") + converted_weights.append(matrix.transpose()) + 
elif path.endswith("post_layer_norm"): + converted_paths.append(f"{base}.post_layer_norm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_layer_norm"): + converted_paths.append(f"{base}.pre_layer_norm.weight") + converted_weights.append(matrix) + elif "fflayer_start" in path: + base = f"conformer.{i}.ffw_layer_start" + + if path.endswith("ffn_layer1"): + converted_paths.append(f"{base}.ffw_layer_1.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("ffn_layer2"): + converted_paths.append(f"{base}.ffw_layer_2.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("post_layer_norm"): + converted_paths.append(f"{base}.post_layer_norm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_layer_norm"): + converted_paths.append(f"{base}.pre_layer_norm.weight") + converted_weights.append(matrix) + elif path.endswith("final_ln"): + converted_paths.append(f"conformer.{i}.norm.weight") + converted_weights.append(matrix) + elif "lconv" in path: + base = f"conformer.{i}.lconv1d" + + if path.endswith("conv_norm"): + converted_paths.append(f"{base}.conv_norm.weight") + converted_weights.append(matrix) + elif path.endswith("depthwise_conv1d"): + converted_paths.append(f"{base}.depthwise_conv1d.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("linear_end"): + converted_paths.append(f"{base}.linear_end.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("linear_start"): + converted_paths.append(f"{base}.linear_start.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("ln"): + converted_paths.append(f"{base}.pre_layer_norm.weight") + converted_weights.append(matrix) + elif "trans_atten" in path: + base = f"conformer.{i}.attention" + + if param == "per_dim_scale": + converted_paths.append(f"{base}.attn.per_dim_scale") + converted_weights.append(matrix) + + if path.endswith("query_key_value_projection"): + converted_paths.extend( + [f"{base}.attn.q_proj.weight", f"{base}.attn.k_proj.weight", f"{base}.attn.v_proj.weight"] + ) + converted_weights.extend( + [ + m.reshape(config.hidden_size, config.hidden_size).transpose() + for m in matrix.transpose(1, 0, 2, 3) + ] + ) + elif path.endswith("pos_proj"): + converted_paths.append(f"{base}.attn.relative_position_embedding.pos_proj.weight") + converted_weights.append(matrix.reshape(config.hidden_size, config.hidden_size).transpose()) + elif path.endswith("post"): + converted_paths.append(f"{base}.post.weight") + converted_weights.append(matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.hidden_size)) + elif path.endswith("post_norm"): + converted_paths.append(f"{base}.post_norm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_norm"): + converted_paths.append(f"{base}.pre_attn_norm.weight") + converted_weights.append(matrix) + elif path.startswith(_AUDIO_ENCODER_SSCP): + if path.endswith("input_proj"): + converted_paths.append("subsample_conv_projection.input_proj_linear.weight") + converted_weights.append( + weights.transpose(2, 0, 1).reshape(config.hidden_size, config.sscp_conv_channel_size[1] ** 2) + ) + elif "norm_" in path: + index = int(path[-1]) + converted_paths.append(f"subsample_conv_projection.conv_{index}.norm.weight") + converted_weights.append(weights) + elif "subsampling_" in path: + index = int(path[-1]) + converted_paths.append(f"subsample_conv_projection.conv_{index}.conv.weight") + converted_weights.append(weights.transpose(3, 2, 0, 1)) + + if (cpl := 
len(converted_paths)) != (cwl := len(converted_weights)): + raise ValueError( + "The `converted_paths` and `converted_weights` should be the same " + f"length. Got {cpl} and {cwl}, respectively, for {path}." + ) + + return zip(converted_paths, converted_weights) + + +def convert_transformer_weights( + config: Gemma3nTextConfig, + path: str, + param: str, + weights: np.ndarray, +) -> Iterable[tuple[str, np.ndarray]]: + if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX): + path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:] + + converted_paths: list[str] = [] + converted_weights: list[Any] = [] + + if path.startswith(_TRANSFORMER_ALTUP_PROJ): + index = int(path[-1]) + converted_paths.append(f"altup_projections.{index}.weight") + converted_weights.append(weights.transpose()) + elif path.startswith(_TRANSFORMER_ALTUP_UNEMB): + index = int(path[-1]) + converted_paths.append(f"altup_unembed_projections.{index}.weight") + converted_weights.append(weights.transpose()) + elif path.startswith(_TRANSFORMER_DECODER_BLOCK): + attention_type_index = int(path[_TRANSFORMER_DECODER_BLOCK_LEN]) + assert weights.shape[0] == config.num_hidden_layers / _SLIDING_WINDOW_PATTERN + + for i, matrix in enumerate(weights): + layer_idx = _SLIDING_WINDOW_PATTERN * i + attention_type_index + base_path = f"layers.{layer_idx}" + + if "altup" in path: + altup_path = f"{base_path}.altup" + + if param == "correct_output_scale": + converted_paths.append(f"{altup_path}.correct_output_scale") + converted_weights.append(matrix) + elif param == "correction_coefs": + converted_paths.append(f"{altup_path}.correction_coefs.weight") + converted_weights.append(matrix.transpose()) + elif param == "prediction_coefs": + converted_paths.append(f"{altup_path}.prediction_coefs.weight") + converted_weights.append( + np.clip( + matrix.reshape(config.altup_num_inputs, config.altup_num_inputs**2).transpose(), + -config.altup_coef_clip, + config.altup_coef_clip, + ) + ) + + if path.endswith("modality_router"): + converted_paths.append(f"{altup_path}.modality_router.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("router_norm_layer"): + converted_paths.append(f"{altup_path}.router_norm.weight") + converted_weights.append(matrix) + elif path.endswith("attn/attn_vec_einsum"): + converted_paths.append(f"{base_path}.self_attn.o_proj.weight") + converted_weights.append( + matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.num_attention_heads * config.head_dim) + ) + elif path.endswith("attn/kv_einsum"): + converted_paths.extend( + [ + f"{base_path}.self_attn.k_proj.weight", + f"{base_path}.self_attn.v_proj.weight", + ] + ) + k_proj_weights, v_proj_weights = matrix.transpose(0, 2, 1, 3) + kv_proj_shape = (config.hidden_size, config.num_key_value_heads * config.head_dim) + converted_weights.extend( + [ + k_proj_weights.reshape(kv_proj_shape).transpose(), + v_proj_weights.reshape(kv_proj_shape).transpose(), + ] + ) + elif path.endswith("attn/q_einsum"): + converted_paths.append(f"{base_path}.self_attn.q_proj.weight") + converted_weights.append( + matrix.transpose(1, 0, 2) + .reshape(config.hidden_size, config.num_attention_heads * config.head_dim) + .transpose() + ) + elif path.endswith("attn/query_norm"): + converted_paths.append(f"{base_path}.self_attn.q_norm.weight") + converted_weights.append(matrix) + elif path.endswith("attn/key_norm"): + converted_paths.append(f"{base_path}.self_attn.k_norm.weight") + converted_weights.append(matrix) + elif path.endswith("laurel_block/linear_left"): + 
converted_paths.append(f"{base_path}.laurel.linear_left.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("laurel_block/linear_right"): + converted_paths.append(f"{base_path}.laurel.linear_right.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("mlp/gating_einsum"): + converted_paths.extend([f"{base_path}.mlp.gate_proj.weight", f"{base_path}.mlp.up_proj.weight"]) + gate_proj_weight, up_proj_weight = matrix + converted_weights.extend([gate_proj_weight, up_proj_weight]) + elif path.endswith("mlp/linear"): + converted_paths.append(f"{base_path}.mlp.down_proj.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("per_layer_input_gate"): + converted_paths.append(f"{base_path}.per_layer_input_gate.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("per_layer_projection"): + converted_paths.append(f"{base_path}.per_layer_projection.weight") + converted_weights.append(matrix.transpose()) + elif path.endswith("post_attention_norm"): + converted_paths.append(f"{base_path}.post_attention_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("post_ffw_norm"): + converted_paths.append(f"{base_path}.post_feedforward_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("post_laurel_norm"): + converted_paths.append(f"{base_path}.laurel.post_laurel_norm.weight") + converted_weights.append(matrix) + elif path.endswith("post_per_layer_input_norm"): + converted_paths.append(f"{base_path}.post_per_layer_input_norm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_attention_norm"): + converted_paths.append(f"{base_path}.input_layernorm.weight") + converted_weights.append(matrix) + elif path.endswith("pre_ffw_norm"): + converted_paths.append(f"{base_path}.pre_feedforward_layernorm.weight") + converted_weights.append(matrix) + elif path == _TRANSFORMER_EMBEDDER: + if param == "input_embedding": + converted_paths.append("embed_tokens.weight") + # Gemma 3n model doesn't have soft tokens or "end of" tokens for images and audio in its input and output + # embeddings, so we resize to avoid bugs observed with Mllama + pre_expansion_embeddings = weights + pad_token_slice = slice(config.pad_token_id, config.pad_token_id + 1) + new_embeddings = np.repeat(pre_expansion_embeddings[pad_token_slice], 256, axis=0) + weights = np.vstack([pre_expansion_embeddings, new_embeddings]) + converted_weights.append(weights) + elif param == "per_layer_embeddings": + converted_paths.append("embed_tokens_per_layer.weight") + converted_weights.append( + weights.reshape( + config.vocab_size_per_layer_input, config.num_hidden_layers * config.hidden_size_per_layer_input + ) + ) + elif path.startswith(_TRANSFORMER_EMBEDDER): + # TODO: ryanmullins - support multimodal norms and projections + if path.endswith("per_layer_model_projection"): + converted_paths.append("per_layer_model_projection.weight") + converted_weights.append( + weights.reshape( + config.hidden_size, config.num_hidden_layers * config.hidden_size_per_layer_input + ).transpose() + ) + elif path.endswith("per_layer_projection_norm"): + converted_paths.append("per_layer_projection_norm.weight") + converted_weights.append(weights) + elif path == _TRANSFORMER_FINAL_NORM: + converted_paths = ["norm.weight"] + converted_weights = [weights] + + if (cpl := len(converted_paths)) != (cwl := len(converted_weights)): + raise ValueError( + "The `converted_paths` and `converted_weights` should be the same " + f"length. 
Got {cpl} and {cwl}, respectively, for {path}." + ) + + return zip(converted_paths, converted_weights) + + +def convert_vision_weights( + config: Gemma3nVisionConfig, + path: str, + param: str, + weights: np.ndarray, +) -> Iterable[tuple[str, np.ndarray]]: + def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]]: + re_str = r"{}(\d+)/".format(block_type) + re_pattern = re.compile(re_str) + match = re.search(re_pattern, path).group(1) + idx = abs(int(match)) - 1 + + for block_idx, v in enumerate(_MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES): + if v > idx: + offset = _MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES[block_idx - 1] if block_idx > 0 else 0 + layer_idx = idx - offset + return f"blocks.{block_idx}.{layer_idx}", (block_idx, layer_idx) + + raise ValueError(f"could not extract a base path from {path}") + + if _MOBILE_NET_MSFA in path: + converted_path = "msfa" + + if "ffn/Normalize_0" in path: + converted_path += ".ffn.pw_exp.bn.weight" + converted_weight = weights + elif "ffn/Normalize_1" in path: + converted_path += ".ffn.pw_proj.bn.weight" + converted_weight = weights + elif "ffn/expand" in path: + converted_path += ".ffn.pw_exp.conv.weight" + converted_weight = weights.transpose()[:, :, None, None] + elif "ffn/project" in path: + converted_path += ".ffn.pw_proj.conv.weight" + converted_weight = weights.transpose()[:, :, None, None] + elif "Normalize_0" in path: + converted_path += ".norm.weight" + converted_weight = weights + elif _MOBILE_NET_CONV in path: + if "Conv_0" in path: + converted_path = "conv_stem.conv.weight" + converted_weight = weights.transpose(3, 2, 1, 0) + elif "Normalize_0" in path: + converted_path = "conv_stem.bn.weight" + converted_weight = weights + elif _MOBILE_NET_FIB in path: + converted_path, _ = generate_base_path(path, _MOBILE_NET_FIB) + if "Normalize_0" in path: + converted_path += ".bn1.weight" + converted_weight = weights + elif "Normalize_1" in path: + converted_path += ".bn2.weight" + converted_weight = weights + elif "expand_conv" in path: + converted_path += ".conv_exp.weight" + converted_weight = weights.transpose(3, 2, 1, 0) + else: + converted_path += ".conv_pwl.weight" + converted_weight = weights.transpose()[:, :, None, None] + elif _MOBILE_NET_MQA in path: + converted_path, _ = generate_base_path(path, _MOBILE_NET_MQA) + + if "LayerScale_0" in path: + converted_path += ".layer_scale.gamma" + converted_weight = weights + elif "Normalize_0" in path: + converted_path += ".norm.weight" + converted_weight = weights + elif "Normalize_1" in path: + converted_path += ".attn.key.norm.weight" + converted_weight = weights + elif "Normalize_2" in path: + converted_path += ".attn.value.norm.weight" + converted_weight = weights + elif "key_dwconv" in path: + converted_path += ".attn.key.down_conv.weight" + converted_weight = weights.transpose() + elif "key_proj" in path: + converted_path += ".attn.key.proj.weight" + converted_weight = weights.transpose()[:, :, None, None] + elif "output_proj" in path: + converted_path += ".attn.output.proj.weight" + converted_weight = weights.transpose()[:, :, None, None] + elif "query_proj" in path: + converted_path += ".attn.query.proj.weight" + converted_weight = weights.transpose()[:, :, None, None] + elif "value_dwconv" in path: + converted_path += ".attn.value.down_conv.weight" + converted_weight = weights.transpose() + elif "value_proj" in path: + converted_path += ".attn.value.proj.weight" + converted_weight = weights.transpose()[:, :, None, None] + elif _MOBILE_NET_UIB in path: + converted_path, idx_key = 
generate_base_path(path, _MOBILE_NET_UIB) + + has_dw_start = idx_key in _MOBILE_NET_UIB_HAS_DW_START + has_dw_mid = idx_key in _MOBILE_NET_UIB_HAS_DW_MID + + if "LayerScale_0" in path: + converted_path += ".layer_scale.gamma" + converted_weight = weights + elif "Normalize_0" in path: + converted_path += ".dw_start.bn.weight" if has_dw_start else ".pw_exp.bn.weight" + converted_weight = weights + elif "Normalize_1" in path: + converted_path += ".pw_exp.bn.weight" if has_dw_start else ".pw_proj.bn.weight" + converted_weight = weights + elif "Normalize_2" in path: + converted_path += ".dw_mid.bn.weight" if has_dw_mid else ".pw_proj.bn.weight" + converted_weight = weights + elif "Normalize_3" in path: + converted_path += ".pw_proj.bn.weight" + converted_weight = weights + elif "expand" in path: + converted_path += ".pw_exp.conv.weight" + converted_weight = weights.transpose()[:, :, None, None] + elif "middle_dwconv" in path: + converted_path += ".dw_mid.conv.weight" + converted_weight = weights.transpose(3, 2, 1, 0) + elif "project" in path: + converted_path += ".pw_proj.conv.weight" + converted_weight = weights.transpose()[:, :, None, None] + elif "start_dwconv" in path: + converted_path += ".dw_start.conv.weight" + converted_weight = weights.transpose(3, 2, 1, 0) + + return [(converted_path, converted_weight)] + + +def convert(checkpoint_path: str, config: Gemma3nConfig) -> dict[str, torch.Tensor]: + """Loads Orbax checkpoint from `input_path` and converts it to HF tree.""" + checkpointer = obc.PyTreeCheckpointer() + ckpt = checkpointer.restore(checkpoint_path) + hf_tree: dict[str, torch.Tensor] = {} + + def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None: + hf_tree[path] = torch.from_numpy(weights.astype("float32")).type(target_dtype) + if _VERBOSE.value: + logging.info( + "%s converted shape=%s with dtype=%s", + path, + weights.shape, + target_dtype, + ) + + for (path, param), value in tree.flatten_with_path(ckpt): + if param == "audio_input_embedding_extra": + update_tree("model.embed_audio.embedding.weight", value, config.audio_config.torch_dtype) + elif path.endswith("audio_embedding_norm"): + update_tree("model.embed_audio.hard_embedding_norm.weight", value, config.audio_config.torch_dtype) + elif path.endswith("audio_input_projection"): + update_tree( + "model.embed_audio.embedding_projection.weight", value.transpose(), config.audio_config.torch_dtype + ) + elif path.endswith("audio_soft_embedding_norm"): + update_tree("model.embed_audio.soft_embedding_norm.weight", value, config.audio_config.torch_dtype) + elif param == "mm_input_embedding_extra": + update_tree("model.embed_vision.embedding.weight", value, config.vision_config.torch_dtype) + elif path.endswith("mm_hard_embedding_norm"): + update_tree("model.embed_vision.hard_embedding_norm.weight", value, config.vision_config.torch_dtype) + elif path.endswith("mm_input_projection"): + update_tree( + "model.embed_vision.embedding_projection.weight", value.transpose(), config.vision_config.torch_dtype + ) + elif path.endswith("mm_soft_embedding_norm"): + update_tree("model.embed_vision.soft_embedding_norm.weight", value, config.vision_config.torch_dtype) + elif path.startswith(_TRANSFORMER_PARAMETER): + for path, weights in convert_transformer_weights(config.text_config, path, param, value): + update_tree(f"model.language_model.{path}", weights, config.text_config.torch_dtype) + elif _MOBILE_NET_PREFIX in path: + mobilenet_prefix_idx = path.index(_MOBILE_NET_PREFIX) + path = path[mobilenet_prefix_idx:] + 
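+            # Drop everything before the MobileNet prefix so the remaining path matches the timm-style
+            # parameter names handled by convert_vision_weights().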
+            for path, weights in convert_vision_weights(config.vision_config, path, param, value):
+                update_tree(f"model.vision_tower.timm_model.{path}", weights, config.vision_config.torch_dtype)
+        elif path.startswith(_AUDIO_ENCODER_PARAMETER):
+            for path, weights in convert_audio_encoder_weights(config.audio_config, path, param, value):
+                update_tree(f"model.audio_tower.{path}", weights, config.audio_config.torch_dtype)
+
+    hf_tree["lm_head.weight"] = hf_tree["model.language_model.embed_tokens.weight"]
+
+    return hf_tree
+
+
+def main(*args):
+    del args
+
+    output_path = _OUTPUT_PATH.value
+    variant = _VARIANT.value
+
+    config = _VARIANTS[variant]
+    config.audio_config.torch_dtype = getattr(torch, _AUDIO_DTYPE.value)
+    config.text_config.torch_dtype = getattr(torch, _TRANSFORMER_DTYPE.value)
+    config.vision_config.torch_dtype = getattr(torch, _VISION_DTYPE.value)
+    if _INCLUDE_CHAT_TEMPLATE.value:
+        # Chat template is included for instruction tuned models, which treat
+        # both "<eos>" and "<end_of_turn>" as generation stoppers.
+        config.eos_token_id = [1, 106]
+
+    logging.info(
+        "Converting Gemma 3n (%s) @ %s (language) and %s (vision)",
+        variant,
+        _TRANSFORMER_DTYPE.value,
+        _VISION_DTYPE.value,
+    )
+    state_tree = convert(_CHECKPOINT_PATH.value, config)
+    logging.info("Converted Gemma 3n (%s) state tree from Orbax to Hugging Face.", variant)
+
+    with accelerate.init_empty_weights():
+        model = Gemma3nForConditionalGeneration(config=config)
+
+    model.load_state_dict(state_tree, assign=True, strict=True)
+    logging.info(
+        "Loaded Gemma 3n (%s) in Hugging Face Transformers as a %s instance.",
+        variant,
+        type(model).__name__,
+    )
+    model.save_pretrained(output_path, state_dict=state_tree, safe_serialization=True)
+    logging.info(
+        "Saved Gemma 3n (%s) to SafeTensors in %s using %s",
+        variant,
+        output_path,
+        type(model).__name__,
+    )
+    del model
+    del state_tree
+
+    chat_template_kwargs = {"chat_template": _CHAT_TEMPLATE} if _INCLUDE_CHAT_TEMPLATE.value else {}
+
+    tokenizer = GemmaTokenizerFast(
+        _TOKENIZER_PATH.value,
+        add_bos_token=True,
+        extra_special_tokens={
+            "image_token": "<image_soft_token>",  # Should be ID=262_145
+            "boi_token": "<start_of_image>",  # Should be ID=255_999
+            "eoi_token": "<end_of_image>",  # Should be ID=262_144
+            "audio_token": "<audio_soft_token>",  # Should be ID=262_273
+            "boa_token": "<start_of_audio>",  # Should be ID=256_000
+            "eoa_token": "<end_of_audio>",  # Should be ID=262_272
+        },
+        **chat_template_kwargs,
+    )
+    tokenizer.save_pretrained(output_path)
+    logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)
+
+    feature_extractor = Gemma3nAudioFeatureExtractor()
+    image_processor = SiglipImageProcessorFast(
+        image_seq_length=256,
+        image_mean=(0.5,) * 3,
+        image_std=(0.5,) * 3,
+        size={"height": 768, "width": 768},
+        resample=PILImageResampling.BILINEAR,
+        do_normalize=False,
+    )
+    processor = Gemma3nProcessor(
+        feature_extractor=feature_extractor,
+        image_processor=image_processor,
+        tokenizer=tokenizer,
+        **chat_template_kwargs,
+    )
+    processor.save_pretrained(output_path)
+
+    logging.info("Saved Gemma3nProcessor for %s to %s", variant, output_path)
+
+    # NOTE: feature_extractor and image_processor both use the same filename, preprocessor_config.json, when saved to
+    # disk, but the files are overwritten by processor.save_pretrained(). However, the configs can be unioned, saved,
+    # and loaded from the same preprocessor_config.json file, so we do that explicitly here.
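+    # Plain dict union: when the two configs share a key, the value from image_processor_config wins.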
+ feature_extractor_config = json.loads(feature_extractor.to_json_string()) + image_processor_config = json.loads(image_processor.to_json_string()) + preprocessor_config = {**feature_extractor_config, **image_processor_config} + with open(os.path.join(output_path, "preprocessor_config.json"), "w", encoding="utf-8") as writer: + writer.write(json.dumps(preprocessor_config, indent=2, sort_keys=True) + "\n") + + logging.info("Saved joint preprocessor_config.json for %s to %s", variant, output_path) + + del feature_extractor, image_processor, processor, tokenizer + + generation_config = GenerationConfig( + pad_token_id=config.text_config.pad_token_id, + bos_token_id=config.text_config.bos_token_id, + eos_token_id=( + [config.text_config.eos_token_id, 106] if _INCLUDE_CHAT_TEMPLATE.value else config.text_config.eos_token_id + ), + cache_implementation="hybrid", + temperature=1.0, + do_sample=True, + top_k=64, + top_p=0.95, + ) + generation_config.save_pretrained(output_path) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py new file mode 100644 index 000000000000..63598926af2b --- /dev/null +++ b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py @@ -0,0 +1,338 @@ +# coding=utf-8 +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Sequence +from typing import Optional, Union + +import numpy as np + +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +def create_fb_matrix( + n_freqs: int, + f_min: float, + f_max: float, + n_mels: int, + sample_rate: int, + fft_length: int, + norm: Optional[str] = None, +) -> np.ndarray: + r"""Create a frequency bin conversion matrix (NumPy version). + + Args: + n_freqs (int): Number of frequencies to highlight/apply + f_min (float): Minimum frequency (Hz) + f_max (float): Maximum frequency (Hz) + n_mels (int): Number of mel filterbanks + sample_rate (int): Sample rate of the audio waveform + fft_length (int): FFT length + norm (Optional[str]): If 'slaney', divide the triangular mel weights by + the width of the mel band (area normalization). (Default: ``None``) + + Returns: + np.ndarray: Triangular filter banks (fb matrix) of size (``n_freqs``, + ``n_mels``) + meaning number of frequencies to highlight/apply to x the number of + filterbanks. + Each column is a filterbank so that assuming there is a matrix A of + size (..., ``n_freqs``), the applied result would be + ``A @ create_fb_matrix_numpy(A.shape[-1], ...)``. 
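+
+    Example:
+        With the defaults derived by ``Gemma3nAudioFeatureExtractor`` below (``fft_length=1024``, hence
+        ``n_freqs=513``, and ``n_mels=128``), ``create_fb_matrix(n_freqs=513, f_min=125.0, f_max=7600.0,
+        n_mels=128, sample_rate=16000, fft_length=1024)`` returns a ``(513, 128)`` matrix; log-mel features
+        are then ``np.log(np.maximum(magnitude_spec @ fb, mel_floor))``.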
+ """ + + if norm is not None and norm != "slaney": + raise ValueError("norm must be one of None or 'slaney'") + + # freq bins + all_freqs = np.arange(n_freqs, dtype=np.float32) * (sample_rate / fft_length) + + # calculate mel freq bins + # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.)) + m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0)) + m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0)) + m_pts = np.linspace(m_min, m_max, n_mels + 2) + # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.) + f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0) + # calculate difference between each mel point and each stft freq point in Hz + f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1) + slopes = np.expand_dims(f_pts, 0) - np.expand_dims(all_freqs, 1) # (n_freqs, n_mels + 2) + # create overlapping triangles + zero = np.zeros(1, dtype=np.float32) + down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels) + up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels) + fb = np.maximum(zero, np.minimum(down_slopes, up_slopes)) + + if norm is not None and norm == "slaney": + # Slaney-style mel is scaled to be approx constant energy per channel + enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) + fb *= np.expand_dims(enorm, 0) + + return fb + + +def _unfold(array: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray: + """A basic NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim.""" + if array.ndim != 2: + raise ValueError("This unfold implementation currently supports 2D arrays (batch, time).") + if dimension != -1 and dimension != array.ndim - 1: + raise ValueError("This unfold implementation only supports unfolding the last dimension.") + + batch_size, original_length = array.shape + num_frames = (original_length - size) // step + 1 + + if num_frames <= 0: + return np.zeros((batch_size, 0, size), dtype=array.dtype) + + output_shape = (batch_size, num_frames, size) + output_strides = (array.strides[0], array.strides[1] * step, array.strides[1]) + + return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides) + + +class Gemma3nAudioFeatureExtractor(SequenceFeatureExtractor): + """An audio feature extractor Universal Speech Models https://arxiv.org/abs/2303.01037. + + Args: + feature_size (`int`, *optional*, defaults to 128): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + padding_value (`float`, *optional*, defaults to 0.0): + Padding value used to pad the audio. Should correspond to silences. + return_attention_mask (`bool`, *optional*, defaults to `True`): + Whether to return the attention mask for the generated MEL spectrograms. + frame_length_ms (`float`, *optional*, defaults to 32.0): + The length of a frame in milliseconds. + hop_length_ms (`float`, *optional*, defaults to 10.0): + Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. + min_frequency (`float`, *optional*, defaults to 125.0): + The minimum frequency (in Hz) for the Mel filterbank. + max_frequency (`float`, *optional*, defaults to 7600.0): + The maximum frequency (in Hz) for the Mel filterbank. + preemphasis (`float`, *optional*, defaults to 0.97): + The preemphasis coefficient. + preemphasis_htk_flavor (`bool`, *optional*, defaults to `True`): + Whether to use HTK-style preemphasis. 
+ fft_overdrive (`bool`, *optional*, defaults to `True`): + Whether to use FFT overdrive. + dither (`float`, *optional*, defaults to 0.0): + Adds dithering. In other words, adds a small Gaussian noise to each frame. + E.g. use 0.0001 to add dithering with a normal distribution centered + around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech). + The value 0.0 means no dithering. + Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces + the high log_mel_fbank values for signals with hard-zero sections, + when VAD cutoff is present in the signal. + input_scale_factor (`float`, *optional*, defaults to 1.0): + Scaling factor applied to the input waveform. + mel_floor (`float`, *optional*, defaults to 1e-05): + Minimum value for Mel spectrograms to avoid log(0). + per_bin_mean (`Optional[Sequence[float]]`, *optional*): + Mean values for per-bin normalization. + per_bin_stddev (`Optional[Sequence[float]]`, *optional*): + Standard deviation values for per-bin normalization. + """ + + model_input_names = ["input_features", "input_features_mask"] + + def __init__( + self, + feature_size: int = 128, + sampling_rate: int = 16_000, + padding_value: float = 0.0, + return_attention_mask: bool = True, + frame_length_ms: float = 32.0, + hop_length_ms: float = 10.0, + min_frequency: float = 125.0, + max_frequency: float = 7600.0, + preemphasis: float = 0.97, + preemphasis_htk_flavor: bool = True, + fft_overdrive: bool = True, + dither: float = 0.0, + input_scale_factor: float = 1.0, + mel_floor: float = 1e-5, + per_bin_mean: Optional[Sequence[float]] = None, + per_bin_stddev: Optional[Sequence[float]] = None, + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + + self.min_frequency = min_frequency + self.max_frequency = max_frequency + self.preemphasis = preemphasis + self.preemphasis_htk_flavor = preemphasis_htk_flavor + self.fft_overdrive = fft_overdrive + self.dither = dither + self.input_scale_factor = input_scale_factor + self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0)) + self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0)) + self.mel_floor = np.array(mel_floor, dtype=np.float64) + + fft_length = 2 ** math.ceil(math.log2(self.frame_length)) + if self.fft_overdrive: + fft_length *= 2 + self.fft_length = fft_length + + hann_arange = np.arange(self.frame_length, dtype=np.float32) + window = 0.5 * (1 - np.cos(2 * np.pi * hann_arange / self.frame_length)) + self.window = window.astype(np.float32) + + self.mel_filters = create_fb_matrix( + n_freqs=self.fft_length // 2 + 1, + f_min=min_frequency, + f_max=max_frequency, + n_mels=feature_size, + sample_rate=self.sampling_rate, + norm=None, + fft_length=fft_length, + ) + + if per_bin_mean is not None: + self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, feature_size) + else: + self.per_bin_mean = None + + if per_bin_stddev is not None: + self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, feature_size) + else: + self.per_bin_stddev = None + + def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """""" + if waveform.ndim == 1: # If single waveform, add batch dimension + waveform = np.expand_dims(waveform, axis=0) + + if self.dither > 0.0: + waveform = waveform + self.dither * np.random.randn(*waveform.shape).astype(waveform.dtype) + + if self.input_scale_factor != 
1.0: + waveform = waveform * self.input_scale_factor + + frame_size_for_unfold = self.frame_length + 1 + + # NumPy equivalent of unfold for [B, NumFrames, frame_size_for_unfold] + frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length) + + if self.preemphasis > 0.0: + if self.preemphasis_htk_flavor: + first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis) + rest_in_frame = frames_to_process[..., 1:-1] - self.preemphasis * frames_to_process[..., :-2] + frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1) + else: + frames = frames_to_process[..., 1:] - self.preemphasis * frames_to_process[..., :-1] + else: + frames = frames_to_process[..., :-1] + + frames = frames * self.window # Broadcasting window + stft = np.fft.rfft(frames, n=self.fft_length, axis=-1) + + magnitude_spec = np.abs(stft) + + mel_spec = np.matmul(magnitude_spec, self.mel_filters) + log_mel_spec = np.log(np.maximum(mel_spec, self.mel_floor)) + + if self.per_bin_mean is not None: + log_mel_spec = log_mel_spec - self.per_bin_mean # Broadcasting + + if self.per_bin_stddev is not None: + log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting + + mel_spectrogram = log_mel_spec.squeeze() + mask = attention_mask[:: self.hop_length].astype(bool) + # TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why??? + return mel_spectrogram, mask[: mel_spectrogram.shape[0]] + + def __call__( + self, + raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]], + padding: Union[bool, str, PaddingStrategy] = "longest", + max_length: Optional[int] = 480_000, + truncation: bool = True, + pad_to_multiple_of: Optional[int] = 128, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: Optional[bool] = True, + **kwargs, + ) -> BatchFeature: + """Creates a batch of MEL spectrograms from the provided raw speech. + + This implementation uses a different algorithm for windowing and preemphasis compared to the built-in + `transformers.audio_utils.spectrogram()` function that _will_ result in different outputs. Consider this + carefully when selecting an audio feature extactor, especially with pre-trained models. + + Args: + raw_speech: + The audio for which MEL spectrograms are created. + padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `"longest"`): + The padding strategy to use for batches of audio with different lengths. + max_length (`int`, *optional*, defaults to 480000): + If provided, defines the maximum length of the audio to allow. Audio longer than this will be + truncated if `truncation=True`. + truncation (`bool`, *optional*, defaults to `True`): + Whether or not to truncate audio above `max_length`. + pad_to_multiple_of (`int`, *optional*, defaults to 128): + When padding, pad to a multiple of this value. The default value is defined for optimal TPU support. + return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`): + The type of tensors to return (e.g., NumPy, Torch, JAX, TensorFlow). + return_attention_mask (`bool`, *optional*, defaults to `True`): + Whether to return the attention mask for the generated MEL spectrograms. 
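+
+        Returns:
+            A [`BatchFeature`] containing `input_features` (one log-mel spectrogram of shape `(num_frames,
+            feature_size)` per example) and `input_features_mask`.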
+ """ + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence)) + is_batched = is_batched_numpy or is_batched_sequence + + if is_batched: + raw_speech = [np.asarray([rs]).T for rs in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech) + + if not is_batched: # always return a batch + raw_speech = [np.asarray([raw_speech])] + + batched_speech = self.pad( + BatchFeature({"input_features": raw_speech}), + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + prepared_speech = [] + prepared_speech_mask = [] + for speech, mask in zip(batched_speech.input_features, batched_speech.attention_mask): + speech, mask = self._extract_spectrogram(speech.T, mask) + prepared_speech.append(speech.astype(np.float32)) + prepared_speech_mask.append(mask) + + return BatchFeature( + {"input_features": prepared_speech, "input_features_mask": prepared_speech_mask}, + tensor_type=return_tensors, + ) + + +__all__ = ["Gemma3nAudioFeatureExtractor"] diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py new file mode 100644 index 000000000000..0817e16451ac --- /dev/null +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -0,0 +1,2422 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma3n.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import math +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, HybridCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + ModelOutput, + auto_docstring, + can_return_tuple, + is_torchdynamo_compiling, + logging, +) +from ...utils.deprecation import deprecate_kwarg +from ..auto import AutoModel +from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Gemma3n outputs, with hidden states and attentions. + """ +) +class Gemma3nModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + audio_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state. + """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + audio_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Gemma3n causal language model (or autoregressive) outputs. + """ +) +class Gemma3nCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + audio_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + audio_hidden_states: Optional[torch.FloatTensor] = None + + +class Gemma3nRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): + super().__init__() + self.eps = eps + self.with_scale = with_scale + + if self.with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_buffer("weight", torch.tensor(1.0), persistent=False) + + def _norm(self, x): + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = self._norm(x.float()) * self.weight.float() + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +# ==== Audio Encoder ==== + + +class Gemma3nAudioRelativePositionEmbedding(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + self.num_heads = self.config.conf_num_attention_heads + self.channels = self.config.hidden_size + self.head_dim = self.channels // self.num_heads + self.max_backward = max(0, self.config.conf_attention_context_left - 1) + self.max_forward = self.config.conf_attention_context_right + + self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False) + + min_timescale = 1.0 + max_timescale = 1.0e4 + num_timescales = self.channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1) + inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment) + self.register_buffer( + "inv_timescales", + inv_timescales.float().unsqueeze(0).unsqueeze(0), + persistent=False, + ) + + def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + position = position.float().unsqueeze(-1) + scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32) + timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1) + return timing_signal.type(dtype) + + def _relative_shift( + self, + term_bd_before_shift: torch.Tensor, + batch_size: int, + num_heads: int, + num_query_blocks: int, + query_block_size: int, + key_context_size: int, + max_span_plus_1: int, + ) -> torch.Tensor: + """Performs the relative shift. + + Args: + term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size + (B), num_heads (N), num_query_blocks (U), query_block_size (W), + key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1). 
+ + Returns: + Tensor of shape [B, N, U, W, C]. + """ + # term_bd_before_shift shape: [B, N, U, W, F_span] + # Target shape after shift: [B, N, U, W, C] + + # Padding amount for the last dimension (F_span) to become (C + 1) + # C = key_context_size + # F_span = max_span_plus_1 + pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1 + + # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...) + # We only pad the last dimension on the right. + padding_tuple = (0, pad_amount_last_dim) + + term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple) + # Shape after pad: [B, N, U, W, C+1] + + # Reshape for slicing (emulating JAX's behavior) + # [B, N, U, W * (C+1)] + term_bd_reshaped = term_bd_padded.reshape( + ( + batch_size, + num_heads, + num_query_blocks, + query_block_size * (key_context_size + 1), + ) + ) + + # Slice to effective [B, N, U, W * C] + term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size] + + # Reshape back to [B, N, U, W, C] + term_bd_shifted = term_bd_sliced.reshape( + ( + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + ) + ) + return term_bd_shifted + + def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor: + # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim) + # keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim) + # C = W + L + R (key_context_size) + # F_span = L + R + 1 (max_span + 1) + + batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape + _, _, key_context_size, _, _ = keys.shape + + # Relative positions for sinusoidal embeddings: [L, L-1, ..., -R] + # Length is L+R+1 = self.max_span + 1 + pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze( + 0 + ) # Shape [1, F_span] + + max_span_plus_1 = pos_indices.shape[1] # F_span + + sin_emb_timing_signal = self._get_timing_signal_1d_pos( + pos_indices, dtype=queries.dtype + ) # Shape [1, F_span, self.channels] + + # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H] + projected_sin_emb = self.pos_proj(sin_emb_timing_signal) + # Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H] + sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze( + 0 + ) # Shape [F, N, H] + + # term_ac: Query-Key content interaction + # queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul + # keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul + queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H] + keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C] + term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C] + + # term_bd: Query-Position interaction + # Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb) + # queries shape: [B, U, W, N, H] + # sin_emb shape: [F, N, H] + # Target output shape: [B, N, U, W, F] + + # Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb + q_permuted = queries.permute(0, 3, 1, 2, 4) + + # Permute sin_emb to [N, H, F] to prepare for matmul + # sin_emb original is [F, N, H] + s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F] + + # Reshape queries for matmul: [B, N, U*W, H] + q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim) + + # Perform matmul: [B, N, U*W, H] @ [N, H, F] + # s_permuted ([N, H, F]) will be broadcast to 
[B, N, H, F] + # Result: [B, N, U*W, F] + term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted) + + # Reshape to target [B, N, U, W, F] + term_bd_unshifed = term_bd_unshifed_matmul.reshape( + batch_size, + num_heads, + num_query_blocks, + query_block_size, + max_span_plus_1, + ) + + # Apply relative shift to term_bd_unshifed + term_bd_shifted = self._relative_shift( + term_bd_unshifed, + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + max_span_plus_1, + ) # Shape [B, N, U, W, C] + + return term_ac + term_bd_shifted + + +class Gemma3nAudioAttention(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + self.num_heads = self.config.conf_num_attention_heads + self.hidden_size = self.config.hidden_size + self.head_dim = self.hidden_size // self.num_heads + + self.chunk_size = self.config.conf_attention_chunk_size + self.max_future_horizon = self.config.conf_attention_context_right + self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1) + self.attention_logits_soft_cap = self.config.conf_attention_logit_cap + self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon + + self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config) + self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,))) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + + q_scale = self.head_dim**-0.5 + r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0)) + self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False) + + lower_causal_mask = torch.tril( + torch.ones((self.context_size, self.chunk_size), dtype=torch.bool), + diagonal=0, + ).T + upper_causal_mask = torch.tril( + torch.ones((self.chunk_size, self.context_size), dtype=torch.bool), + diagonal=self.max_past_horizon + self.max_future_horizon, + ) + local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool) + local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask + self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False) + + self.register_buffer( + "softcap", + torch.tensor(self.attention_logits_soft_cap).float(), + persistent=False, + ) + + def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor: + batch, _, *tail_shape = x.shape + left = x.new_zeros((batch, pad_left, *tail_shape)) + right = x.new_zeros((batch, pad_right, *tail_shape)) + x = torch.cat([left, x, right], dim=1) + return x + + def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Turns a sequence to non overlapping blocks. + + Args: + hidden_states: a tensor of [batch, time, ...]. + + Returns: + A tensor of [batch, num_blocks, block_size, ...], with necessary + paddings, + where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...]. 
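+
+        For example, a [batch, 25, N, H] input with a chunk size of 12 is right-padded to
+        length 36 and reshaped to [batch, 3, 12, N, H].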
+ """ + shape = hidden_states.shape + b, t = shape[:2] + num_blocks = (t + self.chunk_size - 1) // self.chunk_size + + if (padding_len := num_blocks * self.chunk_size - t) > 0: + hidden_states = self._pad_dim1(hidden_states, 0, padding_len) + + permute_dims = (b, num_blocks, self.chunk_size) + shape[2:] + hidden_states = hidden_states.reshape(permute_dims).contiguous() + return hidden_states + + def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Extracts temporal context for every block. + + Args: + hidden_states: a tensor of [batch, time, ...]. + + Returns: + A tensor of [batch, num_blocks, context_size, ...], with necessary + paddings, + where context_size = block_size + left_context + right_context, + and output[:, i, ...] are x[:, start-left_context:end+right_context, + ...], + start = i * block_size, end = (i + 1) * block_size. + """ + pad_left = self.max_past_horizon + # The JAX equivalent padding for signal.frame with pad_mode='valid' is + # (left_context, right_context + block_size - 1) on the time dimension. + # PyTorch's _pad_dim1 applies padding symmetrically if only one value is given, + # or (pad_dim_start, pad_dim_end) if two are given. + # Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H]) + # or dim 1 (time for [B,T]). + # The current pad_right calculation matches the JAX effective padding. + pad_right = self.max_future_horizon + self.chunk_size - 1 + hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right) + + frame_len = self.context_size + frame_step = self.chunk_size + + # Directly use unfold without the subframe_factor logic + # x.unfold(dimension, size, step) + # dimension=1 (time dimension, assuming x is [B, T_padded, ...]) + # size=frame_len (context_size) + # step=frame_step (chunk_size) + x_unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step) + + # If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len] + # If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len] + # We want to match JAX's typical output for such operations which might be + # [B, num_blocks, frame_len, N, H] if N, H are present. + # The relative_position_embedding expects keys as [B, U, C, N, H]. + # If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C. 
+ if hidden_states.ndim > 2 and x_unfolded.ndim > 3: # Check if inner dimensions (like N, H) exist + # Current shape after unfold for [B, T_pad, N, H] is [B, U, N, H, C] + # Target shape for keys in RPE: [B, U, C, N, H] + x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2) + + return x_unfolded.contiguous() + + def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + # sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select() + qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim) + query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous() + key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous() + value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous() + + per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale) + + broadcast_shape = (1, 1, 1, self.head_dim) + per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape) + query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast + + batch_size, q_time = query_states.shape[:2] + + query_blocks = self._convert_to_block(query_states) + key_blocks = self._extract_block_context(key_states) + value_blocks = self._extract_block_context(value_states) + num_query_blocks = query_blocks.shape[1] + + # 1. Create a mask indicating originally valid positions. + original_valid_mask = ~mask # True for valid, False for padded + + # 2. Extract blocks from this validity mask. + extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask) + + # If subframe_factor was used in _extract_block_context for a [B, T] input mask, + # the shape might be [B, U, C/SF, SF]. Reshape to [B, U, C]. + # batch_size and num_query_blocks are known from query_blocks. + # self.context_size is C. + if ( + extracted_valid_mask_blocks.ndim == 4 + and extracted_valid_mask_blocks.shape[2] * extracted_valid_mask_blocks.shape[3] == self.context_size + ): + extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape( + batch_size, num_query_blocks, self.context_size + ) + # After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask. + # This assertion might be too strict if _extract_block_context handles higher-rank inputs differently, + # but for the mask case, this should hold. + if extracted_valid_mask_blocks.shape != ( + batch_size, + num_query_blocks, + self.context_size, + ): + raise ValueError( + "Shape of extracted_valid_mask_blocks" + f" {extracted_valid_mask_blocks.shape} is not ({batch_size}," + f" {num_query_blocks}, {self.context_size}) after potential reshape." + ) + + # 3. Expand dimensions for broadcasting with logits and causal mask. + # Target shape for broadcasting with logits [B,N,U,W,C] + # extracted_valid_mask_blocks to [B, 1, U, 1, C] + condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(1).unsqueeze(-2) + + # self.local_causal_valid_mask is [W, C], True where allowed by local window. + # Expand to [1, 1, 1, W, C] + condition_from_causality = self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0) + + # 4. Combine the two conditions. + # final_condition will be True where a key is *both* originally valid *and* causally accessible. 
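+        # A key slot can be excluded either because it lies outside the local attention window or because it
+        # maps to padded input frames; both cases have their logits replaced by the dtype minimum below.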
+ # Broadcasts to [B, 1, U, W, C] + final_condition_for_where = torch.logical_and( + condition_from_input_validity, + condition_from_causality.to(condition_from_input_validity.device), # Ensure same device + ) + + # Embed queries and keys + logits = self.relative_position_embedding(query_blocks, key_blocks) + + # Apply attention logit softcap + # Ensure softcap is on the same device as logits + softcap_val = self.softcap.to(logits.device) + logits = logits / softcap_val + logits = torch.tanh(logits) + logits = logits * softcap_val + + # Apply the combined mask. + # final_condition_for_where will broadcast with logits [B,N,U,W,C] + logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min) + probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype) + + # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...) + b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape + h_dim = value_blocks.shape[-1] + prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim) + v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim) + result_bmm = torch.bmm(prob_bun, v_bun) + context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4) + context_vectors = context_vectors.reshape( + ( + batch_size, + num_query_blocks * self.chunk_size, + self.num_heads, + self.head_dim, + ) + ) + context_vectors = context_vectors[:, :q_time] + + return context_vectors + + +class Gemma3nAudioCumulativeGroupNorm(nn.Module): + """Applies Group Normalization cumulatively over the time dimension. + + This layer normalizes the input by calculating the mean and variance + cumulatively over the time dimension (dim 1). The statistics are computed + over all feature dimensions (specified by `feature_dims` and `num_channels`) + for elements marked as valid by the optional `mask`. + + If a `mask` is provided (True for valid, False for invalid/padded), + invalid time steps do not contribute to the statistics calculation, and + their corresponding output values are zeroed out. + + Scale and bias, if enabled, are applied per-channel (last dimension). + This behavior is similar to JAX's `GroupNormalization` with `num_groups=1` + and `cumulative=True`. + """ + + def __init__( + self, + num_channels: int, # Number of channels (size of the last dimension) + feature_dims: Sequence[int], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C] + eps: float = 1e-3, + ): + super().__init__() + self.num_channels = num_channels + self.feature_dims = tuple(feature_dims) + self.eps = eps + + # Scale parameter depends only on the channel dimension + self.weight = nn.Parameter(torch.ones(num_channels)) + + # Axes for normalization: all dimensions except Batch (0) and Time (1). + # For input [B, T, *feature_dims, C], these are dims from 2 onwards. + self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Applies cumulative group norm, optionally using a mask. + + Args: + hidden_states: Input tensor, shape [B, T, *feature_dims, C]. + + Returns: + Normalized tensor with the same shape as x. 
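+
+        At time step t the statistics cover all feature elements of steps 0..t, so early
+        frames are normalized with shorter prefixes than later frames.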
+ """ + expected_input_suffix = self.feature_dims + (self.num_channels,) + if hidden_states.shape[2:] != expected_input_suffix: + raise ValueError( + f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected" + f" suffix (feature_dims + num_channels) {expected_input_suffix}" + ) + + input_dtype = hidden_states.dtype + # Calculations are performed in float32 for numerical stability. + calc_dtype = torch.float32 + x_calc = hidden_states.to(calc_dtype) + + # Prepare a broadcastable mask (`mask_calc`). + # If no mask is provided, treat all elements as valid + # (mask_calc is all ones). + # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting. + mask_calc = torch.ones_like(x_calc, dtype=calc_dtype) + + # Cumulative Statistics Calculation + # 1. Sum of values over reduction axes at each time step. + sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True) + # 2. Cumulative sum of values over time. + cum_sum_values = torch.cumsum(sum_values_at_t, dim=1) + + # 3. Count of valid elements in the normalization group at each time step. + # (A "group" here consists of all features at a given Batch, Time). + elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True) + # 4. Cumulative count of valid elements over time. + cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1) + # Avoid division by zero if all preceding elements were masked. + safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0) + + # 5. Cumulative mean. + cum_mean = cum_sum_values / safe_cum_count_elements + + # 6. Sum of squared differences from the cumulative mean. + # Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc. + # Using x_calc here for the difference, as cum_mean already accounts for masking. + squared_diff_from_mean = (x_calc - cum_mean).pow(2) + sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True) + + # 7. Cumulative sum of squared differences over time. + cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1) + + # 8. Cumulative variance. + cum_variance = cum_sum_sq_diff / safe_cum_count_elements + + # Normalize the input using the calculated cumulative statistics: + # (x - E[x]) / sqrt(Var[x] + eps) + normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps) + + # Apply affine transformation (scale and bias) if enabled. + # Scale and bias are applied per-channel (last dimension). + scale = self.weight.to(calc_dtype) + # Reshape for broadcasting: [C] -> [1, ..., 1, C] + scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels] + normalized_x = normalized_x * scale.view(scale_view_shape) + + # Zero out outputs for time steps that were originally masked (where mask_calc is 0). + # This ensures padded/invalid positions in the input result in zero output. + final_output = normalized_x * mask_calc + + return final_output.to(input_dtype) + + +class Gemma3nAudioSSCPConvBlock(nn.Module): + """A single convolution block for the SubSampleConvProjection. + + This block consists of a 2D convolution, followed by CumulativeGroupNorm, + and a ReLU activation. It handles manual padding for the convolution. 
+ """ + + def __init__( + self, + config: Gemma3nAudioConfig, + idx: int, + input_freq_dim: int, # Changed from input_spatial_dim + manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0), + ): + super().__init__() + self.config = config + self.manual_padding = manual_padding + + # in_channels is 1 for the first block, or C_out from previous block's conv + in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1] + out_channels = self.config.sscp_conv_channel_size[idx] + kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx] + stride_h, stride_w = self.config.sscp_conv_stride_size[idx] + + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=( + kernel_h, + kernel_w, + ), # Kernel (kH, kW) operates on (Time, Freq_dim) + stride=(stride_h, stride_w), + padding=(0, 0), # Manual padding is used + bias=False, + ) + + # Calculate output frequency dimension (f_out_conv) after this convolution. + # input_freq_dim is the unpadded width (feature dimension). + # self.manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom) + f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1] + f_out_conv = (f_in_padded - kernel_w) // stride_w + 1 + + self.norm = Gemma3nAudioCumulativeGroupNorm( + num_channels=out_channels, # Channels of the conv output + feature_dims=(f_out_conv,), # The frequency dimension size after conv + eps=self.config.sscp_conv_group_norm_eps, + ) + + self.activation = nn.ReLU() + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1) + # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom) + # F.pad applies to last two dims: F_in then T_in + audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0) + # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2 + # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2 + audio_encodings_conv = self.conv(audio_encodings_padded) + # Expected conv output shape: [B, C_out, T_out, F_out] + # Input to norm is [B, T_out, F_out, C_out] + x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous() + x_normed = self.norm(x_for_norm) + # Output of norm is [B, T_out, F_out, C_out], permute back to [B, C_out, T_out, F_out] + audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous() + return self.activation(audio_encodings_normed) + + +class Gemma3nAudioSubSampleConvProjection(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + current_f_for_block_input = config.input_feat_size # Start with original feature dim + calculated_block_padding = [] + calculated_f_out_dims = [] # Tracking frequency dimension output sizes + + for i in range(2): # Assuming 2 conv layers as per sscp_conv_... arrays + kernel_h, kernel_w = config.sscp_conv_kernel_size[i] + stride_h, stride_w = config.sscp_conv_stride_size[i] + + # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like + # JAX 'reverse_causal' padding is (0, kernel_size - 1) + pad_t_top = 0 + pad_t_bottom = kernel_h - 1 + + # Frequency Padding (Width for Conv2d) + # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2 + # and the successful test configuration. + # If kernel/stride/input_freq for frequency changes, this might need re-evaluation + # to match generic JAX 'SAME' behavior if it differs. 
+ pad_f_left = 1 + pad_f_right = 1 + + manual_padding_tuple = ( + pad_f_left, + pad_f_right, + pad_t_top, + pad_t_bottom, + ) + calculated_block_padding.append(manual_padding_tuple) + + # Calculate output frequency dimension after this convolution + # This uses the actual padding applied and kernel/stride. + f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right + f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 # Assuming dilation_w = 1 + calculated_f_out_dims.append(f_out_after_conv) + current_f_for_block_input = f_out_after_conv + + self.conv_0 = Gemma3nAudioSSCPConvBlock( + idx=0, + input_freq_dim=config.input_feat_size, # Pass original feature dim + config=config, + manual_padding=calculated_block_padding[0], + ) + self.conv_1 = Gemma3nAudioSSCPConvBlock( + idx=1, + input_freq_dim=calculated_f_out_dims[0], # Output freq dim from conv_0 + config=config, + manual_padding=calculated_block_padding[1], + ) + final_c_out = config.sscp_conv_channel_size[-1] + final_f_out = calculated_f_out_dims[-1] # Final frequency dimension + self.input_proj_in_features = final_c_out * final_f_out + self.input_proj_linear = nn.Linear(self.input_proj_in_features, self.config.hidden_size, bias=False) + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + # audio_encodings is [B, T, F_in] + # Reshape to [B, 1, T, F_in] (Batch, Channels=1, Height=Time, Width=F_in) + audio_encodings_reshaped = audio_encodings.unsqueeze(1) + x = self.conv_0(audio_encodings_reshaped) + x = self.conv_1(x) + # x from conv_1 is [B, C_out_1, T_out_1, F_out_1] + b, c_out, t_out, f_out = x.shape + # Permute to [B, T_out_1, F_out_1, C_out_1] then flatten F_out_1 and C_out_1 + x_permuted = x.permute(0, 2, 3, 1).contiguous() + output_flattened = x_permuted.view(b, t_out, f_out * c_out) + output = self.input_proj_linear(output_flattened) + return output + + +class Gemma3nAudioConformerAttention(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + self.post_in_features = self.config.hidden_size + self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False) + self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size) + self.attn = Gemma3nAudioAttention(config) + self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False) + self.post_norm = Gemma3nRMSNorm(self.config.hidden_size) + + def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor: + audio_encodings_input_to_attn = audio_encodings + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + audio_encodings_norm = self.pre_attn_norm(audio_encodings) + # Output of self.attn is [B, T, NumHeads, HeadDim] + audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask) + + # Reshape from [B, T, NumHeads, HeadDim] to [B, T, NumHeads * HeadDim] + # NumHeads * HeadDim = hidden_size + b, t, num_heads, head_dim = audio_encodings_attn_out.shape + audio_encodings_reshaped = audio_encodings_attn_out.reshape(b, t, num_heads * head_dim) + + audio_encodings = self.post(audio_encodings_reshaped) + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + return audio_encodings_input_to_attn + self.post_norm(audio_encodings) + + +class Gemma3nAudioConformerFeedForward(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + 
self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False) + + self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size) + self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False) + self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False) + self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size) + self.post_layer_scale = torch.tensor(self.config.conf_residual_weight) + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + residual = audio_encodings + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + audio_encodings = self.pre_layer_norm(audio_encodings) + audio_encodings: torch.Tensor = self.ffw_layer_1(audio_encodings) + audio_encodings = nn.functional.silu(audio_encodings) + audio_encodings: torch.Tensor = self.ffw_layer_2(audio_encodings) + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + audio_encodings = self.post_layer_norm(audio_encodings) + return residual + (audio_encodings * self.post_layer_scale) + + +class Gemma3nAudioConformerLightConv1d(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.linear_start = nn.Linear(self.config.hidden_size, self.config.hidden_size * 2, bias=False) + self.depthwise_conv1d = nn.Conv1d( + in_channels=self.config.hidden_size, + out_channels=self.config.hidden_size, + kernel_size=self.config.conf_conv_kernel_size, + stride=1, + padding=0, # Manual causal padding + groups=self.config.hidden_size, # Depthwise + bias=False, + ) + self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False) + self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False) + + self.causal_padding = self.config.conf_conv_kernel_size - 1 + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + audio_encodings_residual = audio_encodings # Save for residual connection + + audio_encodings = self.pre_layer_norm(audio_encodings) + audio_encodings = self.linear_start(audio_encodings) + audio_encodings = torch.nn.functional.glu(audio_encodings, dim=-1) + # Permute for Conv1d: [B, T, D] -> [B, D, T] + audio_encodings_permuted = audio_encodings.permute(0, 2, 1) + # Apply manual causal padding + audio_encodings_permuted_padded = F.pad(audio_encodings_permuted, (self.causal_padding, 0)) + audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded) + # Permute back: [B, D, T_out] -> [B, T_out, D] + audio_encodings = audio_encodings.permute(0, 2, 1) + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + audio_encodings = self.conv_norm(audio_encodings) + audio_encodings = nn.functional.silu(audio_encodings) + audio_encodings = self.linear_end(audio_encodings) + output = audio_encodings + audio_encodings_residual + return output + + +class Gemma3nAudioConformerBlock(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + self.ffw_layer_start = Gemma3nAudioConformerFeedForward(self.config) + self.attention = Gemma3nAudioConformerAttention(self.config) + self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config) 
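The light convolution keeps the sequence length and stays causal by left-padding with `kernel_size - 1` zeros before a depthwise `Conv1d` that uses `padding=0`. A standalone sketch of that padding trick with toy sizes and without the GLU or norms:

```py
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, kernel = 8, 5
conv = nn.Conv1d(hidden, hidden, kernel_size=kernel, padding=0, groups=hidden, bias=False)

x = torch.randn(2, 12, hidden)                    # [B, T, D]
x_bdt = x.permute(0, 2, 1)                        # Conv1d expects [B, D, T]
y = conv(F.pad(x_bdt, (kernel - 1, 0)))           # left-pad only => causal
print(y.shape)                                    # torch.Size([2, 8, 12]) -- T preserved

# Causality check: changing a future time step leaves earlier outputs untouched.
x2 = x_bdt.clone()
x2[..., -1] += 1.0
y2 = conv(F.pad(x2, (kernel - 1, 0)))
print(torch.allclose(y[..., :-1], y2[..., :-1]))  # True
```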
+ self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config) + self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False) + self.norm = Gemma3nRMSNorm(self.config.hidden_size) + + def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor: + audio_encodings = self.ffw_layer_start(audio_encodings) + audio_encodings = self.attention(audio_encodings, audio_mel_mask) + validity_mask_for_lconv = ~audio_mel_mask # True for valid + audio_encodings_for_lconv_input = audio_encodings * validity_mask_for_lconv.unsqueeze(-1).to( + audio_encodings.dtype + ) + audio_encodings = self.lconv1d(audio_encodings_for_lconv_input) + + audio_encodings = self.ffw_layer_end(audio_encodings) + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + output = self.norm(audio_encodings) + return output + + +class Gemma3nAudioEncoder(PreTrainedModel): + """A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037""" + + config_class = Gemma3nAudioConfig + + main_input_name = "audio_mel" + + def __init__(self, config: Gemma3nAudioConfig): + super().__init__(config) + self.config = config + + self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config) + self.conformer = nn.ModuleList( + [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)] + ) + + def forward( + self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor + ) -> tuple[torch.Tensor, torch.BoolTensor]: + """Encodes a batch of MELs. + + Args: + audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels, + mel_bins]. + + Returns: + audio_encodings: a torch.Tensor of shape + `[batch_size, self.config.audio_soft_tokens_per_image, + self.config.audio_config.hidden_size]` + audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames]. + """ + audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D] + + # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub) + t_sub = audio_encodings.shape[1] + + time_stride_product = 1 + for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)): + time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0] + + # Create indices for gathering from the original mask. + # These indices map to original time steps corresponding to the start of each + # receptive field in the subsampled output. + indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product + indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid + + # Expand indices for batch compatibility if B > 1 and indices is 1D. 
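Because the SSCP convolutions stride over time, the `[B, T]` mel mask has to be downsampled to the subsampled length `T_sub`. The encoder does this by gathering every `prod(time_strides)`-th mask entry. A small sketch, assuming a stride product of 4 (the real value is derived from `sscp_conv_stride_size`):

```py
import torch

audio_mel_mask = torch.zeros(2, 50, dtype=torch.bool)  # True = padded frame
audio_mel_mask[:, 40:] = True

time_stride_product = 4   # assumed; product of the per-layer time strides
t_sub = 13                # subsampled length produced by the conv stack

indices = torch.arange(t_sub) * time_stride_product
indices = indices.clamp(max=audio_mel_mask.shape[1] - 1)
current_mask = torch.gather(audio_mel_mask, 1, indices.unsqueeze(0).expand(2, -1))
print(current_mask.shape, current_mask[0])  # [2, 13]; the last few entries are True
```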
+ if audio_mel_mask.ndim > 1 and indices.ndim == 1: + indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub] + elif ( + audio_mel_mask.ndim == indices.ndim + and audio_mel_mask.shape[0] == 1 + and indices.shape[0] != 1 + and t_sub == indices.shape[0] + ): + # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub] + indices = indices.unsqueeze(0) + + current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub] + + for block in self.conformer: + audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask + + if self.config.conf_reduction_factor > 1: + audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor] + # Reduce the mask as well + current_mask = current_mask[:, :: self.config.conf_reduction_factor] + + audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0) + return audio_encodings, current_mask + + +class Gemma3nTextScaledWordEmbedding(nn.Embedding): + """ + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False) + + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype) + + +class Gemma3nTextLaurelBlock(nn.Module): + """Learned Augmented Residual Layer""" + + def __init__(self, config: Gemma3nTextConfig): + super().__init__() + self.config = config + + self.linear_left = nn.Linear(self.config.hidden_size, self.config.laurel_rank, bias=False) + self.linear_right = nn.Linear(self.config.laurel_rank, self.config.hidden_size, bias=False) + self.post_laurel_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + laurel_hidden_states: torch.Tensor = self.linear_left(hidden_states) + laurel_hidden_states: torch.Tensor = self.linear_right(laurel_hidden_states) + normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states) + return hidden_states + normed_laurel_hidden_states + + +class Gemma3nTextMLP(nn.Module): + def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size[layer_idx] + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + self.activation_sparsity = config.activation_sparsity_pattern[layer_idx] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_proj = self.gate_proj(hidden_states) + if self.activation_sparsity > 0.0: + gate_proj = self._gaussian_topk(gate_proj) + activations = self.act_fn(gate_proj) + up_proj = self.up_proj(hidden_states) + down_proj = self.down_proj(activations * up_proj) + return down_proj + + def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor: + target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device) + # normal_dist and std_multiplier are adapted from jax.scipy.stats.norm.ppf(). 
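LAuReL adds a learned low-rank correction on top of the plain residual: the hidden state is projected down to `laurel_rank`, back up to `hidden_size`, RMS-normalized, and added to the input. A minimal standalone version with toy dimensions, using `nn.RMSNorm` (PyTorch ≥ 2.4) as a stand-in for `Gemma3nRMSNorm`:

```py
import torch
import torch.nn as nn

class TinyLaurel(nn.Module):
    def __init__(self, hidden_size: int, rank: int):
        super().__init__()
        self.down = nn.Linear(hidden_size, rank, bias=False)
        self.up = nn.Linear(rank, hidden_size, bias=False)
        self.norm = nn.RMSNorm(hidden_size, eps=1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.norm(self.up(self.down(x)))  # residual + low-rank branch

x = torch.randn(2, 6, 32)
print(TinyLaurel(hidden_size=32, rank=4)(x).shape)   # torch.Size([2, 6, 32])
```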
+ # + # References: + # * https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html + # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal + # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution.icdf + normal_dist = torch.distributions.normal.Normal(0, 1) + std_multiplier: torch.Tensor = normal_dist.icdf(target_sparsity_tensor) + std_multiplier = std_multiplier.type(inputs.dtype) + inputs_mean = torch.mean(inputs, dim=-1, keepdim=True) + inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False) + cutoff_x = inputs_mean + inputs_std * std_multiplier + return nn.functional.relu(inputs - cutoff_x) + + +class Gemma3nTextAltUp(nn.Module): + """Alternating Updates (AltUp) + + The AltUp module wraps transformer layers. The `predict` step modifies the + input to the transformer layer, and the `correct` step propagates the output + of the transformer layer to the sparsely updated dimensions. + + See more in the research paper: + + https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf + """ + + def __init__(self, config: Gemma3nTextConfig): + super().__init__() + self.config = config + self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size)) + self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False) + self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False) + self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False) + self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False) + + def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor: + router_inputs = self.router_norm(x) * self.router_input_scale + routed = self.modality_router(router_inputs) + return torch.tanh(routed.float()).type_as(x) + + def predict(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Predicts the output of a layer using a trainable map. + + Args: + hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by + stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices. + + Returns: + A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions. 
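`_gaussian_topk` implements activation sparsity by treating each row of `gate_proj` as roughly Gaussian: the cutoff is `mean + std * Phi^{-1}(p)`, where `p` is the target sparsity, and everything below the cutoff is zeroed by the shifted ReLU. A quick check that about `p` of the activations end up at zero:

```py
import torch

def gaussian_topk(x: torch.Tensor, target_sparsity: float) -> torch.Tensor:
    std_multiplier = torch.distributions.Normal(0.0, 1.0).icdf(torch.tensor(target_sparsity))
    cutoff = x.mean(dim=-1, keepdim=True) + x.std(dim=-1, keepdim=True, unbiased=False) * std_multiplier
    return torch.nn.functional.relu(x - cutoff)

x = torch.randn(4, 1024)
y = gaussian_topk(x, target_sparsity=0.95)
print((y == 0).float().mean())  # approximately 0.95
```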
+        """
+        modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
+
+        if self.training and self.config.altup_coef_clip is not None:
+            self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
+
+        # Project and then transpose the per-token 2D coefficient matrices so that matmul gives the correct result
+        all_coefs: torch.Tensor = (
+            self.prediction_coefs(modalities)
+            .reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs)
+            .permute(0, 1, 3, 2)
+        )
+
+        # permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs]
+        predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
+        predictions = predictions.permute(3, 0, 1, 2)  # undo the permute
+        predictions += hidden_states  # add the original input
+        return predictions.contiguous().type_as(hidden_states)
+
+    def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
+        """Corrects the predictions relative to the activated layer outputs.
+
+        Args:
+            predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
+                stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
+            activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
+
+        Returns:
+            A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
+            predictions relative to the activated input embeddings.
+        """
+        modalities = self.compute_router_modalities(activated)
+        innovation = activated - predictions[self.config.altup_active_idx]  # (batch, num_tokens, hidden_size)
+        innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1)  # Repeat on dim0 to match predictions
+
+        if self.config.altup_coef_clip is not None:
+            self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
+
+        # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
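Shapes aside, `predict` mixes the `altup_num_inputs` parallel streams with per-token coefficient matrices produced by the router, and `correct` broadcasts the "innovation" (activated layer output minus the active prediction) back onto every stream with its own per-token scalar. A toy einsum version that mirrors this mixing on random tensors, with illustrative names and no routing network:

```py
import torch

num_inputs, batch, tokens, hidden = 4, 2, 3, 8
hidden_states = torch.randn(num_inputs, batch, tokens, hidden)

# predict: a per-token [num_inputs x num_inputs] mixing matrix, plus the identity residual.
coefs = torch.randn(batch, tokens, num_inputs, num_inputs)
predictions = torch.einsum("nbth,btnm->mbth", hidden_states, coefs) + hidden_states

# correct: broadcast the innovation of the active stream back onto every stream.
active_idx = 0
activated = torch.randn(batch, tokens, hidden)             # stand-in for the decoder layer output
innovation = activated - predictions[active_idx]
corr_coefs = torch.randn(batch, tokens, num_inputs) + 1.0  # one scalar per stream and token
corrected = predictions + corr_coefs.permute(2, 0, 1).unsqueeze(-1) * innovation
print(predictions.shape, corrected.shape)                  # both [4, 2, 3, 8]
```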
+ # Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input + # and expand on dim1 for broadcastability + all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0 + all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1) + + corrected = torch.mul(innovation, all_coefs) + corrected += predictions # add the original input + return corrected.contiguous().type_as(activated) + + def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: + """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size].""" + return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) + + +class Gemma3nTextRotaryEmbedding(nn.Module): + def __init__(self, config: Gemma3nTextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def apply_rotary_pos_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + unsqueeze_dim: int = 1, +): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + x (`torch.Tensor`): The tensor to embed. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
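When `softcap` is set, the attention logits are squashed as `softcap * tanh(logits / softcap)`, which bounds them to `(-softcap, softcap)` while staying roughly linear for small values. A one-liner illustrating the effect:

```py
import torch

def soft_cap(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    return softcap * torch.tanh(logits / softcap)

logits = torch.tensor([-100.0, -5.0, 0.5, 5.0, 100.0])
print(soft_cap(logits, softcap=50.0))  # large values saturate near +/-50, small ones barely change
```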
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + return (x * cos) + (rotate_half(x) * sin) + + +class Gemma3nTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Gemma3nTextConfig, layer_idx: int): + super().__init__() + self.is_sliding = config.layer_types[layer_idx] == "sliding_attention" + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = self.config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.sliding_window = config.sliding_window if self.is_sliding else None + + self.q_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.k_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False) + + first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx + # Find the index of the last sliding or full layer before sharing starts (or None if no sharing) + layer_type = config.layer_types[layer_idx] + self.kv_shared_layer_index = ( + first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type) + if self.is_kv_shared_layer + else None + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.config.head_dim) + + cos, sin = position_embeddings + + query_states = self.q_proj(hidden_states).view(hidden_shape) + query_states = self.q_norm(query_states) + query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) + query_states = query_states.transpose(1, 2) + + if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: + # HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache. + if isinstance(past_key_value, HybridCache) and self.is_sliding: + max_length = past_key_value.sliding_window + if cache_position.shape[0] > max_length: + # If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache, + # slice into the entire cache. 
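The shared-KV wiring picks, for each of the last `num_kv_shared_layers` layers, the most recent earlier layer of the same attention type (sliding vs. full) and reuses its cache entry. A small reproduction of that index arithmetic on a made-up `layer_types` list:

```py
# Assumed toy configuration: 8 layers, the last 4 share KV caches.
layer_types = [
    "sliding_attention", "sliding_attention", "full_attention", "sliding_attention",
    "sliding_attention", "full_attention", "sliding_attention", "full_attention",
]
num_hidden_layers, num_kv_shared_layers = len(layer_types), 4
first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers  # 4

for layer_idx in range(num_hidden_layers):
    if layer_idx < first_kv_shared_layer_idx:
        print(layer_idx, layer_types[layer_idx], "computes its own KV")
        continue
    layer_type = layer_types[layer_idx]
    # Walk backwards from the last non-shared layer to find the newest layer of the same type.
    shared_idx = (
        first_kv_shared_layer_idx - 1
        - layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
    )
    print(layer_idx, layer_type, "reuses KV from layer", shared_idx)
```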
+ indices = slice(0, max_length) + else: + # If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1 + indices = cache_position.clamp(min=0, max=max_length - 1) + else: + indices = cache_position + + key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] + value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] + else: + key_states = self.k_proj(hidden_states).view(hidden_shape) + key_states = self.k_norm(key_states) + key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) + key_states = key_states.transpose(1, 2) + + value_states = self.v_proj(hidden_states).view(hidden_shape) + value_states = self.v_norm(value_states) + value_states = value_states.transpose(1, 2) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=1.0, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Gemma3nTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Gemma3nTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.attention_type = config.layer_types[layer_idx] + self.self_attn = Gemma3nTextAttention(config, layer_idx) + self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx) + self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + self.act_fn = ACT2FN[config.hidden_activation] + + self.altup = Gemma3nTextAltUp(config) + self.laurel = Gemma3nTextLaurelBlock(config) + self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False) + self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False) + self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("last_cache_position", version="4.53.0") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings_global: torch.Tensor, + position_embeddings_local: torch.Tensor, + per_layer_input: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> 
tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + predictions = self.altup.predict(hidden_states) + active_prediction = predictions[self.config.altup_active_idx] + + active_prediction_normed = self.input_layernorm(active_prediction) + laurel_output = self.laurel(active_prediction_normed) + + # apply global RoPE to non-sliding layer only + if self.self_attn.is_sliding: + position_embeddings = position_embeddings_local + else: + position_embeddings = position_embeddings_global + + attn, self_attn_weights = self.self_attn( + hidden_states=active_prediction_normed, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + attn = self.post_attention_layernorm(attn) + + attn_gated = active_prediction + attn + attn_laurel = (attn_gated + laurel_output) / math.sqrt(2) + + attn_norm = self.pre_feedforward_layernorm(attn_laurel) + attn_ffw = self.mlp(attn_norm) + attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw) + attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm + corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated) + + first_prediction = corrected_predictions[self.config.altup_active_idx] + first_prediction_clone = first_prediction.clone() + if self.config.altup_correct_scale: + first_prediction = self.altup.scale_corrected_output(first_prediction_clone) + + # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...) + first_prediction = self.per_layer_input_gate(first_prediction) + first_prediction = self.act_fn(first_prediction) + first_prediction = torch.multiply(first_prediction, per_layer_input) + + # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...) 
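Each decoder layer also consumes its slice of the per-layer embeddings: the corrected active AltUp stream is projected down to `hidden_size_per_layer_input`, passed through the activation, multiplied element-wise with the layer's per-layer input, and projected back up. A shape-only sketch of that gate with toy sizes, using `nn.RMSNorm` and GELU as stand-ins for the configured norm and activation:

```py
import torch
import torch.nn as nn

hidden_size, per_layer_size = 32, 8  # illustrative, much smaller than the real model
gate = nn.Linear(hidden_size, per_layer_size, bias=False)
proj = nn.Linear(per_layer_size, hidden_size, bias=False)
norm = nn.RMSNorm(hidden_size, eps=1e-6)

first_prediction = torch.randn(2, 5, hidden_size)    # corrected active AltUp stream
per_layer_input = torch.randn(2, 5, per_layer_size)  # this layer's slice of the PLE table

gated = nn.functional.gelu(gate(first_prediction)) * per_layer_input
print(norm(proj(gated)).shape)                       # torch.Size([2, 5, 32])
```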
+ first_prediction = self.per_layer_projection(first_prediction) + first_prediction = self.post_per_layer_input_norm(first_prediction) + corrected_predictions[1:] += first_prediction + + outputs = (corrected_predictions,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +@auto_docstring +class Gemma3nPreTrainedModel(PreTrainedModel): + config_class = Gemma3nConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _no_split_modules = ["Gemma3nDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + # important: this ported version of Gemma2 isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Gemma3nRMSNorm): + if module.with_scale: + module.weight.data.fill_(1.0) + elif isinstance(module, Gemma3nAudioCumulativeGroupNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Gemma3nAudioAttention): + module.per_dim_scale.data.zero_() + elif isinstance(module, Gemma3nTextAltUp): + module.correct_output_scale.data.zero_() + + +@auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.") +class Gemma3nTextModel(Gemma3nPreTrainedModel): + config_class = Gemma3nTextConfig + + def __init__(self, config: Gemma3nTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Gemma3n downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402 + self.embed_tokens = Gemma3nTextScaledWordEmbedding( + config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 + ) + self.layers = nn.ModuleList( + [Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Gemma3nTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # TODO (raushan): Fix this after RoPE refactor. For now we hack it by + # reassigning thetas when we want to create a local RoPE layer. Config + # defaults should hold values for global RoPE. 
+ config = copy.deepcopy(config) + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default"} + self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config) + + self.hidden_size = config.hidden_size + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + + self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding( + config.vocab_size_per_layer_input, + config.num_hidden_layers * config.hidden_size_per_layer_input, + self.padding_idx, + embed_scale=config.hidden_size_per_layer_input**0.5, + ) + + self.per_layer_model_projection = nn.Linear( + self.hidden_size, + config.num_hidden_layers * config.hidden_size_per_layer_input, + bias=False, + ) + + self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps) + + self.altup_projections = nn.ModuleList( + [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)] + ) + + self.altup_unembed_projections = nn.ModuleList( + [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)] + ) + + self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False) + self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + r""" + per_layer_inputs (torch.Tensor, *optional*, defaults to None): + Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if input_ids is not None: + inputs_embeds = self.embed_tokens(input_ids) + per_layer_inputs = self.get_per_layer_inputs(input_ids) + + per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs) + + if use_cache and past_key_values is None and not self.training: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + # embed positions + hidden_states_0 = inputs_embeds + + # Initialize RoPE embeddings + position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids) + position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids) + + # Expand hidden_states to support per-layer inputs + target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 + epsilon_tensor = torch.tensor(torch.finfo().min) + + temp_hidden_states = [hidden_states_0] + for i in range(1, self.config.altup_num_inputs): + # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...) 
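Before the decoder stack, the embedding stream is duplicated into `altup_num_inputs` streams; each extra stream is a linear projection that is then rescaled so its per-token RMS magnitude matches the original embeddings. A minimal sketch of that renormalization with toy sizes (a small positive epsilon is used here for clarity):

```py
import torch
import torch.nn as nn

hidden = 16
proj = nn.Linear(hidden, hidden, bias=False)

h0 = torch.randn(2, 5, hidden)                                  # original embedding stream
target_magnitude = h0.pow(2).mean(dim=-1, keepdim=True).sqrt()  # per-token RMS of the original

h1 = proj(h0)
new_magnitude = h1.pow(2).mean(dim=-1, keepdim=True).sqrt()
h1 = h1 * target_magnitude / torch.clamp(new_magnitude, min=1e-6)

print(torch.allclose(h1.pow(2).mean(-1).sqrt(), target_magnitude.squeeze(-1), atol=1e-5))  # True
```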
+ altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0) + current_hidden_state = altup_proj.type(hidden_states_0.dtype) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 + current_hidden_state = current_hidden_state * ( + target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) + ) + temp_hidden_states.append(current_hidden_state) + + hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size] + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + causal_mask = causal_mask_mapping[decoder_layer.attention_type] + per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :] + + layer_outputs = decoder_layer( + hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + per_layer_input=per_layer_input, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # Per-layer inputs to single output + target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5 + temp_hidden_states = [hidden_states[0]] + for i in range(1, self.config.altup_num_inputs): + # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...) 
+ altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i]) + current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 + current_hidden_state = current_hidden_state * ( + target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) + ) + temp_hidden_states.append(current_hidden_state) + + hidden_states = torch.stack(temp_hidden_states) + hidden_states = torch.mean(hidden_states, dim=0) + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor: + return self.embed_tokens_per_layer(input_ids).reshape( + *input_ids.shape, + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds) + per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype) + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + + if per_layer_inputs is None: + return per_layer_projection + + if per_layer_projection.shape != per_layer_inputs.shape: + # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings. + per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :] + + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype) + + +@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.") +class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config_class = Gemma3nTextConfig + base_model_prefix = "model" + _checkpoint_conversion_mapping = {"model.language_model": "model"} + + def __init__(self, config: Gemma3nTextConfig): + super().__init__(config) + self.model = Gemma3nTextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: 
Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Gemma3nForCausalLM + + >>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma3n models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **loss_kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Gemma3nMultimodalEmbedder(nn.Module): + """Embeds token ids or soft tokens for multimodal content into language model space.""" + + def __init__( + self, + multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig], + text_config: Gemma3nTextConfig, + ): + super().__init__() + + self.multimodal_hidden_size = multimodal_config.hidden_size + self.eps = multimodal_config.rms_norm_eps + self.vocab_offset = multimodal_config.vocab_offset + self.vocab_size = multimodal_config.vocab_size + self.text_hidden_size = text_config.hidden_size + + self.embedding = nn.Embedding(self.vocab_size, 
self.multimodal_hidden_size) + self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps) + self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps) + self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False) + self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Embeds token ids or soft tokens for multimodal content into language model space. + + Args: + input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range + `[vocab_offset, vocab_offset + vocab_size)`. + inputs_embeds: A torch.Tensor containing the soft tokens to embed. + + Returns: + A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is not None: + emb_norm = self.soft_embedding_norm(inputs_embeds) + else: + hard_emb = self.embedding(input_ids - self.vocab_offset) + emb_norm = self.hard_embedding_norm(hard_emb) + + emb_norm_proj = self.embedding_projection(emb_norm) + return self.embedding_post_projection_norm(emb_norm_proj) + + +@auto_docstring( + custom_intro=""" + The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a + language modeling head. + """ +) +class Gemma3nModel(Gemma3nPreTrainedModel): + _checkpoint_conversion_mapping = {} + # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch + accepts_loss_kwargs = False + + def __init__(self, config: Gemma3nConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.vocab_size = config.text_config.vocab_size + + language_model = AutoModel.from_config(config=config.text_config) + self.language_model = language_model + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input + self.audio_tower = AutoModel.from_config(config.audio_config) + self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config) + self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Projects the last hidden state from the vision model into language model space. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
+ """ + vision_outputs = self.vision_tower( + pixel_values=pixel_values, do_pooling=False, return_dict=True + ).last_hidden_state + # Convert from (batch, channels, height, width) to (batch, height * width, channels) where: + # height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image. + vision_outputs = vision_outputs.reshape( + vision_outputs.shape[0], + self.config.vision_config.hidden_size, + self.config.vision_soft_tokens_per_image, + ).permute(0, 2, 1) + # Normalize and embed the soft tokens into language model space. + vision_outputs *= self.config.vision_config.hidden_size**0.5 + return self.embed_vision(inputs_embeds=vision_outputs) + + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, # text inputs + pixel_values: Optional[torch.FloatTensor] = None, # vision inputs + input_features: Optional[torch.FloatTensor] = None, # audio inputs + attention_mask: Optional[torch.Tensor] = None, + input_features_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **lm_kwargs, + ) -> Gemma3nCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration + + >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224") + >>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224") + + >>> prompt = "Where is the cat standing?" 
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs,) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Where is the cat standing?\nsnow" + ``` + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if input_ids is not None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # Prepare per-layer inputs from inputs_ids + per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input) + per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)) + per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens) + + # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset) + vision_mask = torch.logical_and( + input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset + ) + dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1 + vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device) + vision_embeds = self.embed_vision(input_ids=vision_input_ids) + expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds) + + # Handle audio tokens (>= embed_audio.vocab_offset) + audio_mask = input_ids >= self.embed_audio.vocab_offset + dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1 + audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device) + audio_embeds = self.embed_audio(input_ids=audio_input_ids) + expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds) + else: + per_layer_inputs = None + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text and " + f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings." 
+ ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + # Merge text and audio + if input_features is not None and input_features_mask is not None: + audio_features, audio_mask = self.get_audio_features(input_features, ~input_features_mask) + + # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the + # text to account for this. However, the audio preprocessing and encoder do not gurarantee they will + # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens + # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad + # the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab. + audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device) + audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks) + audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features) + + audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape + extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len + extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim) + + audio_features = torch.cat((audio_features, extra_padding_features), dim=1) + + if input_ids is None: + special_audio_mask = inputs_embeds == self.embed_audio( + input_ids=torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_audio_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel(): + audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of audio input features does not match number of special audio tokens in the input text. " + f"Got {audio_tokens_in_text} audio tokens in the text and " + f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings." + ) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) + + outputs = self.language_model( + input_ids=None, + per_layer_inputs=per_layer_inputs, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **lm_kwargs, + ) + + return Gemma3nModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values if use_cache else None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + audio_hidden_states=audio_features if input_features is not None else None, + ) + + def get_audio_features( + self, input_features: torch.Tensor, input_features_mask: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Projects the last hidden state from the audio encoder into language model space. 
+
+        Args:
+            input_features (`torch.FloatTensor` of shape `(batch_size, seq_length, num_features)`):
+                The tensors corresponding to the input audio.
+            input_features_mask (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
+                The attention mask for the input audio.
+
+        Returns:
+            audio_features (`torch.Tensor`): Audio feature tensor of shape `(batch_size, audio_length, embed_dim)`.
+        """
+        audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
+        return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
+
+
+@auto_docstring(
+    custom_intro="""
+    The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
+    head.
+    """
+)
+class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
+    _checkpoint_conversion_mapping = {}
+    _tied_weights_keys = ["lm_head.weight"]
+    base_model_prefix = "model"
+
+    def __init__(self, config: Gemma3nConfig):
+        super().__init__(config)
+        self.model = Gemma3nModel(config)
+        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.set_input_embeddings(value)
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
+    def get_image_features(self, pixel_values):
+        return self.model.get_image_features(pixel_values)
+
+    # Make modules available through the conditional class for BC
+    @property
+    def language_model(self):
+        return self.model.language_model
+
+    @property
+    def vision_tower(self):
+        return self.model.vision_tower
+
+    @property
+    def multi_modal_projector(self):
+        raise AttributeError("Use embed_vision instead of multi_modal_projector.")
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,  # text inputs
+        pixel_values: Optional[torch.FloatTensor] = None,  # vision inputs
+        input_features: Optional[torch.FloatTensor] = None,  # audio inputs
+        attention_mask: Optional[torch.Tensor] = None,
+        input_features_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **lm_kwargs,
+    ) -> Gemma3nCausalLMOutputWithPast:
+        r"""
+        input_features (`torch.Tensor`, *optional*, defaults to None):
+            The audio inputs to be encoded.
+        input_features_mask (`torch.Tensor`, *optional*, defaults to None):
+            The attention mask for the input audio.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
+            ignored (masked), the loss is only computed for the tokens with labels in
+            `[0, ..., config.text_config.vocab_size]`.
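For context on the `-100` convention above, here is a minimal, hypothetical sketch of how such labels could be built from tokenized inputs; the `build_labels` helper and the pad id are illustrative assumptions, not part of this model's API.

```python
import torch

def build_labels(input_ids: torch.Tensor, pad_token_id: int = 0) -> torch.Tensor:
    # Copy the targets and mask out positions that should not contribute to the loss.
    labels = input_ids.clone()
    labels[labels == pad_token_id] = -100  # -100 is ignored by the cross-entropy loss
    return labels

input_ids = torch.tensor([[0, 0, 2, 10, 11, 12]])  # toy batch with two left-pad tokens
print(build_labels(input_ids))  # -> tensor([[-100, -100, 2, 10, 11, 12]])
```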
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+        >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-e4b-it")
+        >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-e4b-it")
+
+        >>> messages = [
+        ...     {
+        ...         "role": "system",
+        ...         "content": [
+        ...             {"type": "text", "text": "You are a helpful assistant."}
+        ...         ]
+        ...     },
+        ...     {
+        ...         "role": "user", "content": [
+        ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+        ...             {"type": "text", "text": "Where is the cat standing?"},
+        ...         ]
+        ...     },
+        ... ]
+
+        >>> inputs = processor.apply_chat_template(
+        ...     messages,
+        ...     tokenize=True,
+        ...     return_dict=True,
+        ...     return_tensors="pt",
+        ...     add_generation_prompt=True
+        ... )
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
+        ```
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        outputs = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            input_features=input_features,
+            attention_mask=attention_mask,
+            input_features_mask=input_features_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            token_type_ids=token_type_ids,
+            cache_position=cache_position,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            **lm_kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
+            logits = logits / final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * final_logit_softcapping
+
+        loss = None
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
+            if attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
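+                # positions where the attention mask is 0 (e.g. padding) are excluded from the loss below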
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + + return Gemma3nCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + audio_hidden_states=outputs.audio_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + pixel_values=None, + input_features=None, + attention_mask=None, + input_features_mask=None, + token_type_ids=None, + use_cache=True, + logits_to_keep=None, + labels=None, + **kwargs, + ): + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special + # tokens anymore. Otherwise multimodal inputs should be passed to model. + # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + model_inputs["input_features"] = input_features + model_inputs["input_features_mask"] = input_features_mask + + return model_inputs + + @property + def audio_tower(self): + return self.model.audio_tower + + +__all__ = [ + "Gemma3nAudioEncoder", + "Gemma3nForCausalLM", + "Gemma3nForConditionalGeneration", + "Gemma3nModel", + "Gemma3nPreTrainedModel", + "Gemma3nTextModel", +] diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py new file mode 100644 index 000000000000..a3ffa710d842 --- /dev/null +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -0,0 +1,2664 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy
+import math
+from collections.abc import Callable, Sequence
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, HybridCache
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_rope_utils import rope_config_validation
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ..auto import AutoModel
+from ..gemma2.configuration_gemma2 import Gemma2Config
+from ..gemma2.modeling_gemma2 import (
+    Gemma2MLP,
+    Gemma2PreTrainedModel,
+    Gemma2RotaryEmbedding,
+    eager_attention_forward,
+    rotate_half,
+)
+from ..gemma3.modeling_gemma3 import (
+    Gemma3Attention,
+    Gemma3DecoderLayer,
+    Gemma3ForCausalLM,
+    Gemma3RMSNorm,
+    Gemma3TextModel,
+    Gemma3TextScaledWordEmbedding,
+)
+from ..paligemma.modeling_paligemma import (
+    PaliGemmaCausalLMOutputWithPast,
+    PaliGemmaForConditionalGeneration,
+    PaliGemmaModel,
+    PaligemmaModelOutputWithPast,
+)
+from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class Gemma3nTextConfig(Gemma2Config, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Gemma3nTextModel`]. It is used to instantiate a
+    Gemma3nTextModel model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.
+    [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+    Configuration objects inherit from [`Gemma3nTextConfig`] and can be used to control the model outputs. Read the
+    documentation from [`Gemma3nTextConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 262400):
+            Vocabulary size of the Gemma3nText model. Defines the number of different tokens that can be represented
+            by the `input_ids` passed when calling [`Gemma3nTextModel`].
+        vocab_size_per_layer_input (`int`, *optional*, defaults to 262144):
+            Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        hidden_size_per_layer_input (`int`, *optional*, defaults to 256):
+            Dimension of the hidden representations for per-layer embeddings.
+        intermediate_size (`int` or `Sequence[int]`, *optional*, defaults to 16384):
+            Dimension of the MLP representations. MatFormer configurations may wish to provide a sequence of integers
+            to account for variable intermediate_size values across layers. In such cases,
+            `len(intermediate_size) == num_hidden_layers`.
+        num_hidden_layers (`int`, *optional*, defaults to 35):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 2):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out this
+            [paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, will default to `num_attention_heads`.
+        head_dim (`int`, *optional*, defaults to 256):
+            The attention head dimension.
+        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+            The non-linear activation function (function or string) in the decoder. Will default to
+            `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
+            activation function.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        eos_token_id (`int`, *optional*, defaults to 1):
+            End of stream token id.
+        bos_token_id (`int`, *optional*, defaults to 2):
+            Beginning of stream token id.
+        rope_theta (`float`, *optional*, defaults to 1000000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings used in global attention.
+            NOTE: if you apply a new rope type and you expect the model to work on longer `max_position_embeddings`,
+            we recommend you update this value accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2.
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
+        rope_local_base_freq (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings for local attention.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        sliding_window (`int`, *optional*, defaults to 512):
+            This is the size of the sliding window used by local attention layers.
+        layer_types (`Sequence[str]`, *optional*):
+            A sequence of strings defining the attention type for each layer as either "sliding_attention" or
+            "full_attention". If not provided, `layer_types` will be inferred from `num_hidden_layers` using a pattern
+            of four "sliding_attention" layers followed by one "full_attention". The last layer in the model should
+            always be a "full_attention" layer.
+        final_logit_softcapping (`float`, *optional*, defaults to 30.0):
+            Scaling factor when applying tanh softcapping on the logits.
+        altup_active_idx (`int`, *optional*, defaults to 0):
+            The index of the prediction from which AltUp will compute additional predictions or corrections.
+        altup_coef_clip (`float`, *optional*, defaults to 120.0):
+            The maximum amplitude of an AltUp prediction or correction coefficient weight.
+        altup_correct_scale (`bool`, *optional*, defaults to `True`):
+            If True, apply the `AltUp.correct_output_scale` to the corrected prediction at `altup_active_idx`.
+        altup_num_inputs (`int`, *optional*, defaults to 4):
+            The number of predictions that AltUp should make given the input sequence.
+        num_kv_shared_layers (`int`, *optional*, defaults to 15):
+            The number of layers that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
+            layers in the model "share" the KV values in that each local and global layer in this range uses the KV
+            cache values computed for the last local or global layer, respectively, before entering this range. The
+            value of `num_kv_shared_layers` should be a multiple of `sliding_window_pattern`.
+        laurel_rank (`int`, *optional*, defaults to 64):
+            The intermediate size for the linear projections in the Learned Augmented Residual Layer.
+        activation_sparsity_pattern (`Sequence[float]`, *optional*, defaults to `(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)`):
+            The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
+            explicitly provide a sparsity value for each layer in the model.
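To make the two list-valued arguments above concrete, the short sketch below mirrors the defaults that `__init__` falls back to (illustration only; the values are taken from the signature and fallback logic shown further down, not a substitute for instantiating the config):

```python
num_hidden_layers = 35

# Default attention layout: every fifth layer (i % 5 == 0) is a global "full_attention"
# layer, the remaining layers use local "sliding_attention".
layer_types = ["full_attention" if i % 5 == 0 else "sliding_attention" for i in range(num_hidden_layers)]

# activation_sparsity_pattern must provide one sparsity value per layer; the E4B default
# uses 0.95 for the first 10 layers and 0.0 for the remaining 25.
activation_sparsity_pattern = (0.95,) * 10 + (0.0,) * 25
assert len(layer_types) == len(activation_sparsity_pattern) == num_hidden_layers
```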
+
+    ```python
+    >>> from transformers import Gemma3nTextModel, Gemma3nTextConfig
+
+    >>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration
+    >>> configuration = Gemma3nTextConfig()
+
+    >>> # Initializing a model from the gemma3n_text-E4B style configuration
+    >>> model = Gemma3nTextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "gemma3n_text"
+
+    def __init__(
+        self,
+        vocab_size: int = 262_400,
+        vocab_size_per_layer_input: int = 262_144,
+        hidden_size: int = 2048,
+        hidden_size_per_layer_input: int = 256,
+        intermediate_size: Union[int, Sequence[int]] = 16_384,
+        num_hidden_layers: int = 35,
+        num_attention_heads: int = 8,
+        num_key_value_heads: int = 2,
+        head_dim: int = 256,
+        hidden_activation: str = "gelu_pytorch_tanh",
+        max_position_embeddings: int = 32_768,
+        initializer_range: float = 0.02,
+        rms_norm_eps: float = 1e-6,
+        use_cache: bool = True,
+        pad_token_id: int = 0,
+        eos_token_id: int = 1,
+        bos_token_id: int = 2,
+        rope_theta: float = 1_000_000.0,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        rope_local_base_freq: float = 10_000.0,
+        attention_bias: bool = False,
+        attention_dropout: float = 0.0,
+        sliding_window: int = 512,
+        layer_types: Optional[Sequence[str]] = None,
+        final_logit_softcapping: float = 30.0,
+        altup_active_idx: int = 0,
+        altup_coef_clip: float = 120.0,
+        altup_correct_scale: bool = True,
+        altup_num_inputs: int = 4,
+        num_kv_shared_layers: int = 15,
+        laurel_rank: int = 64,
+        activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25,
+        **kwargs,
+    ):
+        # `PretrainedConfig.__init__` is called explicitly (not via `super()`), so `self` must be passed.
+        PretrainedConfig.__init__(
+            self,
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs,
+        )
+
+        if isinstance(intermediate_size, Sequence) and (intsize_len := len(intermediate_size)) != num_hidden_layers:
+            raise ValueError(
+                "intermediate_size must have an explicit intermediate size for every layer or one for all layers. "
+                f"Expected {num_hidden_layers} values but got {intsize_len}."
+ ) + elif not isinstance(intermediate_size, Sequence): + intermediate_size = [intermediate_size] * num_hidden_layers + + self.vocab_size = vocab_size + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.layer_types = layer_types + + self.rope_local_base_freq = rope_local_base_freq + self.rope_scaling = rope_scaling + rope_config_validation(self) + + if layer_types is None: + self.layer_types = [ + "full_attention" if i % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers) + ] + else: + self.layer_types = layer_types + + layer_type_validation(self.layer_types) + + self.hidden_size_per_layer_input = hidden_size_per_layer_input + self.num_kv_shared_layers = num_kv_shared_layers + + self.altup_active_idx = altup_active_idx + self.altup_coef_clip = altup_coef_clip + self.altup_correct_scale = altup_correct_scale + self.altup_num_inputs = altup_num_inputs + + self.laurel_rank = laurel_rank + + if activation_sparsity_pattern is None: + activation_sparsity_pattern = [0.0] * num_hidden_layers + + if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers: + raise ValueError( + "activation_sparsity_pattern must have an explicit activation sparsity value for every layer." + f"Expected {num_hidden_layers} values but got {len_asp}." + ) + self.activation_sparsity_pattern = activation_sparsity_pattern + + +class Gemma3nAudioConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`], based on Gogole's + [Universal Speech Model](). It is used to instantiate an Gemma3nAudioEncoder model according to the specified + arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar + configuration to that of the Gemma 3n E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B). + + Configuration objects that inherit from [`Gemma3nAudioConfig`] and can be used to control the model outputs. Read + the documentation from [`Gemma3nAudioConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 128): + Vocabulary size of the additional hard-token embeddings for audio model. These augment the embeddings + included in the `Gemma3nTextModel` to provide, e.g., the end of audio and audio soft token placeholder + tokens when converting `input_ids` to embeddings in the `Gemma3nForConditionalGeneration` model. + vocab_offset (`int`, *optional*, defaults to 262272): + Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the + 0-indexed `Gemma3nMultimodalEmbedder.embedding` table. + input_feat_size (`int`, *optional*, defaults to 128): + The number of channels in each mel-spectrogram frame. + hidden_size (`int`, *optional*, defaults to 1536): + Dimension of the hidden representations. 
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + gradient_clipping (`float`, *optional*, defaults to 10000000000.0): + Clipping value used to stablize extremely large gradient values. + conf_attention_chunk_size (`int`, *optional*, defaults to 12): + The sub-sequence size for local attention processing inside the Conformer ("conf") section of the + Universal Speech Model. + conf_attention_context_left (`int`, *optional*, defaults to 13): + The left context size of the local attention inside the Conformer ("conf") section of the + Universal Speech Model. + conf_attention_context_right (`int`, *optional*, defaults to 0): + The right context size of the local attention inside the Conformer ("conf") section of the + Universal Speech Model. + conf_attention_logit_cap (`float`, *optional*, defaults to 50.0): + Logit cap applied during local attention inside the Conformer ("conf") section of the + Universal Speech Model. + conf_num_attention_heads (`int`, *optional*, defaults to 8): + The number of attention heads in local attention inside the Conformer ("conf") section of the + Universal Speech Model. + conf_num_hidden_layers (`int`, *optional*, defaults to 12): + The number of layers that use local attention inside the Conformer ("conf") section of the + Universal Speech Model. + conf_conv_kernel_size (`int`, *optional*, defaults to 5): + Convolution kernel size for the conformer block inside the Conformer ("conf") section of the + Universal Speech Model. + conf_reduction_factor (`int`, *optional*, defaults to 4): + Reduction factor used in the conformer block inside the Conformer ("conf") section of the + Universal Speech Model. + conf_residual_weight (`float`, *optional*, defaults to 0.5): + Residual connection weight inside the Conformer ("conf") section of the + Universal Speech Model. + sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`): + The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection + ("sscp") section of the Universal Speech Model. + sscp_conv_group_norm_eps (`float`, *optional*, defaults to 0.001): + Epsilon used in group normalization in the subsample convolution projection in the Sub-sample Convolution + Projection ("sscp") section of the Universal Speech Model. + sscp_conv_kernel_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((3, 3), (3, 3))`): + Kernel sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample + Convolution Projection ("sscp") section of the Universal Speech Model. The kernel sizes are specified as a + tuple of height and width for each layer, where the height corresponds to the time dimension and the width + corresponds to the frequency dimension. + sscp_conv_stride_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((2, 2), (2, 2))`): + Stride sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample + Convolution Projection ("sscp") section of the Universal Speech Model. The stride sizes are specified as a + tuple of height and width for each layer, where the height corresponds to the time dimension and the width + corresponds to the frequency dimension. 
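As a rough aid for the stride arguments above, the sketch below multiplies the default per-layer strides to get the overall subsampling factors; kernel-size and padding edge effects are ignored, so treat this as an approximation rather than the exact frame count produced by the feature extractor.

```python
# Defaults documented above: two conv layers, each with stride (2, 2), where the first
# value is the time (height) stride and the second is the frequency (width) stride.
sscp_conv_stride_size = ((2, 2), (2, 2))

time_stride = 1
freq_stride = 1
for stride_t, stride_f in sscp_conv_stride_size:
    time_stride *= stride_t
    freq_stride *= stride_f

print(time_stride, freq_stride)  # 4 4 -> roughly one output frame per 4 input mel frames
```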
+ + Example: + + ```python + >>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder + + >>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration + >>> configuration = Gemma3nAudioConfig() + + >>> # Initializing a model from the gemma3n_audio-E4B style configuration + >>> model = Gemma3nAudioEncoder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "gemma3n_audio" + + def __init__( + self, + vocab_size: int = 128, + vocab_offset: int = 262_144 + 128, # text vocab size + vision vocab size + input_feat_size: int = 128, + hidden_size: int = 1536, + rms_norm_eps: float = 1e-6, + gradient_clipping: float = 10_000_000_000.0, + conf_attention_chunk_size: int = 12, + conf_attention_context_left: int = 13, + conf_attention_context_right: int = 0, + conf_attention_logit_cap: float = 50.0, + conf_num_attention_heads: int = 8, + conf_num_hidden_layers: int = 12, + conf_conv_kernel_size: int = 5, + conf_reduction_factor: int = 4, + conf_residual_weight: float = 0.5, + sscp_conv_channel_size: tuple[int, int] = (128, 32), + sscp_conv_group_norm_eps: float = 1e-3, + sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = ( + (3, 3), + (3, 3), + ), + sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = ( + (2, 2), + (2, 2), + ), + **kwargs, + ): + super().__init__(**kwargs) + self.input_feat_size = input_feat_size + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps + self.vocab_size = vocab_size + self.vocab_offset = vocab_offset + self.gradient_clipping = gradient_clipping + self.conf_attention_chunk_size = conf_attention_chunk_size + self.conf_attention_context_left = conf_attention_context_left + self.conf_attention_context_right = conf_attention_context_right + self.conf_attention_logit_cap = conf_attention_logit_cap + self.conf_num_attention_heads = conf_num_attention_heads + self.conf_num_hidden_layers = conf_num_hidden_layers + self.conf_conv_kernel_size = conf_conv_kernel_size + self.conf_reduction_factor = conf_reduction_factor + self.conf_residual_weight = conf_residual_weight + self.sscp_conv_channel_size = sscp_conv_channel_size + self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps + self.sscp_conv_kernel_size = sscp_conv_kernel_size + self.sscp_conv_stride_size = sscp_conv_stride_size + + +class Gemma3nVisionConfig(TimmWrapperConfig): + r""" + This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to + instantiate an timm model model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B + vision tower, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B). + + Configuration objects inherit from [`Gemma3nVisionConfig`] and can be used to control the model outputs. Read the + documentation from [`Gemma3nVisionConfig`] for more information. + + Config loads imagenet label descriptions and stores them in `id2label` attribute, `label2id` attribute for default + imagenet models is set to `None` due to occlusions in the label descriptions. + + Args: + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + do_pooling (`bool`, *optional*, defaults to `False`): + Whether to do pooling for the last_hidden_state in `TimmWrapper` or not. 
+ architecture (`str`, *optional*, defaults to `"mobilenetv5_300m_enc"`): + Determines vision architecture for TimmWrapper. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + vocab_size (`int`, *optional*, defaults to 128): + Vocabulary size of the additional hard-token embeddings for vision model. + vocab_offset (`int`, *optional*, defaults to 262144): + Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the + 0-indexed `Gemma3nMultimodalEmbedder.embedding` table. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + + Example: + ```python + >>> from transformers import Gemma3nVisionConfig, TimmWrapper + + >>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration + >>> configuration = Gemma3nVisionConfig() + + >>> # Initializing a gemma3n_vision-E4B-style TimmWrapper from the configuration + >>> model = TimmWrapper(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "gemma3n_vision" + + def __init__( + self, + initializer_range: float = 0.02, + do_pooling: bool = False, + architecture: str = "mobilenetv5_300m_enc", + hidden_size: int = 2048, + vocab_size: int = 128, + vocab_offset: int = 262_144, + rms_norm_eps: float = 1e-06, + model_args: Optional[dict] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.architecture = architecture + self.initializer_range = initializer_range + self.do_pooling = do_pooling + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.vocab_offset = vocab_offset + self.rms_norm_eps = rms_norm_eps + + +class Gemma3nConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Gemma3nForConditionalGeneration`]. It is used to + instantiate a Gemma3nForConditionalGeneration according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of + Gemma3n-E4B. + + e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`Union[Gemma3nTextConfig, dict]`, *optional*): + The config object of the text backbone. + vision_config (`Union[AutoConfig, dict]`, *optional*): + Custom vision config or dict. + audio_config (`Union[AutoConfig, dict]`, *optional*): + Custom audio config or dict. + audio_soft_tokens_per_image (`int`, *optional*, defaults to 188): + The number of soft tokens per audio clip. + vision_soft_tokens_per_image (`int`, *optional*, defaults to 256): + The number of soft tokens per image. + boi_token_id (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_id (`int`, *optional*, defaults to 262144): + The end-of-image token index to wrap the image prompt. + image_token_id (`int`, *optional*, defaults to 262145): + The image token index to encode the image prompt. + boa_token_id (`int`, *optional*, defaults to 256000): + The begin-of-audio token index to wrap the audio prompt. + eoa_token_id (`int`, *optional*, defaults to 262272): + The end-of-audio token index to wrap the audio prompt. + audio_token_id (`int`, *optional*, defaults to 262273): + The audio token index to encode the audio prompt. 
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+
+    Example:
+
+    ```python
+    >>> from transformers import Gemma3nForConditionalGeneration, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig, Gemma3nAudioConfig
+
+    >>> # Initializing a MobileNet vision config, which is loaded from TIMM
+    >>> vision_config = Gemma3nVisionConfig()
+
+    >>> # Initializing a Gemma3n Audio config
+    >>> audio_config = Gemma3nAudioConfig()
+
+    >>> # Initializing a Gemma3n Text config
+    >>> text_config = Gemma3nTextConfig()
+
+    >>> # Initializing a Gemma3n gemma-3n-E4B style configuration
+    >>> configuration = Gemma3nConfig(text_config, vision_config, audio_config)
+
+    >>> # Initializing a model from the gemma-3n-E4B style configuration
+    >>> model = Gemma3nForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "gemma3n"
+    sub_configs = {
+        "text_config": Gemma3nTextConfig,
+        "vision_config": Gemma3nVisionConfig,
+        "audio_config": Gemma3nAudioConfig,
+    }
+
+    def __init__(
+        self,
+        text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None,
+        vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None,
+        audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None,
+        audio_soft_tokens_per_image: int = 188,
+        vision_soft_tokens_per_image: int = 256,
+        boi_token_id: int = 255_999,
+        eoi_token_id: int = 262_144,
+        image_token_id: int = 262_145,
+        boa_token_id: int = 256_000,
+        eoa_token_id: int = 262_272,
+        audio_token_id: int = 262_273,
+        initializer_range: float = 0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        if isinstance(text_config, dict):
+            text_config = Gemma3nTextConfig(**text_config)
+        elif text_config is None:
+            text_config = Gemma3nTextConfig()
+            logger.info("text_config is None. Using default Gemma3nTextConfig.")
+
+        if isinstance(vision_config, dict):
+            vision_config = Gemma3nVisionConfig(**vision_config)
+        elif vision_config is None:
+            vision_config = Gemma3nVisionConfig()
+            logger.info("vision_config is None. Using default Gemma3nVisionConfig.")
+
+        if isinstance(audio_config, dict):
+            audio_config = Gemma3nAudioConfig(**audio_config)
+        elif audio_config is None:
+            audio_config = Gemma3nAudioConfig()
+            logger.info("audio_config is None. Using default Gemma3nAudioConfig.")
+
+        self.text_config = text_config
+        self.vision_config = vision_config
+        self.audio_config = audio_config
+
+        self.audio_soft_tokens_per_image = audio_soft_tokens_per_image
+        self.vision_soft_tokens_per_image = vision_soft_tokens_per_image
+        self.boi_token_id = boi_token_id
+        self.eoi_token_id = eoi_token_id
+        self.image_token_id = image_token_id
+        self.boa_token_id = boa_token_id
+        self.eoa_token_id = eoa_token_id
+        self.audio_token_id = audio_token_id
+        self.initializer_range = initializer_range
+
+
+class Gemma3nModelOutputWithPast(PaligemmaModelOutputWithPast):
+    r"""
+    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+        `(batch_size, num_heads, sequence_length, embed_size_per_head)`
+
+        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+        `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + audio_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state. + """ + + audio_hidden_states: Optional[torch.FloatTensor] = None + + +class Gemma3nCausalLMOutputWithPast(PaliGemmaCausalLMOutputWithPast): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder after projecting last hidden state. + audio_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state. 
+ """ + + audio_hidden_states: Optional[torch.FloatTensor] = None + + +class Gemma3nRMSNorm(Gemma3RMSNorm): + def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True): + super().__init__(dim, eps=eps) + del self.weight + self.with_scale = with_scale + + if self.with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_buffer("weight", torch.tensor(1.0), persistent=False) + + def _norm(self, x): + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = self._norm(x.float()) * self.weight.float() + return output.type_as(x) + + +# ==== Audio Encoder ==== + + +class Gemma3nAudioRelativePositionEmbedding(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + self.num_heads = self.config.conf_num_attention_heads + self.channels = self.config.hidden_size + self.head_dim = self.channels // self.num_heads + self.max_backward = max(0, self.config.conf_attention_context_left - 1) + self.max_forward = self.config.conf_attention_context_right + + self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False) + + min_timescale = 1.0 + max_timescale = 1.0e4 + num_timescales = self.channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1) + inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment) + self.register_buffer( + "inv_timescales", + inv_timescales.float().unsqueeze(0).unsqueeze(0), + persistent=False, + ) + + def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + position = position.float().unsqueeze(-1) + scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32) + timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1) + return timing_signal.type(dtype) + + def _relative_shift( + self, + term_bd_before_shift: torch.Tensor, + batch_size: int, + num_heads: int, + num_query_blocks: int, + query_block_size: int, + key_context_size: int, + max_span_plus_1: int, + ) -> torch.Tensor: + """Performs the relative shift. + + Args: + term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size + (B), num_heads (N), num_query_blocks (U), query_block_size (W), + key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1). + + Returns: + Tensor of shape [B, N, U, W, C]. + """ + # term_bd_before_shift shape: [B, N, U, W, F_span] + # Target shape after shift: [B, N, U, W, C] + + # Padding amount for the last dimension (F_span) to become (C + 1) + # C = key_context_size + # F_span = max_span_plus_1 + pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1 + + # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...) + # We only pad the last dimension on the right. 
+ padding_tuple = (0, pad_amount_last_dim) + + term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple) + # Shape after pad: [B, N, U, W, C+1] + + # Reshape for slicing (emulating JAX's behavior) + # [B, N, U, W * (C+1)] + term_bd_reshaped = term_bd_padded.reshape( + ( + batch_size, + num_heads, + num_query_blocks, + query_block_size * (key_context_size + 1), + ) + ) + + # Slice to effective [B, N, U, W * C] + term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size] + + # Reshape back to [B, N, U, W, C] + term_bd_shifted = term_bd_sliced.reshape( + ( + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + ) + ) + return term_bd_shifted + + def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor: + # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim) + # keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim) + # C = W + L + R (key_context_size) + # F_span = L + R + 1 (max_span + 1) + + batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape + _, _, key_context_size, _, _ = keys.shape + + # Relative positions for sinusoidal embeddings: [L, L-1, ..., -R] + # Length is L+R+1 = self.max_span + 1 + pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze( + 0 + ) # Shape [1, F_span] + + max_span_plus_1 = pos_indices.shape[1] # F_span + + sin_emb_timing_signal = self._get_timing_signal_1d_pos( + pos_indices, dtype=queries.dtype + ) # Shape [1, F_span, self.channels] + + # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H] + projected_sin_emb = self.pos_proj(sin_emb_timing_signal) + # Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H] + sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze( + 0 + ) # Shape [F, N, H] + + # term_ac: Query-Key content interaction + # queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul + # keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul + queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H] + keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C] + term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C] + + # term_bd: Query-Position interaction + # Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb) + # queries shape: [B, U, W, N, H] + # sin_emb shape: [F, N, H] + # Target output shape: [B, N, U, W, F] + + # Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb + q_permuted = queries.permute(0, 3, 1, 2, 4) + + # Permute sin_emb to [N, H, F] to prepare for matmul + # sin_emb original is [F, N, H] + s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F] + + # Reshape queries for matmul: [B, N, U*W, H] + q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim) + + # Perform matmul: [B, N, U*W, H] @ [N, H, F] + # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F] + # Result: [B, N, U*W, F] + term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted) + + # Reshape to target [B, N, U, W, F] + term_bd_unshifed = term_bd_unshifed_matmul.reshape( + batch_size, + num_heads, + num_query_blocks, + query_block_size, + max_span_plus_1, + ) + + # Apply relative shift to term_bd_unshifed + term_bd_shifted = self._relative_shift( + term_bd_unshifed, + batch_size, + num_heads, + num_query_blocks, + query_block_size, + 
key_context_size, + max_span_plus_1, + ) # Shape [B, N, U, W, C] + + return term_ac + term_bd_shifted + + +class Gemma3nAudioAttention(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + self.num_heads = self.config.conf_num_attention_heads + self.hidden_size = self.config.hidden_size + self.head_dim = self.hidden_size // self.num_heads + + self.chunk_size = self.config.conf_attention_chunk_size + self.max_future_horizon = self.config.conf_attention_context_right + self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1) + self.attention_logits_soft_cap = self.config.conf_attention_logit_cap + self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon + + self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config) + self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,))) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + + q_scale = self.head_dim**-0.5 + r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0)) + self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False) + + lower_causal_mask = torch.tril( + torch.ones((self.context_size, self.chunk_size), dtype=torch.bool), + diagonal=0, + ).T + upper_causal_mask = torch.tril( + torch.ones((self.chunk_size, self.context_size), dtype=torch.bool), + diagonal=self.max_past_horizon + self.max_future_horizon, + ) + local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool) + local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask + self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False) + + self.register_buffer( + "softcap", + torch.tensor(self.attention_logits_soft_cap).float(), + persistent=False, + ) + + def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor: + batch, _, *tail_shape = x.shape + left = x.new_zeros((batch, pad_left, *tail_shape)) + right = x.new_zeros((batch, pad_right, *tail_shape)) + x = torch.cat([left, x, right], dim=1) + return x + + def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Turns a sequence to non overlapping blocks. + + Args: + hidden_states: a tensor of [batch, time, ...]. + + Returns: + A tensor of [batch, num_blocks, block_size, ...], with necessary + paddings, + where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...]. + """ + shape = hidden_states.shape + b, t = shape[:2] + num_blocks = (t + self.chunk_size - 1) // self.chunk_size + + if (padding_len := num_blocks * self.chunk_size - t) > 0: + hidden_states = self._pad_dim1(hidden_states, 0, padding_len) + + permute_dims = (b, num_blocks, self.chunk_size) + shape[2:] + hidden_states = hidden_states.reshape(permute_dims).contiguous() + return hidden_states + + def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Extracts temporal context for every block. + + Args: + hidden_states: a tensor of [batch, time, ...]. + + Returns: + A tensor of [batch, num_blocks, context_size, ...], with necessary + paddings, + where context_size = block_size + left_context + right_context, + and output[:, i, ...] 
are x[:, start-left_context:end+right_context, + ...], + start = i * block_size, end = (i + 1) * block_size. + """ + pad_left = self.max_past_horizon + # The JAX equivalent padding for signal.frame with pad_mode='valid' is + # (left_context, right_context + block_size - 1) on the time dimension. + # PyTorch's _pad_dim1 applies padding symmetrically if only one value is given, + # or (pad_dim_start, pad_dim_end) if two are given. + # Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H]) + # or dim 1 (time for [B,T]). + # The current pad_right calculation matches the JAX effective padding. + pad_right = self.max_future_horizon + self.chunk_size - 1 + hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right) + + frame_len = self.context_size + frame_step = self.chunk_size + + # Directly use unfold without the subframe_factor logic + # x.unfold(dimension, size, step) + # dimension=1 (time dimension, assuming x is [B, T_padded, ...]) + # size=frame_len (context_size) + # step=frame_step (chunk_size) + x_unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step) + + # If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len] + # If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len] + # We want to match JAX's typical output for such operations which might be + # [B, num_blocks, frame_len, N, H] if N, H are present. + # The relative_position_embedding expects keys as [B, U, C, N, H]. + # If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C. + if hidden_states.ndim > 2 and x_unfolded.ndim > 3: # Check if inner dimensions (like N, H) exist + # Current shape after unfold for [B, T_pad, N, H] is [B, U, N, H, C] + # Target shape for keys in RPE: [B, U, C, N, H] + x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2) + + return x_unfolded.contiguous() + + def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + # sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select() + qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim) + query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous() + key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous() + value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous() + + per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale) + + broadcast_shape = (1, 1, 1, self.head_dim) + per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape) + query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast + + batch_size, q_time = query_states.shape[:2] + + query_blocks = self._convert_to_block(query_states) + key_blocks = self._extract_block_context(key_states) + value_blocks = self._extract_block_context(value_states) + num_query_blocks = query_blocks.shape[1] + + # 1. Create a mask indicating originally valid positions. + original_valid_mask = ~mask # True for valid, False for padded + + # 2. Extract blocks from this validity mask. + extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask) + + # If subframe_factor was used in _extract_block_context for a [B, T] input mask, + # the shape might be [B, U, C/SF, SF]. Reshape to [B, U, C]. + # batch_size and num_query_blocks are known from query_blocks. + # self.context_size is C. 
+ if ( + extracted_valid_mask_blocks.ndim == 4 + and extracted_valid_mask_blocks.shape[2] * extracted_valid_mask_blocks.shape[3] == self.context_size + ): + extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape( + batch_size, num_query_blocks, self.context_size + ) + # After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask. + # This assertion might be too strict if _extract_block_context handles higher-rank inputs differently, + # but for the mask case, this should hold. + if extracted_valid_mask_blocks.shape != ( + batch_size, + num_query_blocks, + self.context_size, + ): + raise ValueError( + "Shape of extracted_valid_mask_blocks" + f" {extracted_valid_mask_blocks.shape} is not ({batch_size}," + f" {num_query_blocks}, {self.context_size}) after potential reshape." + ) + + # 3. Expand dimensions for broadcasting with logits and causal mask. + # Target shape for broadcasting with logits [B,N,U,W,C] + # extracted_valid_mask_blocks to [B, 1, U, 1, C] + condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(1).unsqueeze(-2) + + # self.local_causal_valid_mask is [W, C], True where allowed by local window. + # Expand to [1, 1, 1, W, C] + condition_from_causality = self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0) + + # 4. Combine the two conditions. + # final_condition will be True where a key is *both* originally valid *and* causally accessible. + # Broadcasts to [B, 1, U, W, C] + final_condition_for_where = torch.logical_and( + condition_from_input_validity, + condition_from_causality.to(condition_from_input_validity.device), # Ensure same device + ) + + # Embed queries and keys + logits = self.relative_position_embedding(query_blocks, key_blocks) + + # Apply attention logit softcap + # Ensure softcap is on the same device as logits + softcap_val = self.softcap.to(logits.device) + logits = logits / softcap_val + logits = torch.tanh(logits) + logits = logits * softcap_val + + # Apply the combined mask. + # final_condition_for_where will broadcast with logits [B,N,U,W,C] + logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min) + probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype) + + # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...) + b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape + h_dim = value_blocks.shape[-1] + prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim) + v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim) + result_bmm = torch.bmm(prob_bun, v_bun) + context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4) + context_vectors = context_vectors.reshape( + ( + batch_size, + num_query_blocks * self.chunk_size, + self.num_heads, + self.head_dim, + ) + ) + context_vectors = context_vectors[:, :q_time] + + return context_vectors + + +class Gemma3nAudioCumulativeGroupNorm(nn.Module): + """Applies Group Normalization cumulatively over the time dimension. + + This layer normalizes the input by calculating the mean and variance + cumulatively over the time dimension (dim 1). The statistics are computed + over all feature dimensions (specified by `feature_dims` and `num_channels`) + for elements marked as valid by the optional `mask`. + + If a `mask` is provided (True for valid, False for invalid/padded), + invalid time steps do not contribute to the statistics calculation, and + their corresponding output values are zeroed out. 
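+    (Note: this implementation does not expose a `mask` argument on `forward`; all positions are treated
+    as valid.)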
+ + Scale and bias, if enabled, are applied per-channel (last dimension). + This behavior is similar to JAX's `GroupNormalization` with `num_groups=1` + and `cumulative=True`. + """ + + def __init__( + self, + num_channels: int, # Number of channels (size of the last dimension) + feature_dims: Sequence[int], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C] + eps: float = 1e-3, + ): + super().__init__() + self.num_channels = num_channels + self.feature_dims = tuple(feature_dims) + self.eps = eps + + # Scale parameter depends only on the channel dimension + self.weight = nn.Parameter(torch.ones(num_channels)) + + # Axes for normalization: all dimensions except Batch (0) and Time (1). + # For input [B, T, *feature_dims, C], these are dims from 2 onwards. + self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Applies cumulative group norm, optionally using a mask. + + Args: + hidden_states: Input tensor, shape [B, T, *feature_dims, C]. + + Returns: + Normalized tensor with the same shape as x. + """ + expected_input_suffix = self.feature_dims + (self.num_channels,) + if hidden_states.shape[2:] != expected_input_suffix: + raise ValueError( + f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected" + f" suffix (feature_dims + num_channels) {expected_input_suffix}" + ) + + input_dtype = hidden_states.dtype + # Calculations are performed in float32 for numerical stability. + calc_dtype = torch.float32 + x_calc = hidden_states.to(calc_dtype) + + # Prepare a broadcastable mask (`mask_calc`). + # If no mask is provided, treat all elements as valid + # (mask_calc is all ones). + # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting. + mask_calc = torch.ones_like(x_calc, dtype=calc_dtype) + + # Cumulative Statistics Calculation + # 1. Sum of values over reduction axes at each time step. + sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True) + # 2. Cumulative sum of values over time. + cum_sum_values = torch.cumsum(sum_values_at_t, dim=1) + + # 3. Count of valid elements in the normalization group at each time step. + # (A "group" here consists of all features at a given Batch, Time). + elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True) + # 4. Cumulative count of valid elements over time. + cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1) + # Avoid division by zero if all preceding elements were masked. + safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0) + + # 5. Cumulative mean. + cum_mean = cum_sum_values / safe_cum_count_elements + + # 6. Sum of squared differences from the cumulative mean. + # Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc. + # Using x_calc here for the difference, as cum_mean already accounts for masking. + squared_diff_from_mean = (x_calc - cum_mean).pow(2) + sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True) + + # 7. Cumulative sum of squared differences over time. + cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1) + + # 8. Cumulative variance. + cum_variance = cum_sum_sq_diff / safe_cum_count_elements + + # Normalize the input using the calculated cumulative statistics: + # (x - E[x]) / sqrt(Var[x] + eps) + normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps) + + # Apply affine transformation (scale and bias) if enabled. 
+ # Scale and bias are applied per-channel (last dimension). + scale = self.weight.to(calc_dtype) + # Reshape for broadcasting: [C] -> [1, ..., 1, C] + scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels] + normalized_x = normalized_x * scale.view(scale_view_shape) + + # Zero out outputs for time steps that were originally masked (where mask_calc is 0). + # This ensures padded/invalid positions in the input result in zero output. + final_output = normalized_x * mask_calc + + return final_output.to(input_dtype) + + +class Gemma3nAudioSSCPConvBlock(nn.Module): + """A single convolution block for the SubSampleConvProjection. + + This block consists of a 2D convolution, followed by CumulativeGroupNorm, + and a ReLU activation. It handles manual padding for the convolution. + """ + + def __init__( + self, + config: Gemma3nAudioConfig, + idx: int, + input_freq_dim: int, # Changed from input_spatial_dim + manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0), + ): + super().__init__() + self.config = config + self.manual_padding = manual_padding + + # in_channels is 1 for the first block, or C_out from previous block's conv + in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1] + out_channels = self.config.sscp_conv_channel_size[idx] + kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx] + stride_h, stride_w = self.config.sscp_conv_stride_size[idx] + + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=( + kernel_h, + kernel_w, + ), # Kernel (kH, kW) operates on (Time, Freq_dim) + stride=(stride_h, stride_w), + padding=(0, 0), # Manual padding is used + bias=False, + ) + + # Calculate output frequency dimension (f_out_conv) after this convolution. + # input_freq_dim is the unpadded width (feature dimension). 
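+        # For example, with input_freq_dim=10, kernel_w=3, stride_w=2 and frequency padding (1, 1)
+        # (the test configuration referenced below), f_in_padded = 12 and f_out_conv = (12 - 3) // 2 + 1 = 5.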
+ # self.manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom) + f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1] + f_out_conv = (f_in_padded - kernel_w) // stride_w + 1 + + self.norm = Gemma3nAudioCumulativeGroupNorm( + num_channels=out_channels, # Channels of the conv output + feature_dims=(f_out_conv,), # The frequency dimension size after conv + eps=self.config.sscp_conv_group_norm_eps, + ) + + self.activation = nn.ReLU() + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1) + # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom) + # F.pad applies to last two dims: F_in then T_in + audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0) + # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2 + # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2 + audio_encodings_conv = self.conv(audio_encodings_padded) + # Expected conv output shape: [B, C_out, T_out, F_out] + # Input to norm is [B, T_out, F_out, C_out] + x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous() + x_normed = self.norm(x_for_norm) + # Output of norm is [B, T_out, F_out, C_out], permute back to [B, C_out, T_out, F_out] + audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous() + return self.activation(audio_encodings_normed) + + +class Gemma3nAudioSubSampleConvProjection(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + current_f_for_block_input = config.input_feat_size # Start with original feature dim + calculated_block_padding = [] + calculated_f_out_dims = [] # Tracking frequency dimension output sizes + + for i in range(2): # Assuming 2 conv layers as per sscp_conv_... arrays + kernel_h, kernel_w = config.sscp_conv_kernel_size[i] + stride_h, stride_w = config.sscp_conv_stride_size[i] + + # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like + # JAX 'reverse_causal' padding is (0, kernel_size - 1) + pad_t_top = 0 + pad_t_bottom = kernel_h - 1 + + # Frequency Padding (Width for Conv2d) + # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2 + # and the successful test configuration. + # If kernel/stride/input_freq for frequency changes, this might need re-evaluation + # to match generic JAX 'SAME' behavior if it differs. + pad_f_left = 1 + pad_f_right = 1 + + manual_padding_tuple = ( + pad_f_left, + pad_f_right, + pad_t_top, + pad_t_bottom, + ) + calculated_block_padding.append(manual_padding_tuple) + + # Calculate output frequency dimension after this convolution + # This uses the actual padding applied and kernel/stride. 
+ f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right + f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 # Assuming dilation_w = 1 + calculated_f_out_dims.append(f_out_after_conv) + current_f_for_block_input = f_out_after_conv + + self.conv_0 = Gemma3nAudioSSCPConvBlock( + idx=0, + input_freq_dim=config.input_feat_size, # Pass original feature dim + config=config, + manual_padding=calculated_block_padding[0], + ) + self.conv_1 = Gemma3nAudioSSCPConvBlock( + idx=1, + input_freq_dim=calculated_f_out_dims[0], # Output freq dim from conv_0 + config=config, + manual_padding=calculated_block_padding[1], + ) + final_c_out = config.sscp_conv_channel_size[-1] + final_f_out = calculated_f_out_dims[-1] # Final frequency dimension + self.input_proj_in_features = final_c_out * final_f_out + self.input_proj_linear = nn.Linear(self.input_proj_in_features, self.config.hidden_size, bias=False) + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + # audio_encodings is [B, T, F_in] + # Reshape to [B, 1, T, F_in] (Batch, Channels=1, Height=Time, Width=F_in) + audio_encodings_reshaped = audio_encodings.unsqueeze(1) + x = self.conv_0(audio_encodings_reshaped) + x = self.conv_1(x) + # x from conv_1 is [B, C_out_1, T_out_1, F_out_1] + b, c_out, t_out, f_out = x.shape + # Permute to [B, T_out_1, F_out_1, C_out_1] then flatten F_out_1 and C_out_1 + x_permuted = x.permute(0, 2, 3, 1).contiguous() + output_flattened = x_permuted.view(b, t_out, f_out * c_out) + output = self.input_proj_linear(output_flattened) + return output + + +class Gemma3nAudioConformerAttention(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + self.post_in_features = self.config.hidden_size + self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False) + self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size) + self.attn = Gemma3nAudioAttention(config) + self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False) + self.post_norm = Gemma3nRMSNorm(self.config.hidden_size) + + def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor: + audio_encodings_input_to_attn = audio_encodings + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + audio_encodings_norm = self.pre_attn_norm(audio_encodings) + # Output of self.attn is [B, T, NumHeads, HeadDim] + audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask) + + # Reshape from [B, T, NumHeads, HeadDim] to [B, T, NumHeads * HeadDim] + # NumHeads * HeadDim = hidden_size + b, t, num_heads, head_dim = audio_encodings_attn_out.shape + audio_encodings_reshaped = audio_encodings_attn_out.reshape(b, t, num_heads * head_dim) + + audio_encodings = self.post(audio_encodings_reshaped) + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + return audio_encodings_input_to_attn + self.post_norm(audio_encodings) + + +class Gemma3nAudioConformerFeedForward(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False) + + self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size) + self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False) + self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, 
self.config.hidden_size, bias=False) + self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size) + self.post_layer_scale = torch.tensor(self.config.conf_residual_weight) + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + residual = audio_encodings + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + audio_encodings = self.pre_layer_norm(audio_encodings) + audio_encodings: torch.Tensor = self.ffw_layer_1(audio_encodings) + audio_encodings = nn.functional.silu(audio_encodings) + audio_encodings: torch.Tensor = self.ffw_layer_2(audio_encodings) + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + audio_encodings = self.post_layer_norm(audio_encodings) + return residual + (audio_encodings * self.post_layer_scale) + + +class Gemma3nAudioConformerLightConv1d(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.linear_start = nn.Linear(self.config.hidden_size, self.config.hidden_size * 2, bias=False) + self.depthwise_conv1d = nn.Conv1d( + in_channels=self.config.hidden_size, + out_channels=self.config.hidden_size, + kernel_size=self.config.conf_conv_kernel_size, + stride=1, + padding=0, # Manual causal padding + groups=self.config.hidden_size, # Depthwise + bias=False, + ) + self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False) + self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False) + + self.causal_padding = self.config.conf_conv_kernel_size - 1 + + def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor: + audio_encodings_residual = audio_encodings # Save for residual connection + + audio_encodings = self.pre_layer_norm(audio_encodings) + audio_encodings = self.linear_start(audio_encodings) + audio_encodings = torch.nn.functional.glu(audio_encodings, dim=-1) + # Permute for Conv1d: [B, T, D] -> [B, D, T] + audio_encodings_permuted = audio_encodings.permute(0, 2, 1) + # Apply manual causal padding + audio_encodings_permuted_padded = F.pad(audio_encodings_permuted, (self.causal_padding, 0)) + audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded) + # Permute back: [B, D, T_out] -> [B, T_out, D] + audio_encodings = audio_encodings.permute(0, 2, 1) + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + audio_encodings = self.conv_norm(audio_encodings) + audio_encodings = nn.functional.silu(audio_encodings) + audio_encodings = self.linear_end(audio_encodings) + output = audio_encodings + audio_encodings_residual + return output + + +class Gemma3nAudioConformerBlock(nn.Module): + def __init__(self, config: Gemma3nAudioConfig): + super().__init__() + self.config = config + + self.ffw_layer_start = Gemma3nAudioConformerFeedForward(self.config) + self.attention = Gemma3nAudioConformerAttention(self.config) + self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config) + self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config) + self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False) + self.norm = Gemma3nRMSNorm(self.config.hidden_size) + + def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> 
torch.Tensor: + audio_encodings = self.ffw_layer_start(audio_encodings) + audio_encodings = self.attention(audio_encodings, audio_mel_mask) + validity_mask_for_lconv = ~audio_mel_mask # True for valid + audio_encodings_for_lconv_input = audio_encodings * validity_mask_for_lconv.unsqueeze(-1).to( + audio_encodings.dtype + ) + audio_encodings = self.lconv1d(audio_encodings_for_lconv_input) + + audio_encodings = self.ffw_layer_end(audio_encodings) + audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping) + output = self.norm(audio_encodings) + return output + + +class Gemma3nAudioEncoder(PreTrainedModel): + """A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037""" + + config_class = Gemma3nAudioConfig + + main_input_name = "audio_mel" + + def __init__(self, config: Gemma3nAudioConfig): + super().__init__(config) + self.config = config + + self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config) + self.conformer = nn.ModuleList( + [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)] + ) + + def forward( + self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor + ) -> tuple[torch.Tensor, torch.BoolTensor]: + """Encodes a batch of MELs. + + Args: + audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels, + mel_bins]. + + Returns: + audio_encodings: a torch.Tensor of shape + `[batch_size, self.config.audio_soft_tokens_per_image, + self.config.audio_config.hidden_size]` + audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames]. + """ + audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D] + + # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub) + t_sub = audio_encodings.shape[1] + + time_stride_product = 1 + for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)): + time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0] + + # Create indices for gathering from the original mask. + # These indices map to original time steps corresponding to the start of each + # receptive field in the subsampled output. + indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product + indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid + + # Expand indices for batch compatibility if B > 1 and indices is 1D. 
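+        # E.g., if both SSCP blocks use a time stride of 2, time_stride_product = 4 and the mask is sampled
+        # at original frames 0, 4, 8, ..., matching the 4x time subsampling of the conv stack.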
+ if audio_mel_mask.ndim > 1 and indices.ndim == 1: + indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub] + elif ( + audio_mel_mask.ndim == indices.ndim + and audio_mel_mask.shape[0] == 1 + and indices.shape[0] != 1 + and t_sub == indices.shape[0] + ): + # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub] + indices = indices.unsqueeze(0) + + current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub] + + for block in self.conformer: + audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask + + if self.config.conf_reduction_factor > 1: + audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor] + # Reduce the mask as well + current_mask = current_mask[:, :: self.config.conf_reduction_factor] + + audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0) + return audio_encodings, current_mask + + +# ==== Language Model ==== + + +class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding): + pass + + +class Gemma3nTextLaurelBlock(nn.Module): + """Learned Augmented Residual Layer""" + + def __init__(self, config: Gemma3nTextConfig): + super().__init__() + self.config = config + + self.linear_left = nn.Linear(self.config.hidden_size, self.config.laurel_rank, bias=False) + self.linear_right = nn.Linear(self.config.laurel_rank, self.config.hidden_size, bias=False) + self.post_laurel_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + laurel_hidden_states: torch.Tensor = self.linear_left(hidden_states) + laurel_hidden_states: torch.Tensor = self.linear_right(laurel_hidden_states) + normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states) + return hidden_states + normed_laurel_hidden_states + + +class Gemma3nTextMLP(Gemma2MLP): + def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0): + super().__init__(config) + self.intermediate_size = config.intermediate_size[layer_idx] + self.activation_sparsity = config.activation_sparsity_pattern[layer_idx] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_proj = self.gate_proj(hidden_states) + if self.activation_sparsity > 0.0: + gate_proj = self._gaussian_topk(gate_proj) + activations = self.act_fn(gate_proj) + up_proj = self.up_proj(hidden_states) + down_proj = self.down_proj(activations * up_proj) + return down_proj + + def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor: + target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device) + # normal_dist and std_multiplier are adapted from jax.scipy.stats.norm.ppf(). 
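+        # For example, a target activation sparsity of 0.95 gives std_multiplier = icdf(0.95) ≈ 1.645, so only
+        # gate activations more than ~1.645 standard deviations above the per-token mean stay non-zero (~5%).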
+ # + # References: + # * https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html + # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal + # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution.icdf + normal_dist = torch.distributions.normal.Normal(0, 1) + std_multiplier: torch.Tensor = normal_dist.icdf(target_sparsity_tensor) + std_multiplier = std_multiplier.type(inputs.dtype) + inputs_mean = torch.mean(inputs, dim=-1, keepdim=True) + inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False) + cutoff_x = inputs_mean + inputs_std * std_multiplier + return nn.functional.relu(inputs - cutoff_x) + + +class Gemma3nTextAltUp(nn.Module): + """Alternating Updates (AltUp) + + The AltUp module wraps transformer layers. The `predict` step modifies the + input to the transformer layer, and the `correct` step propagates the output + of the transformer layer to the sparsely updated dimensions. + + See more in the research paper: + + https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf + """ + + def __init__(self, config: Gemma3nTextConfig): + super().__init__() + self.config = config + self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size)) + self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False) + self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False) + self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False) + self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False) + + def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor: + router_inputs = self.router_norm(x) * self.router_input_scale + routed = self.modality_router(router_inputs) + return torch.tanh(routed.float()).type_as(x) + + def predict(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Predicts the output of a layer using a trainable map. + + Args: + hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by + stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices. + + Returns: + A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions. 
+ """ + modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx]) + + if self.training and self.config.altup_coef_clip is not None: + self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip) + + # Project and then transpose all 2D matrices contained so that mulmat gives the correct result + all_coefs: torch.Tensor = ( + self.prediction_coefs(modalities) + .reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs) + .permute(0, 1, 3, 2) + ) + + # permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs] + predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs) + predictions = predictions.permute(3, 0, 1, 2) # undo the permute + predictions += hidden_states # add the original input + return predictions.contiguous().type_as(hidden_states) + + def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor: + """Corrects the predictions relative to the + + Args: + predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by + stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices. + activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs. + + Returns: + A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original + predictions relative to the activated input embeddings. + """ + modalities = self.compute_router_modalities(activated) + innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size) + innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions + + if self.config.altup_coef_clip is not None: + self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip) + + # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...) + # Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input + # and expand on dim1 for broadcastability + all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0 + all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1) + + corrected = torch.mul(innovation, all_coefs) + corrected += predictions # add the original input + return corrected.contiguous().type_as(activated) + + def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: + """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size].""" + return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) + + +class Gemma3nTextRotaryEmbedding(Gemma2RotaryEmbedding): + pass + + +def apply_rotary_pos_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + unsqueeze_dim: int = 1, +): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + x (`torch.Tensor`): The tensor to embed. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
+            For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim].
+            Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `torch.Tensor`: The tensor rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    return (x * cos) + (rotate_half(x) * sin)
+
+
+class Gemma3nTextAttention(Gemma3Attention):
+    def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+        del self.attn_logit_softcapping
+        del self.scaling
+        self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
+
+        first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
+        self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
+        # Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
+        layer_type = config.layer_types[layer_idx]
+        self.kv_shared_layer_index = (
+            first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
+            if self.is_kv_shared_layer
+            else None
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.config.head_dim)
+
+        cos, sin = position_embeddings
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape)
+        query_states = self.q_norm(query_states)
+        query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
+        query_states = query_states.transpose(1, 2)
+
+        if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
+            # HybridCache has complex slicing when layer_type == "sliding_attention" that impacts the shared KV cache.
+            if isinstance(past_key_value, HybridCache) and self.is_sliding:
+                max_length = past_key_value.sliding_window
+                if cache_position.shape[0] > max_length:
+                    # If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
+                    # slice into the entire cache.
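+                    # (For sliding layers, HybridCache only keeps the most recent `sliding_window` positions,
+                    # so the shared layer reuses that entire rolling window here.)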
+ indices = slice(0, max_length) + else: + # If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1 + indices = cache_position.clamp(min=0, max=max_length - 1) + else: + indices = cache_position + + key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] + value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] + else: + key_states = self.k_proj(hidden_states).view(hidden_shape) + key_states = self.k_norm(key_states) + key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2) + key_states = key_states.transpose(1, 2) + + value_states = self.v_proj(hidden_states).view(hidden_shape) + value_states = self.v_norm(value_states) + value_states = value_states.transpose(1, 2) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=1.0, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Gemma3nTextDecoderLayer(Gemma3DecoderLayer): + def __init__(self, config: Gemma3nTextConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx) + + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + self.act_fn = ACT2FN[config.hidden_activation] + + self.altup = Gemma3nTextAltUp(config) + self.laurel = Gemma3nTextLaurelBlock(config) + self.self_attn = Gemma3nTextAttention(config, layer_idx) + self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False) + self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False) + self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings_global: torch.Tensor, + position_embeddings_local: torch.Tensor, + per_layer_input: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + predictions = self.altup.predict(hidden_states) + active_prediction = predictions[self.config.altup_active_idx] + + active_prediction_normed = self.input_layernorm(active_prediction) + laurel_output = self.laurel(active_prediction_normed) + + # apply global RoPE to non-sliding layer only + if self.self_attn.is_sliding: + position_embeddings = position_embeddings_local + else: + position_embeddings = position_embeddings_global + + attn, self_attn_weights = self.self_attn( + 
hidden_states=active_prediction_normed, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + attn = self.post_attention_layernorm(attn) + + attn_gated = active_prediction + attn + attn_laurel = (attn_gated + laurel_output) / math.sqrt(2) + + attn_norm = self.pre_feedforward_layernorm(attn_laurel) + attn_ffw = self.mlp(attn_norm) + attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw) + attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm + corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated) + + first_prediction = corrected_predictions[self.config.altup_active_idx] + first_prediction_clone = first_prediction.clone() + if self.config.altup_correct_scale: + first_prediction = self.altup.scale_corrected_output(first_prediction_clone) + + # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...) + first_prediction = self.per_layer_input_gate(first_prediction) + first_prediction = self.act_fn(first_prediction) + first_prediction = torch.multiply(first_prediction, per_layer_input) + + # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...) + first_prediction = self.per_layer_projection(first_prediction) + first_prediction = self.post_per_layer_input_norm(first_prediction) + corrected_predictions[1:] += first_prediction + + outputs = (corrected_predictions,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Gemma3nPreTrainedModel(Gemma2PreTrainedModel): + config_class = Gemma3nConfig + base_model_prefix = "" + _no_split_modules = ["Gemma3nDecoderLayer"] + + def _init_weights(self, module): + # important: this ported version of Gemma2 isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Gemma3nRMSNorm): + if module.with_scale: + module.weight.data.fill_(1.0) + elif isinstance(module, Gemma3nAudioCumulativeGroupNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, Gemma3nAudioAttention): + module.per_dim_scale.data.zero_() + elif isinstance(module, Gemma3nTextAltUp): + module.correct_output_scale.data.zero_() + + +@auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.") +class Gemma3nTextModel(Gemma3TextModel): + config_class = Gemma3nTextConfig + + def __init__(self, config: Gemma3nTextConfig): + super().__init__(config) + + self.hidden_size = config.hidden_size + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + + self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding( + config.vocab_size_per_layer_input, + config.num_hidden_layers * config.hidden_size_per_layer_input, + self.padding_idx, + embed_scale=config.hidden_size_per_layer_input**0.5, + ) + + self.per_layer_model_projection = nn.Linear( + self.hidden_size, + config.num_hidden_layers * config.hidden_size_per_layer_input, + bias=False, + 
) + + self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps) + self.layers = nn.ModuleList( + [Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.altup_projections = nn.ModuleList( + [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)] + ) + + self.altup_unembed_projections = nn.ModuleList( + [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)] + ) + + self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False) + self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False) + self.rotary_emb = Gemma3nTextRotaryEmbedding(config=config) + + # TODO (raushan): Fix this after RoPE refactor. For now we hack it by + # reassigning thetas when we want to create a local RoPE layer. Config + # defaults should hold values for global RoPE. + config = copy.deepcopy(config) + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default"} + self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config) + + def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor: + return self.embed_tokens_per_layer(input_ids).reshape( + *input_ids.shape, + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds) + per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype) + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + + if per_layer_inputs is None: + return per_layer_projection + + if per_layer_projection.shape != per_layer_inputs.shape: + # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings. + per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :] + + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + r""" + per_layer_inputs (torch.Tensor, *optional*, defaults to None): + Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided. 
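+            The expected shape is `[batch_size, seq_len, num_hidden_layers, hidden_size_per_layer_input]`, as
+            produced by `get_per_layer_inputs`.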
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if input_ids is not None: + inputs_embeds = self.embed_tokens(input_ids) + per_layer_inputs = self.get_per_layer_inputs(input_ids) + + per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs) + + if use_cache and past_key_values is None and not self.training: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + # embed positions + hidden_states_0 = inputs_embeds + + # Initialize RoPE embeddings + position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids) + position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids) + + # Expand hidden_states to support per-layer inputs + target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 + epsilon_tensor = torch.tensor(torch.finfo().min) + + temp_hidden_states = [hidden_states_0] + for i in range(1, self.config.altup_num_inputs): + # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...) 
+ altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0) + current_hidden_state = altup_proj.type(hidden_states_0.dtype) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 + current_hidden_state = current_hidden_state * ( + target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) + ) + temp_hidden_states.append(current_hidden_state) + + hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size] + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + causal_mask = causal_mask_mapping[decoder_layer.attention_type] + per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :] + + layer_outputs = decoder_layer( + hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + per_layer_input=per_layer_input, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # Per-layer inputs to single output + target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5 + temp_hidden_states = [hidden_states[0]] + for i in range(1, self.config.altup_num_inputs): + # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...) 
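+            # Mirror of the expansion above: re-project each non-active stream, match its RMS to stream 0,
+            # then average all streams into a single hidden state below.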
+ altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i]) + current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 + current_hidden_state = current_hidden_state * ( + target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) + ) + temp_hidden_states.append(current_hidden_state) + + hidden_states = torch.stack(temp_hidden_states) + hidden_states = torch.mean(hidden_states, dim=0) + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.") +class Gemma3nForCausalLM(Gemma3ForCausalLM): + _checkpoint_conversion_mapping = {"model.language_model": "model"} + base_model_prefix = "model" + + +class Gemma3nMultimodalEmbedder(nn.Module): + """Embeds token ids or soft tokens for multimodal content into language model space.""" + + def __init__( + self, + multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig], + text_config: Gemma3nTextConfig, + ): + super().__init__() + + self.multimodal_hidden_size = multimodal_config.hidden_size + self.eps = multimodal_config.rms_norm_eps + self.vocab_offset = multimodal_config.vocab_offset + self.vocab_size = multimodal_config.vocab_size + self.text_hidden_size = text_config.hidden_size + + self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size) + self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps) + self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps) + self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False) + self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Embeds token ids or soft tokens for multimodal content into language model space. + + Args: + input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range + `[vocab_offset, vocab_offset + vocab_size)`. + inputs_embeds: A torch.Tensor containing the soft tokens to embed. + + Returns: + A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is not None: + emb_norm = self.soft_embedding_norm(inputs_embeds) + else: + hard_emb = self.embedding(input_ids - self.vocab_offset) + emb_norm = self.hard_embedding_norm(hard_emb) + + emb_norm_proj = self.embedding_projection(emb_norm) + return self.embedding_post_projection_norm(emb_norm_proj) + + +@auto_docstring( + custom_intro=""" + The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a + language modeling head. 
+ """ +) +class Gemma3nModel(PaliGemmaModel): + _checkpoint_conversion_mapping = {} + + def __init__(self, config: Gemma3nConfig): + super().__init__() + del self.multi_modal_projector # Replaced by Gemma3nVisionEmbedder + self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input + self.audio_tower = AutoModel.from_config(config.audio_config) + self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config) + self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config) + + def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Projects the last hidden state from the vision model into language model space. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_outputs = self.vision_tower( + pixel_values=pixel_values, do_pooling=False, return_dict=True + ).last_hidden_state + # Convert from (batch, channels, height, width) to (batch, height * width, channels) where: + # height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image. + vision_outputs = vision_outputs.reshape( + vision_outputs.shape[0], + self.config.vision_config.hidden_size, + self.config.vision_soft_tokens_per_image, + ).permute(0, 2, 1) + # Normalize and embed the soft tokens into language model space. + vision_outputs *= self.config.vision_config.hidden_size**0.5 + return self.embed_vision(inputs_embeds=vision_outputs) + + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, # text inputs + pixel_values: Optional[torch.FloatTensor] = None, # vision inputs + input_features: Optional[torch.FloatTensor] = None, # audio inputs + attention_mask: Optional[torch.Tensor] = None, + input_features_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **lm_kwargs, + ) -> Gemma3nCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration + + >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224") + >>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224") + + >>> prompt = "Where is the cat standing?" 
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs,) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Where is the cat standing?\nsnow" + ``` + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if input_ids is not None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # Prepare per-layer inputs from inputs_ids + per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input) + per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)) + per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens) + + # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset) + vision_mask = torch.logical_and( + input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset + ) + dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1 + vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device) + vision_embeds = self.embed_vision(input_ids=vision_input_ids) + expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds) + + # Handle audio tokens (>= embed_audio.vocab_offset) + audio_mask = input_ids >= self.embed_audio.vocab_offset + dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1 + audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device) + audio_embeds = self.embed_audio(input_ids=audio_input_ids) + expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds) + else: + per_layer_inputs = None + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text and " + f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings." 
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ # Merge text and audio
+ if input_features is not None and input_features_mask is not None:
+ audio_features, audio_mask = self.get_audio_features(input_features, ~input_features_mask)
+
+ # The Gemma3nProcessor expects all audio to be 30s in length and inserts 188 audio soft tokens into the
+ # text to account for this. However, the audio preprocessing and encoder do not guarantee they will
+ # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
+ # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
+ # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
+ audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device)
+ audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
+ audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
+
+ audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
+ extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
+ extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
+
+ audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
+
+ if input_ids is None:
+ special_audio_mask = inputs_embeds == self.embed_audio(
+ input_ids=torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ else:
+ special_audio_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
+ special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
+ audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
+ raise ValueError(
+ f"Number of audio input features does not match number of special audio tokens in the input text. "
+ f"Got {audio_tokens_in_text} audio tokens in the text and "
+ f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings."
+ )
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
+
+ outputs = self.language_model(
+ input_ids=None,
+ per_layer_inputs=per_layer_inputs,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ return Gemma3nModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ audio_hidden_states=audio_features if input_features is not None else None,
+ )
+
+ def get_audio_features(
+ self, input_features: torch.Tensor, input_features_mask: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Projects the last hidden state from the audio encoder into language model space.
+
+ Args:
+ input_features (`torch.FloatTensor` of shape `(batch_size, seq_length, num_features)`):
+ The tensors corresponding to the input audio.
+ input_features_mask (`torch.Tensor` of shape `(batch_size, seq_length)`):
+ The attention mask for the input audio.
+
+ Returns:
+ audio_features (`torch.Tensor`): Audio feature tensor of shape `(batch_size, audio_length, embed_dim)`.
+ """
+ audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
+ return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
+
+ def _update_causal_mask(self, **super_kwargs):
+ raise AttributeError("We don't want to inherit it")
+
+
+@auto_docstring(
+ custom_intro="""
+ The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
+ head.
+ """
+)
+class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration):
+ _checkpoint_conversion_mapping = {}
+ base_model_prefix = "model"
+
+ @property
+ def audio_tower(self):
+ return self.model.audio_tower
+
+ @property
+ def multi_modal_projector(self):
+ raise AttributeError("Use embed_vision instead of multi_modal_projector.")
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None, # text inputs
+ pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
+ input_features: Optional[torch.FloatTensor] = None, # audio inputs
+ attention_mask: Optional[torch.Tensor] = None,
+ input_features_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Gemma3nCausalLMOutputWithPast:
+ r"""
+ input_features (`torch.Tensor`, *optional*, defaults to None):
+ The audio inputs to be encoded.
+ input_features_mask (`torch.Tensor`, *optional*, defaults to None):
+ The attention mask for the input audio.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
+ ignored (masked), the loss is only computed for the tokens with labels in
+ `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+ >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-e4b-it")
+ >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-e4b-it")
+
+ >>> messages = [
+ ... {
+ ... "role": "system",
+ ... "content": [
+ ... {"type": "text", "text": "You are a helpful assistant."}
+ ... ]
+ ... },
+ ... {
+ ... "role": "user", "content": [
+ ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+ ... {"type": "text", "text": "Where is the cat standing?"},
+ ... ]
+ ... },
+ ... ]
+
+ >>> inputs = processor.apply_chat_template(
+ ... messages,
+ ... tokenize=True,
+ ...
return_dict=True, + ... return_tensors="pt", + ... add_generation_prompt=True + ... ) + >>> # Generate + >>> generate_ids = model.generate(**inputs) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to" + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + input_features=input_features, + attention_mask=attention_mask, + input_features_mask=input_features_mask, + position_ids=position_ids, + past_key_values=past_key_values, + token_type_ids=token_type_ids, + cache_position=cache_position, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + **lm_kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None: + logits = logits / final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * final_logit_softcapping + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + + return Gemma3nCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + audio_hidden_states=outputs.audio_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + pixel_values=None, + input_features=None, + attention_mask=None, + input_features_mask=None, + token_type_ids=None, + use_cache=True, + logits_to_keep=None, + labels=None, + **kwargs, + ): + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special + # tokens anymore. Otherwise multimodal inputs should be passed to model. + # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + model_inputs["input_features"] = input_features + model_inputs["input_features_mask"] = input_features_mask + + return model_inputs + + def _prepare_4d_causal_attention_mask_with_cache_position(self, **super_kwargs): + raise AttributeError("Do not inherit _prepare_4d_causal_attention_mask_with_cache_position from PaliGemma") + + +__all__ = [ + "Gemma3nAudioConfig", + "Gemma3nAudioEncoder", + "Gemma3nConfig", + "Gemma3nForCausalLM", + "Gemma3nForConditionalGeneration", + "Gemma3nModel", + "Gemma3nPreTrainedModel", # noqa: F822 + "Gemma3nTextConfig", + "Gemma3nTextModel", + "Gemma3nVisionConfig", +] diff --git a/src/transformers/models/gemma3n/processing_gemma3n.py b/src/transformers/models/gemma3n/processing_gemma3n.py new file mode 100644 index 000000000000..45e953b5c5d2 --- /dev/null +++ b/src/transformers/models/gemma3n/processing_gemma3n.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, make_nested_list_of_images
+from ...processing_utils import AudioKwargs, ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class Gemma3nImagesKwargs(ImagesKwargs):
+ do_pan_and_scan: Optional[bool]
+ pan_and_scan_min_crop_size: Optional[int]
+ pan_and_scan_max_num_crops: Optional[int]
+ pan_and_scan_min_ratio_to_activate: Optional[float]
+ do_convert_rgb: Optional[bool]
+
+
+class Gemma3nProcessorKwargs(ProcessingKwargs, total=False):
+ audio_kwargs: AudioKwargs
+ images_kwargs: Gemma3nImagesKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ },
+ }
+
+
+class Gemma3nProcessor(ProcessorMixin):
+ """
+ A processor for Gemma 3n, wrapping the full capabilities of a feature extractor, image processor, and tokenizer
+ into a single processor.
+
+ Args:
+ feature_extractor (`Gemma3nAudioFeatureExtractor`):
+ Feature extractor that converts raw audio waveforms into MEL spectrograms for the audio encoder. This
+ should return a `BatchFeature` with `input_features` and `input_features_mask` features.
+ image_processor (`SiglipImageProcessorFast`):
+ Image processor that prepares batches of images for the vision encoder. This should return a `BatchFeature`
+ with a `pixel_values` feature.
+ tokenizer (`GemmaTokenizerFast`):
+ The text tokenizer for the model.
+ chat_template (`string`, *optional*):
+ A Jinja template for generating text prompts from a set of messages.
+ audio_seq_length (int, *optional*, defaults to 188):
+ The number of audio soft tokens that will be added to the text prompt.
+ image_seq_length (int, *optional*, defaults to 256):
+ The number of image soft tokens that will be added to the text prompt.
+ """
+
+ attributes = ["feature_extractor", "image_processor", "tokenizer"]
+ feature_extractor_class = "AutoFeatureExtractor"
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ feature_extractor,
+ image_processor,
+ tokenizer,
+ chat_template=None,
+ audio_seq_length: int = 188,
+ image_seq_length: int = 256,
+ **kwargs,
+ ):
+ self.audio_seq_length = audio_seq_length
+ self.audio_token_id = tokenizer.audio_token_id
+ self.boa_token = tokenizer.boa_token
+ self.audio_token = tokenizer.audio_token
+ audio_tokens_expanded = "".join([tokenizer.audio_token] * audio_seq_length)
+ self.full_audio_sequence = f"\n\n{tokenizer.boa_token}{audio_tokens_expanded}{tokenizer.eoa_token}\n\n"
+
+ self.image_seq_length = image_seq_length
+ self.image_token_id = tokenizer.image_token_id
+ self.boi_token = tokenizer.boi_token
+ self.image_token = tokenizer.image_token
+ image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
+ self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
+
+ super().__init__(
+ feature_extractor=feature_extractor,
+ image_processor=image_processor,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ **kwargs,
+ )
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ audio: Optional[Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]]] = None,
+ videos=None,
+ **kwargs: Unpack[Gemma3nProcessorKwargs],
+ )
-> BatchFeature: + if text is None and images is None and audio is None: + raise ValueError("Provide at least one of `text`, `images`, or `audio`.") + + output_kwargs = self._merge_kwargs( + Gemma3nProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + if audio is not None: + audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"]) + + if not text: + text = [self.audio_token for _ in audio] + + # Expand placeholder audio tokens to the full audio token sequence + text = [prompt.replace(self.audio_token, self.full_audio_sequence) for prompt in text] + else: + audio_inputs = {} + + if images is not None: + batched_images = make_nested_list_of_images(images) + image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"]) + + # Create empty text to be replaced with placeholders + if not text: + text = [" ".join([self.image_token] * len(images)) for images in batched_images] + + if len(batched_images) != len(text): + raise ValueError( + f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})." + ) + + # Expand placeholder image tokens to the full image token sequence + text = [prompt.replace(self.image_token, self.full_image_sequence) for prompt in text] + else: + image_inputs = {} + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np") + self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) + + # Add token type ids manually, as tokenizer can't do arbitrary position token types + array_ids = text_inputs["input_ids"] + token_type_ids = np.zeros_like(array_ids) + token_type_ids[array_ids == self.image_token_id] = 1 + token_type_ids[array_ids == self.audio_token_id] = 3 + text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs + text_inputs["token_type_ids"] = token_type_ids.tolist() + return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. 
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"] + image_processor_input_names = self.image_processor.model_input_names + feature_extactor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + feature_extactor_input_names)) + + +__all__ = ["Gemma3nProcessor"] diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 10f31b81c8f5..fe24e85c6419 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1627,6 +1627,7 @@ def set_model_tester_for_less_flaky_test(test_case): "AriaVisionText2TextModelTester", "GPTNeoModelTester", "DPTModelTester", + "Gemma3nTextModelTester", # cannot have a single layer combined with the cache sharing config attrs in the tester ] if test_case.model_tester.__class__.__name__ in exceptional_classes: target_num_hidden_layers = None diff --git a/tests/models/gemma3n/__init__.py b/tests/models/gemma3n/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/gemma3n/test_feature_extraction_gemma3n.py b/tests/models/gemma3n/test_feature_extraction_gemma3n.py new file mode 100644 index 000000000000..d2b10315bd6e --- /dev/null +++ b/tests/models/gemma3n/test_feature_extraction_gemma3n.py @@ -0,0 +1,277 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import itertools +import os +import random +import tempfile +import unittest +from typing import Optional, Sequence + +import numpy as np +from parameterized import parameterized + +from transformers.models.gemma3n import Gemma3nAudioFeatureExtractor +from transformers.testing_utils import ( + check_json_file_has_correct_format, + require_torch, +) +from transformers.utils.import_utils import is_torch_available + +from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin + + +if is_torch_available(): + pass + +global_rng = random.Random() + +MAX_LENGTH_FOR_TESTING = 512 + + +def floats_list(shape, scale=1.0, rng=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + values = [] + for _ in range(shape[0]): + values.append([]) + for _ in range(shape[1]): + values[-1].append(rng.random() * scale) + + return values + + +class Gemma3nAudioFeatureExtractionTester: + def __init__( + self, + parent, + batch_size=7, + min_seq_length=400, + max_seq_length=2000, + feature_size: int = 128, + sampling_rate: int = 16_000, + padding_value: float = 0.0, + return_attention_mask: bool = False, + # ignore hop_length / frame_length for now, as ms -> length conversion causes issues with serialization tests + # frame_length_ms: float = 32.0, + # hop_length: float = 10.0, + min_frequency: float = 125.0, + max_frequency: float = 7600.0, + preemphasis: float = 0.97, + preemphasis_htk_flavor: bool = True, + fft_overdrive: bool = True, + dither: float = 0.0, + input_scale_factor: float = 1.0, + mel_floor: float = 1e-5, + per_bin_mean: Optional[Sequence[float]] = None, + per_bin_stddev: Optional[Sequence[float]] = None, + ): + self.parent = parent + self.batch_size = batch_size + self.min_seq_length = min_seq_length + self.max_seq_length = max_seq_length + self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1) + self.feature_size = feature_size + self.sampling_rate = sampling_rate + self.padding_value = padding_value + self.return_attention_mask = return_attention_mask + # ignore hop_length / frame_length for now, as ms -> length conversion causes issues with serialization tests + # self.frame_length_ms = frame_length_ms + # self.hop_length = hop_length + self.min_frequency = min_frequency + self.max_frequency = max_frequency + self.preemphasis = preemphasis + self.preemphasis_htk_flavor = preemphasis_htk_flavor + self.fft_overdrive = fft_overdrive + self.dither = dither + self.input_scale_factor = input_scale_factor + self.mel_floor = mel_floor + self.per_bin_mean = per_bin_mean + self.per_bin_stddev = per_bin_stddev + + def prepare_feat_extract_dict(self): + return { + "feature_size": self.feature_size, + "sampling_rate": self.sampling_rate, + "padding_value": self.padding_value, + "return_attention_mask": self.return_attention_mask, + "min_frequency": self.min_frequency, + "max_frequency": self.max_frequency, + "preemphasis": self.preemphasis, + "preemphasis_htk_flavor": self.preemphasis_htk_flavor, + "fft_overdrive": self.fft_overdrive, + "dither": self.dither, + "input_scale_factor": self.input_scale_factor, + "mel_floor": self.mel_floor, + "per_bin_mean": self.per_bin_mean, + "per_bin_stddev": self.per_bin_stddev, + } + + def prepare_inputs_for_common(self, equal_length=False, numpify=False): + def _flatten(list_of_lists): + return list(itertools.chain(*list_of_lists)) + + if equal_length: + speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)] + else: + # 
make sure that inputs increase in size + speech_inputs = [ + floats_list((x, self.feature_size)) + for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff) + ] + if numpify: + speech_inputs = [np.asarray(x) for x in speech_inputs] + return speech_inputs + + +class Gemma3nAudioFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): + feature_extraction_class = Gemma3nAudioFeatureExtractor + + def setUp(self): + self.feat_extract_tester = Gemma3nAudioFeatureExtractionTester(self) + + def test_feat_extract_from_and_save_pretrained(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + saved_file = feat_extract_first.save_pretrained(tmpdirname)[0] + check_json_file_has_correct_format(saved_file) + feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname) + + dict_first = feat_extract_first.to_dict() + dict_second = feat_extract_second.to_dict() + mel_1 = feat_extract_first.mel_filters + mel_2 = feat_extract_second.mel_filters + self.assertTrue(np.allclose(mel_1, mel_2)) + self.assertEqual(dict_first, dict_second) + + def test_feat_extract_to_json_file(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + json_file_path = os.path.join(tmpdirname, "feat_extract.json") + feat_extract_first.to_json_file(json_file_path) + feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path) + + dict_first = feat_extract_first.to_dict() + dict_second = feat_extract_second.to_dict() + mel_1 = feat_extract_first.mel_filters + mel_2 = feat_extract_second.mel_filters + self.assertTrue(np.allclose(mel_1, mel_2)) + self.assertEqual(dict_first, dict_second) + + def test_feat_extract_from_pretrained_kwargs(self): + feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + saved_file = feat_extract_first.save_pretrained(tmpdirname)[0] + check_json_file_has_correct_format(saved_file) + feat_extract_second = self.feature_extraction_class.from_pretrained( + tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"] + ) + + mel_1 = feat_extract_first.mel_filters + mel_2 = feat_extract_second.mel_filters + self.assertTrue(2 * mel_1.shape[1] == mel_2.shape[1]) + + @parameterized.expand( + [ + ([floats_list((1, x))[0] for x in range(800, 1400, 200)],), + ([floats_list((1, x))[0] for x in (800, 800, 800)],), + ([floats_list((1, x))[0] for x in range(200, (MAX_LENGTH_FOR_TESTING + 500), 200)], True), + ] + ) + def test_call(self, audio_inputs, test_truncation=False): + feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + np_audio_inputs = [np.asarray(audio_input) for audio_input in audio_inputs] + + input_features = feature_extractor(np_audio_inputs, padding="max_length", return_tensors="np").input_features + self.assertTrue(input_features.ndim == 3) + # input_features.shape should be (batch, num_frames, n_mels) ~= (batch, num_frames, feature_size) + # 480_000 is the max_length that inputs are padded to. 
we use that to calculate num_frames + expected_num_frames = (480_000 - feature_extractor.frame_length) // (feature_extractor.hop_length) + 1 + self.assertTrue( + input_features.shape[-2] == expected_num_frames, + f"no match: {input_features.shape[-1]} vs {expected_num_frames}", + ) + self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size) + + encoded_sequences_1 = feature_extractor(audio_inputs, return_tensors="np").input_features + encoded_sequences_2 = feature_extractor(np_audio_inputs, return_tensors="np").input_features + for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): + self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + + if test_truncation: + audio_inputs_truncated = [x[:MAX_LENGTH_FOR_TESTING] for x in audio_inputs] + np_audio_inputs_truncated = [np.asarray(audio_input) for audio_input in audio_inputs_truncated] + + encoded_sequences_1 = feature_extractor( + audio_inputs_truncated, max_length=MAX_LENGTH_FOR_TESTING, return_tensors="np" + ).input_features + encoded_sequences_2 = feature_extractor( + np_audio_inputs_truncated, max_length=MAX_LENGTH_FOR_TESTING, return_tensors="np" + ).input_features + for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2): + self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3)) + + def test_dither(self): + np.random.seed(42) # seed the dithering randn() + + # Tests that features with and without little dithering are similar, but not the same + dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict() + dict_no_dither["dither"] = 0.0 + + dict_dither = self.feat_extract_tester.prepare_feat_extract_dict() + dict_dither["dither"] = 0.00003 # approx. 1/32k + + feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither) + feature_extractor_dither = self.feature_extraction_class(**dict_dither) + + # create three inputs of length 800, 1000, and 1200 + speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)] + np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs] + + # compute features + input_features_no_dither = feature_extractor_no_dither( + np_speech_inputs, padding=True, return_tensors="np", sampling_rate=dict_no_dither["sampling_rate"] + ).input_features + input_features_dither = feature_extractor_dither( + np_speech_inputs, padding=True, return_tensors="np", sampling_rate=dict_dither["sampling_rate"] + ).input_features + + # test there is a difference between features (there's added noise to input signal) + diff = input_features_dither - input_features_no_dither + + # features are not identical + self.assertTrue(np.abs(diff).mean() > 1e-6) + # features are not too different + self.assertTrue(np.abs(diff).mean() <= 1e-4) + self.assertTrue(np.abs(diff).max() <= 5e-3) + + @require_torch + def test_double_precision_pad(self): + import torch + + feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict()) + np_speech_inputs = np.random.rand(100, 32).astype(np.float64) + py_speech_inputs = np_speech_inputs.tolist() + + for inputs in [py_speech_inputs, np_speech_inputs]: + np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np") + self.assertTrue(np_processed.input_features.dtype == np.float32) + pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt") + self.assertTrue(pt_processed.input_features.dtype == torch.float32) diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py 
b/tests/models/gemma3n/test_modeling_gemma3n.py new file mode 100644 index 000000000000..2f546e19e49c --- /dev/null +++ b/tests/models/gemma3n/test_modeling_gemma3n.py @@ -0,0 +1,886 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Gemma3n model.""" + +import tempfile +import unittest + +import numpy as np +import pytest +from datasets import load_dataset +from parameterized import parameterized + +from transformers import ( + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, + Gemma3nAudioConfig, + Gemma3nAudioFeatureExtractor, + Gemma3nConfig, + Gemma3nTextConfig, + GenerationConfig, + is_torch_available, +) +from transformers.testing_utils import ( + cleanup, + require_flash_attn, + require_read_token, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ..gemma.test_modeling_gemma import GemmaModelTester + + +if is_torch_available(): + import torch + + from transformers import ( + Gemma3nAudioEncoder, + Gemma3nForCausalLM, + Gemma3nForConditionalGeneration, + Gemma3nModel, + Gemma3nTextModel, + ) + + +class Gemma3nAudioModelTester: + def __init__( + self, + parent, + batch_size=2, + num_channels=32, # feature_size / input_feat_size + sampling_rate=16_000, + raw_audio_length=8_000, + is_training=True, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.sampling_rate = sampling_rate + self.raw_audio_length = raw_audio_length + self.is_training = is_training + + def get_feature_extractor_config(self): + return { + "feature_size": self.num_channels, + "sampling_rate": self.sampling_rate, + "padding_value": 0.0, + "return_attention_mask": True, + "frame_length_ms": 32.0, + "hop_length_ms": 10.0, + "dither": 0.0, # Important for determinism + } + + def get_audio_encoder_config(self): + return Gemma3nAudioConfig( + input_feat_size=self.num_channels, + hidden_size=32, + conf_num_attention_heads=4, + conf_num_hidden_layers=2, + sscp_conv_channel_size=(16, 8), + conf_conv_kernel_size=3, + conf_attention_chunk_size=4, + conf_attention_context_left=5, + ) + + def prepare_config_and_inputs_for_common(self): + # Prepare inputs for the audio encoder + feature_extractor_config = self.get_feature_extractor_config() + audio_encoder_config = self.get_audio_encoder_config() + + np.random.seed(0) + raw_speech_1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1, self.raw_audio_length)).astype(np.float32) + raw_speech_2 = np.random.randn(self.raw_audio_length // 2).astype(np.float32) + raw_speech = [raw_speech_1, raw_speech_2] + + feature_extractor = Gemma3nAudioFeatureExtractor(**feature_extractor_config) + audio_inputs = feature_extractor(raw_speech, return_tensors="pt") + + input_features = audio_inputs["input_features"] + # The 
encoder expects a padding mask (True for padding), while the feature extractor + # returns an attention mask (True for valid tokens). We must invert it. + input_features_mask = ~audio_inputs["input_features_mask"].to(torch.bool) + + inputs_dict = { + "audio_mel": input_features, + "audio_mel_mask": input_features_mask, + } + return audio_encoder_config, inputs_dict + + +@unittest.skip("Skipped for now!") +@require_torch +class Gemma3nAudioModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (Gemma3nAudioEncoder,) if is_torch_available() else () + test_pruning = False + test_head_masking = False + test_missing_keys = False + is_generative = False + _is_stateful = True + main_input_name = "audio_mel" + test_initialization = False + test_can_init_all_missing_weights = False + + def setUp(self): + self.model_tester = Gemma3nAudioModelTester(self) + self.config_tester = ConfigTester(self, config_class=Gemma3nAudioConfig, hidden_size=37) + torch.manual_seed(0) + + # The following values are golden outputs from a deterministic run of the components. + # They are used to ensure that changes to the code do not alter the numerical output. + # Generated with seeds np.random.seed(0) and torch.manual_seed(0). + self.expected_input_features_shape = (2, 48, 32) + self.expected_input_features_slice = np.array([-5.733152, -5.337127, -4.916284, -4.378989, -3.7622747]) + self.expected_input_features_mask_shape = (2, 48) + self.expected_input_features_mask_slice = np.array([True, True, True, True, False]) + + self.expected_encoder_output_shape = (2, 3, 32) + self.expected_encoder_output_slice = torch.tensor([-0.4159, 0.6459, 0.6305, 2.2902, 0.9683]) + self.expected_encoder_mask_shape = (2, 3) + self.expected_encoder_mask_slice = torch.tensor([False, False, True]) + + # Prepare a shared feature extractor and raw audio for the tests + self.feature_extractor = Gemma3nAudioFeatureExtractor(**self.model_tester.get_feature_extractor_config()) + np.random.seed(0) + raw_speech_1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1, self.model_tester.raw_audio_length)).astype( + np.float32 + ) + raw_speech_2 = np.random.randn(self.model_tester.raw_audio_length // 2).astype(np.float32) + self.raw_speech = [raw_speech_1, raw_speech_2] + + @unittest.skip("Audio encoder does not support attention output") + def test_attention_outputs(self): + pass + + @unittest.skip("Audio encoder does not support hidden state output") + def test_hidden_states_output(self): + pass + + @unittest.skip("Audio encoder returns a tuple, not a ModelOutput object, skipping equivalence test.") + def test_model_outputs_equivalence(self): + pass + + @unittest.skip("Audio encoder does not support retaining gradients on hidden states/attentions.") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip("Audio encoder does not have a concept of token embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip("Audio encoder does not have a concept of token embeddings") + def test_resize_tokens_embeddings(self): + pass + + @unittest.skip("This model has a complex downsampling scheme that is hard to test with the generic batching test.") + def test_batching_equivalence(self): + pass + + def test_feature_extractor(self): + """ + Tests the feature extractor's output against pre-computed golden values. + This ensures the NumPy-based audio preprocessing is correct and consistent. 
+ """ + audio_inputs = self.feature_extractor( + self.raw_speech, padding="longest", pad_to_multiple_of=128, return_tensors="np" + ) + + input_features = audio_inputs["input_features"] + self.assertEqual(input_features.shape, self.expected_input_features_shape) + np.testing.assert_allclose(input_features[0, 0, :5], self.expected_input_features_slice, rtol=1e-5, atol=1e-5) + + print(input_features[0, 0, :5]) + + input_features_mask = audio_inputs["input_features_mask"] + self.assertEqual(input_features_mask.shape, self.expected_input_features_mask_shape) + # The second audio sample is shorter (22 frames vs 48), so its mask should become False at index 22 + np.testing.assert_array_equal(input_features_mask[1, 21:26], self.expected_input_features_mask_slice) + + def test_audio_encoder(self): + """ + Tests the audio encoder's forward pass against pre-computed golden values. + This ensures the PyTorch-based audio encoding model is correct and consistent. + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = Gemma3nAudioEncoder(config).to(torch_device).eval() + + with torch.no_grad(): + encoder_output, encoder_mask = model(**inputs_dict) + + print(encoder_output[0, 0, :5]) + + # Check output encodings + self.assertEqual(encoder_output.shape, self.expected_encoder_output_shape) + torch.testing.assert_close( + encoder_output[0, 0, :5], self.expected_encoder_output_slice.to(torch_device), rtol=1e-4, atol=1e-4 + ) + + # Check output mask (True means padded) + # Second sample has 22 feature frames. After downsampling by 4 (conv) -> 5 frames. After downsampling by 4 (reduction) -> 1 frame. + # So the mask should be [False, True, True] + self.assertEqual(encoder_mask.shape, self.expected_encoder_mask_shape) + torch.testing.assert_close(encoder_mask[1, :], self.expected_encoder_mask_slice.to(torch_device)) + + +class Gemma3nTextModelTester(GemmaModelTester): + activation_sparsity_pattern = None + forced_config_args = ["activation_sparsity_pattern"] + + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + vocab_size=99, + vocab_size_per_layer_input=99, + hidden_size=16, + num_hidden_layers=4, # override to correctly test sharing cache pattern + num_kv_shared_layers=2, # important to override + layer_types=[ + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + ], # similarly we want to test sharing on both types + num_attention_heads=2, + num_key_value_heads=2, + altup_num_inputs=2, + intermediate_size=21, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + is_decoder=False, + ): + self._verify_model_attributes() + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_kv_shared_layers = num_kv_shared_layers + self.layer_types = layer_types + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.altup_num_inputs = altup_num_inputs + 
self.intermediate_size = intermediate_size + self.hidden_activation = hidden_activation + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.head_dim = self.hidden_size // self.num_attention_heads + self.is_decoder = is_decoder + + if is_torch_available(): + config_class = Gemma3nTextConfig + model_class = Gemma3nTextModel + for_causal_lm_class = Gemma3nForCausalLM + + +@unittest.skip("Skipped for now!") +@require_torch +class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (Gemma3nTextModel, Gemma3nForCausalLM) if is_torch_available() else () + all_generative_model_classes = (Gemma3nForCausalLM,) if is_torch_available() else () + test_headmasking = False + test_pruning = False + _is_stateful = True + model_split_percents = [0.5, 0.6] + + def setUp(self): + self.model_tester = Gemma3nTextModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=Gemma3nConfig, + hidden_size=37, + text_config={"activation_sparsity_pattern": None}, + ) + + def _check_hidden_states_for_generate( + self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False + ): + "Gemma3n has special hidden states shape with 1 additional dim (which is then reduced with projections)" + + self.assertIsInstance(hidden_states, tuple) + self.assertListEqual( + [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], + [True] * len(hidden_states), + ) + self.assertEqual(len(hidden_states), (output_length - prompt_length)) + + # When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the + # new token(s) + # NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more + # elaborate checks + for generated_length, iter_hidden_states in enumerate(hidden_states): + # regardless of using cache, the first forward pass will have the full prompt as input + if use_cache and generated_length > 0: + model_input_length = 1 + else: + model_input_length = prompt_length + generated_length + expected_shape = (config.altup_num_inputs, batch_size, model_input_length, config.hidden_size) + # check hidden size + self.assertListEqual( + [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], + [expected_shape] * len(iter_hidden_states), + ) + + +class Gemma3nVision2TextModelTester: + text_config = {"activation_sparsity_pattern": None} + forced_config_args = ["text_config"] + + def __init__( + self, + parent, + mm_tokens_per_image=2, + image_token_index=1, + boi_token_index=2, + eoi_token_index=3, + seq_length=25, + is_training=True, + vision_config={ + "use_labels": True, + "image_size": 20, + "patch_size": 5, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "num_key_value_heads": 1, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "dropout": 0.1, + "attention_dropout": 0.1, + "initializer_range": 0.02, + }, + use_cache=False, + ): + self.parent = parent + # `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify + self.mm_tokens_per_image = mm_tokens_per_image + self.image_token_index = image_token_index + self.boi_token_index 
= boi_token_index + self.eoi_token_index = eoi_token_index + self.llm_tester = Gemma3nTextModelTester(self.parent) + self.text_config = self.llm_tester.get_config() + self.vision_config = vision_config + self.seq_length = seq_length + self.pad_token_id = self.text_config.pad_token_id + + self.num_hidden_layers = self.text_config.num_hidden_layers + self.vocab_size = self.text_config.vocab_size + self.hidden_size = self.text_config.hidden_size + self.num_attention_heads = self.text_config.num_attention_heads + self.is_training = is_training + + self.batch_size = 3 + self.num_channels = vision_config["num_channels"] + self.image_size = vision_config["image_size"] + self.encoder_seq_length = seq_length + self.use_cache = use_cache + + def get_config(self): + return Gemma3nConfig( + text_config=self.text_config, + vision_config=self.vision_config, + image_token_index=self.image_token_index, + boi_token_index=self.boi_token_index, + eoi_token_index=self.eoi_token_index, + mm_tokens_per_image=self.mm_tokens_per_image, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 + attention_mask = input_ids.ne(self.pad_token_id).to(torch_device) + + # set the 3 first tokens to be image, and ensure that no other tokens are image tokens + # do not change this unless you modified image size or patch size + input_ids[input_ids == config.image_token_index] = self.pad_token_id + input_ids[:, :1] = config.image_token_index + + token_type_ids = torch.zeros_like(input_ids) + token_type_ids[input_ids == config.image_token_index] = 1 + + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + return config, inputs_dict + + +@unittest.skip("Skipped for now!") +@require_torch +class Gemma3nVision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (Gemma3nModel, Gemma3nForConditionalGeneration) if is_torch_available() else () + all_generative_model_classes = (Gemma3nForConditionalGeneration,) if is_torch_available() else () + test_headmasking = False + test_pruning = False + test_missing_keys = False + _is_stateful = True + model_split_percents = [0.5, 0.6] + + # MP works but offload doesn't work when the SigLIP MultiheadAttention is offloaded + # TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"] + # in the dispatch_model function + test_cpu_offload = False + test_disk_offload_safetensors = False + test_disk_offload_bin = False + + def setUp(self): + self.model_tester = Gemma3nVision2TextModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=Gemma3nConfig, + hidden_size=37, + text_config={"activation_sparsity_pattern": None}, + ) + + @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training") + def 
test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip( + reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`" + " as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting" + ) + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip("Failing because of unique cache (HybridCache)") + def test_model_outputs_equivalence(self, **kwargs): + pass + + @parameterized.expand([("random",), ("same",)]) + @pytest.mark.generate + @unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding") + def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): + pass + + @pytest.mark.generate + @unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip("Gemma3n has HybridCache which is not compatible with dola decoding") + def test_dola_decoding_sample(self): + pass + + @unittest.skip("Gemma3n has HybridCache and doesn't support continue from past kv") + def test_generate_continue_from_past_key_values(self): + pass + + @unittest.skip("Gemma3n has HybridCache and doesn't support low_memory generation") + def test_beam_search_low_memory(self): + pass + + @unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate(self): + pass + + @unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_low_memory(self): + pass + + @unittest.skip("Gemma3n has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + def test_generate_with_static_cache(self): + pass + + @unittest.skip("Gemma3n has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @unittest.skip( + reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation" + ) + def test_initialization(self): + pass + + @unittest.skip( + reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan" + ) + def test_flex_attention_with_grads(self): + pass + + def test_automodelforcausallm(self): + """ + Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3n config, i.e. 
that + `AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model + """ + config = self.model_tester.get_config() + model = Gemma3nForConditionalGeneration(config) + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir) + self.assertIsInstance(for_causal_lm, Gemma3nForCausalLM) + + +@unittest.skip("Skipped for now!") +@slow +@require_torch_gpu +@require_read_token +class Gemma3nIntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = AutoProcessor.from_pretrained("Google/gemma-3n-E4B-it", padding_side="left") + + url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png" + self.messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + {"type": "image", "url": url}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + + audio_ds = load_dataset( + "etechgrid/28.5k_wavfiles_dataset", "default", data_files="wav_dataset/103-1240-0000.wav" + ) + self.audio_file_path = audio_ds["train"][0]["audio"]["path"] + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def test_model_4b_bf16(self): + model_id = "Google/gemma-3n-E4B-it" + + model = Gemma3nForConditionalGeneration.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 + ).to(torch_device) + + inputs = self.processor.apply_chat_template( + self.messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + add_generation_prompt=True, + ).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'] # fmt: skip + self.assertEqual(output_text, EXPECTED_TEXTS) + + def test_model_with_audio(self): + """ + Tests the full model pipeline with batched audio inputs provided as file paths. + This ensures the processor correctly loads and processes audio files. + """ + + model_id = "Google/gemma-3n-E4B-it" + + model = Gemma3nForConditionalGeneration.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 + ).to(torch_device) + + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Transcribe the following speech segment in English:"}, + {"type": "audio", "audio": str(self.audio_file_path)}, + ], + } + ], + ] + + inputs = self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding=True, + return_tensors="pt", + ).to(torch_device, dtype=model.dtype) + + input_len = inputs["input_ids"].shape[-1] + + output = model.generate(**inputs, max_new_tokens=16, do_sample=False) + output = output[:, input_len:] + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + EXPECTED_TEXTS = ["Chapter 1. Mrs. Rachel Lind is surprised.\n\nMrs. 
Rachel Lind"] + self.assertEqual(output_text, EXPECTED_TEXTS) + + def test_model_4b_batch(self): + model_id = "Google/gemma-3n-E4B-it" + + model = Gemma3nForConditionalGeneration.from_pretrained( + model_id, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 + ).to(torch_device) + + messages_2 = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png", + }, + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "Are these images identical?"}, + ], + }, + ] + + inputs = self.processor.apply_chat_template( + [self.messages, messages_2], + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + add_generation_prompt=True, + ).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + EXPECTED_TEXTS = [ + 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like', + "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow" + ] # fmt: skip + self.assertEqual(output_text, EXPECTED_TEXTS) + + def test_model_4b_crops(self): + model_id = "Google/gemma-3n-E4B-it" + + model = Gemma3nForConditionalGeneration.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 + ).to(torch_device) + + crop_config = { + "images_kwargs": { + "do_pan_and_scan": True, + "pan_and_scan_max_num_crops": 448, + "pan_and_scan_min_crop_size": 32, + "pan_and_scan_min_ratio_to_activate": 0.3, + } + } + + inputs = self.processor.apply_chat_template( + self.messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + add_generation_prompt=True, + **crop_config, + ).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=30, do_sample=False) + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + EXPECTED_NUM_IMAGES = 3 # one for the origin image and two crops of images + EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background.'] # fmt: skip + self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES) + self.assertEqual(output_text, EXPECTED_TEXTS) + + def test_model_4b_multiimage(self): + model_id = "Google/gemma-3n-E4B-it" + + model = Gemma3nForConditionalGeneration.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 + ).to(torch_device) + + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "What do you see here?"}, + ], + }, + ] + + inputs = self.processor.apply_chat_template( + messages, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding=True, + 
+        ).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+        EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"]  # fmt: skip
+        self.assertEqual(output_text, EXPECTED_TEXTS)
+
+    def test_model_1b_text_only(self):
+        model_id = "google/gemma-3-1b-it"
+
+        model = Gemma3nForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
+            torch_device
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
+        inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+        EXPECTED_TEXTS = ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning']  # fmt: skip
+        self.assertEqual(output_text, EXPECTED_TEXTS)
+
+    # TODO: raushan FA2 generates gibberish for no reason, check later
+    @require_flash_attn
+    @require_torch_gpu
+    @pytest.mark.flash_attn_test
+    def test_model_4b_flash_attn(self):
+        model_id = "Google/gemma-3n-E4B-it"
+
+        model = Gemma3nForConditionalGeneration.from_pretrained(
+            model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+        ).to(torch_device)
+
+        inputs = self.processor.apply_chat_template(
+            self.messages,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+            add_generation_prompt=True,
+        ).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+        EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and']  # fmt: skip
+        self.assertEqual(output_text, EXPECTED_TEXTS)
+
+    @parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
+    def test_generation_beyond_sliding_window(self, attn_implementation: str):
+        """Test that we can correctly generate beyond the sliding window. This is non-trivial as
+        we need to correctly slice the attention mask in all cases (because we use a HybridCache).
+        Outputs for every attention function should be coherent and identical.
+        """
+        model_id = "google/gemma-3-1b-it"
+
+        input_text = [
+            "This is a nice place. 
" * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens + "A list of colors: red, blue", # This will almost all be padding tokens + ] + tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") + inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device) + + model = AutoModelForCausalLM.from_pretrained( + model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16 + ).to(torch_device) + + # Make sure prefill is larger than sliding window + input_size = inputs.input_ids.shape[-1] + self.assertTrue(input_size > model.config.sliding_window) + + out = model.generate(**inputs, max_new_tokens=20)[:, input_size:] + output_text = tokenizer.batch_decode(out) + + EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip + self.assertEqual(output_text, EXPECTED_COMPLETIONS) + + def test_generation_beyond_sliding_window_with_generation_config(self): + """ + Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 -- + ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`. + """ + model_id = "google/gemma-3-1b-it" + attn_implementation = "sdpa" + + input_text = [ + "This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens + "A list of colors: red, blue", # This will almost all be padding tokens + ] + tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") + inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device) + + model = AutoModelForCausalLM.from_pretrained( + model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16 + ).to(torch_device) + + # Make sure prefill is larger than sliding window + input_size = inputs.input_ids.shape[-1] + self.assertTrue(input_size > model.config.sliding_window) + + generation_config = GenerationConfig(max_new_tokens=20) + + out = model.generate(**inputs, generation_config=generation_config)[:, input_size:] + output_text = tokenizer.batch_decode(out) + + EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip + self.assertEqual(output_text, EXPECTED_COMPLETIONS) diff --git a/tests/models/gemma3n/test_processing_gemma3n.py b/tests/models/gemma3n/test_processing_gemma3n.py new file mode 100644 index 000000000000..1d30a80c4896 --- /dev/null +++ b/tests/models/gemma3n/test_processing_gemma3n.py @@ -0,0 +1,185 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import shutil +import tempfile +import unittest + +import numpy as np +from parameterized import parameterized + +from transformers import GemmaTokenizerFast, SiglipImageProcessorFast, is_speech_available +from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio, require_vision + +from .test_feature_extraction_gemma3n import floats_list + + +if is_speech_available(): + from transformers.models.gemma3n import Gemma3nAudioFeatureExtractor, Gemma3nProcessor + + +@require_torch +@require_torchaudio +@require_vision +@require_sentencepiece +class Gemma3nProcessorTest(unittest.TestCase): + def setUp(self): + # TODO: update to google? + self.model_id = "Google/gemma-3n-E4B-it" + self.tmpdirname = tempfile.mkdtemp(suffix="gemma3n") + self.maxDiff = None + + def get_tokenizer(self, **kwargs): + return GemmaTokenizerFast.from_pretrained(self.model_id, **kwargs) + + def get_feature_extractor(self, **kwargs): + return Gemma3nAudioFeatureExtractor.from_pretrained(self.model_id, **kwargs) + + def get_image_processor(self, **kwargs): + return SiglipImageProcessorFast.from_pretrained(self.model_id, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def test_save_load_pretrained_default(self): + # NOTE: feature_extractor and image_processor both use the same filename, preprocessor_config.json, when saved to + # disk, but the files are overwritten by processor.save_pretrained(). This test does not attempt to address + # this potential issue, and as such, does not guarantee content accuracy. + + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + image_processor = self.get_image_processor() + + processor = Gemma3nProcessor( + tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor + ) + + processor.save_pretrained(self.tmpdirname) + processor = Gemma3nProcessor.from_pretrained(self.tmpdirname) + + self.assertIsInstance(processor.tokenizer, GemmaTokenizerFast) + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) + + self.assertIsInstance(processor.feature_extractor, Gemma3nAudioFeatureExtractor) + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + + def test_save_load_pretrained_additional_features(self): + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + image_processor = self.get_image_processor() + + processor = Gemma3nProcessor( + tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor + ) + processor.save_pretrained(self.tmpdirname) + + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS-BOS)", eos_token="(EOS-EOS)") + feature_extractor_add_kwargs = self.get_feature_extractor(dither=5.0, padding_value=1.0) + + processor = Gemma3nProcessor.from_pretrained( + self.tmpdirname, bos_token="(BOS-BOS)", eos_token="(EOS-EOS)", dither=5.0, padding_value=1.0 + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, GemmaTokenizerFast) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.feature_extractor, Gemma3nAudioFeatureExtractor) + + @parameterized.expand([256, 512, 768, 1024]) + def test_image_processor(self, image_size: int): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + image_processor = self.get_image_processor() + 
processor = Gemma3nProcessor( + tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor + ) + + raw_image = np.random.randint(0, 256, size=(image_size, image_size, 3), dtype=np.uint8) + input_image_processor = image_processor(raw_image, return_tensors="pt") + input_processor = processor(text="Describe:", images=raw_image, return_tensors="pt") + + for key in input_image_processor.keys(): + self.assertAlmostEqual(input_image_processor[key].sum(), input_processor[key].sum(), delta=1e-2) + if "pixel_values" in key: + # NOTE: all images should be re-scaled to 768x768 + self.assertEqual(input_image_processor[key].shape, (1, 3, 768, 768)) + self.assertEqual(input_processor[key].shape, (1, 3, 768, 768)) + + def test_audio_feature_extractor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + image_processor = self.get_image_processor() + processor = Gemma3nProcessor( + tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor + ) + + raw_speech = floats_list((3, 1000)) + input_feat_extract = feature_extractor(raw_speech, return_tensors="pt") + input_processor = processor(text="Transcribe:", audio=raw_speech, return_tensors="pt") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_tokenizer(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + image_processor = self.get_image_processor() + processor = Gemma3nProcessor( + tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor + ) + + input_str = "This is a test string" + + encoded_processor = processor(text=input_str) + + encoded_tok = tokenizer(input_str) + + for key in encoded_tok.keys(): + self.assertListEqual(encoded_tok[key], encoded_processor[key][0]) + + def test_tokenizer_decode(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + image_processor = self.get_image_processor() + processor = Gemma3nProcessor( + tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor + ) + + predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] + + decoded_processor = processor.batch_decode(predicted_ids) + decoded_tok = tokenizer.batch_decode(predicted_ids) + + self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + image_processor = self.get_image_processor() + processor = Gemma3nProcessor( + tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor + ) + + for key in feature_extractor.model_input_names: + self.assertIn( + key, + processor.model_input_names, + ) + + for key in image_processor.model_input_names: + self.assertIn( + key, + processor.model_input_names, + ) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 22d6b033afbc..04fb04a64738 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -277,6 +277,7 @@ ], "Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"], "SmolLM3Config": ["no_rope_layer_interval"], + "Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"], # this is for use in `timm` } diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 3c27476bdc03..bc247b2b6011 100644 --- a/utils/check_docstrings.py +++ 
b/utils/check_docstrings.py @@ -79,6 +79,7 @@ # docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the # line before the docstring. OBJECTS_TO_IGNORE = [ + "Gemma3nVisionConfig", "Llama4Processor", # Deprecated "InputExample",