diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 9ed80cfb0b78..7508f0968863 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -959,6 +959,8 @@
title: FLAVA
- local: model_doc/gemma3
title: Gemma3
+ - local: model_doc/gemma3n
+ title: Gemma3n
- local: model_doc/git
title: GIT
- local: model_doc/glm4v
diff --git a/docs/source/en/model_doc/gemma3n.md b/docs/source/en/model_doc/gemma3n.md
new file mode 100644
index 000000000000..7f38c3b18c92
--- /dev/null
+++ b/docs/source/en/model_doc/gemma3n.md
@@ -0,0 +1,204 @@
+
+
+
+
+
+

+

+
+
+
+# Gemma3n
+
+## Overview
+
+Gemma 3n is a multimodal model with pretrained and instruction-tuned variants, available in E4B and E2B sizes. While
+large portions of the language model architecture are shared with prior Gemma releases, there are many new additions in
+this model, including [Alternating Updates][altup] (AltUp), [Learned Augmented Residual Layer][laurel] (LAuReL),
+[MatFormer][matformer], Per-Layer Embeddings (PLE), activation sparsity, and KV cache sharing. The language model uses
+a similar attention pattern to [Gemma 3](./gemma3.md), with 4 local sliding window self-attention layers for every
+global self-attention layer, and supports a maximum context length of 32k tokens. Gemma 3n introduces
+[MobileNet v5][mobilenetv5] as the vision encoder, using a default resolution of 768x768 pixels, and adds a
+[Universal Speech Model][usm] (USM) as the audio encoder.
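+
+The snippet below is a minimal sketch of one way to see how these architectural choices surface in the configuration
+added in this release. It only relies on the `Gemma3nTextConfig` defaults defined here; no checkpoint download is
+required.
+
+```py
+from transformers import Gemma3nTextConfig
+
+# Default E4B-style text config defined in this release.
+config = Gemma3nTextConfig()
+
+print(config.layer_types[:5])              # attention pattern: 4 sliding-window layers per full-attention layer
+print(config.sliding_window)               # local sliding window size (512)
+print(config.num_kv_shared_layers)         # trailing layers that reuse KV cache values (15)
+print(config.activation_sparsity_pattern)  # per-layer activation sparsity factors
+print(config.laurel_rank)                  # rank of the LAuReL projections (64)
+```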
+
+The instruction-tuned variant was post-trained with knowledge distillation and reinforcement learning.
+
+You can find all the original Gemma 3n checkpoints under the [Gemma 3n][gemma3n-collection] release.
+
+> [!TIP]
+> Click on the Gemma 3n models in the right sidebar for more examples of how to apply Gemma 3n to different vision, audio,
+> and language tasks.
+
+The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
+
+<hfoptions id="usage">
+<hfoption id="Pipeline">
+
+```py
+import torch
+from transformers import pipeline
+
+pipeline = pipeline(
+    task="image-text-to-text",
+    model="google/gemma-3n-e4b",
+    device=0,
+    torch_dtype=torch.bfloat16
+)
+pipeline(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
+    text="<image_soft_token> What is shown in this image?"
+)
+```
+
+</hfoption>
+<hfoption id="AutoModel">
+
+```py
+import torch
+from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+model = Gemma3nForConditionalGeneration.from_pretrained(
+ "google/gemma-3n-e4b-it",
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ attn_implementation="sdpa"
+)
+processor = AutoProcessor.from_pretrained(
+ "google/gemma-3n-e4b-it",
+ padding_side="left"
+)
+
+messages = [
+ {
+ "role": "system",
+ "content": [
+ {"type": "text", "text": "You are a helpful assistant."}
+ ]
+ },
+ {
+ "role": "user", "content": [
+ {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+ {"type": "text", "text": "What is shown in this image?"},
+ ]
+ },
+]
+inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ add_generation_prompt=True,
+).to("cuda")
+
+output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
+print(processor.decode(output[0], skip_special_tokens=True))
+```
+
+</hfoption>
+<hfoption id="transformers CLI">
+
+```bash
+echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model google/gemma-3n-e2b --device 0
+```
+
+</hfoption>
+</hfoptions>
+
+## Notes
+
+- Use [`Gemma3nForConditionalGeneration`] for image-audio-and-text, image-and-text, image-and-audio, audio-and-text,
+  image-only and audio-only inputs.
+- Gemma 3n supports multiple images per input, but make sure the images are correctly batched before passing them to
+  the processor. Each batch should be a list of one or more images.
+
+  ```py
+  url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
+  url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+
+  messages = [
+      {
+          "role": "system",
+          "content": [
+              {"type": "text", "text": "You are a helpful assistant."}
+          ]
+      },
+      {
+          "role": "user",
+          "content": [
+              {"type": "image", "url": url_cow},
+              {"type": "image", "url": url_cat},
+              {"type": "text", "text": "Which image is cuter?"},
+          ]
+      },
+  ]
+  ```
+- Text passed to the processor should have a `<image_soft_token>` token wherever an image should be inserted.
+- Gemma 3n accepts at most one target audio clip per input, though multiple audio clips can be provided in few-shot
+  prompts, for example. See the audio sketch after these notes.
+- Text passed to the processor should have a `<audio_soft_token>` token wherever an audio clip should be inserted.
+- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs.
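+
+The following is a minimal sketch of a single-clip audio prompt passed directly to the processor. The `audio`
+keyword argument and the 16 kHz sampling rate are assumptions made for illustration; check the [`Gemma3nProcessor`]
+and [`Gemma3nAudioFeatureExtractor`] references below for the exact signature.
+
+```py
+import numpy as np
+from transformers import AutoProcessor
+
+processor = AutoProcessor.from_pretrained("google/gemma-3n-e4b-it")
+
+# Placeholder one-second waveform at an assumed 16 kHz sampling rate; replace with a real recording.
+waveform = np.zeros(16_000, dtype=np.float32)
+
+inputs = processor(
+    text="<audio_soft_token> Transcribe this audio clip.",
+    audio=[waveform],
+    return_tensors="pt",
+)
+print(inputs.keys())
+```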
+
+## Gemma3nAudioFeatureExtractor
+
+[[autodoc]] Gemma3nAudioFeatureExtractor
+
+## Gemma3nProcessor
+
+[[autodoc]] Gemma3nProcessor
+
+## Gemma3nTextConfig
+
+[[autodoc]] Gemma3nTextConfig
+
+## Gemma3nVisionConfig
+
+[[autodoc]] Gemma3nVisionConfig
+
+## Gemma3nAudioConfig
+
+[[autodoc]] Gemma3nAudioConfig
+
+## Gemma3nConfig
+
+[[autodoc]] Gemma3nConfig
+
+## Gemma3nTextModel
+
+[[autodoc]] Gemma3nTextModel
+ - forward
+
+## Gemma3nModel
+
+[[autodoc]] Gemma3nModel
+ - forward
+
+## Gemma3nForCausalLM
+
+[[autodoc]] Gemma3nForCausalLM
+ - forward
+
+## Gemma3nForConditionalGeneration
+
+[[autodoc]] Gemma3nForConditionalGeneration
+ - forward
+
+[altup]: https://proceedings.neurips.cc/paper_files/paper/2023/hash/f2059277ac6ce66e7e5543001afa8bb5-Abstract-Conference.html
+[attention-mask-viz]: https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139
+[gemma3n-collection]: https://huggingface.co/collections/google/gemma-3n
+[laurel]: https://arxiv.org/abs/2411.07501
+[matformer]: https://arxiv.org/abs/2310.07707
+[usm]: https://arxiv.org/abs/2303.01037
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 71ad6eaadeb5..3b9e3e65df6f 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -140,6 +140,10 @@
("gemma2", "Gemma2Config"),
("gemma3", "Gemma3Config"),
("gemma3_text", "Gemma3TextConfig"),
+ ("gemma3n", "Gemma3nConfig"),
+ ("gemma3n_audio", "Gemma3nAudioConfig"),
+ ("gemma3n_text", "Gemma3nTextConfig"),
+ ("gemma3n_vision", "Gemma3nVisionConfig"),
("git", "GitConfig"),
("glm", "GlmConfig"),
("glm4", "Glm4Config"),
@@ -518,6 +522,10 @@
("gemma2", "Gemma2"),
("gemma3", "Gemma3ForConditionalGeneration"),
("gemma3_text", "Gemma3ForCausalLM"),
+ ("gemma3n", "Gemma3nForConditionalGeneration"),
+ ("gemma3n_audio", "Gemma3nAudioEncoder"),
+ ("gemma3n_text", "Gemma3nForCausalLM"),
+ ("gemma3n_vision", "TimmWrapperModel"),
("git", "GIT"),
("glm", "GLM"),
("glm4", "GLM4"),
@@ -839,6 +847,9 @@
("clip_text_model", "clip"),
("aria_text", "aria"),
("gemma3_text", "gemma3"),
+ ("gemma3n_audio", "gemma3n"),
+ ("gemma3n_text", "gemma3n"),
+ ("gemma3n_vision", "gemma3n"),
("glm4v_text", "glm4v"),
("idefics3_vision", "idefics3"),
("siglip_vision_model", "siglip"),
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index d54ca4b0f5ae..3595de53bbda 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -61,6 +61,7 @@
("dpt", "DPTFeatureExtractor"),
("encodec", "EncodecFeatureExtractor"),
("flava", "FlavaFeatureExtractor"),
+ ("gemma3n", "Gemma3nAudioFeatureExtractor"),
("glpn", "GLPNFeatureExtractor"),
("granite_speech", "GraniteSpeechFeatureExtractor"),
("groupvit", "CLIPFeatureExtractor"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index b99dd365f57c..bee0335338cc 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -88,6 +88,7 @@
("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
("fuyu", ("FuyuImageProcessor",)),
("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
+ ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
("glpn", ("GLPNImageProcessor",)),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index add9d09b0e2d..08b91dc1ea5f 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -132,6 +132,10 @@
("gemma2", "Gemma2Model"),
("gemma3", "Gemma3Model"),
("gemma3_text", "Gemma3TextModel"),
+ ("gemma3n", "Gemma3nModel"),
+ ("gemma3n_audio", "Gemma3nAudioEncoder"),
+ ("gemma3n_text", "Gemma3nTextModel"),
+ ("gemma3n_vision", "TimmWrapperModel"),
("git", "GitModel"),
("glm", "GlmModel"),
("glm4", "Glm4Model"),
@@ -583,6 +587,8 @@
("gemma2", "Gemma2ForCausalLM"),
("gemma3", "Gemma3ForConditionalGeneration"),
("gemma3_text", "Gemma3ForCausalLM"),
+ ("gemma3n", "Gemma3nForConditionalGeneration"),
+ ("gemma3n_text", "Gemma3nForCausalLM"),
("git", "GitForCausalLM"),
("glm", "GlmForCausalLM"),
("glm4", "Glm4ForCausalLM"),
@@ -906,6 +912,7 @@
("emu3", "Emu3ForConditionalGeneration"),
("fuyu", "FuyuForCausalLM"),
("gemma3", "Gemma3ForConditionalGeneration"),
+ ("gemma3n", "Gemma3nForConditionalGeneration"),
("git", "GitForCausalLM"),
("glm4v", "Glm4vForConditionalGeneration"),
("got_ocr2", "GotOcr2ForConditionalGeneration"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index bccfe3e6d57f..e5bd673f6390 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -66,6 +66,7 @@
("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"),
("gemma3", "Gemma3Processor"),
+ ("gemma3n", "Gemma3nProcessor"),
("git", "GitProcessor"),
("glm4v", "Glm4vProcessor"),
("got_ocr2", "GotOcr2Processor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 0456e1945ca4..c8656a710746 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -236,6 +236,20 @@
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
+ (
+ "gemma3n",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "gemma3n_text",
+ (
+ "GemmaTokenizer" if is_sentencepiece_available() else None,
+ "GemmaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
diff --git a/src/transformers/models/gemma3n/__init__.py b/src/transformers/models/gemma3n/__init__.py
new file mode 100644
index 000000000000..229e91827036
--- /dev/null
+++ b/src/transformers/models/gemma3n/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gemma3n import *
+ from .feature_extraction_gemma3n import *
+ from .modeling_gemma3n import *
+ from .processing_gemma3n import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/gemma3n/configuration_gemma3n.py b/src/transformers/models/gemma3n/configuration_gemma3n.py
new file mode 100644
index 000000000000..ca1a06717748
--- /dev/null
+++ b/src/transformers/models/gemma3n/configuration_gemma3n.py
@@ -0,0 +1,680 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gemma3n.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections.abc import Sequence
+from typing import Any, Optional, Union
+
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...modeling_rope_utils import rope_config_validation
+from ...utils import is_timm_available, logging, requires_backends
+
+
+if is_timm_available():
+ from timm.data import ImageNetInfo, infer_imagenet_subset
+
+
+logger = logging.get_logger(__name__)
+
+
+class Gemma3nTextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nTextModel`]. It is used to instantiate a
+ Gemma3nTextModel according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.
+ [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 262400):
+ Vocabulary size of the Gemma3nText model. Defines the number of different tokens that can be represented by
+ the `input_ids` passed when calling [`Gemma3nTextModel`].
+ vocab_size_per_layer_input (`int`, *optional*, defaults to 262144):
+ Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ hidden_size_per_layer_input (`int`, *optional*, defaults to 256):
+ Dimension of the hidden representations for per-layer embeddings.
+ intermediate_size (`int` or `Sequence[int]`, *optional*, defaults to 16384):
+ Dimension of the MLP representations. MatFormer configurations may wish to provide a sequence of integers
+ to account for variable intermediate_size values across layers. In such cases,
+ `len(intermediate_size) == num_hidden_layers`.
+ num_hidden_layers (`int`, *optional*, defaults to 35):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout this
+ [paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, will default to `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 256):
+ The attention head dimension.
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the decoder. Will default to
+ `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
+ activation function.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 1):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ Beginning of stream token id.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings used in global attention.
+ NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we
+ recommend you to update this value accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ rope_local_base_freq (float, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings for local attention.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ sliding_window (`int`, *optional*, defaults to 512):
+ This is the size of the sliding window used by local attention layers.
+ layer_types (`Sequence[str]`, *optional*):
+ A sequence of strings defining the attention type for that layer as either "sliding_attention" or
+ "full_attention". If not provided, `layer_types` will be inferred from `num_hidden_layers` using a pattern
+ of four "sliding_attention" layers followed by one "full_attention". The last layer in the model should always
+ be a "full_attention" layer.
+ final_logit_softcapping (`float`, *optional*, defaults to 30.0):
+ Scaling factor when applying tanh softcapping on the logits.
+ altup_active_idx (`int`, *optional*, defaults to 0):
+ The index of the prediction from which AltUp will compute additional predictions or corrections.
+ altup_coef_clip (`float`, *optional*, defaults to 120.0):
+ The maximum amplitude of an AltUp prediction or correction coefficient weight.
+ altup_correct_scale (`bool`, *optional*, defaults to `True`):
+ If True, apply the `AltUp.correct_output_scale` to the corrected prediction at `altup_active_idx`.
+ altup_num_inputs (`int`, *optional*, defaults to 4):
+ The number of predictions that AltUp should make given the input sequence.
+ num_kv_shared_layers (`int`, *optional*, defaults to 15):
+ The number of layers that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
+ layers in the model "share" the KV values in that each local and global layer in this range uses the KV
+ cache values computed for the last local or global layer, respectively, before entering this range. The
+ value of `num_kv_shared_layers` should be a multiple of the length of the attention pattern defined in `layer_types`.
+ laurel_rank (int, *optional*, defaults to 64):
+ The intermediate size for the linear projections in the Learned Augmented Residual Layer.
+ activation_sparsity_pattern (Sequence[float], *optional*, defaults to `(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)`):
+ The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
+ explicitly provide a sparsity value for each layer in the model.
+
+ ```python
+ >>> from transformers import Gemma3nTextModel, Gemma3nTextConfig
+
+ >>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration
+ >>> configuration = Gemma3nTextConfig()
+
+ >>> # Initializing a model from the gemma3n_text-E4B style configuration
+ >>> model = Gemma3nTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "gemma3n_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size: int = 262_400,
+ vocab_size_per_layer_input: int = 262_144,
+ hidden_size: int = 2048,
+ hidden_size_per_layer_input: int = 256,
+ intermediate_size: Union[int, Sequence[int]] = 16_384,
+ num_hidden_layers: int = 35,
+ num_attention_heads: int = 8,
+ num_key_value_heads: int = 2,
+ head_dim: int = 256,
+ hidden_activation: str = "gelu_pytorch_tanh",
+ max_position_embeddings: int = 32_768,
+ initializer_range: float = 0.02,
+ rms_norm_eps: float = 1e-6,
+ use_cache: bool = True,
+ pad_token_id: int = 0,
+ eos_token_id: int = 1,
+ bos_token_id: int = 2,
+ rope_theta: float = 1_000_000.0,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ rope_local_base_freq: float = 10_000.0,
+ attention_bias: bool = False,
+ attention_dropout: float = 0.0,
+ sliding_window: int = 512,
+ layer_types: Optional[Sequence[str]] = None,
+ final_logit_softcapping: float = 30.0,
+ altup_active_idx: int = 0,
+ altup_coef_clip: float = 120.0,
+ altup_correct_scale: bool = True,
+ altup_num_inputs: int = 4,
+ num_kv_shared_layers: int = 15,
+ laurel_rank: int = 64,
+ activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+
+ if isinstance(intermediate_size, Sequence) and (intsize_len := len(intermediate_size)) != num_hidden_layers:
+ raise ValueError(
+ "intermediate_size must have an explicit intermediate size for every layer or one for all layers. "
+ f"Expected {num_hidden_layers} values but got {intsize_len}."
+ )
+ elif not isinstance(intermediate_size, Sequence):
+ intermediate_size = [intermediate_size] * num_hidden_layers
+
+ self.vocab_size = vocab_size
+ self.vocab_size_per_layer_input = vocab_size_per_layer_input
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.num_key_value_heads = num_key_value_heads
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.hidden_activation = hidden_activation
+ self.sliding_window = sliding_window
+ self.final_logit_softcapping = final_logit_softcapping
+ self.layer_types = layer_types
+
+ self.rope_local_base_freq = rope_local_base_freq
+ self.rope_scaling = rope_scaling
+ rope_config_validation(self)
+
+ if layer_types is None:
+ self.layer_types = [
+ "full_attention" if i % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
+ ]
+ else:
+ self.layer_types = layer_types
+
+ layer_type_validation(self.layer_types)
+
+ self.hidden_size_per_layer_input = hidden_size_per_layer_input
+ self.num_kv_shared_layers = num_kv_shared_layers
+
+ self.altup_active_idx = altup_active_idx
+ self.altup_coef_clip = altup_coef_clip
+ self.altup_correct_scale = altup_correct_scale
+ self.altup_num_inputs = altup_num_inputs
+
+ self.laurel_rank = laurel_rank
+
+ if activation_sparsity_pattern is None:
+ activation_sparsity_pattern = [0.0] * num_hidden_layers
+
+ if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers:
+ raise ValueError(
+ "activation_sparsity_pattern must have an explicit activation sparsity value for every layer."
+ f"Expected {num_hidden_layers} values but got {len_asp}."
+ )
+ self.activation_sparsity_pattern = activation_sparsity_pattern
+
+
+class Gemma3nAudioConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`], based on Google's
+ [Universal Speech Model](https://arxiv.org/abs/2303.01037). It is used to instantiate a Gemma3nAudioEncoder model according to the specified
+ arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
+ configuration to that of the Gemma 3n E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 128):
+ Vocabulary size of the additional hard-token embeddings for audio model. These augment the embeddings
+ included in the `Gemma3nTextModel` to provide, e.g., the end of audio and audio soft token placeholder
+ tokens when converting `input_ids` to embeddings in the `Gemma3nForConditionalGeneration` model.
+ vocab_offset (`int`, *optional*, defaults to 262272):
+ Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
+ 0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
+ input_feat_size (`int`, *optional*, defaults to 128):
+ The number of channels in each mel-spectrogram frame.
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimension of the hidden representations.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
+ Clipping value used to stabilize extremely large gradient values.
+ conf_attention_chunk_size (`int`, *optional*, defaults to 12):
+ The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_context_left (`int`, *optional*, defaults to 13):
+ The left context size of the local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_context_right (`int`, *optional*, defaults to 0):
+ The right context size of the local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
+ Logit cap applied during local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_num_attention_heads (`int`, *optional*, defaults to 8):
+ The number of attention heads in local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_num_hidden_layers (`int`, *optional*, defaults to 12):
+ The number of layers that use local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_conv_kernel_size (`int`, *optional*, defaults to 5):
+ Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_reduction_factor (`int`, *optional*, defaults to 4):
+ Reduction factor used in the conformer block inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_residual_weight (`float`, *optional*, defaults to 0.5):
+ Residual connection weight inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
+ The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
+ ("sscp") section of the Universal Speech Model.
+ sscp_conv_group_norm_eps (`float`, *optional*, defaults to 0.001):
+ Epsilon used in group normalization in the subsample convolution projection in the Sub-sample Convolution
+ Projection ("sscp") section of the Universal Speech Model.
+ sscp_conv_kernel_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((3, 3), (3, 3))`):
+ Kernel sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
+ Convolution Projection ("sscp") section of the Universal Speech Model. The kernel sizes are specified as a
+ tuple of height and width for each layer, where the height corresponds to the time dimension and the width
+ corresponds to the frequency dimension.
+ sscp_conv_stride_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((2, 2), (2, 2))`):
+ Stride sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
+ Convolution Projection ("sscp") section of the Universal Speech Model. The stride sizes are specified as a
+ tuple of height and width for each layer, where the height corresponds to the time dimension and the width
+ corresponds to the frequency dimension.
+
+ Example:
+
+ ```python
+ >>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder
+
+ >>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration
+ >>> configuration = Gemma3nAudioConfig()
+
+ >>> # Initializing a model from the gemma3n_audio-E4B style configuration
+ >>> model = Gemma3nAudioEncoder(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "gemma3n_audio"
+
+ def __init__(
+ self,
+ vocab_size: int = 128,
+ vocab_offset: int = 262_144 + 128, # text vocab size + vision vocab size
+ input_feat_size: int = 128,
+ hidden_size: int = 1536,
+ rms_norm_eps: float = 1e-6,
+ gradient_clipping: float = 10_000_000_000.0,
+ conf_attention_chunk_size: int = 12,
+ conf_attention_context_left: int = 13,
+ conf_attention_context_right: int = 0,
+ conf_attention_logit_cap: float = 50.0,
+ conf_num_attention_heads: int = 8,
+ conf_num_hidden_layers: int = 12,
+ conf_conv_kernel_size: int = 5,
+ conf_reduction_factor: int = 4,
+ conf_residual_weight: float = 0.5,
+ sscp_conv_channel_size: tuple[int, int] = (128, 32),
+ sscp_conv_group_norm_eps: float = 1e-3,
+ sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
+ (3, 3),
+ (3, 3),
+ ),
+ sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = (
+ (2, 2),
+ (2, 2),
+ ),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.input_feat_size = input_feat_size
+ self.hidden_size = hidden_size
+ self.rms_norm_eps = rms_norm_eps
+ self.vocab_size = vocab_size
+ self.vocab_offset = vocab_offset
+ self.gradient_clipping = gradient_clipping
+ self.conf_attention_chunk_size = conf_attention_chunk_size
+ self.conf_attention_context_left = conf_attention_context_left
+ self.conf_attention_context_right = conf_attention_context_right
+ self.conf_attention_logit_cap = conf_attention_logit_cap
+ self.conf_num_attention_heads = conf_num_attention_heads
+ self.conf_num_hidden_layers = conf_num_hidden_layers
+ self.conf_conv_kernel_size = conf_conv_kernel_size
+ self.conf_reduction_factor = conf_reduction_factor
+ self.conf_residual_weight = conf_residual_weight
+ self.sscp_conv_channel_size = sscp_conv_channel_size
+ self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
+ self.sscp_conv_kernel_size = sscp_conv_kernel_size
+ self.sscp_conv_stride_size = sscp_conv_stride_size
+
+
+class Gemma3nVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to
+ instantiate a timm model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B
+ vision tower, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects inherit from [`Gemma3nVisionConfig`] and can be used to control the model outputs. Read the
+ documentation from [`Gemma3nVisionConfig`] for more information.
+
+ The config loads ImageNet label descriptions and stores them in the `id2label` attribute; the `label2id` attribute for
+ default ImageNet models is set to `None` because of collisions (non-unique label descriptions).
+
+ Args:
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ do_pooling (`bool`, *optional*, defaults to `False`):
+ Whether to do pooling for the last_hidden_state in `TimmWrapper` or not.
+ architecture (`str`, *optional*, defaults to `"mobilenetv5_300m_enc"`):
+ Determines vision architecture for TimmWrapper.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ vocab_size (`int`, *optional*, defaults to 128):
+ Vocabulary size of the additional hard-token embeddings for vision model.
+ vocab_offset (`int`, *optional*, defaults to 262144):
+ Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
+ 0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+
+ Example:
+ ```python
+ >>> from transformers import Gemma3nVisionConfig, TimmWrapperModel
+
+ >>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration
+ >>> configuration = Gemma3nVisionConfig()
+
+ >>> # Initializing a gemma3n_vision-E4B-style TimmWrapper from the configuration
+ >>> model = TimmWrapperModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "gemma3n_vision"
+
+ def __init__(
+ self,
+ initializer_range: float = 0.02,
+ do_pooling: bool = False,
+ architecture: str = "mobilenetv5_300m_enc",
+ hidden_size: int = 2048,
+ vocab_size: int = 128,
+ vocab_offset: int = 262_144,
+ rms_norm_eps: float = 1e-06,
+ model_args: Optional[dict] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.initializer_range = initializer_range
+ self.do_pooling = do_pooling
+ self.model_args = model_args # named "model_args" for BC with timm
+ self.architecture = architecture
+ self.hidden_size = hidden_size
+ self.vocab_size = vocab_size
+ self.vocab_offset = vocab_offset
+ self.rms_norm_eps = rms_norm_eps
+
+ @classmethod
+ def from_dict(cls, config_dict: dict[str, Any], **kwargs):
+ label_names = config_dict.get("label_names", None)
+ is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
+
+ # if no labels added to config, use imagenet labeller in timm
+ if label_names is None and not is_custom_model:
+ requires_backends(cls, ["timm"])
+ imagenet_subset = infer_imagenet_subset(config_dict)
+ if imagenet_subset:
+ dataset_info = ImageNetInfo(imagenet_subset)
+ synsets = dataset_info.label_names()
+ label_descriptions = dataset_info.label_descriptions(as_dict=True)
+ label_names = [label_descriptions[synset] for synset in synsets]
+
+ if label_names is not None and not is_custom_model:
+ kwargs["id2label"] = dict(enumerate(label_names))
+
+ # if all label names are unique, create label2id mapping as well
+ if len(set(label_names)) == len(label_names):
+ kwargs["label2id"] = {name: i for i, name in enumerate(label_names)}
+ else:
+ kwargs["label2id"] = None
+
+ # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
+ # We are removing these attributes in order to have the native `transformers` num_labels attribute in config
+ # and to avoid duplicate attributes
+ num_labels_in_kwargs = kwargs.pop("num_labels", None)
+ num_labels_in_dict = config_dict.pop("num_classes", None)
+
+ # passed num_labels has priority over num_classes in config_dict
+ kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict
+
+ # pop num_classes from "pretrained_cfg",
+ # it is not necessary to have it, only root one is used in timm
+ if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
+ config_dict["pretrained_cfg"].pop("num_classes", None)
+
+ return super().from_dict(config_dict, **kwargs)
+
+ def to_dict(self) -> dict[str, Any]:
+ output = super().to_dict()
+ output["num_classes"] = self.num_labels
+ output["label_names"] = list(self.id2label.values())
+ output.pop("id2label", None)
+ output.pop("label2id", None)
+ return output
+
+
+class Gemma3nConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nForConditionalGeneration`]. It is used to
+ instantiate a Gemma3nForConditionalGeneration according to the specified arguments, defining the model
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
+ Gemma3n-E4B.
+
+ e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`Union[Gemma3nTextConfig, dict]`, *optional*):
+ The config object of the text backbone.
+ vision_config (`Union[AutoConfig, dict]`, *optional*):
+ Custom vision config or dict.
+ audio_config (`Union[AutoConfig, dict]`, *optional*):
+ Custom audio config or dict.
+ audio_soft_tokens_per_image (`int`, *optional*, defaults to 188):
+ The number of soft tokens per audio clip.
+ vision_soft_tokens_per_image (`int`, *optional*, defaults to 256):
+ The number of soft tokens per image.
+ boi_token_id (`int`, *optional*, defaults to 255999):
+ The begin-of-image token index to wrap the image prompt.
+ eoi_token_id (`int`, *optional*, defaults to 262144):
+ The end-of-image token index to wrap the image prompt.
+ image_token_id (`int`, *optional*, defaults to 262145):
+ The image token index to encode the image prompt.
+ boa_token_id (`int`, *optional*, defaults to 256000):
+ The begin-of-audio token index to wrap the audio prompt.
+ eoa_token_id (`int`, *optional*, defaults to 262272):
+ The end-of-audio token index to wrap the audio prompt.
+ audio_token_id (`int`, *optional*, defaults to 262273):
+ The audio token index to encode the audio prompt.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nForConditionalGeneration, Gemma3nTextConfig, Gemma3nVisionConfig
+
+ >>> # Initializing a MobileNet vision config, which is loaded from TIMM
+ >>> vision_config = Gemma3nVisionConfig()
+
+ >>> # Initializing a Gemma3n Audio config
+ >>> audio_config = Gemma3nAudioConfig()
+
+ >>> # Initializing a Gemma3n Text config
+ >>> text_config = Gemma3nTextConfig()
+
+ >>> # Initializing a Gemma3n gemma-3n-E4B style configuration
+ >>> configuration = Gemma3nConfig(text_config, vision_config, audio_config)
+
+ >>> # Initializing a model from the gemma-3n-E4B style configuration
+ >>> model = Gemma3nForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gemma3n"
+ sub_configs = {
+ "text_config": Gemma3nTextConfig,
+ "vision_config": Gemma3nVisionConfig,
+ "audio_config": Gemma3nAudioConfig,
+ }
+
+ def __init__(
+ self,
+ text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None,
+ vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None,
+ audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None,
+ audio_soft_tokens_per_image: int = 188,
+ vision_soft_tokens_per_image: int = 256,
+ boi_token_id: int = 255_999,
+ eoi_token_id: int = 262_144,
+ image_token_id: int = 262_145,
+ boa_token_id: int = 256_000,
+ eoa_token_id: int = 262_272,
+ audio_token_id: int = 262_273,
+ initializer_range: float = 0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if isinstance(text_config, dict):
+ text_config = Gemma3nTextConfig(**text_config)
+ elif text_config is None:
+ text_config = Gemma3nTextConfig()
+ logger.info("text_config is None. Using default Gemma3nTextConfig.")
+
+ if isinstance(vision_config, dict):
+ vision_config = Gemma3nVisionConfig(**vision_config)
+ elif vision_config is None:
+ vision_config = Gemma3nVisionConfig()
+ logger.info("vision_config is None. Using default Gemma3nVisionConfig.")
+
+ if isinstance(audio_config, dict):
+ audio_config = Gemma3nAudioConfig(**audio_config)
+ elif audio_config is None:
+ audio_config = Gemma3nAudioConfig()
+ logger.info("audio_config is None. Using default Gemma3nAudioConfig.")
+
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.audio_config = audio_config
+
+ self.audio_soft_tokens_per_image = audio_soft_tokens_per_image
+ self.vision_soft_tokens_per_image = vision_soft_tokens_per_image
+ self.boi_token_id = boi_token_id
+ self.eoi_token_id = eoi_token_id
+ self.image_token_id = image_token_id
+ self.boa_token_id = boa_token_id
+ self.eoa_token_id = eoa_token_id
+ self.audio_token_id = audio_token_id
+ self.initializer_range = initializer_range
+
+
+__all__ = ["Gemma3nAudioConfig", "Gemma3nConfig", "Gemma3nTextConfig", "Gemma3nVisionConfig"]
diff --git a/src/transformers/models/gemma3n/convert_gemma3n_weights.py b/src/transformers/models/gemma3n/convert_gemma3n_weights.py
new file mode 100644
index 000000000000..2f25ca56d465
--- /dev/null
+++ b/src/transformers/models/gemma3n/convert_gemma3n_weights.py
@@ -0,0 +1,807 @@
+# coding=utf-8
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint.
+
+python src/transformers/models/gemma3n/convert_gemma3n_weights.py \
+ --variant='gemma3n_e4b' \
+ --tokenizer_path="$HOME/nano3/checkpoints/tokenizer/gemma-3n-tokenizer.model" \
+ --checkpoint_path="$HOME/nano3/checkpoints/g251_orbax/" \
+ --output_path="$HOME/nano3/checkpoints/g251_vision_encoder/"
+"""
+
+import json
+import os
+import re
+from collections.abc import Iterable, Mapping
+from typing import Any
+
+import accelerate
+import numpy as np
+import torch
+import tree
+from absl import app, flags, logging
+from orbax import checkpoint as obc
+
+from transformers import (
+ Gemma3nAudioConfig,
+ Gemma3nAudioFeatureExtractor,
+ Gemma3nConfig,
+ Gemma3nForConditionalGeneration,
+ Gemma3nProcessor,
+ Gemma3nTextConfig,
+ Gemma3nVisionConfig,
+ GemmaTokenizerFast,
+ GenerationConfig,
+ SiglipImageProcessorFast,
+)
+from transformers.image_utils import PILImageResampling
+
+
+# ==== Internal Constants and Classes ====
+
+
+_CHAT_TEMPLATE = """{{ bos_token }}
+{%- if messages[0]['role'] == 'system' -%}
+ {%- if messages[0]['content'] is string -%}
+ {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
+ {%- else -%}
+ {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
+ {%- endif -%}
+ {%- set loop_messages = messages[1:] -%}
+{%- else -%}
+ {%- set first_user_prefix = "" -%}
+ {%- set loop_messages = messages -%}
+{%- endif -%}
+{%- for message in loop_messages -%}
+ {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
+ {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
+ {%- endif -%}
+ {%- if (message['role'] == 'assistant') -%}
+ {%- set role = "model" -%}
+ {%- else -%}
+ {%- set role = message['role'] -%}
+ {%- endif -%}
+ {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
+ {%- if message['content'] is string -%}
+ {{ message['content'] | trim }}
+ {%- elif message['content'] is iterable -%}
+ {%- for item in message['content'] -%}
+ {%- if item['type'] == 'audio' -%}
+ {{ '<audio_soft_token>' }}
+ {%- elif item['type'] == 'image' -%}
+ {{ '<image_soft_token>' }}
+ {%- elif item['type'] == 'text' -%}
+ {{ item['text'] | trim }}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- else -%}
+ {{ raise_exception("Invalid content type") }}
+ {%- endif -%}
+ {{ '<end_of_turn>\n' }}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+ {{'<start_of_turn>model\n'}}
+{%- endif -%}
+"""
+
+_DTYPES = {"float32", "bfloat16", "float16"}
+
+_SLIDING_WINDOW_PATTERN = 5
+
+_AUDIO_ENCODER_PARAMETER = "AudioEncoder/encoder"
+_AUDIO_ENCODER_CONFORMER = f"{_AUDIO_ENCODER_PARAMETER}/conformer/stacked_layers"
+_AUDIO_ENCODER_SSCP = f"{_AUDIO_ENCODER_PARAMETER}/feature"
+
+_TRANSFORMER_PARAMETER = "transformer"
+_TRANSFORMER_ALTUP_PROJ = f"{_TRANSFORMER_PARAMETER}/altup_projection_"
+_TRANSFORMER_ALTUP_UNEMB = f"{_TRANSFORMER_PARAMETER}/altup_unembed_projection_"
+_TRANSFORMER_DECODER_BLOCK = f"{_TRANSFORMER_PARAMETER}/stacked_layers/attention_type_"
+_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
+_TRANSFORMER_EMBEDDER = f"{_TRANSFORMER_PARAMETER}/embedder"
+_TRANSFORMER_FINAL_NORM = "transformer/final_norm"
+_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
+_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)
+
+# _MOBILE_NET_CONFIG = Gemma3nVisionConfig.from_pretrained("")
+
+_MOBILE_NET_PREFIX = "mobilenet"
+_MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES = [3, 8, 45, 84]
+_MOBILE_NET_CONV = "block_group_conv2d_"
+_MOBILE_NET_FIB = "block_group_fused_ib_"
+_MOBILE_NET_MQA = "block_group_mmqa_"
+_MOBILE_NET_MSFA = "block_adapter_"
+_MOBILE_NET_UIB = "block_group_uib_"
+_MOBILE_NET_UIB_HAS_DW_START = {
+ (1, 0),
+ (1, 1),
+ (1, 2),
+ (1, 3),
+ (1, 4),
+ (2, 0),
+ (2, 1),
+ (2, 2),
+ (2, 3),
+ (2, 4),
+ (2, 5),
+ (2, 6),
+ (2, 7),
+ (3, 0),
+}
+_MOBILE_NET_UIB_HAS_DW_MID = {
+ (1, 0),
+ (2, 0),
+ (3, 0),
+}
+
+_VARIANT_GEMMA_3_2B = "gemma3n_e2b"
+_VARIANT_GEMMA_3_4B = "gemma3n_e4b"
+_VARIANTS: Mapping[str, Gemma3nConfig] = {
+ _VARIANT_GEMMA_3_2B: Gemma3nConfig(
+ text_config=Gemma3nTextConfig(
+ intermediate_size=2048 * 4,
+ num_hidden_layers=30,
+ activation_sparsity_pattern=(0.95,) * 10 + (0.0,) * 20,
+ num_kv_shared_layers=10,
+ ),
+ vision_config=Gemma3nVisionConfig(),
+ audio_config=Gemma3nAudioConfig(),
+ ),
+ _VARIANT_GEMMA_3_4B: Gemma3nConfig(
+ text_config=Gemma3nTextConfig(),
+ vision_config=Gemma3nVisionConfig(),
+ audio_config=Gemma3nAudioConfig(),
+ ),
+}
+
+
+# ==== Flags ====
+
+_AUDIO_DTYPE = flags.DEFINE_enum(
+ name="audio_dtype",
+ default="bfloat16",
+ help="The floating point precision (aka dtype) of the model.",
+ enum_values=_DTYPES,
+)
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+ name="checkpoint_path",
+ default=None,
+ help="Path to the Orbax checkpoint.",
+ required=True,
+)
+
+_INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool(
+ name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer"
+)
+
+_OUTPUT_PATH = flags.DEFINE_string(
+ name="output_path",
+ default=None,
+ help="Path to store the HF checkpoint.",
+ required=True,
+)
+
+_TRANSFORMER_DTYPE = flags.DEFINE_enum(
+ name="text_dtype",
+ default="bfloat16",
+ help="The floating point precision (aka dtype) of the model.",
+ enum_values=_DTYPES,
+)
+
+_TOKENIZER_PATH = flags.DEFINE_string(
+ name="tokenizer_path",
+ default=None,
+ help="Path to the SentencePiece model file.",
+ required=True,
+)
+
+_VARIANT = flags.DEFINE_enum(
+ name="variant",
+ default=_VARIANT_GEMMA_3_4B,
+ help="The model variant to convert.",
+ enum_values=set(_VARIANTS.keys()),
+)
+
+_VERBOSE = flags.DEFINE_bool(
+ name="verbose",
+ default=False,
+ help="If true, log the path, shape, and dtype of every converted layer.",
+)
+
+_VISION_DTYPE = flags.DEFINE_enum(
+ name="vision_dtype",
+ default="bfloat16",
+ help="The floating point precision (aka dtype) of the model.",
+ enum_values=_DTYPES,
+)
+
+
+def convert_audio_encoder_weights(
+ config: Gemma3nAudioConfig,
+ path: str,
+ param: str,
+ weights: np.ndarray,
+) -> Iterable[tuple[str, np.ndarray]]:
+ converted_paths: list[str] = []
+ converted_weights: list[Any] = []
+
+ if path.startswith(_AUDIO_ENCODER_CONFORMER):
+ assert weights.shape[0] == config.conf_num_hidden_layers
+
+ for i, matrix in enumerate(weights):
+ if "fflayer_end" in path:
+ base = f"conformer.{i}.ffw_layer_end"
+
+ if path.endswith("ffn_layer1"):
+ converted_paths.append(f"{base}.ffw_layer_1.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("ffn_layer2"):
+ converted_paths.append(f"{base}.ffw_layer_2.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("post_layer_norm"):
+ converted_paths.append(f"{base}.post_layer_norm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("pre_layer_norm"):
+ converted_paths.append(f"{base}.pre_layer_norm.weight")
+ converted_weights.append(matrix)
+ elif "fflayer_start" in path:
+ base = f"conformer.{i}.ffw_layer_start"
+
+ if path.endswith("ffn_layer1"):
+ converted_paths.append(f"{base}.ffw_layer_1.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("ffn_layer2"):
+ converted_paths.append(f"{base}.ffw_layer_2.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("post_layer_norm"):
+ converted_paths.append(f"{base}.post_layer_norm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("pre_layer_norm"):
+ converted_paths.append(f"{base}.pre_layer_norm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("final_ln"):
+ converted_paths.append(f"conformer.{i}.norm.weight")
+ converted_weights.append(matrix)
+ elif "lconv" in path:
+ base = f"conformer.{i}.lconv1d"
+
+ if path.endswith("conv_norm"):
+ converted_paths.append(f"{base}.conv_norm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("depthwise_conv1d"):
+ converted_paths.append(f"{base}.depthwise_conv1d.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("linear_end"):
+ converted_paths.append(f"{base}.linear_end.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("linear_start"):
+ converted_paths.append(f"{base}.linear_start.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("ln"):
+ converted_paths.append(f"{base}.pre_layer_norm.weight")
+ converted_weights.append(matrix)
+ elif "trans_atten" in path:
+ base = f"conformer.{i}.attention"
+
+ if param == "per_dim_scale":
+ converted_paths.append(f"{base}.attn.per_dim_scale")
+ converted_weights.append(matrix)
+
+ if path.endswith("query_key_value_projection"):
+ converted_paths.extend(
+ [f"{base}.attn.q_proj.weight", f"{base}.attn.k_proj.weight", f"{base}.attn.v_proj.weight"]
+ )
+ converted_weights.extend(
+ [
+ m.reshape(config.hidden_size, config.hidden_size).transpose()
+ for m in matrix.transpose(1, 0, 2, 3)
+ ]
+ )
+ elif path.endswith("pos_proj"):
+ converted_paths.append(f"{base}.attn.relative_position_embedding.pos_proj.weight")
+ converted_weights.append(matrix.reshape(config.hidden_size, config.hidden_size).transpose())
+ elif path.endswith("post"):
+ converted_paths.append(f"{base}.post.weight")
+ converted_weights.append(matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.hidden_size))
+ elif path.endswith("post_norm"):
+ converted_paths.append(f"{base}.post_norm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("pre_norm"):
+ converted_paths.append(f"{base}.pre_attn_norm.weight")
+ converted_weights.append(matrix)
+ elif path.startswith(_AUDIO_ENCODER_SSCP):
+ if path.endswith("input_proj"):
+ converted_paths.append("subsample_conv_projection.input_proj_linear.weight")
+ converted_weights.append(
+ weights.transpose(2, 0, 1).reshape(config.hidden_size, config.sscp_conv_channel_size[1] ** 2)
+ )
+ elif "norm_" in path:
+ index = int(path[-1])
+ converted_paths.append(f"subsample_conv_projection.conv_{index}.norm.weight")
+ converted_weights.append(weights)
+ elif "subsampling_" in path:
+ index = int(path[-1])
+ converted_paths.append(f"subsample_conv_projection.conv_{index}.conv.weight")
+ converted_weights.append(weights.transpose(3, 2, 0, 1))
+
+ if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
+ raise ValueError(
+ "The `converted_paths` and `converted_weights` should be the same "
+ f"length. Got {cpl} and {cwl}, respectively, for {path}."
+ )
+
+ return zip(converted_paths, converted_weights)
+
+
+def convert_transformer_weights(
+ config: Gemma3nTextConfig,
+ path: str,
+ param: str,
+ weights: np.ndarray,
+) -> Iterable[tuple[str, np.ndarray]]:
+ if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX):
+ path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:]
+
+ converted_paths: list[str] = []
+ converted_weights: list[Any] = []
+
+ if path.startswith(_TRANSFORMER_ALTUP_PROJ):
+ index = int(path[-1])
+ converted_paths.append(f"altup_projections.{index}.weight")
+ converted_weights.append(weights.transpose())
+ elif path.startswith(_TRANSFORMER_ALTUP_UNEMB):
+ index = int(path[-1])
+ converted_paths.append(f"altup_unembed_projections.{index}.weight")
+ converted_weights.append(weights.transpose())
+ elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
+ attention_type_index = int(path[_TRANSFORMER_DECODER_BLOCK_LEN])
+ assert weights.shape[0] == config.num_hidden_layers / _SLIDING_WINDOW_PATTERN
+
+ for i, matrix in enumerate(weights):
+ layer_idx = _SLIDING_WINDOW_PATTERN * i + attention_type_index
+ base_path = f"layers.{layer_idx}"
+
+ if "altup" in path:
+ altup_path = f"{base_path}.altup"
+
+ if param == "correct_output_scale":
+ converted_paths.append(f"{altup_path}.correct_output_scale")
+ converted_weights.append(matrix)
+ elif param == "correction_coefs":
+ converted_paths.append(f"{altup_path}.correction_coefs.weight")
+ converted_weights.append(matrix.transpose())
+ elif param == "prediction_coefs":
+ converted_paths.append(f"{altup_path}.prediction_coefs.weight")
+ converted_weights.append(
+ np.clip(
+ matrix.reshape(config.altup_num_inputs, config.altup_num_inputs**2).transpose(),
+ -config.altup_coef_clip,
+ config.altup_coef_clip,
+ )
+ )
+
+ if path.endswith("modality_router"):
+ converted_paths.append(f"{altup_path}.modality_router.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("router_norm_layer"):
+ converted_paths.append(f"{altup_path}.router_norm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("attn/attn_vec_einsum"):
+ converted_paths.append(f"{base_path}.self_attn.o_proj.weight")
+ converted_weights.append(
+ matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.num_attention_heads * config.head_dim)
+ )
+ elif path.endswith("attn/kv_einsum"):
+ converted_paths.extend(
+ [
+ f"{base_path}.self_attn.k_proj.weight",
+ f"{base_path}.self_attn.v_proj.weight",
+ ]
+ )
+ k_proj_weights, v_proj_weights = matrix.transpose(0, 2, 1, 3)
+ kv_proj_shape = (config.hidden_size, config.num_key_value_heads * config.head_dim)
+ converted_weights.extend(
+ [
+ k_proj_weights.reshape(kv_proj_shape).transpose(),
+ v_proj_weights.reshape(kv_proj_shape).transpose(),
+ ]
+ )
+ elif path.endswith("attn/q_einsum"):
+ converted_paths.append(f"{base_path}.self_attn.q_proj.weight")
+ converted_weights.append(
+ matrix.transpose(1, 0, 2)
+ .reshape(config.hidden_size, config.num_attention_heads * config.head_dim)
+ .transpose()
+ )
+ elif path.endswith("attn/query_norm"):
+ converted_paths.append(f"{base_path}.self_attn.q_norm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("attn/key_norm"):
+ converted_paths.append(f"{base_path}.self_attn.k_norm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("laurel_block/linear_left"):
+ converted_paths.append(f"{base_path}.laurel.linear_left.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("laurel_block/linear_right"):
+ converted_paths.append(f"{base_path}.laurel.linear_right.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("mlp/gating_einsum"):
+ converted_paths.extend([f"{base_path}.mlp.gate_proj.weight", f"{base_path}.mlp.up_proj.weight"])
+ gate_proj_weight, up_proj_weight = matrix
+ converted_weights.extend([gate_proj_weight, up_proj_weight])
+ elif path.endswith("mlp/linear"):
+ converted_paths.append(f"{base_path}.mlp.down_proj.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("per_layer_input_gate"):
+ converted_paths.append(f"{base_path}.per_layer_input_gate.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("per_layer_projection"):
+ converted_paths.append(f"{base_path}.per_layer_projection.weight")
+ converted_weights.append(matrix.transpose())
+ elif path.endswith("post_attention_norm"):
+ converted_paths.append(f"{base_path}.post_attention_layernorm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("post_ffw_norm"):
+ converted_paths.append(f"{base_path}.post_feedforward_layernorm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("post_laurel_norm"):
+ converted_paths.append(f"{base_path}.laurel.post_laurel_norm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("post_per_layer_input_norm"):
+ converted_paths.append(f"{base_path}.post_per_layer_input_norm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("pre_attention_norm"):
+ converted_paths.append(f"{base_path}.input_layernorm.weight")
+ converted_weights.append(matrix)
+ elif path.endswith("pre_ffw_norm"):
+ converted_paths.append(f"{base_path}.pre_feedforward_layernorm.weight")
+ converted_weights.append(matrix)
+ elif path == _TRANSFORMER_EMBEDDER:
+ if param == "input_embedding":
+ converted_paths.append("embed_tokens.weight")
+ # Gemma 3n model doesn't have soft tokens or "end of" tokens for images and audio in its input and output
+ # embeddings, so we resize to avoid bugs observed with Mllama
+ pre_expansion_embeddings = weights
+ pad_token_slice = slice(config.pad_token_id, config.pad_token_id + 1)
+ new_embeddings = np.repeat(pre_expansion_embeddings[pad_token_slice], 256, axis=0)
+ weights = np.vstack([pre_expansion_embeddings, new_embeddings])
+ converted_weights.append(weights)
+ elif param == "per_layer_embeddings":
+ converted_paths.append("embed_tokens_per_layer.weight")
+ converted_weights.append(
+ weights.reshape(
+ config.vocab_size_per_layer_input, config.num_hidden_layers * config.hidden_size_per_layer_input
+ )
+ )
+ elif path.startswith(_TRANSFORMER_EMBEDDER):
+ # TODO: ryanmullins - support multimodal norms and projections
+ if path.endswith("per_layer_model_projection"):
+ converted_paths.append("per_layer_model_projection.weight")
+ converted_weights.append(
+ weights.reshape(
+ config.hidden_size, config.num_hidden_layers * config.hidden_size_per_layer_input
+ ).transpose()
+ )
+ elif path.endswith("per_layer_projection_norm"):
+ converted_paths.append("per_layer_projection_norm.weight")
+ converted_weights.append(weights)
+ elif path == _TRANSFORMER_FINAL_NORM:
+ converted_paths = ["norm.weight"]
+ converted_weights = [weights]
+
+ if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
+ raise ValueError(
+ "The `converted_paths` and `converted_weights` should be the same "
+ f"length. Got {cpl} and {cwl}, respectively, for {path}."
+ )
+
+ return zip(converted_paths, converted_weights)
+
+
+def convert_vision_weights(
+ config: Gemma3nVisionConfig,
+ path: str,
+ param: str,
+ weights: np.ndarray,
+) -> Iterable[tuple[str, np.ndarray]]:
+ def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]]:
+ re_str = r"{}(\d+)/".format(block_type)
+ re_pattern = re.compile(re_str)
+ match = re.search(re_pattern, path).group(1)
+ idx = abs(int(match)) - 1
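+        # Convert the layer index parsed from the Orbax path to a zero-based index, then map it onto timm's
+        # nested `blocks.{block_idx}.{layer_idx}` layout via the cumulative per-stage block sizes below.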
+
+ for block_idx, v in enumerate(_MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES):
+ if v > idx:
+ offset = _MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES[block_idx - 1] if block_idx > 0 else 0
+ layer_idx = idx - offset
+ return f"blocks.{block_idx}.{layer_idx}", (block_idx, layer_idx)
+
+ raise ValueError(f"could not extract a base path from {path}")
+
+ if _MOBILE_NET_MSFA in path:
+ converted_path = "msfa"
+
+ if "ffn/Normalize_0" in path:
+ converted_path += ".ffn.pw_exp.bn.weight"
+ converted_weight = weights
+ elif "ffn/Normalize_1" in path:
+ converted_path += ".ffn.pw_proj.bn.weight"
+ converted_weight = weights
+ elif "ffn/expand" in path:
+ converted_path += ".ffn.pw_exp.conv.weight"
+ converted_weight = weights.transpose()[:, :, None, None]
+ elif "ffn/project" in path:
+ converted_path += ".ffn.pw_proj.conv.weight"
+ converted_weight = weights.transpose()[:, :, None, None]
+ elif "Normalize_0" in path:
+ converted_path += ".norm.weight"
+ converted_weight = weights
+ elif _MOBILE_NET_CONV in path:
+ if "Conv_0" in path:
+ converted_path = "conv_stem.conv.weight"
+ converted_weight = weights.transpose(3, 2, 1, 0)
+ elif "Normalize_0" in path:
+ converted_path = "conv_stem.bn.weight"
+ converted_weight = weights
+ elif _MOBILE_NET_FIB in path:
+ converted_path, _ = generate_base_path(path, _MOBILE_NET_FIB)
+ if "Normalize_0" in path:
+ converted_path += ".bn1.weight"
+ converted_weight = weights
+ elif "Normalize_1" in path:
+ converted_path += ".bn2.weight"
+ converted_weight = weights
+ elif "expand_conv" in path:
+ converted_path += ".conv_exp.weight"
+ converted_weight = weights.transpose(3, 2, 1, 0)
+ else:
+ converted_path += ".conv_pwl.weight"
+ converted_weight = weights.transpose()[:, :, None, None]
+ elif _MOBILE_NET_MQA in path:
+ converted_path, _ = generate_base_path(path, _MOBILE_NET_MQA)
+
+ if "LayerScale_0" in path:
+ converted_path += ".layer_scale.gamma"
+ converted_weight = weights
+ elif "Normalize_0" in path:
+ converted_path += ".norm.weight"
+ converted_weight = weights
+ elif "Normalize_1" in path:
+ converted_path += ".attn.key.norm.weight"
+ converted_weight = weights
+ elif "Normalize_2" in path:
+ converted_path += ".attn.value.norm.weight"
+ converted_weight = weights
+ elif "key_dwconv" in path:
+ converted_path += ".attn.key.down_conv.weight"
+ converted_weight = weights.transpose()
+ elif "key_proj" in path:
+ converted_path += ".attn.key.proj.weight"
+ converted_weight = weights.transpose()[:, :, None, None]
+ elif "output_proj" in path:
+ converted_path += ".attn.output.proj.weight"
+ converted_weight = weights.transpose()[:, :, None, None]
+ elif "query_proj" in path:
+ converted_path += ".attn.query.proj.weight"
+ converted_weight = weights.transpose()[:, :, None, None]
+ elif "value_dwconv" in path:
+ converted_path += ".attn.value.down_conv.weight"
+ converted_weight = weights.transpose()
+ elif "value_proj" in path:
+ converted_path += ".attn.value.proj.weight"
+ converted_weight = weights.transpose()[:, :, None, None]
+ elif _MOBILE_NET_UIB in path:
+ converted_path, idx_key = generate_base_path(path, _MOBILE_NET_UIB)
+
+ has_dw_start = idx_key in _MOBILE_NET_UIB_HAS_DW_START
+ has_dw_mid = idx_key in _MOBILE_NET_UIB_HAS_DW_MID
+
+ if "LayerScale_0" in path:
+ converted_path += ".layer_scale.gamma"
+ converted_weight = weights
+ elif "Normalize_0" in path:
+ converted_path += ".dw_start.bn.weight" if has_dw_start else ".pw_exp.bn.weight"
+ converted_weight = weights
+ elif "Normalize_1" in path:
+ converted_path += ".pw_exp.bn.weight" if has_dw_start else ".pw_proj.bn.weight"
+ converted_weight = weights
+ elif "Normalize_2" in path:
+ converted_path += ".dw_mid.bn.weight" if has_dw_mid else ".pw_proj.bn.weight"
+ converted_weight = weights
+ elif "Normalize_3" in path:
+ converted_path += ".pw_proj.bn.weight"
+ converted_weight = weights
+ elif "expand" in path:
+ converted_path += ".pw_exp.conv.weight"
+ converted_weight = weights.transpose()[:, :, None, None]
+ elif "middle_dwconv" in path:
+ converted_path += ".dw_mid.conv.weight"
+ converted_weight = weights.transpose(3, 2, 1, 0)
+ elif "project" in path:
+ converted_path += ".pw_proj.conv.weight"
+ converted_weight = weights.transpose()[:, :, None, None]
+ elif "start_dwconv" in path:
+ converted_path += ".dw_start.conv.weight"
+ converted_weight = weights.transpose(3, 2, 1, 0)
+
+ return [(converted_path, converted_weight)]
+
+
+def convert(checkpoint_path: str, config: Gemma3nConfig) -> dict[str, torch.Tensor]:
+ """Loads Orbax checkpoint from `input_path` and converts it to HF tree."""
+ checkpointer = obc.PyTreeCheckpointer()
+ ckpt = checkpointer.restore(checkpoint_path)
+ hf_tree: dict[str, torch.Tensor] = {}
+
+ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
+ hf_tree[path] = torch.from_numpy(weights.astype("float32")).type(target_dtype)
+ if _VERBOSE.value:
+ logging.info(
+ "%s converted shape=%s with dtype=%s",
+ path,
+ weights.shape,
+ target_dtype,
+ )
+
+ for (path, param), value in tree.flatten_with_path(ckpt):
+ if param == "audio_input_embedding_extra":
+ update_tree("model.embed_audio.embedding.weight", value, config.audio_config.torch_dtype)
+ elif path.endswith("audio_embedding_norm"):
+ update_tree("model.embed_audio.hard_embedding_norm.weight", value, config.audio_config.torch_dtype)
+ elif path.endswith("audio_input_projection"):
+ update_tree(
+ "model.embed_audio.embedding_projection.weight", value.transpose(), config.audio_config.torch_dtype
+ )
+ elif path.endswith("audio_soft_embedding_norm"):
+ update_tree("model.embed_audio.soft_embedding_norm.weight", value, config.audio_config.torch_dtype)
+ elif param == "mm_input_embedding_extra":
+ update_tree("model.embed_vision.embedding.weight", value, config.vision_config.torch_dtype)
+ elif path.endswith("mm_hard_embedding_norm"):
+ update_tree("model.embed_vision.hard_embedding_norm.weight", value, config.vision_config.torch_dtype)
+ elif path.endswith("mm_input_projection"):
+ update_tree(
+ "model.embed_vision.embedding_projection.weight", value.transpose(), config.vision_config.torch_dtype
+ )
+ elif path.endswith("mm_soft_embedding_norm"):
+ update_tree("model.embed_vision.soft_embedding_norm.weight", value, config.vision_config.torch_dtype)
+ elif path.startswith(_TRANSFORMER_PARAMETER):
+ for path, weights in convert_transformer_weights(config.text_config, path, param, value):
+ update_tree(f"model.language_model.{path}", weights, config.text_config.torch_dtype)
+ elif _MOBILE_NET_PREFIX in path:
+ mobilenet_prefix_idx = path.index(_MOBILE_NET_PREFIX)
+ path = path[mobilenet_prefix_idx:]
+ for path, weights in convert_vision_weights(config.vision_config, path, param, value):
+ update_tree(f"model.vision_tower.timm_model.{path}", weights, config.vision_config.torch_dtype)
+ elif path.startswith(_AUDIO_ENCODER_PARAMETER):
+ for path, weights in convert_audio_encoder_weights(config.audio_config, path, param, value):
+ update_tree(f"model.audio_tower.{path}", weights, config.audio_config.torch_dtype)
+
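+    # The LM head shares its weights with the input embeddings, so reuse the converted (and resized) embedding table.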
+ hf_tree["lm_head.weight"] = hf_tree["model.language_model.embed_tokens.weight"]
+
+ return hf_tree
+
+
+def main(*args):
+ del args
+
+ output_path = _OUTPUT_PATH.value
+ variant = _VARIANT.value
+
+ config = _VARIANTS[variant]
+ config.audio_config.torch_dtype = getattr(torch, _AUDIO_DTYPE.value)
+ config.text_config.torch_dtype = getattr(torch, _TRANSFORMER_DTYPE.value)
+ config.vision_config.torch_dtype = getattr(torch, _VISION_DTYPE.value)
+ if _INCLUDE_CHAT_TEMPLATE.value:
+        # Chat template is included for instruction-tuned models, which treat
+        # both "<eos>" (ID 1) and "<end_of_turn>" (ID 106) as generation stoppers.
+ config.eos_token_id = [1, 106]
+
+ logging.info(
+ "Converting Gemma 3 (%s) @ %s (language) and %s (vision)",
+ variant,
+ _TRANSFORMER_DTYPE.value,
+ _VISION_DTYPE.value,
+ )
+ state_tree = convert(_CHECKPOINT_PATH.value, config)
+ logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant)
+
+ with accelerate.init_empty_weights():
+ model = Gemma3nForConditionalGeneration(config=config)
+
+ model.load_state_dict(state_tree, assign=True, strict=True)
+ logging.info(
+ "Loaded Gemma 3 (%s) in Hugging Face Transformers as a %s instance.",
+ variant,
+ type(model).__name__,
+ )
+ model.save_pretrained(output_path, state_dict=state_tree, safe_serialization=True)
+ logging.info(
+ "Saved Gemma 3 (%s) to SafeTensors in %s using %s",
+ variant,
+ output_path,
+ type(model).__name__,
+ )
+ del model
+ del state_tree
+
+ chat_template_kwargs = {"chat_template": _CHAT_TEMPLATE} if _INCLUDE_CHAT_TEMPLATE.value else {}
+
+ tokenizer = GemmaTokenizerFast(
+ _TOKENIZER_PATH.value,
+ add_bos_token=True,
+ extra_special_tokens={
+ "image_token": "", # Should be ID=262_145
+ "boi_token": "", # Should be ID=255_999
+ "eoi_token": "", # Should be ID=262_144
+ "audio_token": "", # Should be ID=262_273
+ "boa_token": "", # Should be ID=256_000
+ "eoa_token": "", # Should be ID=262_272
+ },
+ **chat_template_kwargs,
+ )
+ tokenizer.save_pretrained(output_path)
+ logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)
+
+ feature_extractor = Gemma3nAudioFeatureExtractor()
+ image_processor = SiglipImageProcessorFast(
+ image_seq_length=256,
+ image_mean=(0.5,) * 3,
+ image_std=(0.5,) * 3,
+ size={"height": 768, "width": 768},
+ resample=PILImageResampling.BILINEAR,
+ do_normalize=False,
+ )
+ processor = Gemma3nProcessor(
+ feature_extractor=feature_extractor,
+ image_processor=image_processor,
+ tokenizer=tokenizer,
+ **chat_template_kwargs,
+ )
+ processor.save_pretrained(output_path)
+
+ logging.info("Saved Gemma3nProcessor for %s to %s", variant, output_path)
+
+ # NOTE: feature_extractor and image_processor both use the same filename, preprocessor_config.json, when saved to
+ # disk, but the files are overwritten by processor.save_pretrained(). However, the configs can be unioned, saved,
+ # and loaded from the same preprocessor_config.json file, so we do that explicitly here.
+ feature_extractor_config = json.loads(feature_extractor.to_json_string())
+ image_processor_config = json.loads(image_processor.to_json_string())
+ preprocessor_config = {**feature_extractor_config, **image_processor_config}
+ with open(os.path.join(output_path, "preprocessor_config.json"), "w", encoding="utf-8") as writer:
+ writer.write(json.dumps(preprocessor_config, indent=2, sort_keys=True) + "\n")
+
+ logging.info("Saved joint preprocessor_config.json for %s to %s", variant, output_path)
+
+ del feature_extractor, image_processor, processor, tokenizer
+
+ generation_config = GenerationConfig(
+ pad_token_id=config.text_config.pad_token_id,
+ bos_token_id=config.text_config.bos_token_id,
+ eos_token_id=(
+ [config.text_config.eos_token_id, 106] if _INCLUDE_CHAT_TEMPLATE.value else config.text_config.eos_token_id
+ ),
+ cache_implementation="hybrid",
+ temperature=1.0,
+ do_sample=True,
+ top_k=64,
+ top_p=0.95,
+ )
+ generation_config.save_pretrained(output_path)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py
new file mode 100644
index 000000000000..63598926af2b
--- /dev/null
+++ b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py
@@ -0,0 +1,338 @@
+# coding=utf-8
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections.abc import Sequence
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import PaddingStrategy, TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def create_fb_matrix(
+ n_freqs: int,
+ f_min: float,
+ f_max: float,
+ n_mels: int,
+ sample_rate: int,
+ fft_length: int,
+ norm: Optional[str] = None,
+) -> np.ndarray:
+ r"""Create a frequency bin conversion matrix (NumPy version).
+
+ Args:
+ n_freqs (int): Number of frequencies to highlight/apply
+ f_min (float): Minimum frequency (Hz)
+ f_max (float): Maximum frequency (Hz)
+ n_mels (int): Number of mel filterbanks
+ sample_rate (int): Sample rate of the audio waveform
+ fft_length (int): FFT length
+ norm (Optional[str]): If 'slaney', divide the triangular mel weights by
+ the width of the mel band (area normalization). (Default: ``None``)
+
+    Returns:
+        np.ndarray: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``), i.e. the number of
+        frequencies to highlight/apply by the number of filterbanks. Each column is a filterbank, so that for a
+        matrix A of size (..., ``n_freqs``), the filterbanks are applied as ``A @ create_fb_matrix(A.shape[-1], ...)``.
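+
+    Example (an illustrative sketch mirroring this module's 16 kHz defaults):
+
+    ```python
+    >>> # A 1024-point FFT gives 513 positive-frequency bins, mapped onto 128 mel bins.
+    >>> fb = create_fb_matrix(n_freqs=513, f_min=125.0, f_max=7600.0, n_mels=128, sample_rate=16000, fft_length=1024)
+    >>> fb.shape
+    (513, 128)
+    ```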
+ """
+
+ if norm is not None and norm != "slaney":
+ raise ValueError("norm must be one of None or 'slaney'")
+
+ # freq bins
+ all_freqs = np.arange(n_freqs, dtype=np.float32) * (sample_rate / fft_length)
+
+ # calculate mel freq bins
+ # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
+ m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
+ m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
+ m_pts = np.linspace(m_min, m_max, n_mels + 2)
+ # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
+ f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
+ # calculate difference between each mel point and each stft freq point in Hz
+ f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
+ slopes = np.expand_dims(f_pts, 0) - np.expand_dims(all_freqs, 1) # (n_freqs, n_mels + 2)
+ # create overlapping triangles
+ zero = np.zeros(1, dtype=np.float32)
+ down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels)
+ up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels)
+ fb = np.maximum(zero, np.minimum(down_slopes, up_slopes))
+
+ if norm is not None and norm == "slaney":
+ # Slaney-style mel is scaled to be approx constant energy per channel
+ enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
+ fb *= np.expand_dims(enorm, 0)
+
+ return fb
+
+
+def _unfold(array: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray:
+ """A basic NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim."""
+ if array.ndim != 2:
+ raise ValueError("This unfold implementation currently supports 2D arrays (batch, time).")
+ if dimension != -1 and dimension != array.ndim - 1:
+ raise ValueError("This unfold implementation only supports unfolding the last dimension.")
+
+ batch_size, original_length = array.shape
+ num_frames = (original_length - size) // step + 1
+
+ if num_frames <= 0:
+ return np.zeros((batch_size, 0, size), dtype=array.dtype)
+
+ output_shape = (batch_size, num_frames, size)
+ output_strides = (array.strides[0], array.strides[1] * step, array.strides[1])
+
+ return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides)
+
+
+class Gemma3nAudioFeatureExtractor(SequenceFeatureExtractor):
+ """An audio feature extractor Universal Speech Models https://arxiv.org/abs/2303.01037.
+
+ Args:
+ feature_size (`int`, *optional*, defaults to 128):
+ The feature dimension of the extracted features.
+ sampling_rate (`int`, *optional*, defaults to 16000):
+ The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
+ padding_value (`float`, *optional*, defaults to 0.0):
+ Padding value used to pad the audio. Should correspond to silences.
+ return_attention_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return the attention mask for the generated MEL spectrograms.
+ frame_length_ms (`float`, *optional*, defaults to 32.0):
+ The length of a frame in milliseconds.
+        hop_length_ms (`float`, *optional*, defaults to 10.0):
+            The hop length (stride between successive frames) in milliseconds used for the STFT.
+ min_frequency (`float`, *optional*, defaults to 125.0):
+ The minimum frequency (in Hz) for the Mel filterbank.
+ max_frequency (`float`, *optional*, defaults to 7600.0):
+ The maximum frequency (in Hz) for the Mel filterbank.
+ preemphasis (`float`, *optional*, defaults to 0.97):
+ The preemphasis coefficient.
+ preemphasis_htk_flavor (`bool`, *optional*, defaults to `True`):
+ Whether to use HTK-style preemphasis.
+ fft_overdrive (`bool`, *optional*, defaults to `True`):
+ Whether to use FFT overdrive.
+ dither (`float`, *optional*, defaults to 0.0):
+            Adds dithering, i.e. small Gaussian noise added to each frame.
+            E.g. use 0.0001 to add dithering with a normal distribution centered
+            around 0.0 with standard deviation 0.0001 (assuming a [-1, +1] range of raw_speech).
+            The value 0.0 means no dithering.
+            Dithering has a similar effect to `spectrogram(mel_floor=...)`. It reduces
+            the high log_mel_fbank values for signals with hard-zero sections
+            when a VAD cutoff is present in the signal.
+ input_scale_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor applied to the input waveform.
+ mel_floor (`float`, *optional*, defaults to 1e-05):
+ Minimum value for Mel spectrograms to avoid log(0).
+ per_bin_mean (`Optional[Sequence[float]]`, *optional*):
+ Mean values for per-bin normalization.
+ per_bin_stddev (`Optional[Sequence[float]]`, *optional*):
+ Standard deviation values for per-bin normalization.
+ """
+
+ model_input_names = ["input_features", "input_features_mask"]
+
+ def __init__(
+ self,
+ feature_size: int = 128,
+ sampling_rate: int = 16_000,
+ padding_value: float = 0.0,
+ return_attention_mask: bool = True,
+ frame_length_ms: float = 32.0,
+ hop_length_ms: float = 10.0,
+ min_frequency: float = 125.0,
+ max_frequency: float = 7600.0,
+ preemphasis: float = 0.97,
+ preemphasis_htk_flavor: bool = True,
+ fft_overdrive: bool = True,
+ dither: float = 0.0,
+ input_scale_factor: float = 1.0,
+ mel_floor: float = 1e-5,
+ per_bin_mean: Optional[Sequence[float]] = None,
+ per_bin_stddev: Optional[Sequence[float]] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ feature_size=feature_size,
+ sampling_rate=sampling_rate,
+ padding_value=padding_value,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+
+ self.min_frequency = min_frequency
+ self.max_frequency = max_frequency
+ self.preemphasis = preemphasis
+ self.preemphasis_htk_flavor = preemphasis_htk_flavor
+ self.fft_overdrive = fft_overdrive
+ self.dither = dither
+ self.input_scale_factor = input_scale_factor
+ self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0))
+ self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0))
+ self.mel_floor = np.array(mel_floor, dtype=np.float64)
+
+ fft_length = 2 ** math.ceil(math.log2(self.frame_length))
+ if self.fft_overdrive:
+ fft_length *= 2
+ self.fft_length = fft_length
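+        # With the defaults (16 kHz audio, 32 ms frames), frame_length is 512 samples, so fft_length is 512,
+        # doubled to 1024 by fft_overdrive.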
+
+ hann_arange = np.arange(self.frame_length, dtype=np.float32)
+ window = 0.5 * (1 - np.cos(2 * np.pi * hann_arange / self.frame_length))
+ self.window = window.astype(np.float32)
+
+ self.mel_filters = create_fb_matrix(
+ n_freqs=self.fft_length // 2 + 1,
+ f_min=min_frequency,
+ f_max=max_frequency,
+ n_mels=feature_size,
+ sample_rate=self.sampling_rate,
+ norm=None,
+ fft_length=fft_length,
+ )
+
+ if per_bin_mean is not None:
+ self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, feature_size)
+ else:
+ self.per_bin_mean = None
+
+ if per_bin_stddev is not None:
+ self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, feature_size)
+ else:
+ self.per_bin_stddev = None
+
+ def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+ """"""
+ if waveform.ndim == 1: # If single waveform, add batch dimension
+ waveform = np.expand_dims(waveform, axis=0)
+
+ if self.dither > 0.0:
+ waveform = waveform + self.dither * np.random.randn(*waveform.shape).astype(waveform.dtype)
+
+ if self.input_scale_factor != 1.0:
+ waveform = waveform * self.input_scale_factor
+
+ frame_size_for_unfold = self.frame_length + 1
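+        # The extra sample per frame lets preemphasis (x[t] - coeff * x[t - 1]) be computed inside each frame.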
+
+ # NumPy equivalent of unfold for [B, NumFrames, frame_size_for_unfold]
+ frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length)
+
+ if self.preemphasis > 0.0:
+ if self.preemphasis_htk_flavor:
+ first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis)
+ rest_in_frame = frames_to_process[..., 1:-1] - self.preemphasis * frames_to_process[..., :-2]
+ frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1)
+ else:
+ frames = frames_to_process[..., 1:] - self.preemphasis * frames_to_process[..., :-1]
+ else:
+ frames = frames_to_process[..., :-1]
+
+ frames = frames * self.window # Broadcasting window
+ stft = np.fft.rfft(frames, n=self.fft_length, axis=-1)
+
+ magnitude_spec = np.abs(stft)
+
+ mel_spec = np.matmul(magnitude_spec, self.mel_filters)
+ log_mel_spec = np.log(np.maximum(mel_spec, self.mel_floor))
+
+ if self.per_bin_mean is not None:
+ log_mel_spec = log_mel_spec - self.per_bin_mean # Broadcasting
+
+ if self.per_bin_stddev is not None:
+ log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
+
+ mel_spectrogram = log_mel_spec.squeeze()
+ mask = attention_mask[:: self.hop_length].astype(bool)
+ # TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why???
+ return mel_spectrogram, mask[: mel_spectrogram.shape[0]]
+
+ def __call__(
+ self,
+ raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
+ padding: Union[bool, str, PaddingStrategy] = "longest",
+ max_length: Optional[int] = 480_000,
+ truncation: bool = True,
+ pad_to_multiple_of: Optional[int] = 128,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_attention_mask: Optional[bool] = True,
+ **kwargs,
+ ) -> BatchFeature:
+ """Creates a batch of MEL spectrograms from the provided raw speech.
+
+        This implementation uses a different algorithm for windowing and preemphasis compared to the built-in
+        `transformers.audio_utils.spectrogram()` function, which _will_ result in different outputs. Consider this
+        carefully when selecting an audio feature extractor, especially with pre-trained models.
+
+ Args:
+ raw_speech:
+ The audio for which MEL spectrograms are created.
+ padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `"longest"`):
+ The padding strategy to use for batches of audio with different lengths.
+ max_length (`int`, *optional*, defaults to 480000):
+ If provided, defines the maximum length of the audio to allow. Audio longer than this will be
+ truncated if `truncation=True`.
+ truncation (`bool`, *optional*, defaults to `True`):
+ Whether or not to truncate audio above `max_length`.
+ pad_to_multiple_of (`int`, *optional*, defaults to 128):
+ When padding, pad to a multiple of this value. The default value is defined for optimal TPU support.
+ return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`):
+ The type of tensors to return (e.g., NumPy, Torch, JAX, TensorFlow).
+ return_attention_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return the attention mask for the generated MEL spectrograms.
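+
+        Example (a minimal sketch; the audio here is random noise purely for illustration):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import Gemma3nAudioFeatureExtractor
+
+        >>> feature_extractor = Gemma3nAudioFeatureExtractor()
+        >>> audio = np.random.randn(16_000).astype(np.float32)  # roughly one second at 16 kHz
+        >>> features = feature_extractor([audio], return_tensors="np")
+        >>> list(features.keys())
+        ['input_features', 'input_features_mask']
+        ```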
+ """
+
+ is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+ is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence))
+ is_batched = is_batched_numpy or is_batched_sequence
+
+ if is_batched:
+ raw_speech = [np.asarray([rs]).T for rs in raw_speech]
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
+ raw_speech = np.asarray(raw_speech)
+
+ if not is_batched: # always return a batch
+ raw_speech = [np.asarray([raw_speech])]
+
+ batched_speech = self.pad(
+ BatchFeature({"input_features": raw_speech}),
+ padding=padding,
+ max_length=max_length,
+ truncation=truncation,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_attention_mask=return_attention_mask,
+ )
+
+ prepared_speech = []
+ prepared_speech_mask = []
+ for speech, mask in zip(batched_speech.input_features, batched_speech.attention_mask):
+ speech, mask = self._extract_spectrogram(speech.T, mask)
+ prepared_speech.append(speech.astype(np.float32))
+ prepared_speech_mask.append(mask)
+
+ return BatchFeature(
+ {"input_features": prepared_speech, "input_features_mask": prepared_speech_mask},
+ tensor_type=return_tensors,
+ )
+
+
+__all__ = ["Gemma3nAudioFeatureExtractor"]
diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py
new file mode 100644
index 000000000000..0817e16451ac
--- /dev/null
+++ b/src/transformers/models/gemma3n/modeling_gemma3n.py
@@ -0,0 +1,2422 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gemma3n.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import math
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, HybridCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ ModelOutput,
+ auto_docstring,
+ can_return_tuple,
+ is_torchdynamo_compiling,
+ logging,
+)
+from ...utils.deprecation import deprecate_kwarg
+from ..auto import AutoModel
+from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Gemma3n outputs, with hidden states and attentions.
+ """
+)
+class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+ audio_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Gemma3n causal language model (or autoregressive) outputs.
+ """
+)
+class Gemma3nCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+ audio_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class Gemma3nRMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
+ super().__init__()
+ self.eps = eps
+ self.with_scale = with_scale
+
+ if self.with_scale:
+ self.weight = nn.Parameter(torch.ones(dim))
+ else:
+ self.register_buffer("weight", torch.tensor(1.0), persistent=False)
+
+ def _norm(self, x):
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
+ # See https://github.com/huggingface/transformers/pull/29402
+ output = self._norm(x.float()) * self.weight.float()
+ return output.type_as(x)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+# ==== Audio Encoder ====
+
+
+class Gemma3nAudioRelativePositionEmbedding(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_heads = self.config.conf_num_attention_heads
+ self.channels = self.config.hidden_size
+ self.head_dim = self.channels // self.num_heads
+ self.max_backward = max(0, self.config.conf_attention_context_left - 1)
+ self.max_forward = self.config.conf_attention_context_right
+
+ self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
+
+ min_timescale = 1.0
+ max_timescale = 1.0e4
+ num_timescales = self.channels // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
+ inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
+ self.register_buffer(
+ "inv_timescales",
+ inv_timescales.float().unsqueeze(0).unsqueeze(0),
+ persistent=False,
+ )
+
+ def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ position = position.float().unsqueeze(-1)
+ scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32)
+ timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
+ return timing_signal.type(dtype)
+
+ def _relative_shift(
+ self,
+ term_bd_before_shift: torch.Tensor,
+ batch_size: int,
+ num_heads: int,
+ num_query_blocks: int,
+ query_block_size: int,
+ key_context_size: int,
+ max_span_plus_1: int,
+ ) -> torch.Tensor:
+ """Performs the relative shift.
+
+ Args:
+ term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
+ (B), num_heads (N), num_query_blocks (U), query_block_size (W),
+ key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
+
+ Returns:
+ Tensor of shape [B, N, U, W, C].
+ """
+ # term_bd_before_shift shape: [B, N, U, W, F_span]
+ # Target shape after shift: [B, N, U, W, C]
+
+ # Padding amount for the last dimension (F_span) to become (C + 1)
+ # C = key_context_size
+ # F_span = max_span_plus_1
+ pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
+
+ # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
+ # We only pad the last dimension on the right.
+ padding_tuple = (0, pad_amount_last_dim)
+
+ term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
+ # Shape after pad: [B, N, U, W, C+1]
+
+ # Reshape for slicing (emulating JAX's behavior)
+ # [B, N, U, W * (C+1)]
+ term_bd_reshaped = term_bd_padded.reshape(
+ (
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size * (key_context_size + 1),
+ )
+ )
+
+ # Slice to effective [B, N, U, W * C]
+ term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
+
+ # Reshape back to [B, N, U, W, C]
+ term_bd_shifted = term_bd_sliced.reshape(
+ (
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ key_context_size,
+ )
+ )
+ return term_bd_shifted
+
+ def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
+ # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
+ # keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
+ # C = W + L + R (key_context_size)
+ # F_span = L + R + 1 (max_span + 1)
+
+ batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
+ _, _, key_context_size, _, _ = keys.shape
+
+ # Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
+ # Length is L+R+1 = self.max_span + 1
+ pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
+ 0
+ ) # Shape [1, F_span]
+
+ max_span_plus_1 = pos_indices.shape[1] # F_span
+
+ sin_emb_timing_signal = self._get_timing_signal_1d_pos(
+ pos_indices, dtype=queries.dtype
+ ) # Shape [1, F_span, self.channels]
+
+ # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
+ projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
+ # Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
+ sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
+ 0
+ ) # Shape [F, N, H]
+
+ # term_ac: Query-Key content interaction
+ # queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
+ # keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
+ queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H]
+ keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C]
+ term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
+
+ # term_bd: Query-Position interaction
+ # Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
+ # queries shape: [B, U, W, N, H]
+ # sin_emb shape: [F, N, H]
+ # Target output shape: [B, N, U, W, F]
+
+ # Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
+ q_permuted = queries.permute(0, 3, 1, 2, 4)
+
+ # Permute sin_emb to [N, H, F] to prepare for matmul
+ # sin_emb original is [F, N, H]
+ s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F]
+
+ # Reshape queries for matmul: [B, N, U*W, H]
+ q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
+
+ # Perform matmul: [B, N, U*W, H] @ [N, H, F]
+ # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
+ # Result: [B, N, U*W, F]
+ term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
+
+ # Reshape to target [B, N, U, W, F]
+ term_bd_unshifed = term_bd_unshifed_matmul.reshape(
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ max_span_plus_1,
+ )
+
+ # Apply relative shift to term_bd_unshifed
+ term_bd_shifted = self._relative_shift(
+ term_bd_unshifed,
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ key_context_size,
+ max_span_plus_1,
+ ) # Shape [B, N, U, W, C]
+
+ return term_ac + term_bd_shifted
+
+
+class Gemma3nAudioAttention(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_heads = self.config.conf_num_attention_heads
+ self.hidden_size = self.config.hidden_size
+ self.head_dim = self.hidden_size // self.num_heads
+
+ self.chunk_size = self.config.conf_attention_chunk_size
+ self.max_future_horizon = self.config.conf_attention_context_right
+ self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
+ self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
+ self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
+
+ self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
+ self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+
+ q_scale = self.head_dim**-0.5
+ r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
+ self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
+
+ lower_causal_mask = torch.tril(
+ torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
+ diagonal=0,
+ ).T
+ upper_causal_mask = torch.tril(
+ torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
+ diagonal=self.max_past_horizon + self.max_future_horizon,
+ )
+ local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
+ local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
+ self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
+
+ self.register_buffer(
+ "softcap",
+ torch.tensor(self.attention_logits_soft_cap).float(),
+ persistent=False,
+ )
+
+ def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
+ batch, _, *tail_shape = x.shape
+ left = x.new_zeros((batch, pad_left, *tail_shape))
+ right = x.new_zeros((batch, pad_right, *tail_shape))
+ x = torch.cat([left, x, right], dim=1)
+ return x
+
+ def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Turns a sequence to non overlapping blocks.
+
+ Args:
+ hidden_states: a tensor of [batch, time, ...].
+
+ Returns:
+ A tensor of [batch, num_blocks, block_size, ...], with necessary
+ paddings,
+ where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
+ """
+ shape = hidden_states.shape
+ b, t = shape[:2]
+ num_blocks = (t + self.chunk_size - 1) // self.chunk_size
+
+ if (padding_len := num_blocks * self.chunk_size - t) > 0:
+ hidden_states = self._pad_dim1(hidden_states, 0, padding_len)
+
+ permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
+ hidden_states = hidden_states.reshape(permute_dims).contiguous()
+ return hidden_states
+
+ def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Extracts temporal context for every block.
+
+ Args:
+ hidden_states: a tensor of [batch, time, ...].
+
+ Returns:
+            A tensor of [batch, num_blocks, context_size, ...], with necessary paddings,
+            where context_size = block_size + left_context + right_context, and
+            output[:, i, ...] are x[:, start - left_context : end + right_context, ...],
+            with start = i * block_size and end = (i + 1) * block_size.
+ """
+ pad_left = self.max_past_horizon
+ # The JAX equivalent padding for signal.frame with pad_mode='valid' is
+ # (left_context, right_context + block_size - 1) on the time dimension.
+ # PyTorch's _pad_dim1 applies padding symmetrically if only one value is given,
+ # or (pad_dim_start, pad_dim_end) if two are given.
+ # Our _pad_dim1(x, pad_left, pad_right) pads dim -2 (time for [B,T,N,H])
+ # or dim 1 (time for [B,T]).
+ # The current pad_right calculation matches the JAX effective padding.
+ pad_right = self.max_future_horizon + self.chunk_size - 1
+ hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right)
+
+ frame_len = self.context_size
+ frame_step = self.chunk_size
+
+ # Directly use unfold without the subframe_factor logic
+ # x.unfold(dimension, size, step)
+ # dimension=1 (time dimension, assuming x is [B, T_padded, ...])
+ # size=frame_len (context_size)
+ # step=frame_step (chunk_size)
+ x_unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)
+
+ # If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
+ # If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
+ # We want to match JAX's typical output for such operations which might be
+ # [B, num_blocks, frame_len, N, H] if N, H are present.
+ # The relative_position_embedding expects keys as [B, U, C, N, H].
+ # If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
+ if hidden_states.ndim > 2 and x_unfolded.ndim > 3: # Check if inner dimensions (like N, H) exist
+ # Current shape after unfold for [B, T_pad, N, H] is [B, U, N, H, C]
+ # Target shape for keys in RPE: [B, U, C, N, H]
+ x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)
+
+ return x_unfolded.contiguous()
+
+ def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+ # sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
+ qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
+ query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
+ key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
+ value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous()
+
+ per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale)
+
+ broadcast_shape = (1, 1, 1, self.head_dim)
+ per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
+ query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast
+
+ batch_size, q_time = query_states.shape[:2]
+
+ query_blocks = self._convert_to_block(query_states)
+ key_blocks = self._extract_block_context(key_states)
+ value_blocks = self._extract_block_context(value_states)
+ num_query_blocks = query_blocks.shape[1]
+
+ # 1. Create a mask indicating originally valid positions.
+ original_valid_mask = ~mask # True for valid, False for padded
+
+ # 2. Extract blocks from this validity mask.
+ extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)
+
+ # If subframe_factor was used in _extract_block_context for a [B, T] input mask,
+ # the shape might be [B, U, C/SF, SF]. Reshape to [B, U, C].
+ # batch_size and num_query_blocks are known from query_blocks.
+ # self.context_size is C.
+ if (
+ extracted_valid_mask_blocks.ndim == 4
+ and extracted_valid_mask_blocks.shape[2] * extracted_valid_mask_blocks.shape[3] == self.context_size
+ ):
+ extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
+ batch_size, num_query_blocks, self.context_size
+ )
+ # After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask.
+ # This assertion might be too strict if _extract_block_context handles higher-rank inputs differently,
+ # but for the mask case, this should hold.
+ if extracted_valid_mask_blocks.shape != (
+ batch_size,
+ num_query_blocks,
+ self.context_size,
+ ):
+ raise ValueError(
+ "Shape of extracted_valid_mask_blocks"
+ f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
+ f" {num_query_blocks}, {self.context_size}) after potential reshape."
+ )
+
+ # 3. Expand dimensions for broadcasting with logits and causal mask.
+ # Target shape for broadcasting with logits [B,N,U,W,C]
+ # extracted_valid_mask_blocks to [B, 1, U, 1, C]
+ condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(1).unsqueeze(-2)
+
+ # self.local_causal_valid_mask is [W, C], True where allowed by local window.
+ # Expand to [1, 1, 1, W, C]
+ condition_from_causality = self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+
+ # 4. Combine the two conditions.
+ # final_condition will be True where a key is *both* originally valid *and* causally accessible.
+ # Broadcasts to [B, 1, U, W, C]
+ final_condition_for_where = torch.logical_and(
+ condition_from_input_validity,
+ condition_from_causality.to(condition_from_input_validity.device), # Ensure same device
+ )
+
+ # Embed queries and keys
+ logits = self.relative_position_embedding(query_blocks, key_blocks)
+
+ # Apply attention logit softcap
+ # Ensure softcap is on the same device as logits
+ softcap_val = self.softcap.to(logits.device)
+ logits = logits / softcap_val
+ logits = torch.tanh(logits)
+ logits = logits * softcap_val
+
+ # Apply the combined mask.
+ # final_condition_for_where will broadcast with logits [B,N,U,W,C]
+ logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
+ probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
+
+ # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
+ b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
+ h_dim = value_blocks.shape[-1]
+ prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
+ v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
+ result_bmm = torch.bmm(prob_bun, v_bun)
+ context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
+ context_vectors = context_vectors.reshape(
+ (
+ batch_size,
+ num_query_blocks * self.chunk_size,
+ self.num_heads,
+ self.head_dim,
+ )
+ )
+ context_vectors = context_vectors[:, :q_time]
+
+ return context_vectors
+
+
+class Gemma3nAudioCumulativeGroupNorm(nn.Module):
+ """Applies Group Normalization cumulatively over the time dimension.
+
+ This layer normalizes the input by calculating the mean and variance
+ cumulatively over the time dimension (dim 1). The statistics are computed
+ over all feature dimensions (specified by `feature_dims` and `num_channels`)
+ for elements marked as valid by the optional `mask`.
+
+ If a `mask` is provided (True for valid, False for invalid/padded),
+ invalid time steps do not contribute to the statistics calculation, and
+ their corresponding output values are zeroed out.
+
+ Scale and bias, if enabled, are applied per-channel (last dimension).
+ This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
+ and `cumulative=True`.
+ """
+
+ def __init__(
+ self,
+ num_channels: int, # Number of channels (size of the last dimension)
+ feature_dims: Sequence[int], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
+ eps: float = 1e-3,
+ ):
+ super().__init__()
+ self.num_channels = num_channels
+ self.feature_dims = tuple(feature_dims)
+ self.eps = eps
+
+ # Scale parameter depends only on the channel dimension
+ self.weight = nn.Parameter(torch.ones(num_channels))
+
+ # Axes for normalization: all dimensions except Batch (0) and Time (1).
+ # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
+ self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Applies cumulative group norm, optionally using a mask.
+
+ Args:
+ hidden_states: Input tensor, shape [B, T, *feature_dims, C].
+
+ Returns:
+ Normalized tensor with the same shape as x.
+ """
+ expected_input_suffix = self.feature_dims + (self.num_channels,)
+ if hidden_states.shape[2:] != expected_input_suffix:
+ raise ValueError(
+ f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
+ f" suffix (feature_dims + num_channels) {expected_input_suffix}"
+ )
+
+ input_dtype = hidden_states.dtype
+ # Calculations are performed in float32 for numerical stability.
+ calc_dtype = torch.float32
+ x_calc = hidden_states.to(calc_dtype)
+
+ # Prepare a broadcastable mask (`mask_calc`).
+ # If no mask is provided, treat all elements as valid
+ # (mask_calc is all ones).
+ # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
+ mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)
+
+ # Cumulative Statistics Calculation
+ # 1. Sum of values over reduction axes at each time step.
+ sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
+ # 2. Cumulative sum of values over time.
+ cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
+
+ # 3. Count of valid elements in the normalization group at each time step.
+ # (A "group" here consists of all features at a given Batch, Time).
+ elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
+ # 4. Cumulative count of valid elements over time.
+ cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
+ # Avoid division by zero if all preceding elements were masked.
+ safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)
+
+ # 5. Cumulative mean.
+ cum_mean = cum_sum_values / safe_cum_count_elements
+
+        # 6. Sum of squared differences from the cumulative mean at each time step.
+        #    Note that the difference at step s uses the cumulative mean up to step s.
+ squared_diff_from_mean = (x_calc - cum_mean).pow(2)
+ sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True)
+
+ # 7. Cumulative sum of squared differences over time.
+ cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
+
+ # 8. Cumulative variance.
+ cum_variance = cum_sum_sq_diff / safe_cum_count_elements
+
+ # Normalize the input using the calculated cumulative statistics:
+ # (x - E[x]) / sqrt(Var[x] + eps)
+ normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)
+
+ # Apply affine transformation (scale and bias) if enabled.
+ # Scale and bias are applied per-channel (last dimension).
+ scale = self.weight.to(calc_dtype)
+ # Reshape for broadcasting: [C] -> [1, ..., 1, C]
+ scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels]
+ normalized_x = normalized_x * scale.view(scale_view_shape)
+
+ # Zero out outputs for time steps that were originally masked (where mask_calc is 0).
+ # This ensures padded/invalid positions in the input result in zero output.
+ final_output = normalized_x * mask_calc
+
+ return final_output.to(input_dtype)
+
+
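+# Shape sketch for Gemma3nAudioCumulativeGroupNorm (illustrative only, not used by the model), assuming
+# convolution features of shape [B, T, F, C]:
+#     norm = Gemma3nAudioCumulativeGroupNorm(num_channels=128, feature_dims=(3,))
+#     out = norm(torch.randn(2, 10, 3, 128))  # same shape as the input: [2, 10, 3, 128]
+
+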
+class Gemma3nAudioSSCPConvBlock(nn.Module):
+ """A single convolution block for the SubSampleConvProjection.
+
+ This block consists of a 2D convolution, followed by CumulativeGroupNorm,
+ and a ReLU activation. It handles manual padding for the convolution.
+ """
+
+ def __init__(
+ self,
+ config: Gemma3nAudioConfig,
+ idx: int,
+        input_freq_dim: int,  # Size of the frequency (feature) dimension of the unpadded input
+ manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0),
+ ):
+ super().__init__()
+ self.config = config
+ self.manual_padding = manual_padding
+
+ # in_channels is 1 for the first block, or C_out from previous block's conv
+ in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
+ out_channels = self.config.sscp_conv_channel_size[idx]
+ kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
+ stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
+
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(
+ kernel_h,
+ kernel_w,
+ ), # Kernel (kH, kW) operates on (Time, Freq_dim)
+ stride=(stride_h, stride_w),
+ padding=(0, 0), # Manual padding is used
+ bias=False,
+ )
+
+ # Calculate output frequency dimension (f_out_conv) after this convolution.
+ # input_freq_dim is the unpadded width (feature dimension).
+ # self.manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
+ f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
+ f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
+
+ self.norm = Gemma3nAudioCumulativeGroupNorm(
+ num_channels=out_channels, # Channels of the conv output
+ feature_dims=(f_out_conv,), # The frequency dimension size after conv
+ eps=self.config.sscp_conv_group_norm_eps,
+ )
+
+ self.activation = nn.ReLU()
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
+ # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
+ # F.pad applies to last two dims: F_in then T_in
+ audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
+ # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
+ # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
+ audio_encodings_conv = self.conv(audio_encodings_padded)
+ # Expected conv output shape: [B, C_out, T_out, F_out]
+ # Input to norm is [B, T_out, F_out, C_out]
+ x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
+ x_normed = self.norm(x_for_norm)
+ # Output of norm is [B, T_out, F_out, C_out], permute back to [B, C_out, T_out, F_out]
+ audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
+ return self.activation(audio_encodings_normed)
+
+
+class Gemma3nAudioSubSampleConvProjection(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ current_f_for_block_input = config.input_feat_size # Start with original feature dim
+ calculated_block_padding = []
+ calculated_f_out_dims = [] # Tracking frequency dimension output sizes
+
+ for i in range(2): # Assuming 2 conv layers as per sscp_conv_... arrays
+ kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
+ stride_h, stride_w = config.sscp_conv_stride_size[i]
+
+ # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
+ # JAX 'reverse_causal' padding is (0, kernel_size - 1)
+ pad_t_top = 0
+ pad_t_bottom = kernel_h - 1
+
+ # Frequency Padding (Width for Conv2d)
+ # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
+ # and the successful test configuration.
+ # If kernel/stride/input_freq for frequency changes, this might need re-evaluation
+ # to match generic JAX 'SAME' behavior if it differs.
+ pad_f_left = 1
+ pad_f_right = 1
+
+ manual_padding_tuple = (
+ pad_f_left,
+ pad_f_right,
+ pad_t_top,
+ pad_t_bottom,
+ )
+ calculated_block_padding.append(manual_padding_tuple)
+
+ # Calculate output frequency dimension after this convolution
+ # This uses the actual padding applied and kernel/stride.
+ f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
+ f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 # Assuming dilation_w = 1
+ calculated_f_out_dims.append(f_out_after_conv)
+ current_f_for_block_input = f_out_after_conv
+
+ self.conv_0 = Gemma3nAudioSSCPConvBlock(
+ idx=0,
+ input_freq_dim=config.input_feat_size, # Pass original feature dim
+ config=config,
+ manual_padding=calculated_block_padding[0],
+ )
+ self.conv_1 = Gemma3nAudioSSCPConvBlock(
+ idx=1,
+ input_freq_dim=calculated_f_out_dims[0], # Output freq dim from conv_0
+ config=config,
+ manual_padding=calculated_block_padding[1],
+ )
+ final_c_out = config.sscp_conv_channel_size[-1]
+ final_f_out = calculated_f_out_dims[-1] # Final frequency dimension
+ self.input_proj_in_features = final_c_out * final_f_out
+ self.input_proj_linear = nn.Linear(self.input_proj_in_features, self.config.hidden_size, bias=False)
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ # audio_encodings is [B, T, F_in]
+ # Reshape to [B, 1, T, F_in] (Batch, Channels=1, Height=Time, Width=F_in)
+ audio_encodings_reshaped = audio_encodings.unsqueeze(1)
+ x = self.conv_0(audio_encodings_reshaped)
+ x = self.conv_1(x)
+ # x from conv_1 is [B, C_out_1, T_out_1, F_out_1]
+ b, c_out, t_out, f_out = x.shape
+ # Permute to [B, T_out_1, F_out_1, C_out_1] then flatten F_out_1 and C_out_1
+ x_permuted = x.permute(0, 2, 3, 1).contiguous()
+ output_flattened = x_permuted.view(b, t_out, f_out * c_out)
+ output = self.input_proj_linear(output_flattened)
+ return output
+
+
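+# Shape sketch for Gemma3nAudioSubSampleConvProjection (illustrative only), assuming the default two SSCP
+# blocks, each with a time stride of 2:
+#     proj = Gemma3nAudioSubSampleConvProjection(config)
+#     out = proj(torch.randn(2, 160, config.input_feat_size))  # [B, T, F_in] -> [B, ~T // 4, hidden_size]
+
+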
+class Gemma3nAudioConformerAttention(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+ self.post_in_features = self.config.hidden_size
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.attn = Gemma3nAudioAttention(config)
+ self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
+ self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+ def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
+ audio_encodings_input_to_attn = audio_encodings
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings_norm = self.pre_attn_norm(audio_encodings)
+ # Output of self.attn is [B, T, NumHeads, HeadDim]
+ audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)
+
+ # Reshape from [B, T, NumHeads, HeadDim] to [B, T, NumHeads * HeadDim]
+ # NumHeads * HeadDim = hidden_size
+ b, t, num_heads, head_dim = audio_encodings_attn_out.shape
+ audio_encodings_reshaped = audio_encodings_attn_out.reshape(b, t, num_heads * head_dim)
+
+ audio_encodings = self.post(audio_encodings_reshaped)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ return audio_encodings_input_to_attn + self.post_norm(audio_encodings)
+
+
+class Gemma3nAudioConformerFeedForward(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+
+ self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
+ self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
+ self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ residual = audio_encodings
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.pre_layer_norm(audio_encodings)
+ audio_encodings: torch.Tensor = self.ffw_layer_1(audio_encodings)
+ audio_encodings = nn.functional.silu(audio_encodings)
+ audio_encodings: torch.Tensor = self.ffw_layer_2(audio_encodings)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.post_layer_norm(audio_encodings)
+ return residual + (audio_encodings * self.post_layer_scale)
+
+
+class Gemma3nAudioConformerLightConv1d(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.linear_start = nn.Linear(self.config.hidden_size, self.config.hidden_size * 2, bias=False)
+ self.depthwise_conv1d = nn.Conv1d(
+ in_channels=self.config.hidden_size,
+ out_channels=self.config.hidden_size,
+ kernel_size=self.config.conf_conv_kernel_size,
+ stride=1,
+ padding=0, # Manual causal padding
+ groups=self.config.hidden_size, # Depthwise
+ bias=False,
+ )
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
+
+ self.causal_padding = self.config.conf_conv_kernel_size - 1
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ audio_encodings_residual = audio_encodings # Save for residual connection
+
+ audio_encodings = self.pre_layer_norm(audio_encodings)
+ audio_encodings = self.linear_start(audio_encodings)
+ audio_encodings = torch.nn.functional.glu(audio_encodings, dim=-1)
+ # Permute for Conv1d: [B, T, D] -> [B, D, T]
+ audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
+ # Apply manual causal padding
+ audio_encodings_permuted_padded = F.pad(audio_encodings_permuted, (self.causal_padding, 0))
+ audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
+ # Permute back: [B, D, T_out] -> [B, T_out, D]
+ audio_encodings = audio_encodings.permute(0, 2, 1)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.conv_norm(audio_encodings)
+ audio_encodings = nn.functional.silu(audio_encodings)
+ audio_encodings = self.linear_end(audio_encodings)
+ output = audio_encodings + audio_encodings_residual
+ return output
+
+
+class Gemma3nAudioConformerBlock(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.ffw_layer_start = Gemma3nAudioConformerFeedForward(self.config)
+ self.attention = Gemma3nAudioConformerAttention(self.config)
+ self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
+ self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+ def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
+ audio_encodings = self.ffw_layer_start(audio_encodings)
+ audio_encodings = self.attention(audio_encodings, audio_mel_mask)
+ validity_mask_for_lconv = ~audio_mel_mask # True for valid
+ audio_encodings_for_lconv_input = audio_encodings * validity_mask_for_lconv.unsqueeze(-1).to(
+ audio_encodings.dtype
+ )
+ audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
+
+ audio_encodings = self.ffw_layer_end(audio_encodings)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ output = self.norm(audio_encodings)
+ return output
+
+
+class Gemma3nAudioEncoder(PreTrainedModel):
+ """A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037"""
+
+ config_class = Gemma3nAudioConfig
+
+ main_input_name = "audio_mel"
+
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
+ self.conformer = nn.ModuleList(
+ [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
+ )
+
+ def forward(
+ self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
+ ) -> tuple[torch.Tensor, torch.BoolTensor]:
+        """Encodes a batch of mel spectrograms.
+
+        Args:
+            audio_mel: a torch.Tensor of shape [batch_size, num_frames, mel_bins].
+            audio_mel_mask: a torch.BoolTensor of shape [batch_size, num_frames],
+                True for padded frames.
+
+        Returns:
+            audio_encodings: a torch.Tensor of shape
+                `[batch_size, num_subsampled_frames, self.config.hidden_size]`
+            current_mask: a torch.BoolTensor of shape
+                `[batch_size, num_subsampled_frames]`, True for padded positions.
+        """
+ audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
+
+ # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
+ t_sub = audio_encodings.shape[1]
+
+ time_stride_product = 1
+ for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
+ time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
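+        # E.g. with two SSCP time strides of 2, time_stride_product == 4, so the mask below is sampled at
+        # original frames 0, 4, 8, ...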
+
+ # Create indices for gathering from the original mask.
+ # These indices map to original time steps corresponding to the start of each
+ # receptive field in the subsampled output.
+ indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
+ indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid
+
+ # Expand indices for batch compatibility if B > 1 and indices is 1D.
+ if audio_mel_mask.ndim > 1 and indices.ndim == 1:
+ indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub]
+ elif (
+ audio_mel_mask.ndim == indices.ndim
+ and audio_mel_mask.shape[0] == 1
+ and indices.shape[0] != 1
+ and t_sub == indices.shape[0]
+ ):
+ # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
+ indices = indices.unsqueeze(0)
+
+ current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub]
+
+ for block in self.conformer:
+ audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
+
+ if self.config.conf_reduction_factor > 1:
+ audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
+ # Reduce the mask as well
+ current_mask = current_mask[:, :: self.config.conf_reduction_factor]
+
+ audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
+ return audio_encodings, current_mask
+
+
+class Gemma3nTextScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides `nn.Embedding`'s forward pass by multiplying the embeddings with a scale factor.
+    """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
+ self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
+
+ def forward(self, input_ids: torch.Tensor):
+ return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
+
+
+class Gemma3nTextLaurelBlock(nn.Module):
+ """Learned Augmented Residual Layer"""
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__()
+ self.config = config
+
+ self.linear_left = nn.Linear(self.config.hidden_size, self.config.laurel_rank, bias=False)
+ self.linear_right = nn.Linear(self.config.laurel_rank, self.config.hidden_size, bias=False)
+ self.post_laurel_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ laurel_hidden_states: torch.Tensor = self.linear_left(hidden_states)
+ laurel_hidden_states: torch.Tensor = self.linear_right(laurel_hidden_states)
+ normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
+ return hidden_states + normed_laurel_hidden_states
+
+
+class Gemma3nTextMLP(nn.Module):
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size[layer_idx]
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_activation]
+ self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_proj = self.gate_proj(hidden_states)
+ if self.activation_sparsity > 0.0:
+ gate_proj = self._gaussian_topk(gate_proj)
+ activations = self.act_fn(gate_proj)
+ up_proj = self.up_proj(hidden_states)
+ down_proj = self.down_proj(activations * up_proj)
+ return down_proj
+
+ def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
+ target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device)
+ # normal_dist and std_multiplier are adapted from jax.scipy.stats.norm.ppf().
+ #
+ # References:
+ # * https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html
+ # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal
+ # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution.icdf
+ normal_dist = torch.distributions.normal.Normal(0, 1)
+ std_multiplier: torch.Tensor = normal_dist.icdf(target_sparsity_tensor)
+ std_multiplier = std_multiplier.type(inputs.dtype)
+ inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
+ inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
+ cutoff_x = inputs_mean + inputs_std * std_multiplier
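+        # E.g. with activation_sparsity = 0.95, std_multiplier = Phi^{-1}(0.95) ≈ 1.645, so the cutoff keeps
+        # roughly the top 5% of (approximately Gaussian) gate activations and the ReLU below zeroes the rest.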
+ return nn.functional.relu(inputs - cutoff_x)
+
+
+class Gemma3nTextAltUp(nn.Module):
+ """Alternating Updates (AltUp)
+
+ The AltUp module wraps transformer layers. The `predict` step modifies the
+ input to the transformer layer, and the `correct` step propagates the output
+ of the transformer layer to the sparsely updated dimensions.
+
+ See more in the research paper:
+
+ https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
+ """
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__()
+ self.config = config
+ self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size))
+ self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False)
+ self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
+ self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
+ self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
+
+ def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
+ router_inputs = self.router_norm(x) * self.router_input_scale
+ routed = self.modality_router(router_inputs)
+ return torch.tanh(routed.float()).type_as(x)
+
+ def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Predicts the output of a layer using a trainable map.
+
+ Args:
+ hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
+ stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
+
+ Returns:
+ A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
+ """
+ modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
+
+ if self.training and self.config.altup_coef_clip is not None:
+ self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
+
+        # Project, then transpose the contained 2D coefficient matrices so that matmul gives the correct result
+ all_coefs: torch.Tensor = (
+ self.prediction_coefs(modalities)
+ .reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs)
+ .permute(0, 1, 3, 2)
+ )
+
+ # permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs]
+ predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
+ predictions = predictions.permute(3, 0, 1, 2) # undo the permute
+ predictions += hidden_states # add the original input
+ return predictions.contiguous().type_as(hidden_states)
+
+ def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
+        """Corrects the predictions relative to the activated inputs.
+
+ Args:
+ predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
+ stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
+ activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
+
+ Returns:
+ A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
+ predictions relative to the activated input embeddings.
+ """
+ modalities = self.compute_router_modalities(activated)
+ innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
+ innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
+
+ if self.config.altup_coef_clip is not None:
+ self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
+
+ # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
+        # Permute to (altup_num_inputs, batch_size, num_tokens) since each coefficient is a scalar applied to one
+        # altup input, then unsqueeze the last dim so it broadcasts over hidden_size
+ all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
+ all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
+
+ corrected = torch.mul(innovation, all_coefs)
+ corrected += predictions # add the original input
+ return corrected.contiguous().type_as(activated)
+
+ def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
+ """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
+ return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
+
+
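+# Illustrative AltUp usage (not part of the model; Gemma3nTextDecoderLayer wires this up):
+#     predictions = altup.predict(hidden_states)                    # [num_altup_inputs, B, T, hidden_size]
+#     activated = decoder_sublayers(predictions[altup_active_idx])  # hypothetical attention + MLP call
+#     corrected = altup.correct(predictions, activated)             # same shape as `predictions`
+
+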
+class Gemma3nTextRotaryEmbedding(nn.Module):
+ def __init__(self, config: Gemma3nTextConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
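+# E.g. with num_attention_heads=8 and num_key_value_heads=2, repeat_kv is called with n_rep=4 and expands a
+# [B, 2, T, head_dim] key/value tensor to [B, 8, T, head_dim].
+
+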
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ dropout: float = 0.0,
+ scaling: Optional[float] = None,
+ softcap: Optional[float] = None,
+ **kwargs,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ if scaling is None:
+ scaling = module.head_dim**-0.5
+
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+
+ if softcap is not None:
+ attn_weights = attn_weights / softcap
+ attn_weights = torch.tanh(attn_weights)
+ attn_weights = attn_weights * softcap
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ return attn_output, attn_weights
+
+
+def apply_rotary_pos_emb(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ position_ids: Optional[torch.Tensor] = None,
+ unsqueeze_dim: int = 1,
+):
+    """Applies Rotary Position Embedding to the input tensor.
+
+    Args:
+        x (`torch.Tensor`): The tensor to embed.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The dimension along which to unsqueeze `cos` and `sin` so that they broadcast correctly to the
+            dimensions of `x`. For example, `cos` and `sin` have the shape [batch_size, seq_len, head_dim]. If `x`
+            has the shape [batch_size, heads, seq_len, head_dim], set `unsqueeze_dim=1`; if `x` has the shape
+            [batch_size, seq_len, heads, head_dim], set `unsqueeze_dim=2`.
+    Returns:
+        `torch.Tensor`: the input tensor rotated using the Rotary Position Embedding.
+    """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ return (x * cos) + (rotate_half(x) * sin)
+
+
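+# Usage sketch (illustrative): with query states of shape [B, T, num_heads, head_dim] and cos/sin of shape
+# [B, T, head_dim], Gemma3nTextAttention below rotates before transposing to [B, num_heads, T, head_dim]:
+#     query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
+
+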
+class Gemma3nTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
+ super().__init__()
+ self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.attention_dropout = self.config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.sliding_window = config.sliding_window if self.is_sliding else None
+
+ self.q_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
+ self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
+
+ first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
+ self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
+ # Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
+ layer_type = config.layer_types[layer_idx]
+ self.kv_shared_layer_index = (
+ first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
+ if self.is_kv_shared_layer
+ else None
+ )
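+        # E.g. (illustrative numbers) with num_hidden_layers=30 and num_kv_shared_layers=10,
+        # first_kv_shared_layer_idx is 20 and a shared "sliding_attention" layer reuses the KV cache of the
+        # last "sliding_attention" layer with index < 20.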
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.config.head_dim)
+
+ cos, sin = position_embeddings
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
+ query_states = self.q_norm(query_states)
+ query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
+ query_states = query_states.transpose(1, 2)
+
+ if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
+            # HybridCache has complex slicing when layer_type == "sliding_attention" that impacts the shared KV cache.
+ if isinstance(past_key_value, HybridCache) and self.is_sliding:
+ max_length = past_key_value.sliding_window
+ if cache_position.shape[0] > max_length:
+ # If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
+ # slice into the entire cache.
+ indices = slice(0, max_length)
+ else:
+                    # If the prefill fits or we are generating for a "sliding_attention" layer,
+                    # clamp positions to the sliding window length - 1
+ indices = cache_position.clamp(min=0, max=max_length - 1)
+ else:
+ indices = cache_position
+
+ key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
+ value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
+ else:
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
+ key_states = self.k_norm(key_states)
+ key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
+ key_states = key_states.transpose(1, 2)
+
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
+ value_states = self.v_norm(value_states)
+ value_states = value_states.transpose(1, 2)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "cache_position": cache_position,
+ "sliding_window": self.sliding_window,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=self.attention_dropout if self.training else 0.0,
+ scaling=1.0,
+ sliding_window=self.sliding_window,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.layer_idx = layer_idx
+ self.attention_type = config.layer_types[layer_idx]
+ self.self_attn = Gemma3nTextAttention(config, layer_idx)
+ self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
+ self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.pre_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+ self.post_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+ self.act_fn = ACT2FN[config.hidden_activation]
+
+ self.altup = Gemma3nTextAltUp(config)
+ self.laurel = Gemma3nTextLaurelBlock(config)
+ self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
+ self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
+ self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings_global: torch.Tensor,
+ position_embeddings_local: torch.Tensor,
+ per_layer_input: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ predictions = self.altup.predict(hidden_states)
+ active_prediction = predictions[self.config.altup_active_idx]
+
+ active_prediction_normed = self.input_layernorm(active_prediction)
+ laurel_output = self.laurel(active_prediction_normed)
+
+ # apply global RoPE to non-sliding layer only
+ if self.self_attn.is_sliding:
+ position_embeddings = position_embeddings_local
+ else:
+ position_embeddings = position_embeddings_global
+
+ attn, self_attn_weights = self.self_attn(
+ hidden_states=active_prediction_normed,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ attn = self.post_attention_layernorm(attn)
+
+ attn_gated = active_prediction + attn
+ attn_laurel = (attn_gated + laurel_output) / math.sqrt(2)
+
+ attn_norm = self.pre_feedforward_layernorm(attn_laurel)
+ attn_ffw = self.mlp(attn_norm)
+ attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
+ attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
+ corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
+
+ first_prediction = corrected_predictions[self.config.altup_active_idx]
+ first_prediction_clone = first_prediction.clone()
+ if self.config.altup_correct_scale:
+ first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
+
+ # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
+ first_prediction = self.per_layer_input_gate(first_prediction)
+ first_prediction = self.act_fn(first_prediction)
+ first_prediction = torch.multiply(first_prediction, per_layer_input)
+
+ # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ first_prediction = self.per_layer_projection(first_prediction)
+ first_prediction = self.post_per_layer_input_norm(first_prediction)
+ corrected_predictions[1:] += first_prediction
+
+ outputs = (corrected_predictions,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
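+# Gemma3nTextDecoderLayer flow (summary): altup.predict -> attention on the active prediction (plus the
+# LAuReL residual) -> MLP -> altup.correct -> gate the corrected active prediction with the per-layer input
+# and add the result to the remaining altup streams.
+
+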
+@auto_docstring
+class Gemma3nPreTrainedModel(PreTrainedModel):
+ config_class = Gemma3nConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Gemma3nDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+        # important: this ported model isn't meant for training from scratch - only
+        # inference and fine-tuning - so the proper init weights code has been removed
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, Gemma3nRMSNorm):
+ if module.with_scale:
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Gemma3nAudioCumulativeGroupNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Gemma3nAudioAttention):
+ module.per_dim_scale.data.zero_()
+ elif isinstance(module, Gemma3nTextAltUp):
+ module.correct_output_scale.data.zero_()
+
+
+@auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
+class Gemma3nTextModel(Gemma3nPreTrainedModel):
+ config_class = Gemma3nTextConfig
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+        # Gemma3n downcasts the below to bfloat16, so sqrt(hidden_size) loses precision (e.g. sqrt(3072)=55.4256 becomes 55.5). See https://github.com/huggingface/transformers/pull/29402
+ self.embed_tokens = Gemma3nTextScaledWordEmbedding(
+ config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
+ )
+ self.layers = nn.ModuleList(
+ [Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+
+ self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Gemma3nTextRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # TODO (raushan): Fix this after RoPE refactor. For now we hack it by
+ # reassigning thetas when we want to create a local RoPE layer. Config
+ # defaults should hold values for global RoPE.
+ config = copy.deepcopy(config)
+ config.rope_theta = config.rope_local_base_freq
+ config.rope_scaling = {"rope_type": "default"}
+ self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
+
+ self.hidden_size = config.hidden_size
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+
+ self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
+ config.vocab_size_per_layer_input,
+ config.num_hidden_layers * config.hidden_size_per_layer_input,
+ self.padding_idx,
+ embed_scale=config.hidden_size_per_layer_input**0.5,
+ )
+
+ self.per_layer_model_projection = nn.Linear(
+ self.hidden_size,
+ config.num_hidden_layers * config.hidden_size_per_layer_input,
+ bias=False,
+ )
+
+ self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)
+
+ self.altup_projections = nn.ModuleList(
+ [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
+ )
+
+ self.altup_unembed_projections = nn.ModuleList(
+ [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
+ )
+
+ self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
+ self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ per_layer_inputs: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> BaseModelOutputWithPast:
+ r"""
+        per_layer_inputs (`torch.Tensor`, *optional*):
+            Pre-computed per-layer embeddings. If `None`, they are derived from `input_ids` if provided.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if input_ids is not None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ per_layer_inputs = self.get_per_layer_inputs(input_ids)
+
+ per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
+
+ if use_cache and past_key_values is None and not self.training:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
+
+ # embed positions
+ hidden_states_0 = inputs_embeds
+
+ # Initialize RoPE embeddings
+ position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids)
+ position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
+
+ # Expand hidden_states to support per-layer inputs
+ target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
+ epsilon_tensor = torch.tensor(torch.finfo().min)
+
+ temp_hidden_states = [hidden_states_0]
+ for i in range(1, self.config.altup_num_inputs):
+ # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0)
+ current_hidden_state = altup_proj.type(hidden_states_0.dtype)
+ new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
+ current_hidden_state = current_hidden_state * (
+ target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
+ )
+ temp_hidden_states.append(current_hidden_state)
+
+ hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ causal_mask = causal_mask_mapping[decoder_layer.attention_type]
+ per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ position_embeddings_global=position_embeddings_global,
+ position_embeddings_local=position_embeddings_local,
+ per_layer_input=per_layer_input,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **flash_attn_kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # Per-layer inputs to single output
+ target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
+ temp_hidden_states = [hidden_states[0]]
+ for i in range(1, self.config.altup_num_inputs):
+ # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
+ current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
+ new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
+ current_hidden_state = current_hidden_state * (
+ target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
+ )
+ temp_hidden_states.append(current_hidden_state)
+
+ hidden_states = torch.stack(temp_hidden_states)
+ hidden_states = torch.mean(hidden_states, dim=0)
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
+ return self.embed_tokens_per_layer(input_ids).reshape(
+ *input_ids.shape,
+ self.config.num_hidden_layers,
+ self.hidden_size_per_layer_input,
+ )
+
+ def project_per_layer_inputs(
+ self,
+ inputs_embeds: torch.Tensor,
+ per_layer_inputs: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
+ per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype)
+ per_layer_projection = per_layer_projection.reshape(
+ *inputs_embeds.shape[:-1],
+ self.config.num_hidden_layers,
+ self.hidden_size_per_layer_input,
+ )
+ per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
+
+ if per_layer_inputs is None:
+ return per_layer_projection
+
+ if per_layer_projection.shape != per_layer_inputs.shape:
+ # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
+ per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
+
+ return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype)
+
+
+@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
+class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+ config_class = Gemma3nTextConfig
+ base_model_prefix = "model"
+ _checkpoint_conversion_mapping = {"model.language_model": "model"}
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__(config)
+ self.model = Gemma3nTextModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **loss_kwargs,
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Gemma3nForCausalLM
+
+        >>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-3n-e2b")
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-3n-e2b")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+
+ if self.training and self.config._attn_implementation != "eager":
+ logger.warning_once(
+ "It is strongly recommended to train Gemma3n models with the `eager` attention implementation "
+                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+ )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **loss_kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+ if self.config.final_logit_softcapping is not None:
+ logits = logits / self.config.final_logit_softcapping
+ logits = torch.tanh(logits)
+ logits = logits * self.config.final_logit_softcapping
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class Gemma3nMultimodalEmbedder(nn.Module):
+ """Embeds token ids or soft tokens for multimodal content into language model space."""
+
+ def __init__(
+ self,
+ multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
+ text_config: Gemma3nTextConfig,
+ ):
+ super().__init__()
+
+ self.multimodal_hidden_size = multimodal_config.hidden_size
+ self.eps = multimodal_config.rms_norm_eps
+ self.vocab_offset = multimodal_config.vocab_offset
+ self.vocab_size = multimodal_config.vocab_size
+ self.text_hidden_size = text_config.hidden_size
+
+ self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
+ self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
+ self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
+ self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
+ self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Embeds token ids or soft tokens for multimodal content into language model space.
+
+ Args:
+ input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
+ `[vocab_offset, vocab_offset + vocab_size)`.
+ inputs_embeds: A torch.Tensor containing the soft tokens to embed.
+
+ Returns:
+ A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is not None:
+ emb_norm = self.soft_embedding_norm(inputs_embeds)
+ else:
+ hard_emb = self.embedding(input_ids - self.vocab_offset)
+ emb_norm = self.hard_embedding_norm(hard_emb)
+
+ emb_norm_proj = self.embedding_projection(emb_norm)
+ return self.embedding_post_projection_norm(emb_norm_proj)
+
+
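+# Usage sketch (illustrative): hard tokens use ids in [vocab_offset, vocab_offset + vocab_size), while soft
+# tokens produced by the vision/audio towers bypass the lookup:
+#     text_space = embed_vision(input_ids=image_token_ids)        # hard vision tokens
+#     text_space = embed_audio(inputs_embeds=audio_encoder_out)   # soft audio tokens
+
+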
+@auto_docstring(
+ custom_intro="""
+ The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
+ language modeling head.
+ """
+)
+class Gemma3nModel(Gemma3nPreTrainedModel):
+ _checkpoint_conversion_mapping = {}
+ # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+ accepts_loss_kwargs = False
+
+ def __init__(self, config: Gemma3nConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config=config.vision_config)
+ self.vocab_size = config.text_config.vocab_size
+
+ language_model = AutoModel.from_config(config=config.text_config)
+ self.language_model = language_model
+
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
+ self.audio_tower = AutoModel.from_config(config.audio_config)
+ self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
+ self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ """
+ Projects the last hidden state from the vision model into language model space.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+
+ Returns:
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ vision_outputs = self.vision_tower(
+ pixel_values=pixel_values, do_pooling=False, return_dict=True
+ ).last_hidden_state
+ # Convert from (batch, channels, height, width) to (batch, height * width, channels) where:
+ # height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image.
+ vision_outputs = vision_outputs.reshape(
+ vision_outputs.shape[0],
+ self.config.vision_config.hidden_size,
+ self.config.vision_soft_tokens_per_image,
+ ).permute(0, 2, 1)
+ # Normalize and embed the soft tokens into language model space.
+ vision_outputs *= self.config.vision_config.hidden_size**0.5
+ return self.embed_vision(inputs_embeds=vision_outputs)
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None, # text inputs
+ pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
+ input_features: Optional[torch.FloatTensor] = None, # audio inputs
+ attention_mask: Optional[torch.Tensor] = None,
+ input_features_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ **lm_kwargs,
+ ) -> Gemma3nCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+ >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-e4b-it")
+ >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-e4b-it")
+
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Where is the cat standing?\nsnow"
+ ```
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if input_ids is not None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # Prepare per-layer inputs from input_ids
+ per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
+ per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
+ per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
+
+ # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
+ vision_mask = torch.logical_and(
+ input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
+ )
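+ # Non-vision positions are temporarily filled with a valid in-range dummy id so the embedding lookup never indexes out of bounds; their embeddings are discarded by the torch.where below.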
+ dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1
+ vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device)
+ vision_embeds = self.embed_vision(input_ids=vision_input_ids)
+ expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds)
+
+ # Handle audio tokens (>= embed_audio.vocab_offset)
+ audio_mask = input_ids >= self.embed_audio.vocab_offset
+ dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1
+ audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device)
+ audio_embeds = self.embed_audio(input_ids=audio_input_ids)
+ expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
+ else:
+ per_layer_inputs = None
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ else:
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ raise ValueError(
+ f"Number of images does not match number of special image tokens in the input text. "
+ f"Got {image_tokens_in_text} image tokens in the text and "
+ f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings."
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
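+ # Replace the placeholder image-token embeddings in place with the projected image features.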
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ # Merge text and audio
+ if input_features is not None and input_features_mask is not None:
+ audio_features, audio_mask = self.get_audio_features(input_features, ~input_features_mask)
+
+ # The Gemma3nProcessor expects all audio to be 30s in length and inserts 188 audio soft tokens into the
+ # text to account for this. However, the audio preprocessing and encoder do not guarantee they will
+ # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
+ # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
+ # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
+ audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device)
+ audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
+ audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
+
+ audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
+ extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
+ extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
+
+ audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
+
+ if input_ids is None:
+ special_audio_mask = inputs_embeds == self.embed_audio(
+ input_ids=torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ else:
+ special_audio_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
+ special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
+ audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
+ raise ValueError(
+ f"Number of audio input features does not match number of special audio tokens in the input text. "
+ f"Got {audio_tokens_in_text} audio tokens in the text and "
+ f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings."
+ )
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
+
+ outputs = self.language_model(
+ input_ids=None,
+ per_layer_inputs=per_layer_inputs,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ return Gemma3nModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ audio_hidden_states=audio_features if input_features is not None else None,
+ )
+
+ def get_audio_features(
+ self, input_features: torch.Tensor, input_features_mask: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Projects the last hidden state from the audio encoder into language model space.
+
+ Args:
+ input_features (`torch.FloatTensor` of shape `(num_audio, seq_length, num_features)`):
+ The tensors corresponding to the input audio.
+ input_features_mask (`torch.FloatTensor` of shape `(num_audio, seq_length)`):
+ The attention mask for the input audio.
+
+ Returns:
+ audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audio, audio_length, embed_dim)`.
+ """
+ audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
+ return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
+
+
+@auto_docstring(
+ custom_intro="""
+ The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
+ head.
+ """
+)
+class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = ["lm_head.weight"]
+ base_model_prefix = "model"
+
+ def __init__(self, config: Gemma3nConfig):
+ super().__init__(config)
+ self.model = Gemma3nModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(self, pixel_values):
+ return self.model.get_image_features(pixel_values)
+
+ # Make modules available through the conditional generation class for BC
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ raise AttributeError("Use embed_vision instead of multi_modal_projector.")
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None, # text inputs
+ pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
+ input_features: Optional[torch.FloatTensor] = None, # audio inputs
+ attention_mask: Optional[torch.Tensor] = None,
+ input_features_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Gemma3nCausalLMOutputWithPast:
+ r"""
+ input_features (`torch.Tensor`, *optional*):
+ The audio inputs to be encoded.
+ input_features_mask (`torch.Tensor`, *optional*):
+ The attention mask for the input audio.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
+ ignored (masked), the loss is only computed for the tokens with labels in
+ `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+ >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-e4b-it")
+ >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-e4b-it")
+
+ >>> messages = [
+ ... {
+ ... "role": "system",
+ ... "content": [
+ ... {"type": "text", "text": "You are a helpful assistant."}
+ ... ]
+ ... },
+ ... {
+ ... "role": "user", "content": [
+ ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+ ... {"type": "text", "text": "Where is the cat standing?"},
+ ... ]
+ ... },
+ ... ]
+
+ >>> inputs = processor.apply_chat_template(
+ ... messages,
+ ... tokenize=True,
+ ... return_dict=True,
+ ... return_tensors="pt",
+ ... add_generation_prompt=True
+ ... )
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ input_features=input_features,
+ attention_mask=attention_mask,
+ input_features_mask=input_features_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ token_type_ids=token_type_ids,
+ cache_position=cache_position,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
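+ # Final logit soft-capping: cap * tanh(logits / cap) smoothly bounds the logits to (-cap, cap).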
+ if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
+ logits = logits / final_logit_softcapping
+ logits = torch.tanh(logits)
+ logits = logits * final_logit_softcapping
+
+ loss = None
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:]
+ if attention_mask is not None:
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+ else:
+ shift_logits = shift_logits.contiguous()
+ shift_labels = shift_labels.contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
+ loss = loss_fct(flat_logits, flat_labels)
+
+ return Gemma3nCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ audio_hidden_states=outputs.audio_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ pixel_values=None,
+ input_features=None,
+ attention_mask=None,
+ input_features_mask=None,
+ token_type_ids=None,
+ use_cache=True,
+ logits_to_keep=None,
+ labels=None,
+ **kwargs,
+ ):
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ use_cache=use_cache,
+ logits_to_keep=logits_to_keep,
+ token_type_ids=token_type_ids,
+ **kwargs,
+ )
+
+ # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
+ # tokens anymore. Otherwise multimodal inputs should be passed to model.
+ # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
+ if cache_position[0] == 0:
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["input_features"] = input_features
+ model_inputs["input_features_mask"] = input_features_mask
+
+ return model_inputs
+
+ @property
+ def audio_tower(self):
+ return self.model.audio_tower
+
+
+__all__ = [
+ "Gemma3nAudioEncoder",
+ "Gemma3nForCausalLM",
+ "Gemma3nForConditionalGeneration",
+ "Gemma3nModel",
+ "Gemma3nPreTrainedModel",
+ "Gemma3nTextModel",
+]
diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py
new file mode 100644
index 000000000000..a3ffa710d842
--- /dev/null
+++ b/src/transformers/models/gemma3n/modular_gemma3n.py
@@ -0,0 +1,2664 @@
+# coding=utf-8
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import math
+from collections.abc import Callable, Sequence
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, HybridCache
+from ...configuration_utils import PretrainedConfig, layer_type_validation
+from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_rope_utils import rope_config_validation
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ..auto import AutoModel
+from ..gemma2.configuration_gemma2 import Gemma2Config
+from ..gemma2.modeling_gemma2 import (
+ Gemma2MLP,
+ Gemma2PreTrainedModel,
+ Gemma2RotaryEmbedding,
+ eager_attention_forward,
+ rotate_half,
+)
+from ..gemma3.modeling_gemma3 import (
+ Gemma3Attention,
+ Gemma3DecoderLayer,
+ Gemma3ForCausalLM,
+ Gemma3RMSNorm,
+ Gemma3TextModel,
+ Gemma3TextScaledWordEmbedding,
+)
+from ..paligemma.modeling_paligemma import (
+ PaliGemmaCausalLMOutputWithPast,
+ PaliGemmaForConditionalGeneration,
+ PaliGemmaModel,
+ PaligemmaModelOutputWithPast,
+)
+from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class Gemma3nTextConfig(Gemma2Config, PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nTextModel`]. It is used to instantiate an
+ Gemma3nTextModel model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.
+ [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects inherit from [`Gemma3nTextConfig`] and can be used to control the model outputs. Read
+ the documentation from [`Gemma3nTextConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 262400):
+ Vocabulary size of the Gemma3nText model. Defines the number of different tokens that can be represented by
+ the `input_ids` passed when calling [`Gemma3nTextModel`].
+ vocab_size_per_layer_input (`int`, *optional*, defaults to 262144):
+ Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ hidden_size_per_layer_input (`int`, *optional*, defaults to 256):
+ Dimension of the hidden representations for per-layer embeddings.
+ intermediate_size (`int` or `Sequence[int]`, *optional*, defaults to 16384):
+ Dimension of the MLP representations. MatFormer configurations may wish to provide a sequence of integers
+ to account for variable intermediate_size values across layers. In such cases,
+ `len(intermediate_size) == num_hidden_layers`.
+ num_hidden_layers (`int`, *optional*, defaults to 35):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out this
+ [paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, will default to `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 256):
+ The attention head dimension.
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The non-linear activation function (function or string) in the decoder. Will default to
+ `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
+ activation function.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 1):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ Beginning of stream token id.
+ rope_theta (`float`, *optional*, defaults to 1000000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings used in global attention.
+ NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we
+ recommend you to update this value accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+ rope_local_base_freq (float, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings for local attention.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ sliding_window (`int`, *optional*, defaults to 512):
+ This is the size of the sliding window used by local attention layers.
+ layer_types (`Optional`, *optional*):
+ A sequence of strings defining the attention type for that layer as either "sliding_attention" or
+ "full_attention". If not provided, `layer_types` will de inferred from `num_hidden_layers` using a pattern
+ of four "sliding_attention" layers followed one "full_attention". The last layer in the model should always
+ be a "full_attention" layer.
+ final_logit_softcapping (`float`, *optional*, defaults to 30.0):
+ Scaling factor when applying tanh softcapping on the logits.
+ altup_active_idx (`int`, *optional*, defaults to 0):
+ The index of the prediction from which AltUp will compute additional predictions or corrections.
+ altup_coef_clip (`float`, *optional*, defaults to 120.0):
+ The maximum amplitude of an AltUp prediction or correction coefficient weight.
+ altup_correct_scale (`bool`, *optional*, defaults to `True`):
+ If True, apply the `AltUp.correct_output_scale` to the corrected prediction at `altup_active_idx`.
+ altup_num_inputs (`int`, *optional*, defaults to 4):
+ The number of predictions that AltUp should make given the input sequence.
+ num_kv_shared_layers (`int`, *optional*, defaults to 15):
+ The number of layers that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
+ layers in the model "share" the KV values in that each local and global layer in this range uses the KV
+ cache values computed for the last local or global layer, respectively, before entering this range. The
+ value of `num_kv_shared_layers` should be a multiple of `sliding_window_pattern`.
+ laurel_rank (int, *optional*, defaults to 64):
+ The intermediate size for the linear projections in the Learned Augmented Residual Layer.
+ activation_sparsity_pattern (Sequence[float], *optional*, defaults to `(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)`):
+ The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
+ explicitly provide a sparsity value for each layer in the model.
+
+ ```python
+ >>> from transformers import Gemma3nTextModel, Gemma3nTextConfig
+
+ >>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration
+ >>> configuration = Gemma3nTextConfig()
+
+ >>> # Initializing a model from the gemma3n_text-E4B style configuration
+ >>> model = Gemma3nTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "gemma3n_text"
+
+ def __init__(
+ self,
+ vocab_size: int = 262_400,
+ vocab_size_per_layer_input: int = 262_144,
+ hidden_size: int = 2048,
+ hidden_size_per_layer_input: int = 256,
+ intermediate_size: Union[int, Sequence[int]] = 16_384,
+ num_hidden_layers: int = 35,
+ num_attention_heads: int = 8,
+ num_key_value_heads: int = 2,
+ head_dim: int = 256,
+ hidden_activation: str = "gelu_pytorch_tanh",
+ max_position_embeddings: int = 32_768,
+ initializer_range: float = 0.02,
+ rms_norm_eps: float = 1e-6,
+ use_cache: bool = True,
+ pad_token_id: int = 0,
+ eos_token_id: int = 1,
+ bos_token_id: int = 2,
+ rope_theta: float = 1_000_000.0,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ rope_local_base_freq: float = 10_000.0,
+ attention_bias: bool = False,
+ attention_dropout: float = 0.0,
+ sliding_window: int = 512,
+ layer_types: Optional[Sequence[str]] = None,
+ final_logit_softcapping: float = 30.0,
+ altup_active_idx: int = 0,
+ altup_coef_clip: float = 120.0,
+ altup_correct_scale: bool = True,
+ altup_num_inputs: int = 4,
+ num_kv_shared_layers: int = 15,
+ laurel_rank: int = 64,
+ activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25,
+ **kwargs,
+ ):
+ PretrainedConfig.__init__(
+ self,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+
+ if isinstance(intermediate_size, Sequence) and (intsize_len := len(intermediate_size)) != num_hidden_layers:
+ raise ValueError(
+ "intermediate_size must have an explicit intermediate size for every layer or one for all layers. "
+ f"Expected {num_hidden_layers} values but got {intsize_len}."
+ )
+ elif not isinstance(intermediate_size, Sequence):
+ intermediate_size = [intermediate_size] * num_hidden_layers
+
+ self.vocab_size = vocab_size
+ self.vocab_size_per_layer_input = vocab_size_per_layer_input
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.num_key_value_heads = num_key_value_heads
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.hidden_activation = hidden_activation
+ self.sliding_window = sliding_window
+ self.final_logit_softcapping = final_logit_softcapping
+ self.layer_types = layer_types
+
+ self.rope_local_base_freq = rope_local_base_freq
+ self.rope_scaling = rope_scaling
+ rope_config_validation(self)
+
+ if layer_types is None:
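+ # Default attention pattern: every fifth layer (i = 0, 5, 10, ...) uses full attention, the rest use sliding-window attention.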
+ self.layer_types = [
+ "full_attention" if i % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
+ ]
+ else:
+ self.layer_types = layer_types
+
+ layer_type_validation(self.layer_types)
+
+ self.hidden_size_per_layer_input = hidden_size_per_layer_input
+ self.num_kv_shared_layers = num_kv_shared_layers
+
+ self.altup_active_idx = altup_active_idx
+ self.altup_coef_clip = altup_coef_clip
+ self.altup_correct_scale = altup_correct_scale
+ self.altup_num_inputs = altup_num_inputs
+
+ self.laurel_rank = laurel_rank
+
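+ # A per-layer sparsity value of 0.0 leaves that layer's MLP activations dense (no activation sparsity).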
+ if activation_sparsity_pattern is None:
+ activation_sparsity_pattern = [0.0] * num_hidden_layers
+
+ if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers:
+ raise ValueError(
+ "activation_sparsity_pattern must have an explicit activation sparsity value for every layer."
+ f"Expected {num_hidden_layers} values but got {len_asp}."
+ )
+ self.activation_sparsity_pattern = activation_sparsity_pattern
+
+
+class Gemma3nAudioConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`], based on Google's
+ [Universal Speech Model](). It is used to instantiate a Gemma3nAudioEncoder model according to the specified
+ arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
+ configuration to that of the Gemma 3n E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects inherit from [`Gemma3nAudioConfig`] and can be used to control the model outputs. Read
+ the documentation from [`Gemma3nAudioConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 128):
+ Vocabulary size of the additional hard-token embeddings for audio model. These augment the embeddings
+ included in the `Gemma3nTextModel` to provide, e.g., the end of audio and audio soft token placeholder
+ tokens when converting `input_ids` to embeddings in the `Gemma3nForConditionalGeneration` model.
+ vocab_offset (`int`, *optional*, defaults to 262272):
+ Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
+ 0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
+ input_feat_size (`int`, *optional*, defaults to 128):
+ The number of channels in each mel-spectrogram frame.
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimension of the hidden representations.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
+ Clipping value used to stabilize extremely large gradient values.
+ conf_attention_chunk_size (`int`, *optional*, defaults to 12):
+ The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_context_left (`int`, *optional*, defaults to 13):
+ The left context size of the local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_context_right (`int`, *optional*, defaults to 0):
+ The right context size of the local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
+ Logit cap applied during local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_num_attention_heads (`int`, *optional*, defaults to 8):
+ The number of attention heads in local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_num_hidden_layers (`int`, *optional*, defaults to 12):
+ The number of layers that use local attention inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_conv_kernel_size (`int`, *optional*, defaults to 5):
+ Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_reduction_factor (`int`, *optional*, defaults to 4):
+ Reduction factor used in the conformer block inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ conf_residual_weight (`float`, *optional*, defaults to 0.5):
+ Residual connection weight inside the Conformer ("conf") section of the
+ Universal Speech Model.
+ sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
+ The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
+ ("sscp") section of the Universal Speech Model.
+ sscp_conv_group_norm_eps (`float`, *optional*, defaults to 0.001):
+ Epsilon used in group normalization in the subsample convolution projection in the Sub-sample Convolution
+ Projection ("sscp") section of the Universal Speech Model.
+ sscp_conv_kernel_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((3, 3), (3, 3))`):
+ Kernel sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
+ Convolution Projection ("sscp") section of the Universal Speech Model. The kernel sizes are specified as a
+ tuple of height and width for each layer, where the height corresponds to the time dimension and the width
+ corresponds to the frequency dimension.
+ sscp_conv_stride_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((2, 2), (2, 2))`):
+ Stride sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
+ Convolution Projection ("sscp") section of the Universal Speech Model. The stride sizes are specified as a
+ tuple of height and width for each layer, where the height corresponds to the time dimension and the width
+ corresponds to the frequency dimension.
+
+ Example:
+
+ ```python
+ >>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder
+
+ >>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration
+ >>> configuration = Gemma3nAudioConfig()
+
+ >>> # Initializing a model from the gemma3n_audio-E4B style configuration
+ >>> model = Gemma3nAudioEncoder(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "gemma3n_audio"
+
+ def __init__(
+ self,
+ vocab_size: int = 128,
+ vocab_offset: int = 262_144 + 128, # text vocab size + vision vocab size
+ input_feat_size: int = 128,
+ hidden_size: int = 1536,
+ rms_norm_eps: float = 1e-6,
+ gradient_clipping: float = 10_000_000_000.0,
+ conf_attention_chunk_size: int = 12,
+ conf_attention_context_left: int = 13,
+ conf_attention_context_right: int = 0,
+ conf_attention_logit_cap: float = 50.0,
+ conf_num_attention_heads: int = 8,
+ conf_num_hidden_layers: int = 12,
+ conf_conv_kernel_size: int = 5,
+ conf_reduction_factor: int = 4,
+ conf_residual_weight: float = 0.5,
+ sscp_conv_channel_size: tuple[int, int] = (128, 32),
+ sscp_conv_group_norm_eps: float = 1e-3,
+ sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
+ (3, 3),
+ (3, 3),
+ ),
+ sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = (
+ (2, 2),
+ (2, 2),
+ ),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.input_feat_size = input_feat_size
+ self.hidden_size = hidden_size
+ self.rms_norm_eps = rms_norm_eps
+ self.vocab_size = vocab_size
+ self.vocab_offset = vocab_offset
+ self.gradient_clipping = gradient_clipping
+ self.conf_attention_chunk_size = conf_attention_chunk_size
+ self.conf_attention_context_left = conf_attention_context_left
+ self.conf_attention_context_right = conf_attention_context_right
+ self.conf_attention_logit_cap = conf_attention_logit_cap
+ self.conf_num_attention_heads = conf_num_attention_heads
+ self.conf_num_hidden_layers = conf_num_hidden_layers
+ self.conf_conv_kernel_size = conf_conv_kernel_size
+ self.conf_reduction_factor = conf_reduction_factor
+ self.conf_residual_weight = conf_residual_weight
+ self.sscp_conv_channel_size = sscp_conv_channel_size
+ self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
+ self.sscp_conv_kernel_size = sscp_conv_kernel_size
+ self.sscp_conv_stride_size = sscp_conv_stride_size
+
+
+class Gemma3nVisionConfig(TimmWrapperConfig):
+ r"""
+ This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to
+ instantiate a timm model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B
+ vision tower, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
+
+ Configuration objects inherit from [`Gemma3nVisionConfig`] and can be used to control the model outputs. Read the
+ documentation from [`Gemma3nVisionConfig`] for more information.
+
+ The config loads ImageNet label descriptions and stores them in the `id2label` attribute; the `label2id` attribute
+ for default ImageNet models is set to `None` due to occlusions in the label descriptions.
+
+ Args:
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ do_pooling (`bool`, *optional*, defaults to `False`):
+ Whether to do pooling for the last_hidden_state in `TimmWrapper` or not.
+ architecture (`str`, *optional*, defaults to `"mobilenetv5_300m_enc"`):
+ Determines vision architecture for TimmWrapper.
+ hidden_size (`int`, *optional*, defaults to 2048):
+ Dimension of the hidden representations.
+ vocab_size (`int`, *optional*, defaults to 128):
+ Vocabulary size of the additional hard-token embeddings for vision model.
+ vocab_offset (`int`, *optional*, defaults to 262144):
+ Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
+ 0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+
+ Example:
+ ```python
+ >>> from transformers import Gemma3nVisionConfig, TimmWrapperModel
+
+ >>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration
+ >>> configuration = Gemma3nVisionConfig()
+
+ >>> # Initializing a gemma3n_vision-E4B-style TimmWrapper from the configuration
+ >>> model = TimmWrapperModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```
+ """
+
+ model_type = "gemma3n_vision"
+
+ def __init__(
+ self,
+ initializer_range: float = 0.02,
+ do_pooling: bool = False,
+ architecture: str = "mobilenetv5_300m_enc",
+ hidden_size: int = 2048,
+ vocab_size: int = 128,
+ vocab_offset: int = 262_144,
+ rms_norm_eps: float = 1e-06,
+ model_args: Optional[dict] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
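+ # `architecture` names the timm backbone instantiated through TimmWrapper.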
+ self.architecture = architecture
+ self.initializer_range = initializer_range
+ self.do_pooling = do_pooling
+ self.hidden_size = hidden_size
+ self.vocab_size = vocab_size
+ self.vocab_offset = vocab_offset
+ self.rms_norm_eps = rms_norm_eps
+ self.model_args = model_args
+
+
+class Gemma3nConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Gemma3nForConditionalGeneration`]. It is used to
+ instantiate a Gemma3nForConditionalGeneration according to the specified arguments, defining the model
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
+ Gemma3n-E4B.
+
+ e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config (`Union[Gemma3nTextConfig, dict]`, *optional*):
+ The config object of the text backbone.
+ vision_config (`Union[AutoConfig, dict]`, *optional*):
+ Custom vision config or dict.
+ audio_config (`Union[AutoConfig, dict]`, *optional*):
+ Custom audio config or dict.
+ audio_soft_tokens_per_image (`int`, *optional*, defaults to 188):
+ The number of soft tokens per audio clip.
+ vision_soft_tokens_per_image (`int`, *optional*, defaults to 256):
+ The number of soft tokens per image.
+ boi_token_id (`int`, *optional*, defaults to 255999):
+ The begin-of-image token index to wrap the image prompt.
+ eoi_token_id (`int`, *optional*, defaults to 262144):
+ The end-of-image token index to wrap the image prompt.
+ image_token_id (`int`, *optional*, defaults to 262145):
+ The image token index to encode the image prompt.
+ boa_token_id (`int`, *optional*, defaults to 256000):
+ The begin-of-audio token index to wrap the audio prompt.
+ eoa_token_id (`int`, *optional*, defaults to 262272):
+ The end-of-audio token index to wrap the audio prompt.
+ audio_token_id (`int`, *optional*, defaults to 262273):
+ The audio token index to encode the audio prompt.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nForConditionalGeneration, Gemma3nTextConfig, Gemma3nVisionConfig
+
+ >>> # Initializing a MobileNet vision config, which is loaded from TIMM
+ >>> vision_config = Gemma3nVisionConfig()
+
+ >>> # Initializing a Gemma3n Audio config
+ >>> audio_config = Gemma3nAudioConfig()
+
+ >>> # Initializing a Gemma3n Text config
+ >>> text_config = Gemma3nTextConfig()
+
+ >>> # Initializing a Gemma3n gemma-3n-E4B style configuration
+ >>> configuration = Gemma3nConfig(text_config, vision_config, audio_config)
+
+ >>> # Initializing a model from the gemma-3n-E4B style configuration
+ >>> model = Gemma3nForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gemma3n"
+ sub_configs = {
+ "text_config": Gemma3nTextConfig,
+ "vision_config": Gemma3nVisionConfig,
+ "audio_config": Gemma3nAudioConfig,
+ }
+
+ def __init__(
+ self,
+ text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None,
+ vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None,
+ audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None,
+ audio_soft_tokens_per_image: int = 188,
+ vision_soft_tokens_per_image: int = 256,
+ boi_token_id: int = 255_999,
+ eoi_token_id: int = 262_144,
+ image_token_id: int = 262_145,
+ boa_token_id: int = 256_000,
+ eoa_token_id: int = 262_272,
+ audio_token_id: int = 262_273,
+ initializer_range: float = 0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if isinstance(text_config, dict):
+ text_config = Gemma3nTextConfig(**text_config)
+ elif text_config is None:
+ text_config = Gemma3nTextConfig()
+ logger.info("text_config is None. Using default Gemma3nTextConfig.")
+
+ if isinstance(vision_config, dict):
+ vision_config = Gemma3nVisionConfig(**vision_config)
+ elif vision_config is None:
+ vision_config = Gemma3nVisionConfig()
+ logger.info("vision_config is None. Using default Gemma3nVisionConfig.")
+
+ if isinstance(audio_config, dict):
+ audio_config = Gemma3nAudioConfig(**audio_config)
+ elif audio_config is None:
+ audio_config = Gemma3nAudioConfig()
+ logger.info("audio_config is None. Using default Gemma3nAudioConfig.")
+
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.audio_config = audio_config
+
+ self.audio_soft_tokens_per_image = audio_soft_tokens_per_image
+ self.vision_soft_tokens_per_image = vision_soft_tokens_per_image
+ self.boi_token_id = boi_token_id
+ self.eoi_token_id = eoi_token_id
+ self.image_token_id = image_token_id
+ self.boa_token_id = boa_token_id
+ self.eoa_token_id = eoa_token_id
+ self.audio_token_id = audio_token_id
+ self.initializer_range = initializer_range
+
+
+class Gemma3nModelOutputWithPast(PaligemmaModelOutputWithPast):
+ r"""
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
+ """
+
+ audio_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class Gemma3nCausalLMOutputWithPast(PaliGemmaCausalLMOutputWithPast):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
+ audio_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
+ """
+
+ audio_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class Gemma3nRMSNorm(Gemma3RMSNorm):
+ def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
+ super().__init__(dim, eps=eps)
+ del self.weight
+ self.with_scale = with_scale
+
+ if self.with_scale:
+ self.weight = nn.Parameter(torch.ones(dim))
+ else:
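+ # Keep a constant scale of 1.0 as a non-persistent buffer so forward() can always multiply by self.weight.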
+ self.register_buffer("weight", torch.tensor(1.0), persistent=False)
+
+ def _norm(self, x):
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
+ # See https://github.com/huggingface/transformers/pull/29402
+ output = self._norm(x.float()) * self.weight.float()
+ return output.type_as(x)
+
+
+# ==== Audio Encoder ====
+
+
+class Gemma3nAudioRelativePositionEmbedding(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_heads = self.config.conf_num_attention_heads
+ self.channels = self.config.hidden_size
+ self.head_dim = self.channels // self.num_heads
+ self.max_backward = max(0, self.config.conf_attention_context_left - 1)
+ self.max_forward = self.config.conf_attention_context_right
+
+ self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
+
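+ # Sinusoidal position encoding over the relative offsets: half the channels carry sin terms and half cos terms, with geometrically spaced timescales between min_timescale and max_timescale.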
+ min_timescale = 1.0
+ max_timescale = 1.0e4
+ num_timescales = self.channels // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
+ inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
+ self.register_buffer(
+ "inv_timescales",
+ inv_timescales.float().unsqueeze(0).unsqueeze(0),
+ persistent=False,
+ )
+
+ def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ position = position.float().unsqueeze(-1)
+ scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32)
+ timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
+ return timing_signal.type(dtype)
+
+ def _relative_shift(
+ self,
+ term_bd_before_shift: torch.Tensor,
+ batch_size: int,
+ num_heads: int,
+ num_query_blocks: int,
+ query_block_size: int,
+ key_context_size: int,
+ max_span_plus_1: int,
+ ) -> torch.Tensor:
+ """Performs the relative shift.
+
+ Args:
+ term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
+ (B), num_heads (N), num_query_blocks (U), query_block_size (W),
+ key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
+
+ Returns:
+ Tensor of shape [B, N, U, W, C].
+ """
+ # term_bd_before_shift shape: [B, N, U, W, F_span]
+ # Target shape after shift: [B, N, U, W, C]
+
+ # Padding amount for the last dimension (F_span) to become (C + 1)
+ # C = key_context_size
+ # F_span = max_span_plus_1
+ pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
+
+ # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
+ # We only pad the last dimension on the right.
+ padding_tuple = (0, pad_amount_last_dim)
+
+ term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
+ # Shape after pad: [B, N, U, W, C+1]
+
+ # Reshape for slicing (emulating JAX's behavior)
+ # [B, N, U, W * (C+1)]
+ term_bd_reshaped = term_bd_padded.reshape(
+ (
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size * (key_context_size + 1),
+ )
+ )
+
+ # Slice to effective [B, N, U, W * C]
+ term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
+
+ # Reshape back to [B, N, U, W, C]
+ term_bd_shifted = term_bd_sliced.reshape(
+ (
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ key_context_size,
+ )
+ )
+ return term_bd_shifted
+
+ def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
+ # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
+ # keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
+ # C = W + L + R (key_context_size)
+ # F_span = L + R + 1 (max_span + 1)
+
+ batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
+ _, _, key_context_size, _, _ = keys.shape
+
+ # Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
+ # Length is L+R+1 = self.max_span + 1
+ pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
+ 0
+ ) # Shape [1, F_span]
+
+ max_span_plus_1 = pos_indices.shape[1] # F_span
+
+ sin_emb_timing_signal = self._get_timing_signal_1d_pos(
+ pos_indices, dtype=queries.dtype
+ ) # Shape [1, F_span, self.channels]
+
+ # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
+ projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
+ # Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
+ sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
+ 0
+ ) # Shape [F, N, H]
+
+ # term_ac: Query-Key content interaction
+ # queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
+ # keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
+ queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H]
+ keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C]
+ term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
+
+ # term_bd: Query-Position interaction
+ # Original einsum: term_bd_unshifted = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
+ # queries shape: [B, U, W, N, H]
+ # sin_emb shape: [F, N, H]
+ # Target output shape: [B, N, U, W, F]
+
+ # Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
+ q_permuted = queries.permute(0, 3, 1, 2, 4)
+
+ # Permute sin_emb to [N, H, F] to prepare for matmul
+ # sin_emb original is [F, N, H]
+ s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F]
+
+ # Reshape queries for matmul: [B, N, U*W, H]
+ q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
+
+ # Perform matmul: [B, N, U*W, H] @ [N, H, F]
+ # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
+ # Result: [B, N, U*W, F]
+ term_bd_unshifted_matmul = torch.matmul(q_reshaped, s_permuted)
+
+ # Reshape to target [B, N, U, W, F]
+ term_bd_unshifted = term_bd_unshifted_matmul.reshape(
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ max_span_plus_1,
+ )
+
+ # Apply relative shift to term_bd_unshifted
+ term_bd_shifted = self._relative_shift(
+ term_bd_unshifted,
+ batch_size,
+ num_heads,
+ num_query_blocks,
+ query_block_size,
+ key_context_size,
+ max_span_plus_1,
+ ) # Shape [B, N, U, W, C]
+
+ return term_ac + term_bd_shifted
+
+
+class Gemma3nAudioAttention(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.num_heads = self.config.conf_num_attention_heads
+ self.hidden_size = self.config.hidden_size
+ self.head_dim = self.hidden_size // self.num_heads
+
+ self.chunk_size = self.config.conf_attention_chunk_size
+ self.max_future_horizon = self.config.conf_attention_context_right
+ self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
+ self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
+ self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
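+ # Each query chunk of length `chunk_size` attends to `max_past_horizon` frames
+ # to its left and `max_future_horizon` frames to its right, so every chunk sees
+ # context_size = W + L + R keys.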
+
+ self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
+ self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+
+ q_scale = self.head_dim**-0.5
+ r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
+ self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
+
+ lower_causal_mask = torch.tril(
+ torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
+ diagonal=0,
+ ).T
+ upper_causal_mask = torch.tril(
+ torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
+ diagonal=self.max_past_horizon + self.max_future_horizon,
+ )
+ local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
+ local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
+ self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
+
+ self.register_buffer(
+ "softcap",
+ torch.tensor(self.attention_logits_soft_cap).float(),
+ persistent=False,
+ )
+
+ def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
+ batch, _, *tail_shape = x.shape
+ left = x.new_zeros((batch, pad_left, *tail_shape))
+ right = x.new_zeros((batch, pad_right, *tail_shape))
+ x = torch.cat([left, x, right], dim=1)
+ return x
+
+ def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Turns a sequence to non overlapping blocks.
+
+ Args:
+ hidden_states: a tensor of [batch, time, ...].
+
+ Returns:
+ A tensor of [batch, num_blocks, block_size, ...], padded as necessary,
+ where output[:, i, ...] = x[:, i*block_size:(i+1)*block_size, ...].
+ """
+ shape = hidden_states.shape
+ b, t = shape[:2]
+ num_blocks = (t + self.chunk_size - 1) // self.chunk_size
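+ # e.g. (hypothetical sizes) t=25 with chunk_size=12 gives num_blocks=3 and
+ # padding_len=11, so the padded sequence reshapes cleanly into blocks.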
+
+ if (padding_len := num_blocks * self.chunk_size - t) > 0:
+ hidden_states = self._pad_dim1(hidden_states, 0, padding_len)
+
+ blocked_shape = (b, num_blocks, self.chunk_size) + shape[2:]
+ hidden_states = hidden_states.reshape(blocked_shape).contiguous()
+ return hidden_states
+
+ def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Extracts temporal context for every block.
+
+ Args:
+ hidden_states: a tensor of [batch, time, ...].
+
+ Returns:
+ A tensor of [batch, num_blocks, context_size, ...], padded as necessary,
+ where context_size = block_size + left_context + right_context and
+ output[:, i, ...] = x[:, start - left_context:end + right_context, ...]
+ with start = i * block_size and end = (i + 1) * block_size.
+ """
+ pad_left = self.max_past_horizon
+ # The JAX equivalent padding for signal.frame with pad_mode='valid' is
+ # (left_context, right_context + block_size - 1) on the time dimension.
+ # _pad_dim1(x, pad_left, pad_right) always pads dim 1, which is the time
+ # dimension for both [B, T] and [B, T, N, H] inputs.
+ # The pad_right calculation below matches the JAX effective padding.
+ pad_right = self.max_future_horizon + self.chunk_size - 1
+ hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right)
+
+ frame_len = self.context_size
+ frame_step = self.chunk_size
+
+ # Directly use unfold without the subframe_factor logic
+ # x.unfold(dimension, size, step)
+ # dimension=1 (time dimension, assuming x is [B, T_padded, ...])
+ # size=frame_len (context_size)
+ # step=frame_step (chunk_size)
+ x_unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)
+
+ # If x was [B, T_padded], x_unfolded is [B, num_blocks, frame_len]
+ # If x was [B, T_padded, N, H], x_unfolded is [B, num_blocks, N, H, frame_len]
+ # We want to match JAX's typical output for such operations which might be
+ # [B, num_blocks, frame_len, N, H] if N, H are present.
+ # The relative_position_embedding expects keys as [B, U, C, N, H].
+ # If x_unfolded is [B, U, N, H, C(frame_len)], we need to move C.
+ if hidden_states.ndim > 2 and x_unfolded.ndim > 3: # Check if inner dimensions (like N, H) exist
+ # Current shape after unfold for [B, T_pad, N, H] is [B, U, N, H, C]
+ # Target shape for keys in RPE: [B, U, C, N, H]
+ x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)
+
+ return x_unfolded.contiguous()
+
+ def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+ # sl.Dense uses jax.numpy.einsum("...a,abcd->...bcd") and jax.numpy.select()
+ qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
+ query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
+ key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
+ value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous()
+
+ per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale)
+
+ broadcast_shape = (1, 1, 1, self.head_dim)
+ per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
+ query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast
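+ # At initialization per_dim_scale is zero, so softplus(per_dim_scale) cancels
+ # the 1/softplus(0) factor baked into q_scale and the effective scaling
+ # reduces to head_dim**-0.5.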
+
+ batch_size, q_time = query_states.shape[:2]
+
+ query_blocks = self._convert_to_block(query_states)
+ key_blocks = self._extract_block_context(key_states)
+ value_blocks = self._extract_block_context(value_states)
+ num_query_blocks = query_blocks.shape[1]
+
+ # 1. Create a mask indicating originally valid positions.
+ original_valid_mask = ~mask # True for valid, False for padded
+
+ # 2. Extract blocks from this validity mask.
+ extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)
+
+ # Defensive reshape: if the extracted mask comes back as [B, U, C/SF, SF]
+ # (e.g. from a framing implementation that uses a subframe factor), flatten
+ # it to [B, U, C]. batch_size and num_query_blocks are known from
+ # query_blocks; self.context_size is C.
+ if (
+ extracted_valid_mask_blocks.ndim == 4
+ and extracted_valid_mask_blocks.shape[2] * extracted_valid_mask_blocks.shape[3] == self.context_size
+ ):
+ extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
+ batch_size, num_query_blocks, self.context_size
+ )
+ # After potential reshape, ensure it's [B, U, C] if it was from a [B,T] mask.
+ # This assertion might be too strict if _extract_block_context handles higher-rank inputs differently,
+ # but for the mask case, this should hold.
+ if extracted_valid_mask_blocks.shape != (
+ batch_size,
+ num_query_blocks,
+ self.context_size,
+ ):
+ raise ValueError(
+ "Shape of extracted_valid_mask_blocks"
+ f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
+ f" {num_query_blocks}, {self.context_size}) after potential reshape."
+ )
+
+ # 3. Expand dimensions for broadcasting with logits and causal mask.
+ # Target shape for broadcasting with logits [B,N,U,W,C]
+ # extracted_valid_mask_blocks to [B, 1, U, 1, C]
+ condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(1).unsqueeze(-2)
+
+ # self.local_causal_valid_mask is [W, C], True where allowed by local window.
+ # Expand to [1, 1, 1, W, C]
+ condition_from_causality = self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+
+ # 4. Combine the two conditions.
+ # final_condition will be True where a key is *both* originally valid *and* causally accessible.
+ # Broadcasts to [B, 1, U, W, C]
+ final_condition_for_where = torch.logical_and(
+ condition_from_input_validity,
+ condition_from_causality.to(condition_from_input_validity.device), # Ensure same device
+ )
+
+ # Embed queries and keys
+ logits = self.relative_position_embedding(query_blocks, key_blocks)
+
+ # Apply attention logit softcap
+ # Ensure softcap is on the same device as logits
+ softcap_val = self.softcap.to(logits.device)
+ logits = logits / softcap_val
+ logits = torch.tanh(logits)
+ logits = logits * softcap_val
+
+ # Apply the combined mask.
+ # final_condition_for_where will broadcast with logits [B,N,U,W,C]
+ logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
+ probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
+
+ # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...)
+ b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
+ h_dim = value_blocks.shape[-1]
+ prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
+ v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
+ result_bmm = torch.bmm(prob_bun, v_bun)
+ context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
+ context_vectors = context_vectors.reshape(
+ (
+ batch_size,
+ num_query_blocks * self.chunk_size,
+ self.num_heads,
+ self.head_dim,
+ )
+ )
+ context_vectors = context_vectors[:, :q_time]
+
+ return context_vectors
+
+
+class Gemma3nAudioCumulativeGroupNorm(nn.Module):
+ """Applies Group Normalization cumulatively over the time dimension.
+
+ This layer normalizes the input by calculating the mean and variance
+ cumulatively over the time dimension (dim 1). The statistics are computed
+ over all feature dimensions (specified by `feature_dims` and `num_channels`)
+ for elements marked as valid by the optional `mask`.
+
+ This implementation does not take an explicit mask; every position is
+ treated as valid when accumulating the statistics.
+
+ A learned per-channel scale (last dimension) is applied after normalization.
+ This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
+ and `cumulative=True`.
+ """
+
+ def __init__(
+ self,
+ num_channels: int, # Number of channels (size of the last dimension)
+ feature_dims: Sequence[int], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
+ eps: float = 1e-3,
+ ):
+ super().__init__()
+ self.num_channels = num_channels
+ self.feature_dims = tuple(feature_dims)
+ self.eps = eps
+
+ # Scale parameter depends only on the channel dimension
+ self.weight = nn.Parameter(torch.ones(num_channels))
+
+ # Axes for normalization: all dimensions except Batch (0) and Time (1).
+ # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
+ self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Applies cumulative group norm, optionally using a mask.
+
+ Args:
+ hidden_states: Input tensor, shape [B, T, *feature_dims, C].
+
+ Returns:
+ Normalized tensor with the same shape as x.
+ """
+ expected_input_suffix = self.feature_dims + (self.num_channels,)
+ if hidden_states.shape[2:] != expected_input_suffix:
+ raise ValueError(
+ f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
+ f" suffix (feature_dims + num_channels) {expected_input_suffix}"
+ )
+
+ input_dtype = hidden_states.dtype
+ # Calculations are performed in float32 for numerical stability.
+ calc_dtype = torch.float32
+ x_calc = hidden_states.to(calc_dtype)
+
+ # mask_calc is a broadcastable all-ones mask: every element is treated as
+ # valid. It stands in for a potential [B, T] validity mask expanded to
+ # [B, T, 1, ..., 1].
+ mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)
+
+ # Cumulative Statistics Calculation
+ # 1. Sum of values over reduction axes at each time step.
+ sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
+ # 2. Cumulative sum of values over time.
+ cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
+
+ # 3. Count of valid elements in the normalization group at each time step.
+ # (A "group" here consists of all features at a given Batch, Time).
+ elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
+ # 4. Cumulative count of valid elements over time.
+ cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
+ # Avoid division by zero if all preceding elements were masked.
+ safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)
+
+ # 5. Cumulative mean.
+ cum_mean = cum_sum_values / safe_cum_count_elements
+
+ # 6. Sum of squared differences from the cumulative mean.
+ # With all positions valid, the squared differences are summed directly
+ # (a masked variant would weight them by mask_calc).
+ squared_diff_from_mean = (x_calc - cum_mean).pow(2)
+ sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True)
+
+ # 7. Cumulative sum of squared differences over time.
+ cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
+
+ # 8. Cumulative variance.
+ cum_variance = cum_sum_sq_diff / safe_cum_count_elements
+
+ # Normalize the input using the calculated cumulative statistics:
+ # (x - E[x]) / sqrt(Var[x] + eps)
+ normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)
+
+ # Apply the learned per-channel scale (last dimension).
+ scale = self.weight.to(calc_dtype)
+ # Reshape for broadcasting: [C] -> [1, ..., 1, C]
+ scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels]
+ normalized_x = normalized_x * scale.view(scale_view_shape)
+
+ # With an all-ones mask this multiplication is a no-op; it is kept for
+ # parity with the masked reference, where padded time steps are zeroed out.
+ final_output = normalized_x * mask_calc
+
+ return final_output.to(input_dtype)
+
+
+class Gemma3nAudioSSCPConvBlock(nn.Module):
+ """A single convolution block for the SubSampleConvProjection.
+
+ This block consists of a 2D convolution, followed by CumulativeGroupNorm,
+ and a ReLU activation. It handles manual padding for the convolution.
+ """
+
+ def __init__(
+ self,
+ config: Gemma3nAudioConfig,
+ idx: int,
+ input_freq_dim: int, # The unpadded frequency (feature) dimension of the input to this block
+ manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0),
+ ):
+ super().__init__()
+ self.config = config
+ self.manual_padding = manual_padding
+
+ # in_channels is 1 for the first block, or C_out from previous block's conv
+ in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
+ out_channels = self.config.sscp_conv_channel_size[idx]
+ kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
+ stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
+
+ self.conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(
+ kernel_h,
+ kernel_w,
+ ), # Kernel (kH, kW) operates on (Time, Freq_dim)
+ stride=(stride_h, stride_w),
+ padding=(0, 0), # Manual padding is used
+ bias=False,
+ )
+
+ # Calculate output frequency dimension (f_out_conv) after this convolution.
+ # input_freq_dim is the unpadded width (feature dimension).
+ # self.manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
+ f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
+ f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
+
+ self.norm = Gemma3nAudioCumulativeGroupNorm(
+ num_channels=out_channels, # Channels of the conv output
+ feature_dims=(f_out_conv,), # The frequency dimension size after conv
+ eps=self.config.sscp_conv_group_norm_eps,
+ )
+
+ self.activation = nn.ReLU()
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
+ # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
+ # F.pad applies to last two dims: F_in then T_in
+ audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0)
+ # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
+ # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
+ audio_encodings_conv = self.conv(audio_encodings_padded)
+ # Expected conv output shape: [B, C_out, T_out, F_out]
+ # Input to norm is [B, T_out, F_out, C_out]
+ x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
+ x_normed = self.norm(x_for_norm)
+ # Output of norm is [B, T_out, F_out, C_out], permute back to [B, C_out, T_out, F_out]
+ audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
+ return self.activation(audio_encodings_normed)
+
+
+class Gemma3nAudioSubSampleConvProjection(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ current_f_for_block_input = config.input_feat_size # Start with original feature dim
+ calculated_block_padding = []
+ calculated_f_out_dims = [] # Tracking frequency dimension output sizes
+
+ for i in range(2): # Assuming 2 conv layers as per sscp_conv_... arrays
+ kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
+ stride_h, stride_w = config.sscp_conv_stride_size[i]
+
+ # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
+ # JAX 'reverse_causal' padding is (0, kernel_size - 1)
+ pad_t_top = 0
+ pad_t_bottom = kernel_h - 1
+
+ # Frequency Padding (Width for Conv2d)
+ # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
+ # and the successful test configuration.
+ # If kernel/stride/input_freq for frequency changes, this might need re-evaluation
+ # to match generic JAX 'SAME' behavior if it differs.
+ pad_f_left = 1
+ pad_f_right = 1
+
+ manual_padding_tuple = (
+ pad_f_left,
+ pad_f_right,
+ pad_t_top,
+ pad_t_bottom,
+ )
+ calculated_block_padding.append(manual_padding_tuple)
+
+ # Calculate output frequency dimension after this convolution
+ # This uses the actual padding applied and kernel/stride.
+ f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
+ f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 # Assuming dilation_w = 1
+ calculated_f_out_dims.append(f_out_after_conv)
+ current_f_for_block_input = f_out_after_conv
+
+ self.conv_0 = Gemma3nAudioSSCPConvBlock(
+ idx=0,
+ input_freq_dim=config.input_feat_size, # Pass original feature dim
+ config=config,
+ manual_padding=calculated_block_padding[0],
+ )
+ self.conv_1 = Gemma3nAudioSSCPConvBlock(
+ idx=1,
+ input_freq_dim=calculated_f_out_dims[0], # Output freq dim from conv_0
+ config=config,
+ manual_padding=calculated_block_padding[1],
+ )
+ final_c_out = config.sscp_conv_channel_size[-1]
+ final_f_out = calculated_f_out_dims[-1] # Final frequency dimension
+ self.input_proj_in_features = final_c_out * final_f_out
+ self.input_proj_linear = nn.Linear(self.input_proj_in_features, self.config.hidden_size, bias=False)
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ # audio_encodings is [B, T, F_in]
+ # Reshape to [B, 1, T, F_in] (Batch, Channels=1, Height=Time, Width=F_in)
+ audio_encodings_reshaped = audio_encodings.unsqueeze(1)
+ x = self.conv_0(audio_encodings_reshaped)
+ x = self.conv_1(x)
+ # x from conv_1 is [B, C_out_1, T_out_1, F_out_1]
+ b, c_out, t_out, f_out = x.shape
+ # Permute to [B, T_out_1, F_out_1, C_out_1] then flatten F_out_1 and C_out_1
+ x_permuted = x.permute(0, 2, 3, 1).contiguous()
+ output_flattened = x_permuted.view(b, t_out, f_out * c_out)
+ output = self.input_proj_linear(output_flattened)
+ return output
+
+
+class Gemma3nAudioConformerAttention(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+ self.post_in_features = self.config.hidden_size
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.attn = Gemma3nAudioAttention(config)
+ self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
+ self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+ def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
+ audio_encodings_input_to_attn = audio_encodings
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings_norm = self.pre_attn_norm(audio_encodings)
+ # Output of self.attn is [B, T, NumHeads, HeadDim]
+ audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)
+
+ # Reshape from [B, T, NumHeads, HeadDim] to [B, T, NumHeads * HeadDim]
+ # NumHeads * HeadDim = hidden_size
+ b, t, num_heads, head_dim = audio_encodings_attn_out.shape
+ audio_encodings_reshaped = audio_encodings_attn_out.reshape(b, t, num_heads * head_dim)
+
+ audio_encodings = self.post(audio_encodings_reshaped)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ return audio_encodings_input_to_attn + self.post_norm(audio_encodings)
+
+
+class Gemma3nAudioConformerFeedForward(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+
+ self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
+ self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
+ self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
+ self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ residual = audio_encodings
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.pre_layer_norm(audio_encodings)
+ audio_encodings: torch.Tensor = self.ffw_layer_1(audio_encodings)
+ audio_encodings = nn.functional.silu(audio_encodings)
+ audio_encodings: torch.Tensor = self.ffw_layer_2(audio_encodings)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.post_layer_norm(audio_encodings)
+ return residual + (audio_encodings * self.post_layer_scale)
+
+
+class Gemma3nAudioConformerLightConv1d(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.linear_start = nn.Linear(self.config.hidden_size, self.config.hidden_size * 2, bias=False)
+ self.depthwise_conv1d = nn.Conv1d(
+ in_channels=self.config.hidden_size,
+ out_channels=self.config.hidden_size,
+ kernel_size=self.config.conf_conv_kernel_size,
+ stride=1,
+ padding=0, # Manual causal padding
+ groups=self.config.hidden_size, # Depthwise
+ bias=False,
+ )
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
+
+ self.causal_padding = self.config.conf_conv_kernel_size - 1
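+ # Left-only padding of (kernel_size - 1) keeps the depthwise convolution
+ # causal: the output at time t depends only on inputs at times <= t.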
+
+ def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
+ audio_encodings_residual = audio_encodings # Save for residual connection
+
+ audio_encodings = self.pre_layer_norm(audio_encodings)
+ audio_encodings = self.linear_start(audio_encodings)
+ audio_encodings = torch.nn.functional.glu(audio_encodings, dim=-1)
+ # Permute for Conv1d: [B, T, D] -> [B, D, T]
+ audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
+ # Apply manual causal padding
+ audio_encodings_permuted_padded = F.pad(audio_encodings_permuted, (self.causal_padding, 0))
+ audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
+ # Permute back: [B, D, T_out] -> [B, T_out, D]
+ audio_encodings = audio_encodings.permute(0, 2, 1)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ audio_encodings = self.conv_norm(audio_encodings)
+ audio_encodings = nn.functional.silu(audio_encodings)
+ audio_encodings = self.linear_end(audio_encodings)
+ output = audio_encodings + audio_encodings_residual
+ return output
+
+
+class Gemma3nAudioConformerBlock(nn.Module):
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__()
+ self.config = config
+
+ self.ffw_layer_start = Gemma3nAudioConformerFeedForward(self.config)
+ self.attention = Gemma3nAudioConformerAttention(self.config)
+ self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
+ self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
+ self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
+ self.norm = Gemma3nRMSNorm(self.config.hidden_size)
+
+ def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
+ audio_encodings = self.ffw_layer_start(audio_encodings)
+ audio_encodings = self.attention(audio_encodings, audio_mel_mask)
+ validity_mask_for_lconv = ~audio_mel_mask # True for valid
+ audio_encodings_for_lconv_input = audio_encodings * validity_mask_for_lconv.unsqueeze(-1).to(
+ audio_encodings.dtype
+ )
+ audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
+
+ audio_encodings = self.ffw_layer_end(audio_encodings)
+ audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
+ output = self.norm(audio_encodings)
+ return output
+
+
+class Gemma3nAudioEncoder(PreTrainedModel):
+ """A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037"""
+
+ config_class = Gemma3nAudioConfig
+
+ main_input_name = "audio_mel"
+
+ def __init__(self, config: Gemma3nAudioConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
+ self.conformer = nn.ModuleList(
+ [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
+ )
+
+ def forward(
+ self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
+ ) -> tuple[torch.Tensor, torch.BoolTensor]:
+ """Encodes a batch of MELs.
+
+ Args:
+ audio_mel: a torch.Tensor of shape [batch, num_frames, mel_bins].
+ audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames], True for padded frames.
+
+ Returns:
+ audio_encodings: a torch.Tensor of shape
+ `[batch_size, num_subsampled_frames, self.config.hidden_size]`
+ audio_mel_mask: a torch.BoolTensor of shape [batch, num_subsampled_frames].
+ """
+ audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
+
+ # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
+ t_sub = audio_encodings.shape[1]
+
+ time_stride_product = 1
+ for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
+ time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
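+ # time_stride_product is the product of the time strides of the subsampling
+ # convolutions (e.g. 2 * 2 = 4 if both convs use a time stride of 2).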
+
+ # Create indices for gathering from the original mask.
+ # These indices map to original time steps corresponding to the start of each
+ # receptive field in the subsampled output.
+ indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
+ indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid
+
+ # Expand indices for batch compatibility if B > 1 and indices is 1D.
+ if audio_mel_mask.ndim > 1 and indices.ndim == 1:
+ indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub]
+ elif (
+ audio_mel_mask.ndim == indices.ndim
+ and audio_mel_mask.shape[0] == 1
+ and indices.shape[0] != 1
+ and t_sub == indices.shape[0]
+ ):
+ # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
+ indices = indices.unsqueeze(0)
+
+ current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub]
+
+ for block in self.conformer:
+ audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
+
+ if self.config.conf_reduction_factor > 1:
+ audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
+ # Reduce the mask as well
+ current_mask = current_mask[:, :: self.config.conf_reduction_factor]
+
+ audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
+ return audio_encodings, current_mask
+
+
+# ==== Language Model ====
+
+
+class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
+ pass
+
+
+class Gemma3nTextLaurelBlock(nn.Module):
+ """Learned Augmented Residual Layer"""
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__()
+ self.config = config
+
+ self.linear_left = nn.Linear(self.config.hidden_size, self.config.laurel_rank, bias=False)
+ self.linear_right = nn.Linear(self.config.laurel_rank, self.config.hidden_size, bias=False)
+ self.post_laurel_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ laurel_hidden_states: torch.Tensor = self.linear_left(hidden_states)
+ laurel_hidden_states: torch.Tensor = self.linear_right(laurel_hidden_states)
+ normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
+ return hidden_states + normed_laurel_hidden_states
+
+
+class Gemma3nTextMLP(Gemma2MLP):
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0):
+ super().__init__(config)
+ self.intermediate_size = config.intermediate_size[layer_idx]
+ self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ gate_proj = self.gate_proj(hidden_states)
+ if self.activation_sparsity > 0.0:
+ gate_proj = self._gaussian_topk(gate_proj)
+ activations = self.act_fn(gate_proj)
+ up_proj = self.up_proj(hidden_states)
+ down_proj = self.down_proj(activations * up_proj)
+ return down_proj
+
+ def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
+ target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device)
+ # normal_dist and std_multiplier are adapted from jax.scipy.stats.norm.ppf().
+ #
+ # References:
+ # * https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html
+ # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal
+ # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution.icdf
+ normal_dist = torch.distributions.normal.Normal(0, 1)
+ std_multiplier: torch.Tensor = normal_dist.icdf(target_sparsity_tensor)
+ std_multiplier = std_multiplier.type(inputs.dtype)
+ inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
+ inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
+ cutoff_x = inputs_mean + inputs_std * std_multiplier
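+ # Values below mean + std * Phi^-1(target_sparsity) are clipped to zero by the
+ # ReLU, so roughly a `activation_sparsity` fraction of each row is zeroed out
+ # (assuming the activations are approximately Gaussian).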
+ return nn.functional.relu(inputs - cutoff_x)
+
+
+class Gemma3nTextAltUp(nn.Module):
+ """Alternating Updates (AltUp)
+
+ The AltUp module wraps transformer layers. The `predict` step modifies the
+ input to the transformer layer, and the `correct` step propagates the output
+ of the transformer layer to the sparsely updated dimensions.
+
+ See more in the research paper:
+
+ https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
+ """
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__()
+ self.config = config
+ self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size))
+ self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False)
+ self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
+ self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
+ self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+ self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
+
+ def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
+ router_inputs = self.router_norm(x) * self.router_input_scale
+ routed = self.modality_router(router_inputs)
+ return torch.tanh(routed.float()).type_as(x)
+
+ def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ """Predicts the output of a layer using a trainable map.
+
+ Args:
+ hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
+ stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
+
+ Returns:
+ A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
+ """
+ modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
+
+ if self.training and self.config.altup_coef_clip is not None:
+ self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
+
+ # Project and then transpose the inner 2D coefficient matrices so that the matmul gives the correct result
+ all_coefs: torch.Tensor = (
+ self.prediction_coefs(modalities)
+ .reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs)
+ .permute(0, 1, 3, 2)
+ )
+
+ # permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs]
+ predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
+ predictions = predictions.permute(3, 0, 1, 2) # undo the permute
+ predictions += hidden_states # add the original input
+ return predictions.contiguous().type_as(hidden_states)
+
+ def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
+ """Corrects the predictions relative to the
+
+ Args:
+ predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
+ stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
+ activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
+
+ Returns:
+ A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
+ predictions relative to the activated input embeddings.
+ """
+ modalities = self.compute_router_modalities(activated)
+ innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
+ innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
+
+ if self.config.altup_coef_clip is not None:
+ self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
+
+ # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
+ # Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
+ # and expand on dim1 for broadcastability
+ all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
+ all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
+
+ corrected = torch.mul(innovation, all_coefs)
+ corrected += predictions # add the original input
+ return corrected.contiguous().type_as(activated)
+
+ def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
+ """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
+ return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
+
+
+class Gemma3nTextRotaryEmbedding(Gemma2RotaryEmbedding):
+ pass
+
+
+def apply_rotary_pos_emb(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ position_ids: Optional[torch.Tensor] = None,
+ unsqueeze_dim: int = 1,
+):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ x (`torch.Tensor`): The tensor to embed.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The dimension along which to unsqueeze `cos` and `sin` so that they broadcast correctly against `x`.
+ For example, if `cos` and `sin` have the shape [batch_size, seq_len, head_dim] and `x` has the shape
+ [batch_size, heads, seq_len, head_dim], set unsqueeze_dim=1. If `x` has the shape
+ [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
+ Returns:
+ `torch.Tensor`: The input tensor rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ return (x * cos) + (rotate_half(x) * sin)
+
+
+class Gemma3nTextAttention(Gemma3Attention):
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ del self.attn_logit_softcapping
+ del self.scaling
+ self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
+
+ first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
+ self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
+ # Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
+ layer_type = config.layer_types[layer_idx]
+ self.kv_shared_layer_index = (
+ first_kv_shared_layer_idx - 1 - config.layer_types[first_kv_shared_layer_idx - 1 :: -1].index(layer_type)
+ if self.is_kv_shared_layer
+ else None
+ )
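+ # e.g. (hypothetical sizes) with 30 layers and 10 shared KV layers, a
+ # sliding-attention layer at index 25 reuses the KV cache of the last
+ # sliding-attention layer before index 20.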
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.config.head_dim)
+
+ cos, sin = position_embeddings
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
+ query_states = self.q_norm(query_states)
+ query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
+ query_states = query_states.transpose(1, 2)
+
+ if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
+ # HybridCache applies special slicing to "sliding_attention" layers, which impacts the shared KV cache.
+ if isinstance(past_key_value, HybridCache) and self.is_sliding:
+ max_length = past_key_value.sliding_window
+ if cache_position.shape[0] > max_length:
+ # If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
+ # slice into the entire cache.
+ indices = slice(0, max_length)
+ else:
+ # If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
+ indices = cache_position.clamp(min=0, max=max_length - 1)
+ else:
+ indices = cache_position
+
+ key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
+ value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
+ else:
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
+ key_states = self.k_norm(key_states)
+ key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
+ key_states = key_states.transpose(1, 2)
+
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
+ value_states = self.v_norm(value_states)
+ value_states = value_states.transpose(1, 2)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "cache_position": cache_position,
+ "sliding_window": self.sliding_window,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=self.attention_dropout if self.training else 0.0,
+ scaling=1.0,
+ sliding_window=self.sliding_window,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
+ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
+
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+ self.act_fn = ACT2FN[config.hidden_activation]
+
+ self.altup = Gemma3nTextAltUp(config)
+ self.laurel = Gemma3nTextLaurelBlock(config)
+ self.self_attn = Gemma3nTextAttention(config, layer_idx)
+ self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
+ self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
+ self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings_global: torch.Tensor,
+ position_embeddings_local: torch.Tensor,
+ per_layer_input: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ predictions = self.altup.predict(hidden_states)
+ active_prediction = predictions[self.config.altup_active_idx]
+
+ active_prediction_normed = self.input_layernorm(active_prediction)
+ laurel_output = self.laurel(active_prediction_normed)
+
+ # apply global RoPE to non-sliding layer only
+ if self.self_attn.is_sliding:
+ position_embeddings = position_embeddings_local
+ else:
+ position_embeddings = position_embeddings_global
+
+ attn, self_attn_weights = self.self_attn(
+ hidden_states=active_prediction_normed,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ attn = self.post_attention_layernorm(attn)
+
+ attn_gated = active_prediction + attn
+ attn_laurel = (attn_gated + laurel_output) / math.sqrt(2)
+
+ attn_norm = self.pre_feedforward_layernorm(attn_laurel)
+ attn_ffw = self.mlp(attn_norm)
+ attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
+ attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
+ corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
+
+ first_prediction = corrected_predictions[self.config.altup_active_idx]
+ first_prediction_clone = first_prediction.clone()
+ if self.config.altup_correct_scale:
+ first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
+
+ # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
+ first_prediction = self.per_layer_input_gate(first_prediction)
+ first_prediction = self.act_fn(first_prediction)
+ first_prediction = torch.multiply(first_prediction, per_layer_input)
+
+ # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ first_prediction = self.per_layer_projection(first_prediction)
+ first_prediction = self.post_per_layer_input_norm(first_prediction)
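+ # Add the gated per-layer input only to altup streams 1..altup_num_inputs - 1;
+ # stream 0 is returned unchanged from the correction step.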
+ corrected_predictions[1:] += first_prediction
+
+ outputs = (corrected_predictions,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
+ config_class = Gemma3nConfig
+ base_model_prefix = ""
+ _no_split_modules = ["Gemma3nDecoderLayer"]
+
+ def _init_weights(self, module):
+ # important: this ported version of Gemma2 isn't meant for training from scratch - only
+ # inference and fine-tuning - so the proper init weights code has been removed
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, Gemma3nRMSNorm):
+ if module.with_scale:
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Gemma3nAudioCumulativeGroupNorm):
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, Gemma3nAudioAttention):
+ module.per_dim_scale.data.zero_()
+ elif isinstance(module, Gemma3nTextAltUp):
+ module.correct_output_scale.data.zero_()
+
+
+@auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
+class Gemma3nTextModel(Gemma3TextModel):
+ config_class = Gemma3nTextConfig
+
+ def __init__(self, config: Gemma3nTextConfig):
+ super().__init__(config)
+
+ self.hidden_size = config.hidden_size
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+
+ self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
+ config.vocab_size_per_layer_input,
+ config.num_hidden_layers * config.hidden_size_per_layer_input,
+ self.padding_idx,
+ embed_scale=config.hidden_size_per_layer_input**0.5,
+ )
+
+ self.per_layer_model_projection = nn.Linear(
+ self.hidden_size,
+ config.num_hidden_layers * config.hidden_size_per_layer_input,
+ bias=False,
+ )
+
+ self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)
+ self.layers = nn.ModuleList(
+ [Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+
+ self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.altup_projections = nn.ModuleList(
+ [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
+ )
+
+ self.altup_unembed_projections = nn.ModuleList(
+ [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
+ )
+
+ self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
+ self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
+ self.rotary_emb = Gemma3nTextRotaryEmbedding(config=config)
+
+ # TODO (raushan): Fix this after RoPE refactor. For now we hack it by
+ # reassigning thetas when we want to create a local RoPE layer. Config
+ # defaults should hold values for global RoPE.
+ config = copy.deepcopy(config)
+ config.rope_theta = config.rope_local_base_freq
+ config.rope_scaling = {"rope_type": "default"}
+ self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
+
+ def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
+ return self.embed_tokens_per_layer(input_ids).reshape(
+ *input_ids.shape,
+ self.config.num_hidden_layers,
+ self.hidden_size_per_layer_input,
+ )
+
+ def project_per_layer_inputs(
+ self,
+ inputs_embeds: torch.Tensor,
+ per_layer_inputs: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
+ per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype)
+ per_layer_projection = per_layer_projection.reshape(
+ *inputs_embeds.shape[:-1],
+ self.config.num_hidden_layers,
+ self.hidden_size_per_layer_input,
+ )
+ per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
+
+ if per_layer_inputs is None:
+ return per_layer_projection
+
+ if per_layer_projection.shape != per_layer_inputs.shape:
+ # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
+ per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
+
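+ # Combine the projected embeddings with the provided per-layer inputs; the
+ # 1/sqrt(2) scale keeps the variance of the sum roughly unchanged.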
+ return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ per_layer_inputs: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> BaseModelOutputWithPast:
+ r"""
+ per_layer_inputs (`torch.Tensor`, *optional*):
+ Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if input_ids is not None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ per_layer_inputs = self.get_per_layer_inputs(input_ids)
+
+ per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
+
+ if use_cache and past_key_values is None and not self.training:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+ }
+
+ # embed positions
+ hidden_states_0 = inputs_embeds
+
+ # Initialize RoPE embeddings
+ position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids)
+ position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
+
+ # Expand hidden_states to support per-layer inputs
+ target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
+ epsilon_tensor = torch.tensor(torch.finfo().min)
+
+ temp_hidden_states = [hidden_states_0]
+ for i in range(1, self.config.altup_num_inputs):
+ # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0)
+ current_hidden_state = altup_proj.type(hidden_states_0.dtype)
+ new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
+ current_hidden_state = current_hidden_state * (
+ target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
+ )
+ temp_hidden_states.append(current_hidden_state)
+
+ hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ causal_mask = causal_mask_mapping[decoder_layer.attention_type]
+ per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ position_embeddings_global=position_embeddings_global,
+ position_embeddings_local=position_embeddings_local,
+ per_layer_input=per_layer_input,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **flash_attn_kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # Per-layer inputs to single output
+ target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
+ temp_hidden_states = [hidden_states[0]]
+ for i in range(1, self.config.altup_num_inputs):
+ # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
+ altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
+ current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
+ new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
+ current_hidden_state = current_hidden_state * (
+ target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
+ )
+ temp_hidden_states.append(current_hidden_state)
+
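+ # Collapse the altup streams back to a single hidden state by unembedding
+ # streams 1..K-1, matching their RMS magnitude to stream 0, and averaging.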
+ hidden_states = torch.stack(temp_hidden_states)
+ hidden_states = torch.mean(hidden_states, dim=0)
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
+class Gemma3nForCausalLM(Gemma3ForCausalLM):
+ _checkpoint_conversion_mapping = {"model.language_model": "model"}
+ base_model_prefix = "model"
+
+
+class Gemma3nMultimodalEmbedder(nn.Module):
+ """Embeds token ids or soft tokens for multimodal content into language model space."""
+
+ def __init__(
+ self,
+ multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
+ text_config: Gemma3nTextConfig,
+ ):
+ super().__init__()
+
+ self.multimodal_hidden_size = multimodal_config.hidden_size
+ self.eps = multimodal_config.rms_norm_eps
+ self.vocab_offset = multimodal_config.vocab_offset
+ self.vocab_size = multimodal_config.vocab_size
+ self.text_hidden_size = text_config.hidden_size
+
+ self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
+ self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
+ self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
+ self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
+ self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Embeds token ids or soft tokens for multimodal content into language model space.
+
+ Args:
+ input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
+ `[vocab_offset, vocab_offset + vocab_size)`.
+ inputs_embeds: A torch.Tensor containing the soft tokens to embed.
+
+ Returns:
+ A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is not None:
+ emb_norm = self.soft_embedding_norm(inputs_embeds)
+ else:
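+            # Multimodal token ids live in a reserved range above the text vocabulary, so shift them back into the
+            # [0, vocab_size) range of this embedder's local embedding table before looking them up.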
+ hard_emb = self.embedding(input_ids - self.vocab_offset)
+ emb_norm = self.hard_embedding_norm(hard_emb)
+
+ emb_norm_proj = self.embedding_projection(emb_norm)
+ return self.embedding_post_projection_norm(emb_norm_proj)
+
+
+@auto_docstring(
+ custom_intro="""
+ The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
+ language modeling head.
+ """
+)
+class Gemma3nModel(PaliGemmaModel):
+ _checkpoint_conversion_mapping = {}
+
+ def __init__(self, config: Gemma3nConfig):
+        super().__init__(config)
+        del self.multi_modal_projector  # Replaced by Gemma3nMultimodalEmbedder (embed_vision)
+ self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
+ self.audio_tower = AutoModel.from_config(config.audio_config)
+ self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
+ self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
+
+ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ """
+ Projects the last hidden state from the vision model into language model space.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+
+ Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ vision_outputs = self.vision_tower(
+ pixel_values=pixel_values, do_pooling=False, return_dict=True
+ ).last_hidden_state
+ # Convert from (batch, channels, height, width) to (batch, height * width, channels) where:
+ # height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image.
+ vision_outputs = vision_outputs.reshape(
+ vision_outputs.shape[0],
+ self.config.vision_config.hidden_size,
+ self.config.vision_soft_tokens_per_image,
+ ).permute(0, 2, 1)
+ # Normalize and embed the soft tokens into language model space.
+ vision_outputs *= self.config.vision_config.hidden_size**0.5
+ return self.embed_vision(inputs_embeds=vision_outputs)
+
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None, # text inputs
+ pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
+ input_features: Optional[torch.FloatTensor] = None, # audio inputs
+ attention_mask: Optional[torch.Tensor] = None,
+ input_features_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ **lm_kwargs,
+ ) -> Gemma3nCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+        >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-e4b-it")
+        >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-e4b-it")
+
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+        >>> generate_ids = model.generate(**inputs)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Where is the cat standing?\nsnow"
+ ```
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if input_ids is not None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+            # Prepare per-layer inputs from input_ids
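+            # Token ids below `vocab_size_per_layer_input` have dedicated per-layer embeddings (PLE); each decoder
+            # layer later consumes its own slice of these embeddings alongside the regular hidden states.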
+ per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
+ per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
+ per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
+
+ # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
+ vision_mask = torch.logical_and(
+ input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
+ )
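+            # Non-vision positions are filled with a dummy in-range id so the embedding lookup stays valid; their
+            # embeddings are discarded by the torch.where below.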
+ dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1
+ vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device)
+ vision_embeds = self.embed_vision(input_ids=vision_input_ids)
+ expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds)
+
+ # Handle audio tokens (>= embed_audio.vocab_offset)
+ audio_mask = input_ids >= self.embed_audio.vocab_offset
+ dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1
+ audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device)
+ audio_embeds = self.embed_audio(input_ids=audio_input_ids)
+ expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
+ else:
+ per_layer_inputs = None
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ else:
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ raise ValueError(
+ f"Number of images does not match number of special image tokens in the input text. "
+ f"Got {image_tokens_in_text} image tokens in the text and "
+ f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings."
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ # Merge text and audio
+ if input_features is not None and input_features_mask is not None:
+ audio_features, audio_mask = self.get_audio_features(input_features, ~input_features_mask)
+
+            # The Gemma3nProcessor expects all audio to be 30s in length and inserts 188 audio soft tokens into the
+            # text to account for this. However, the audio preprocessing and encoder do not guarantee they will
+            # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer
+            # depending on the length of the longest audio input in the batch. When we encounter this situation, we
+            # pad the audio features out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
+ audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device)
+ audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
+ audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
+
+ audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
+ extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
+ extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
+
+ audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
+
+ if input_ids is None:
+ special_audio_mask = inputs_embeds == self.embed_audio(
+ input_ids=torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ else:
+ special_audio_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
+ special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
+ audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0]
+ raise ValueError(
+ f"Number of audio input features does not match number of special audio tokens in the input text. "
+ f"Got {audio_tokens_in_text} audio tokens in the text and "
+ f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings."
+ )
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
+
+ outputs = self.language_model(
+ input_ids=None,
+ per_layer_inputs=per_layer_inputs,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **lm_kwargs,
+ )
+
+ return Gemma3nModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values if use_cache else None,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ audio_hidden_states=audio_features if input_features is not None else None,
+ )
+
+ def get_audio_features(
+ self, input_features: torch.Tensor, input_features_mask: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Projects the last hidden state from the audio encoder into language model space.
+
+ Args:
+            input_features (`torch.FloatTensor` of shape `(num_audio, seq_length, num_features)`):
+ The tensors corresponding to the input audio.
+            input_features_mask (`torch.FloatTensor` of shape `(num_audio, seq_length)`):
+ The attention mask for the input audio.
+
+ Returns:
+            audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audio, audio_length, embed_dim)`.
+ """
+ audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
+ return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
+
+ def _update_causal_mask(self, **super_kwargs):
+ raise AttributeError("We don't want to inherit it")
+
+
+@auto_docstring(
+ custom_intro="""
+ The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
+ head.
+ """
+)
+class Gemma3nForConditionalGeneration(PaliGemmaForConditionalGeneration):
+ _checkpoint_conversion_mapping = {}
+ base_model_prefix = "model"
+
+ @property
+ def audio_tower(self):
+ return self.model.audio_tower
+
+ @property
+ def multi_modal_projector(self):
+ raise AttributeError("Use embed_vision instead of multi_modal_projector.")
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None, # text inputs
+ pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
+ input_features: Optional[torch.FloatTensor] = None, # audio inputs
+ attention_mask: Optional[torch.Tensor] = None,
+ input_features_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Gemma3nCausalLMOutputWithPast:
+ r"""
+        input_features (`torch.Tensor`, *optional*):
+            The audio inputs to be encoded.
+        input_features_mask (`torch.Tensor`, *optional*):
+            The attention mask for the input audio.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
+ ignored (masked), the loss is only computed for the tokens with labels in
+ `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+        >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+
+        >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma-3n-e4b-it")
+        >>> processor = AutoProcessor.from_pretrained("google/gemma-3n-e4b-it")
+
+ >>> messages = [
+ ... {
+ ... "role": "system",
+ ... "content": [
+ ... {"type": "text", "text": "You are a helpful assistant."}
+ ... ]
+ ... },
+ ... {
+ ... "role": "user", "content": [
+ ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+ ... {"type": "text", "text": "Where is the cat standing?"},
+ ... ]
+ ... },
+ ... ]
+
+ >>> inputs = processor.apply_chat_template(
+ ... messages,
+        ...     tokenize=True,
+ ... return_dict=True,
+ ... return_tensors="pt",
+ ... add_generation_prompt=True
+ ... )
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
+ ```
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ input_features=input_features,
+ attention_mask=attention_mask,
+ input_features_mask=input_features_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ token_type_ids=token_type_ids,
+ cache_position=cache_position,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
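+        # Optionally soft-cap the logits: scale down by the cap, squash with tanh, then scale back up so values
+        # stay within (-final_logit_softcapping, final_logit_softcapping).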
+ if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
+ logits = logits / final_logit_softcapping
+ logits = torch.tanh(logits)
+ logits = logits * final_logit_softcapping
+
+ loss = None
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:]
+ if attention_mask is not None:
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+ else:
+ shift_logits = shift_logits.contiguous()
+ shift_labels = shift_labels.contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
+ loss = loss_fct(flat_logits, flat_labels)
+
+ return Gemma3nCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ audio_hidden_states=outputs.audio_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ pixel_values=None,
+ input_features=None,
+ attention_mask=None,
+ input_features_mask=None,
+ token_type_ids=None,
+ use_cache=True,
+ logits_to_keep=None,
+ labels=None,
+ **kwargs,
+ ):
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ use_cache=use_cache,
+ logits_to_keep=logits_to_keep,
+ token_type_ids=token_type_ids,
+ **kwargs,
+ )
+
+ # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
+ # tokens anymore. Otherwise multimodal inputs should be passed to model.
+ # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
+ if cache_position[0] == 0:
+ model_inputs["pixel_values"] = pixel_values
+ model_inputs["input_features"] = input_features
+ model_inputs["input_features_mask"] = input_features_mask
+
+ return model_inputs
+
+ def _prepare_4d_causal_attention_mask_with_cache_position(self, **super_kwargs):
+ raise AttributeError("Do not inherit _prepare_4d_causal_attention_mask_with_cache_position from PaliGemma")
+
+
+__all__ = [
+ "Gemma3nAudioConfig",
+ "Gemma3nAudioEncoder",
+ "Gemma3nConfig",
+ "Gemma3nForCausalLM",
+ "Gemma3nForConditionalGeneration",
+ "Gemma3nModel",
+ "Gemma3nPreTrainedModel", # noqa: F822
+ "Gemma3nTextConfig",
+ "Gemma3nTextModel",
+ "Gemma3nVisionConfig",
+]
diff --git a/src/transformers/models/gemma3n/processing_gemma3n.py b/src/transformers/models/gemma3n/processing_gemma3n.py
new file mode 100644
index 000000000000..45e953b5c5d2
--- /dev/null
+++ b/src/transformers/models/gemma3n/processing_gemma3n.py
@@ -0,0 +1,191 @@
+# coding=utf-8
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput, make_nested_list_of_images
+from ...processing_utils import AudioKwargs, ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+class Gemma3nImagesKwargs(ImagesKwargs):
+ do_pan_and_scan: Optional[bool]
+ pan_and_scan_min_crop_size: Optional[int]
+ pan_and_scan_max_num_crops: Optional[int]
+ pan_and_scan_min_ratio_to_activate: Optional[float]
+ do_convert_rgb: Optional[bool]
+
+
+class Gemma3nProcessorKwargs(ProcessingKwargs, total=False):
+ audio_kwargs: AudioKwargs
+ images_kwargs: Gemma3nImagesKwargs
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ },
+ }
+
+
+class Gemma3nProcessor(ProcessorMixin):
+ """
+ A processor for Gemma 3n, wrapping the full capabilities of a feature extractor, image processor, and tokenizer
+ into a single processor.
+
+ Args:
+ feature_extractor (`Gemma3nAudioFeatureExtractor`):
+ Feature extractor that converts raw audio waveforms into MEL spectrograms for the audio encoder. This
+ should return a `BatchFeature` with `input_features` and `input_features_mask` features.
+ image_processor (`SiglipImageProcessorFast`):
+ Image processor that prepares batches of images for the vision encoder. This should return a `BatchFeature`
+ with a `pixel_values` feature.
+ tokenizer (`GemmaTokenizerFast`):
+ The text tokenizer for the model.
+ chat_template (`string`, *optional*):
+ A Jinja template for generating text prompts from a set of messages.
+        audio_seq_length (int, *optional*, defaults to 188):
+            The number of audio soft tokens that will be added to the text prompt.
+        image_seq_length (int, *optional*, defaults to 256):
+            The number of image soft tokens that will be added to the text prompt.
+ """
+
+ attributes = ["feature_extractor", "image_processor", "tokenizer"]
+ feature_extractor_class = "AutoFeatureExtractor"
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ feature_extractor,
+ image_processor,
+ tokenizer,
+ chat_template=None,
+ audio_seq_length: int = 188,
+ image_seq_length: int = 256,
+ **kwargs,
+ ):
+ self.audio_seq_length = audio_seq_length
+ self.audio_token_id = tokenizer.audio_token_id
+ self.boa_token = tokenizer.boa_token
+ self.audio_token = tokenizer.audio_token
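+        # Pre-build the expanded audio sequence (begin-of-audio token, `audio_seq_length` audio soft tokens, then
+        # end-of-audio token) that replaces each placeholder audio token in the prompt.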
+ audio_tokens_expanded = "".join([tokenizer.audio_token] * audio_seq_length)
+ self.full_audio_sequence = f"\n\n{tokenizer.boa_token}{audio_tokens_expanded}{tokenizer.eoa_token}\n\n"
+
+ self.image_seq_length = image_seq_length
+ self.image_token_id = tokenizer.image_token_id
+ self.boi_token = tokenizer.boi_token
+ self.image_token = tokenizer.image_token
+ image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
+ self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
+
+ super().__init__(
+ feature_extractor=feature_extractor,
+ image_processor=image_processor,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ **kwargs,
+ )
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
+ audio: Optional[Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]]] = None,
+ videos=None,
+ **kwargs: Unpack[Gemma3nProcessorKwargs],
+ ) -> BatchFeature:
+ if text is None and images is None and audio is None:
+ raise ValueError("Provide at least one of `text`, `images`, or `audio`.")
+
+ output_kwargs = self._merge_kwargs(
+ Gemma3nProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) and not isinstance(text[0], str):
+ raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+ if audio is not None:
+ audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
+
+ if not text:
+ text = [self.audio_token for _ in audio]
+
+ # Expand placeholder audio tokens to the full audio token sequence
+ text = [prompt.replace(self.audio_token, self.full_audio_sequence) for prompt in text]
+ else:
+ audio_inputs = {}
+
+ if images is not None:
+ batched_images = make_nested_list_of_images(images)
+ image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
+
+ # Create empty text to be replaced with placeholders
+ if not text:
+ text = [" ".join([self.image_token] * len(images)) for images in batched_images]
+
+ if len(batched_images) != len(text):
+ raise ValueError(
+ f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
+ )
+
+ # Expand placeholder image tokens to the full image token sequence
+ text = [prompt.replace(self.image_token, self.full_image_sequence) for prompt in text]
+ else:
+ image_inputs = {}
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
+
+ # Add token type ids manually, as tokenizer can't do arbitrary position token types
+ array_ids = text_inputs["input_ids"]
+ token_type_ids = np.zeros_like(array_ids)
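+        # Soft-token positions get their own type ids: 1 for image tokens and 3 for audio tokens (0 is regular text)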
+ token_type_ids[array_ids == self.image_token_id] = 1
+ token_type_ids[array_ids == self.audio_token_id] = 3
+ text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
+ text_inputs["token_type_ids"] = token_type_ids.tolist()
+ return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
+
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
+ image_processor_input_names = self.image_processor.model_input_names
+        feature_extractor_input_names = self.feature_extractor.model_input_names
+        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + feature_extractor_input_names))
+
+
+__all__ = ["Gemma3nProcessor"]
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 10f31b81c8f5..fe24e85c6419 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -1627,6 +1627,7 @@ def set_model_tester_for_less_flaky_test(test_case):
"AriaVisionText2TextModelTester",
"GPTNeoModelTester",
"DPTModelTester",
+ "Gemma3nTextModelTester", # cannot have a single layer combined with the cache sharing config attrs in the tester
]
if test_case.model_tester.__class__.__name__ in exceptional_classes:
target_num_hidden_layers = None
diff --git a/tests/models/gemma3n/__init__.py b/tests/models/gemma3n/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/gemma3n/test_feature_extraction_gemma3n.py b/tests/models/gemma3n/test_feature_extraction_gemma3n.py
new file mode 100644
index 000000000000..d2b10315bd6e
--- /dev/null
+++ b/tests/models/gemma3n/test_feature_extraction_gemma3n.py
@@ -0,0 +1,277 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import itertools
+import os
+import random
+import tempfile
+import unittest
+from typing import Optional, Sequence
+
+import numpy as np
+from parameterized import parameterized
+
+from transformers.models.gemma3n import Gemma3nAudioFeatureExtractor
+from transformers.testing_utils import (
+ check_json_file_has_correct_format,
+ require_torch,
+)
+from transformers.utils.import_utils import is_torch_available
+
+from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
+
+
+if is_torch_available():
+ pass
+
+global_rng = random.Random()
+
+MAX_LENGTH_FOR_TESTING = 512
+
+
+def floats_list(shape, scale=1.0, rng=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ values = []
+ for _ in range(shape[0]):
+ values.append([])
+ for _ in range(shape[1]):
+ values[-1].append(rng.random() * scale)
+
+ return values
+
+
+class Gemma3nAudioFeatureExtractionTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ min_seq_length=400,
+ max_seq_length=2000,
+ feature_size: int = 128,
+ sampling_rate: int = 16_000,
+ padding_value: float = 0.0,
+ return_attention_mask: bool = False,
+ # ignore hop_length / frame_length for now, as ms -> length conversion causes issues with serialization tests
+ # frame_length_ms: float = 32.0,
+ # hop_length: float = 10.0,
+ min_frequency: float = 125.0,
+ max_frequency: float = 7600.0,
+ preemphasis: float = 0.97,
+ preemphasis_htk_flavor: bool = True,
+ fft_overdrive: bool = True,
+ dither: float = 0.0,
+ input_scale_factor: float = 1.0,
+ mel_floor: float = 1e-5,
+ per_bin_mean: Optional[Sequence[float]] = None,
+ per_bin_stddev: Optional[Sequence[float]] = None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.min_seq_length = min_seq_length
+ self.max_seq_length = max_seq_length
+ self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
+ self.feature_size = feature_size
+ self.sampling_rate = sampling_rate
+ self.padding_value = padding_value
+ self.return_attention_mask = return_attention_mask
+ # ignore hop_length / frame_length for now, as ms -> length conversion causes issues with serialization tests
+ # self.frame_length_ms = frame_length_ms
+ # self.hop_length = hop_length
+ self.min_frequency = min_frequency
+ self.max_frequency = max_frequency
+ self.preemphasis = preemphasis
+ self.preemphasis_htk_flavor = preemphasis_htk_flavor
+ self.fft_overdrive = fft_overdrive
+ self.dither = dither
+ self.input_scale_factor = input_scale_factor
+ self.mel_floor = mel_floor
+ self.per_bin_mean = per_bin_mean
+ self.per_bin_stddev = per_bin_stddev
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "feature_size": self.feature_size,
+ "sampling_rate": self.sampling_rate,
+ "padding_value": self.padding_value,
+ "return_attention_mask": self.return_attention_mask,
+ "min_frequency": self.min_frequency,
+ "max_frequency": self.max_frequency,
+ "preemphasis": self.preemphasis,
+ "preemphasis_htk_flavor": self.preemphasis_htk_flavor,
+ "fft_overdrive": self.fft_overdrive,
+ "dither": self.dither,
+ "input_scale_factor": self.input_scale_factor,
+ "mel_floor": self.mel_floor,
+ "per_bin_mean": self.per_bin_mean,
+ "per_bin_stddev": self.per_bin_stddev,
+ }
+
+ def prepare_inputs_for_common(self, equal_length=False, numpify=False):
+ def _flatten(list_of_lists):
+ return list(itertools.chain(*list_of_lists))
+
+ if equal_length:
+ speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
+ else:
+ # make sure that inputs increase in size
+ speech_inputs = [
+ floats_list((x, self.feature_size))
+ for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
+ ]
+ if numpify:
+ speech_inputs = [np.asarray(x) for x in speech_inputs]
+ return speech_inputs
+
+
+class Gemma3nAudioFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
+ feature_extraction_class = Gemma3nAudioFeatureExtractor
+
+ def setUp(self):
+ self.feat_extract_tester = Gemma3nAudioFeatureExtractionTester(self)
+
+ def test_feat_extract_from_and_save_pretrained(self):
+ feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
+ check_json_file_has_correct_format(saved_file)
+ feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
+
+ dict_first = feat_extract_first.to_dict()
+ dict_second = feat_extract_second.to_dict()
+ mel_1 = feat_extract_first.mel_filters
+ mel_2 = feat_extract_second.mel_filters
+ self.assertTrue(np.allclose(mel_1, mel_2))
+ self.assertEqual(dict_first, dict_second)
+
+ def test_feat_extract_to_json_file(self):
+ feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ json_file_path = os.path.join(tmpdirname, "feat_extract.json")
+ feat_extract_first.to_json_file(json_file_path)
+ feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
+
+ dict_first = feat_extract_first.to_dict()
+ dict_second = feat_extract_second.to_dict()
+ mel_1 = feat_extract_first.mel_filters
+ mel_2 = feat_extract_second.mel_filters
+ self.assertTrue(np.allclose(mel_1, mel_2))
+ self.assertEqual(dict_first, dict_second)
+
+ def test_feat_extract_from_pretrained_kwargs(self):
+ feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
+ check_json_file_has_correct_format(saved_file)
+ feat_extract_second = self.feature_extraction_class.from_pretrained(
+ tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"]
+ )
+
+ mel_1 = feat_extract_first.mel_filters
+ mel_2 = feat_extract_second.mel_filters
+ self.assertTrue(2 * mel_1.shape[1] == mel_2.shape[1])
+
+ @parameterized.expand(
+ [
+ ([floats_list((1, x))[0] for x in range(800, 1400, 200)],),
+ ([floats_list((1, x))[0] for x in (800, 800, 800)],),
+ ([floats_list((1, x))[0] for x in range(200, (MAX_LENGTH_FOR_TESTING + 500), 200)], True),
+ ]
+ )
+ def test_call(self, audio_inputs, test_truncation=False):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ np_audio_inputs = [np.asarray(audio_input) for audio_input in audio_inputs]
+
+ input_features = feature_extractor(np_audio_inputs, padding="max_length", return_tensors="np").input_features
+ self.assertTrue(input_features.ndim == 3)
+ # input_features.shape should be (batch, num_frames, n_mels) ~= (batch, num_frames, feature_size)
+ # 480_000 is the max_length that inputs are padded to. we use that to calculate num_frames
+ expected_num_frames = (480_000 - feature_extractor.frame_length) // (feature_extractor.hop_length) + 1
+ self.assertTrue(
+ input_features.shape[-2] == expected_num_frames,
+ f"no match: {input_features.shape[-1]} vs {expected_num_frames}",
+ )
+ self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size)
+
+ encoded_sequences_1 = feature_extractor(audio_inputs, return_tensors="np").input_features
+ encoded_sequences_2 = feature_extractor(np_audio_inputs, return_tensors="np").input_features
+ for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+ self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
+ if test_truncation:
+ audio_inputs_truncated = [x[:MAX_LENGTH_FOR_TESTING] for x in audio_inputs]
+ np_audio_inputs_truncated = [np.asarray(audio_input) for audio_input in audio_inputs_truncated]
+
+ encoded_sequences_1 = feature_extractor(
+ audio_inputs_truncated, max_length=MAX_LENGTH_FOR_TESTING, return_tensors="np"
+ ).input_features
+ encoded_sequences_2 = feature_extractor(
+ np_audio_inputs_truncated, max_length=MAX_LENGTH_FOR_TESTING, return_tensors="np"
+ ).input_features
+ for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+ self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
+ def test_dither(self):
+ np.random.seed(42) # seed the dithering randn()
+
+ # Tests that features with and without little dithering are similar, but not the same
+ dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict()
+ dict_no_dither["dither"] = 0.0
+
+ dict_dither = self.feat_extract_tester.prepare_feat_extract_dict()
+ dict_dither["dither"] = 0.00003 # approx. 1/32k
+
+ feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither)
+ feature_extractor_dither = self.feature_extraction_class(**dict_dither)
+
+ # create three inputs of length 800, 1000, and 1200
+ speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
+ np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
+
+ # compute features
+ input_features_no_dither = feature_extractor_no_dither(
+ np_speech_inputs, padding=True, return_tensors="np", sampling_rate=dict_no_dither["sampling_rate"]
+ ).input_features
+ input_features_dither = feature_extractor_dither(
+ np_speech_inputs, padding=True, return_tensors="np", sampling_rate=dict_dither["sampling_rate"]
+ ).input_features
+
+ # test there is a difference between features (there's added noise to input signal)
+ diff = input_features_dither - input_features_no_dither
+
+ # features are not identical
+ self.assertTrue(np.abs(diff).mean() > 1e-6)
+ # features are not too different
+ self.assertTrue(np.abs(diff).mean() <= 1e-4)
+ self.assertTrue(np.abs(diff).max() <= 5e-3)
+
+ @require_torch
+ def test_double_precision_pad(self):
+ import torch
+
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+ np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
+ py_speech_inputs = np_speech_inputs.tolist()
+
+ for inputs in [py_speech_inputs, np_speech_inputs]:
+ np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
+ self.assertTrue(np_processed.input_features.dtype == np.float32)
+ pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
+ self.assertTrue(pt_processed.input_features.dtype == torch.float32)
diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py b/tests/models/gemma3n/test_modeling_gemma3n.py
new file mode 100644
index 000000000000..2f546e19e49c
--- /dev/null
+++ b/tests/models/gemma3n/test_modeling_gemma3n.py
@@ -0,0 +1,886 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Gemma3n model."""
+
+import tempfile
+import unittest
+
+import numpy as np
+import pytest
+from datasets import load_dataset
+from parameterized import parameterized
+
+from transformers import (
+ AutoModelForCausalLM,
+ AutoProcessor,
+ AutoTokenizer,
+ Gemma3nAudioConfig,
+ Gemma3nAudioFeatureExtractor,
+ Gemma3nConfig,
+ Gemma3nTextConfig,
+ GenerationConfig,
+ is_torch_available,
+)
+from transformers.testing_utils import (
+ cleanup,
+ require_flash_attn,
+ require_read_token,
+ require_torch,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ..gemma.test_modeling_gemma import GemmaModelTester
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ Gemma3nAudioEncoder,
+ Gemma3nForCausalLM,
+ Gemma3nForConditionalGeneration,
+ Gemma3nModel,
+ Gemma3nTextModel,
+ )
+
+
+class Gemma3nAudioModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=2,
+ num_channels=32, # feature_size / input_feat_size
+ sampling_rate=16_000,
+ raw_audio_length=8_000,
+ is_training=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.sampling_rate = sampling_rate
+ self.raw_audio_length = raw_audio_length
+ self.is_training = is_training
+
+ def get_feature_extractor_config(self):
+ return {
+ "feature_size": self.num_channels,
+ "sampling_rate": self.sampling_rate,
+ "padding_value": 0.0,
+ "return_attention_mask": True,
+ "frame_length_ms": 32.0,
+ "hop_length_ms": 10.0,
+ "dither": 0.0, # Important for determinism
+ }
+
+ def get_audio_encoder_config(self):
+ return Gemma3nAudioConfig(
+ input_feat_size=self.num_channels,
+ hidden_size=32,
+ conf_num_attention_heads=4,
+ conf_num_hidden_layers=2,
+ sscp_conv_channel_size=(16, 8),
+ conf_conv_kernel_size=3,
+ conf_attention_chunk_size=4,
+ conf_attention_context_left=5,
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ # Prepare inputs for the audio encoder
+ feature_extractor_config = self.get_feature_extractor_config()
+ audio_encoder_config = self.get_audio_encoder_config()
+
+ np.random.seed(0)
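+        # Two raw clips of different lengths (a 440 Hz tone and a shorter white-noise clip) so the padding and
+        # masking paths of the feature extractor and encoder are both exercised.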
+ raw_speech_1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1, self.raw_audio_length)).astype(np.float32)
+ raw_speech_2 = np.random.randn(self.raw_audio_length // 2).astype(np.float32)
+ raw_speech = [raw_speech_1, raw_speech_2]
+
+ feature_extractor = Gemma3nAudioFeatureExtractor(**feature_extractor_config)
+ audio_inputs = feature_extractor(raw_speech, return_tensors="pt")
+
+ input_features = audio_inputs["input_features"]
+ # The encoder expects a padding mask (True for padding), while the feature extractor
+ # returns an attention mask (True for valid tokens). We must invert it.
+ input_features_mask = ~audio_inputs["input_features_mask"].to(torch.bool)
+
+ inputs_dict = {
+ "audio_mel": input_features,
+ "audio_mel_mask": input_features_mask,
+ }
+ return audio_encoder_config, inputs_dict
+
+
+@unittest.skip("Skipped for now!")
+@require_torch
+class Gemma3nAudioModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (Gemma3nAudioEncoder,) if is_torch_available() else ()
+ test_pruning = False
+ test_head_masking = False
+ test_missing_keys = False
+ is_generative = False
+ _is_stateful = True
+ main_input_name = "audio_mel"
+ test_initialization = False
+ test_can_init_all_missing_weights = False
+
+ def setUp(self):
+ self.model_tester = Gemma3nAudioModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=Gemma3nAudioConfig, hidden_size=37)
+ torch.manual_seed(0)
+
+ # The following values are golden outputs from a deterministic run of the components.
+ # They are used to ensure that changes to the code do not alter the numerical output.
+ # Generated with seeds np.random.seed(0) and torch.manual_seed(0).
+ self.expected_input_features_shape = (2, 48, 32)
+ self.expected_input_features_slice = np.array([-5.733152, -5.337127, -4.916284, -4.378989, -3.7622747])
+ self.expected_input_features_mask_shape = (2, 48)
+ self.expected_input_features_mask_slice = np.array([True, True, True, True, False])
+
+ self.expected_encoder_output_shape = (2, 3, 32)
+ self.expected_encoder_output_slice = torch.tensor([-0.4159, 0.6459, 0.6305, 2.2902, 0.9683])
+ self.expected_encoder_mask_shape = (2, 3)
+ self.expected_encoder_mask_slice = torch.tensor([False, False, True])
+
+ # Prepare a shared feature extractor and raw audio for the tests
+ self.feature_extractor = Gemma3nAudioFeatureExtractor(**self.model_tester.get_feature_extractor_config())
+ np.random.seed(0)
+ raw_speech_1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1, self.model_tester.raw_audio_length)).astype(
+ np.float32
+ )
+ raw_speech_2 = np.random.randn(self.model_tester.raw_audio_length // 2).astype(np.float32)
+ self.raw_speech = [raw_speech_1, raw_speech_2]
+
+ @unittest.skip("Audio encoder does not support attention output")
+ def test_attention_outputs(self):
+ pass
+
+ @unittest.skip("Audio encoder does not support hidden state output")
+ def test_hidden_states_output(self):
+ pass
+
+ @unittest.skip("Audio encoder returns a tuple, not a ModelOutput object, skipping equivalence test.")
+ def test_model_outputs_equivalence(self):
+ pass
+
+ @unittest.skip("Audio encoder does not support retaining gradients on hidden states/attentions.")
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ @unittest.skip("Audio encoder does not have a concept of token embeddings")
+ def test_model_get_set_embeddings(self):
+ pass
+
+ @unittest.skip("Audio encoder does not have a concept of token embeddings")
+ def test_resize_tokens_embeddings(self):
+ pass
+
+ @unittest.skip("This model has a complex downsampling scheme that is hard to test with the generic batching test.")
+ def test_batching_equivalence(self):
+ pass
+
+ def test_feature_extractor(self):
+ """
+ Tests the feature extractor's output against pre-computed golden values.
+ This ensures the NumPy-based audio preprocessing is correct and consistent.
+ """
+ audio_inputs = self.feature_extractor(
+ self.raw_speech, padding="longest", pad_to_multiple_of=128, return_tensors="np"
+ )
+
+ input_features = audio_inputs["input_features"]
+ self.assertEqual(input_features.shape, self.expected_input_features_shape)
+ np.testing.assert_allclose(input_features[0, 0, :5], self.expected_input_features_slice, rtol=1e-5, atol=1e-5)
+
+ input_features_mask = audio_inputs["input_features_mask"]
+ self.assertEqual(input_features_mask.shape, self.expected_input_features_mask_shape)
+        # The second audio sample is shorter than the first, so its mask should flip from True to False within this
+        # slice, matching the golden values above
+ np.testing.assert_array_equal(input_features_mask[1, 21:26], self.expected_input_features_mask_slice)
+
+ def test_audio_encoder(self):
+ """
+ Tests the audio encoder's forward pass against pre-computed golden values.
+ This ensures the PyTorch-based audio encoding model is correct and consistent.
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = Gemma3nAudioEncoder(config).to(torch_device).eval()
+
+ with torch.no_grad():
+ encoder_output, encoder_mask = model(**inputs_dict)
+
+ # Check output encodings
+ self.assertEqual(encoder_output.shape, self.expected_encoder_output_shape)
+ torch.testing.assert_close(
+ encoder_output[0, 0, :5], self.expected_encoder_output_slice.to(torch_device), rtol=1e-4, atol=1e-4
+ )
+
+ # Check output mask (True means padded)
+        # The second sample is shorter; after the conv subsampling and the temporal reduction it keeps 2 valid frames,
+        # so the mask should be [False, False, True], matching the golden values above
+ self.assertEqual(encoder_mask.shape, self.expected_encoder_mask_shape)
+ torch.testing.assert_close(encoder_mask[1, :], self.expected_encoder_mask_slice.to(torch_device))
+
+
+class Gemma3nTextModelTester(GemmaModelTester):
+ activation_sparsity_pattern = None
+ forced_config_args = ["activation_sparsity_pattern"]
+
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=False,
+ use_labels=True,
+ vocab_size=99,
+ vocab_size_per_layer_input=99,
+ hidden_size=16,
+ num_hidden_layers=4, # override to correctly test sharing cache pattern
+ num_kv_shared_layers=2, # important to override
+ layer_types=[
+ "full_attention",
+ "sliding_attention",
+ "full_attention",
+ "sliding_attention",
+ ], # similarly we want to test sharing on both types
+ num_attention_heads=2,
+ num_key_value_heads=2,
+ altup_num_inputs=2,
+ intermediate_size=21,
+ hidden_activation="gelu_pytorch_tanh",
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ is_decoder=False,
+ ):
+ self._verify_model_attributes()
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.vocab_size_per_layer_input = vocab_size_per_layer_input
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_kv_shared_layers = num_kv_shared_layers
+ self.layer_types = layer_types
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.altup_num_inputs = altup_num_inputs
+ self.intermediate_size = intermediate_size
+ self.hidden_activation = hidden_activation
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.head_dim = self.hidden_size // self.num_attention_heads
+ self.is_decoder = is_decoder
+
+ if is_torch_available():
+ config_class = Gemma3nTextConfig
+ model_class = Gemma3nTextModel
+ for_causal_lm_class = Gemma3nForCausalLM
+
+
+@unittest.skip("Skipped for now!")
+@require_torch
+class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ all_model_classes = (Gemma3nTextModel, Gemma3nForCausalLM) if is_torch_available() else ()
+ all_generative_model_classes = (Gemma3nForCausalLM,) if is_torch_available() else ()
+ test_headmasking = False
+ test_pruning = False
+ _is_stateful = True
+ model_split_percents = [0.5, 0.6]
+
+ def setUp(self):
+ self.model_tester = Gemma3nTextModelTester(self)
+ self.config_tester = ConfigTester(
+ self,
+ config_class=Gemma3nConfig,
+ hidden_size=37,
+ text_config={"activation_sparsity_pattern": None},
+ )
+
+ def _check_hidden_states_for_generate(
+ self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
+ ):
+ "Gemma3n has special hidden states shape with 1 additional dim (which is then reduced with projections)"
+
+ self.assertIsInstance(hidden_states, tuple)
+ self.assertListEqual(
+ [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
+ [True] * len(hidden_states),
+ )
+ self.assertEqual(len(hidden_states), (output_length - prompt_length))
+
+ # When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
+ # new token(s)
+ # NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more
+ # elaborate checks
+ for generated_length, iter_hidden_states in enumerate(hidden_states):
+ # regardless of using cache, the first forward pass will have the full prompt as input
+ if use_cache and generated_length > 0:
+ model_input_length = 1
+ else:
+ model_input_length = prompt_length + generated_length
+ expected_shape = (config.altup_num_inputs, batch_size, model_input_length, config.hidden_size)
+ # check hidden size
+ self.assertListEqual(
+ [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
+ [expected_shape] * len(iter_hidden_states),
+ )
+
+
+class Gemma3nVision2TextModelTester:
+ text_config = {"activation_sparsity_pattern": None}
+ forced_config_args = ["text_config"]
+
+ def __init__(
+ self,
+ parent,
+ mm_tokens_per_image=2,
+ image_token_index=1,
+ boi_token_index=2,
+ eoi_token_index=3,
+ seq_length=25,
+ is_training=True,
+ vision_config={
+ "use_labels": True,
+ "image_size": 20,
+ "patch_size": 5,
+ "num_channels": 3,
+ "is_training": True,
+ "hidden_size": 32,
+ "num_key_value_heads": 1,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 4,
+ "intermediate_size": 37,
+ "dropout": 0.1,
+ "attention_dropout": 0.1,
+ "initializer_range": 0.02,
+ },
+ use_cache=False,
+ ):
+ self.parent = parent
+        # `image_token_index` is set to a small value to pass the "resize_embeddings" test, do not modify
+ self.mm_tokens_per_image = mm_tokens_per_image
+ self.image_token_index = image_token_index
+ self.boi_token_index = boi_token_index
+ self.eoi_token_index = eoi_token_index
+ self.llm_tester = Gemma3nTextModelTester(self.parent)
+ self.text_config = self.llm_tester.get_config()
+ self.vision_config = vision_config
+ self.seq_length = seq_length
+ self.pad_token_id = self.text_config.pad_token_id
+
+ self.num_hidden_layers = self.text_config.num_hidden_layers
+ self.vocab_size = self.text_config.vocab_size
+ self.hidden_size = self.text_config.hidden_size
+ self.num_attention_heads = self.text_config.num_attention_heads
+ self.is_training = is_training
+
+ self.batch_size = 3
+ self.num_channels = vision_config["num_channels"]
+ self.image_size = vision_config["image_size"]
+ self.encoder_seq_length = seq_length
+ self.use_cache = use_cache
+
+ def get_config(self):
+ return Gemma3nConfig(
+ text_config=self.text_config,
+ vision_config=self.vision_config,
+ image_token_index=self.image_token_index,
+ boi_token_index=self.boi_token_index,
+ eoi_token_index=self.eoi_token_index,
+ mm_tokens_per_image=self.mm_tokens_per_image,
+ )
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor(
+ [
+ self.batch_size,
+ self.vision_config["num_channels"],
+ self.vision_config["image_size"],
+ self.vision_config["image_size"],
+ ]
+ )
+ config = self.get_config()
+
+ return config, pixel_values
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
+ attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
+
+ # set the 3 first tokens to be image, and ensure that no other tokens are image tokens
+ # do not change this unless you modified image size or patch size
+ input_ids[input_ids == config.image_token_index] = self.pad_token_id
+ input_ids[:, :1] = config.image_token_index
+
+ token_type_ids = torch.zeros_like(input_ids)
+ token_type_ids[input_ids == config.image_token_index] = 1
+
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "token_type_ids": token_type_ids,
+ }
+ return config, inputs_dict
+
+
+@unittest.skip("Skipped for now!")
+@require_torch
+class Gemma3nVision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ all_model_classes = (Gemma3nModel, Gemma3nForConditionalGeneration) if is_torch_available() else ()
+ all_generative_model_classes = (Gemma3nForConditionalGeneration,) if is_torch_available() else ()
+ test_headmasking = False
+ test_pruning = False
+ test_missing_keys = False
+ _is_stateful = True
+ model_split_percents = [0.5, 0.6]
+
+ # MP works but offload doesn't work when the SigLIP MultiheadAttention is offloaded
+    # TODO: One potential solution would be to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
+ # in the dispatch_model function
+ test_cpu_offload = False
+ test_disk_offload_safetensors = False
+ test_disk_offload_bin = False
+
+ def setUp(self):
+ self.model_tester = Gemma3nVision2TextModelTester(self)
+ self.config_tester = ConfigTester(
+ self,
+ config_class=Gemma3nConfig,
+ hidden_size=37,
+ text_config={"activation_sparsity_pattern": None},
+ )
+
+ @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
+ def test_training_gradient_checkpointing_use_reentrant(self):
+ pass
+
+ @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
+ def test_training_gradient_checkpointing_use_reentrant_false(self):
+ pass
+
+ @unittest.skip(
+ reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
+ " as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
+ )
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
+
+ @unittest.skip("Failing because of unique cache (HybridCache)")
+ def test_model_outputs_equivalence(self, **kwargs):
+ pass
+
+ @parameterized.expand([("random",), ("same",)])
+ @pytest.mark.generate
+ @unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
+ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
+ pass
+
+ @unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
+ def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
+ pass
+
+ @pytest.mark.generate
+ @unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
+ def test_assisted_decoding_sample(self):
+ pass
+
+ @unittest.skip("Gemma3n has HybridCache which is not compatible with dola decoding")
+ def test_dola_decoding_sample(self):
+ pass
+
+ @unittest.skip("Gemma3n has HybridCache and doesn't support continue from past kv")
+ def test_generate_continue_from_past_key_values(self):
+ pass
+
+ @unittest.skip("Gemma3n has HybridCache and doesn't support low_memory generation")
+ def test_beam_search_low_memory(self):
+ pass
+
+ @unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
+ def test_contrastive_generate(self):
+ pass
+
+ @unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
+ def test_contrastive_generate_dict_outputs_use_cache(self):
+ pass
+
+ @unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
+ def test_contrastive_generate_low_memory(self):
+ pass
+
+ @unittest.skip("Gemma3n has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
+ def test_generate_with_static_cache(self):
+ pass
+
+ @unittest.skip("Gemma3n has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
+ def test_generate_from_inputs_embeds_with_static_cache(self):
+ pass
+
+ @unittest.skip(
+ reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
+ )
+ def test_initialization(self):
+ pass
+
+ @unittest.skip(
+ reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan"
+ )
+ def test_flex_attention_with_grads(self):
+ pass
+
+ def test_automodelforcausallm(self):
+ """
+ Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3n config, i.e. that
+ `AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
+ """
+ config = self.model_tester.get_config()
+ model = Gemma3nForConditionalGeneration(config)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.save_pretrained(tmp_dir)
+ for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
+ self.assertIsInstance(for_causal_lm, Gemma3nForCausalLM)
+
+
+@unittest.skip("Skipped for now!")
+@slow
+@require_torch_gpu
+@require_read_token
+class Gemma3nIntegrationTest(unittest.TestCase):
+ def setUp(self):
+ self.processor = AutoProcessor.from_pretrained("Google/gemma-3n-E4B-it", padding_side="left")
+
+ url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
+ self.messages = [
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": url},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+ ]
+
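+ # load a single short wav file; it is used by the audio transcription test below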
+ audio_ds = load_dataset(
+ "etechgrid/28.5k_wavfiles_dataset", "default", data_files="wav_dataset/103-1240-0000.wav"
+ )
+ self.audio_file_path = audio_ds["train"][0]["audio"]["path"]
+
+ def tearDown(self):
+ cleanup(torch_device, gc_collect=True)
+
+ def test_model_4b_bf16(self):
+ model_id = "Google/gemma-3n-E4B-it"
+
+ model = Gemma3nForConditionalGeneration.from_pretrained(
+ model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
+ ).to(torch_device)
+
+ inputs = self.processor.apply_chat_template(
+ self.messages,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ add_generation_prompt=True,
+ ).to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+ output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+ EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'] # fmt: skip
+ self.assertEqual(output_text, EXPECTED_TEXTS)
+
+ def test_model_with_audio(self):
+ """
+ Tests the full model pipeline with batched audio inputs provided as file paths.
+ This ensures the processor correctly loads and processes audio files.
+ """
+
+ model_id = "Google/gemma-3n-E4B-it"
+
+ model = Gemma3nForConditionalGeneration.from_pretrained(
+ model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
+ ).to(torch_device)
+
+ messages = [
+ [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Transcribe the following speech segment in English:"},
+ {"type": "audio", "audio": str(self.audio_file_path)},
+ ],
+ }
+ ],
+ ]
+
+ inputs = self.processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ padding=True,
+ return_tensors="pt",
+ ).to(torch_device, dtype=model.dtype)
+
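+ # keep the prompt length so that only newly generated tokens are decoded below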
+ input_len = inputs["input_ids"].shape[-1]
+
+ output = model.generate(**inputs, max_new_tokens=16, do_sample=False)
+ output = output[:, input_len:]
+ output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+ EXPECTED_TEXTS = ["Chapter 1. Mrs. Rachel Lind is surprised.\n\nMrs. Rachel Lind"]
+ self.assertEqual(output_text, EXPECTED_TEXTS)
+
+ def test_model_4b_batch(self):
+ model_id = "Google/gemma-3n-E4B-it"
+
+ model = Gemma3nForConditionalGeneration.from_pretrained(
+ model_id, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
+ ).to(torch_device)
+
+ messages_2 = [
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
+ },
+ {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+ {"type": "text", "text": "Are these images identical?"},
+ ],
+ },
+ ]
+
+ inputs = self.processor.apply_chat_template(
+ [self.messages, messages_2],
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding=True,
+ add_generation_prompt=True,
+ ).to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+ output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+ EXPECTED_TEXTS = [
+ 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
+ "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
+ ] # fmt: skip
+ self.assertEqual(output_text, EXPECTED_TEXTS)
+
+ def test_model_4b_crops(self):
+ model_id = "Google/gemma-3n-E4B-it"
+
+ model = Gemma3nForConditionalGeneration.from_pretrained(
+ model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
+ ).to(torch_device)
+
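+ # enable pan-and-scan so the processor returns the original image plus additional crops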
+ crop_config = {
+ "images_kwargs": {
+ "do_pan_and_scan": True,
+ "pan_and_scan_max_num_crops": 448,
+ "pan_and_scan_min_crop_size": 32,
+ "pan_and_scan_min_ratio_to_activate": 0.3,
+ }
+ }
+
+ inputs = self.processor.apply_chat_template(
+ self.messages,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ add_generation_prompt=True,
+ **crop_config,
+ ).to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+ output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+ EXPECTED_NUM_IMAGES = 3  # one for the original image and two for its crops
+ EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background.'] # fmt: skip
+ self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
+ self.assertEqual(output_text, EXPECTED_TEXTS)
+
+ def test_model_4b_multiimage(self):
+ model_id = "Google/gemma-3n-E4B-it"
+
+ model = Gemma3nForConditionalGeneration.from_pretrained(
+ model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
+ ).to(torch_device)
+
+ messages = [
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+ {"type": "text", "text": "What do you see here?"},
+ ],
+ },
+ ]
+
+ inputs = self.processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding=True,
+ add_generation_prompt=True,
+ ).to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+ output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+ EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"] # fmt: skip
+ self.assertEqual(output_text, EXPECTED_TEXTS)
+
+ def test_model_1b_text_only(self):
+ model_id = "google/gemma-3-1b-it"
+
+ model = Gemma3nForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
+ torch_device
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
+ inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+ output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+ EXPECTED_TEXTS = ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'] # fmt: skip
+ self.assertEqual(output_text, EXPECTED_TEXTS)
+
+ # TODO (raushan): FA2 generates gibberish for no apparent reason, check later
+ @require_flash_attn
+ @require_torch_gpu
+ @pytest.mark.flash_attn_test
+ def test_model_4b_flash_attn(self):
+ model_id = "Google/gemma-3n-E4B-it"
+
+ model = Gemma3nForConditionalGeneration.from_pretrained(
+ model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ ).to(torch_device)
+
+ inputs = self.processor.apply_chat_template(
+ self.messages,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ add_generation_prompt=True,
+ ).to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+ output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+ EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and'] # fmt: skip
+ self.assertEqual(output_text, EXPECTED_TEXTS)
+
+ @parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
+ def test_generation_beyond_sliding_window(self, attn_implementation: str):
+ """Test that we can correctly generate beyond the sliding window. This is non trivial as
+ we need to correctly slice the attention mask in all cases (because we use a HybridCache).
+ Outputs for every attention functions should be coherent and identical.
+ """
+ model_id = "google/gemma-3-1b-it"
+
+ input_text = [
+ "This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
+ "A list of colors: red, blue", # This will almost all be padding tokens
+ ]
+ tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
+ inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
+ ).to(torch_device)
+
+ # Make sure prefill is larger than sliding window
+ input_size = inputs.input_ids.shape[-1]
+ self.assertTrue(input_size > model.config.sliding_window)
+
+ out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
+ output_text = tokenizer.batch_decode(out)
+
+ EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
+ self.assertEqual(output_text, EXPECTED_COMPLETIONS)
+
+ def test_generation_beyond_sliding_window_with_generation_config(self):
+ """
+ Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
+ ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
+ """
+ model_id = "google/gemma-3-1b-it"
+ attn_implementation = "sdpa"
+
+ input_text = [
+ "This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
+ "A list of colors: red, blue", # This will almost all be padding tokens
+ ]
+ tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
+ inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
+ ).to(torch_device)
+
+ # Make sure prefill is larger than sliding window
+ input_size = inputs.input_ids.shape[-1]
+ self.assertTrue(input_size > model.config.sliding_window)
+
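+ # pass max_new_tokens through a GenerationConfig (rather than as a kwarg) to exercise config inheritance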
+ generation_config = GenerationConfig(max_new_tokens=20)
+
+ out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
+ output_text = tokenizer.batch_decode(out)
+
+ EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
+ self.assertEqual(output_text, EXPECTED_COMPLETIONS)
diff --git a/tests/models/gemma3n/test_processing_gemma3n.py b/tests/models/gemma3n/test_processing_gemma3n.py
new file mode 100644
index 000000000000..1d30a80c4896
--- /dev/null
+++ b/tests/models/gemma3n/test_processing_gemma3n.py
@@ -0,0 +1,185 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from transformers import GemmaTokenizerFast, SiglipImageProcessorFast, is_speech_available
+from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio, require_vision
+
+from .test_feature_extraction_gemma3n import floats_list
+
+
+if is_speech_available():
+ from transformers.models.gemma3n import Gemma3nAudioFeatureExtractor, Gemma3nProcessor
+
+
+@require_torch
+@require_torchaudio
+@require_vision
+@require_sentencepiece
+class Gemma3nProcessorTest(unittest.TestCase):
+ def setUp(self):
+ # TODO: update to google?
+ self.model_id = "Google/gemma-3n-E4B-it"
+ self.tmpdirname = tempfile.mkdtemp(suffix="gemma3n")
+ self.maxDiff = None
+
+ def get_tokenizer(self, **kwargs):
+ return GemmaTokenizerFast.from_pretrained(self.model_id, **kwargs)
+
+ def get_feature_extractor(self, **kwargs):
+ return Gemma3nAudioFeatureExtractor.from_pretrained(self.model_id, **kwargs)
+
+ def get_image_processor(self, **kwargs):
+ return SiglipImageProcessorFast.from_pretrained(self.model_id, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ def test_save_load_pretrained_default(self):
+ # NOTE: feature_extractor and image_processor both serialize to the same filename, preprocessor_config.json, so
+ # one file overwrites the other when processor.save_pretrained() writes them to disk. This test does not attempt
+ # to address this potential issue and, as such, does not guarantee content accuracy.
+
+ tokenizer = self.get_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+ image_processor = self.get_image_processor()
+
+ processor = Gemma3nProcessor(
+ tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
+ )
+
+ processor.save_pretrained(self.tmpdirname)
+ processor = Gemma3nProcessor.from_pretrained(self.tmpdirname)
+
+ self.assertIsInstance(processor.tokenizer, GemmaTokenizerFast)
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
+
+ self.assertIsInstance(processor.feature_extractor, Gemma3nAudioFeatureExtractor)
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+
+ def test_save_load_pretrained_additional_features(self):
+ tokenizer = self.get_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+ image_processor = self.get_image_processor()
+
+ processor = Gemma3nProcessor(
+ tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
+ )
+ processor.save_pretrained(self.tmpdirname)
+
+ tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS-BOS)", eos_token="(EOS-EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(dither=5.0, padding_value=1.0)
+
+ processor = Gemma3nProcessor.from_pretrained(
+ self.tmpdirname, bos_token="(BOS-BOS)", eos_token="(EOS-EOS)", dither=5.0, padding_value=1.0
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, GemmaTokenizerFast)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, Gemma3nAudioFeatureExtractor)
+
+ @parameterized.expand([256, 512, 768, 1024])
+ def test_image_processor(self, image_size: int):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+ image_processor = self.get_image_processor()
+ processor = Gemma3nProcessor(
+ tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
+ )
+
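+ # random RGB image at the parameterized resolution; the processor is expected to resize it to 768x768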
+ raw_image = np.random.randint(0, 256, size=(image_size, image_size, 3), dtype=np.uint8)
+ input_image_processor = image_processor(raw_image, return_tensors="pt")
+ input_processor = processor(text="Describe:", images=raw_image, return_tensors="pt")
+
+ for key in input_image_processor.keys():
+ self.assertAlmostEqual(input_image_processor[key].sum(), input_processor[key].sum(), delta=1e-2)
+ if "pixel_values" in key:
+ # NOTE: all images should be re-scaled to 768x768
+ self.assertEqual(input_image_processor[key].shape, (1, 3, 768, 768))
+ self.assertEqual(input_processor[key].shape, (1, 3, 768, 768))
+
+ def test_audio_feature_extractor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+ image_processor = self.get_image_processor()
+ processor = Gemma3nProcessor(
+ tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
+ )
+
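+ # three dummy waveforms of 1000 float samples each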
+ raw_speech = floats_list((3, 1000))
+ input_feat_extract = feature_extractor(raw_speech, return_tensors="pt")
+ input_processor = processor(text="Transcribe:", audio=raw_speech, return_tensors="pt")
+
+ for key in input_feat_extract.keys():
+ self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ def test_tokenizer(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+ image_processor = self.get_image_processor()
+ processor = Gemma3nProcessor(
+ tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
+ )
+
+ input_str = "This is a test string"
+
+ encoded_processor = processor(text=input_str)
+
+ encoded_tok = tokenizer(input_str)
+
+ for key in encoded_tok.keys():
+ self.assertListEqual(encoded_tok[key], encoded_processor[key][0])
+
+ def test_tokenizer_decode(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+ image_processor = self.get_image_processor()
+ processor = Gemma3nProcessor(
+ tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
+ )
+
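+ # arbitrary token ids; batch_decode on the processor should defer to the tokenizer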
+ predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
+
+ decoded_processor = processor.batch_decode(predicted_ids)
+ decoded_tok = tokenizer.batch_decode(predicted_ids)
+
+ self.assertListEqual(decoded_tok, decoded_processor)
+
+ def test_model_input_names(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+ image_processor = self.get_image_processor()
+ processor = Gemma3nProcessor(
+ tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
+ )
+
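+ # the processor should expose the input names of both the feature extractor and the image processor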
+ for key in feature_extractor.model_input_names:
+ self.assertIn(
+ key,
+ processor.model_input_names,
+ )
+
+ for key in image_processor.model_input_names:
+ self.assertIn(
+ key,
+ processor.model_input_names,
+ )
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index 22d6b033afbc..04fb04a64738 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -277,6 +277,7 @@
],
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
"SmolLM3Config": ["no_rope_layer_interval"],
+ "Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"], # this is for use in `timm`
}
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index 3c27476bdc03..bc247b2b6011 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -79,6 +79,7 @@
# docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the
# line before the docstring.
OBJECTS_TO_IGNORE = [
+ "Gemma3nVisionConfig",
"Llama4Processor",
# Deprecated
"InputExample",