diff --git a/examples/gemma/README.md b/examples/gemma/README.md
index b3035efa990..7e026a5c9de 100644
--- a/examples/gemma/README.md
+++ b/examples/gemma/README.md
@@ -24,6 +24,8 @@
   - [Run inference under INT8 KV caches for keras checkpoint](#run-inference-under-int8-kv-caches-for-keras-checkpoint)
   - [Run Gemma 2](#run-gemma-2)
     - [Run inference under bfloat16 for torch checkpoint](#run-inference-under-bfloat16-for-torch-checkpoint-1)
+  - [Run Gemma 3](#run-gemma-3)
+    - [Run inference under bfloat16 for HF checkpoint](#run-inference-under-bfloat16-for-hf-checkpoint-1)
   - [Run Modelopt Quantization](#run-modelopt-quantization)
     - [Requirements](#requirements)
     - [Quantize Checkpoints](#quantize-checkpoints)
@@ -628,6 +630,52 @@
 Average accuracy 0.697 - other (business, health, misc.)
 Average accuracy: 0.630
 ```
+### Run Gemma 3
+
+Gemma 3's text-generation model from Hugging Face is supported. The Gemma 3 1B model interleaves five local layers between successive global layers: local layers use sliding-window attention with a short span of 512 tokens, while global layers attend to the full context. TensorRT-LLM supports layer-wise sliding-window attention, and the sliding window size for each layer can be passed with the `--max_attention_window_size` parameter at runtime. If only a sub-pattern is provided, TensorRT-LLM extrapolates the complete per-layer pattern and prints the result to the terminal.
+
+#### Run inference under bfloat16 for HF checkpoint
+```bash
+git clone https://huggingface.co/google/gemma-3-1b-it
+
+CKPT_PATH=gemma-3-1b-it/
+UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_1b_it_tensorrt_llm/bf16/tp1/
+ENGINE_PATH=/tmp/gemma3/1b/bf16/1-gpu/
+VOCAB_FILE_PATH=gemma-3-1b-it/tokenizer.model
+
+python3 ./examples/gemma/convert_checkpoint.py \
+    --ckpt-type hf \
+    --model-dir ${CKPT_PATH} \
+    --dtype bfloat16 \
+    --world-size 1 \
+    --output-model-dir ${UNIFIED_CKPT_PATH}
+
+trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
+    --gemm_plugin auto \
+    --max_batch_size 8 \
+    --max_input_len 3000 \
+    --max_seq_len 3100 \
+    --output_dir ${ENGINE_PATH}
+
+python3 ./examples/summarize.py --test_trt_llm \
+    --vocab_file ${VOCAB_FILE_PATH} \
+    --engine_dir ${ENGINE_PATH} \
+    --batch_size 1 \
+    --max_ite 5 \
+    --max_attention_window_size 512 512 512 512 512 3100
+
+...
+[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: (512, 512, 512, 512, 512, 3100) * 4 + (512, 512)
+...
+[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM (total latency: 1.6197962760925293 sec)
+[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 475)
+[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 293.2467539349165)
+[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM beam 0 result
+[04/09/2025-18:28:26] [TRT-LLM] [I] rouge1: 22.780314381954003
+[04/09/2025-18:28:26] [TRT-LLM] [I] rouge2: 4.331099231480823
+[04/09/2025-18:28:26] [TRT-LLM] [I] rougeL: 15.26751867562475
+[04/09/2025-18:28:26] [TRT-LLM] [I] rougeLsum: 20.14696930976001
+```
 
 ### Run Modelopt Quantization
 
diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py
index e3c26fb62d6..231b7be8daa 100755
--- a/tensorrt_llm/layers/attention.py
+++ b/tensorrt_llm/layers/attention.py
@@ -175,6 +175,13 @@ def __init__(self,
         self.embed_positions = None
         self.rotary_inv_freq = None
         self.embed_positions_for_gpt_attention = None
+
+        # auxiliary params to support models with non-homogeneous attn layers requiring
+        # a different set of rope params, e.g. Gemma3.
+ self.embed_positions_local = None + self.rotary_inv_freq_local = None + self.embed_positions_for_gpt_attention_local = None + # long rope const parameters self.long_rope_embed_positions = None self.long_rope_rotary_inv_freq = None @@ -186,10 +193,16 @@ def fill_attention_const_params_for_rope( self, embed_positions: Tensor = None, rotary_inv_freq: Tensor = None, - embed_positions_for_gpt_attention: Tensor = None): + embed_positions_for_gpt_attention: Tensor = None, + embed_positions_local: Tensor = None, + rotary_inv_freq_local: Tensor = None, + embed_positions_for_gpt_attention_local: Tensor = None): self.embed_positions = embed_positions self.rotary_inv_freq = rotary_inv_freq self.embed_positions_for_gpt_attention = embed_positions_for_gpt_attention + self.embed_positions_local = embed_positions_local + self.rotary_inv_freq_local = rotary_inv_freq_local + self.embed_positions_for_gpt_attention_local = embed_positions_for_gpt_attention_local return self def fill_attention_const_params_for_long_rope( @@ -359,6 +372,7 @@ def __init__(self, dtype=None, position_embedding_type=PositionEmbeddingType.learned_absolute, rotary_embedding_base=10000.0, + rotary_embedding_base_local=1.0, rotary_embedding_scaling=None, rotary_embedding_percentage=1.0, rope_scaling_short_factors=None, @@ -388,7 +402,8 @@ def __init__(self, cp_size=1, cp_rank=0, max_seqlen_for_logn_scaling=8192, - use_logn_scaling=False): + use_logn_scaling=False, + is_local=False): super().__init__() self.local_layer_idx = local_layer_idx @@ -417,6 +432,7 @@ def __init__(self, self.cp_group = cp_group self.cp_size = cp_size self.cp_rank = cp_rank + self.is_local = is_local self.num_layers = num_layers self.apply_query_key_layer_scaling = apply_query_key_layer_scaling @@ -437,6 +453,7 @@ def __init__(self, self.max_distance = max_distance self.num_buckets = num_buckets self.rotary_embedding_base = rotary_embedding_base + self.rotary_embedding_base_local = rotary_embedding_base_local self.rotary_embedding_scaling = rotary_embedding_scaling self.rotary_embedding_scale_type = RotaryScalingType.none self.rotary_embedding_scale = 1.0 @@ -656,26 +673,45 @@ def create_attention_const_params(model_cls, config): model_cls.short_mscale = short_mscale model_cls.long_mscale = long_mscale else: - # Rotary const weights. - embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions( - max_position_embeddings, - rotary_embedding_dim, - ) - rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin( - max_position_embeddings, rotary_embedding_dim, - rotary_embedding_base, rotary_embedding_scale, - rotary_embedding_scale_type, rotary_embedding_scaling) - model_cls.register_parameter( - 'embed_positions', - Parameter(embed_positions, dtype='float32', is_buffer=True)) - model_cls.register_parameter( - 'rotary_inv_freq', - Parameter(rotary_inv_freq, dtype='float32', is_buffer=True)) - model_cls.register_parameter( - 'embed_positions_for_gpt_attention', - Parameter(embed_positions_for_gpt_attention, - dtype='float32', - is_buffer=True)) + + def register_rope_params(rotary_base, names_to_register): + # Rotary const weights. 
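+            # Build the rope tables for the given base frequency and register them
+            # as float32 buffers under the three names in `names_to_register`:
+            # embed positions, inverse frequencies, and the plugin's cos/sin cache.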
+            embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
+                max_position_embeddings,
+                rotary_embedding_dim,
+            )
+            rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
+                max_position_embeddings, rotary_embedding_dim, rotary_base,
+                rotary_embedding_scale, rotary_embedding_scale_type,
+                rotary_embedding_scaling)
+            model_cls.register_parameter(
+                names_to_register[0],
+                Parameter(embed_positions, dtype='float32', is_buffer=True))
+            model_cls.register_parameter(
+                names_to_register[1],
+                Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
+            model_cls.register_parameter(
+                names_to_register[2],
+                Parameter(embed_positions_for_gpt_attention,
+                          dtype='float32',
+                          is_buffer=True))
+
+        register_rope_params(rotary_base=rotary_embedding_base,
+                             names_to_register=[
+                                 'embed_positions', 'rotary_inv_freq',
+                                 'embed_positions_for_gpt_attention'
+                             ])
+
+        # For models with non-homogeneous attention layers requiring a second set of rope params, e.g. Gemma3.
+        rotary_embedding_base_local = getattr(config,
+                                              'rope_local_base_freq', None)
+        if rotary_embedding_base_local is not None:
+            register_rope_params(
+                rotary_base=rotary_embedding_base_local,
+                names_to_register=[
+                    'embed_positions_local', 'rotary_inv_freq_local',
+                    'embed_positions_for_gpt_attention_local'
+                ])
 
     @staticmethod
     def fill_attention_params(model_cls, attention_params):
@@ -695,7 +731,15 @@ def fill_attention_params(model_cls, attention_params):
             return attention_params.fill_attention_const_params_for_rope(
                 model_cls.embed_positions.value,
                 model_cls.rotary_inv_freq.value,
-                model_cls.embed_positions_for_gpt_attention.value)
+                model_cls.embed_positions_for_gpt_attention.value,
+                model_cls.embed_positions_local.value if hasattr(
+                    model_cls, "embed_positions_local") else None,
+                model_cls.rotary_inv_freq_local.value if hasattr(
+                    model_cls, "rotary_inv_freq_local") else None,
+                model_cls.embed_positions_for_gpt_attention_local.value
+                if hasattr(
+                    model_cls,
+                    "embed_positions_for_gpt_attention_local") else None)
 
         # Fill nothing.
         return attention_params
@@ -1020,6 +1064,11 @@ def compute_cross_kv(encoder_output):
         # Rotary cos/sin cache.
         rotary_cos_sin = getattr(attention_params,
                                  "embed_positions_for_gpt_attention", None)
+        rotary_inv_freq_local = getattr(attention_params,
+                                        "rotary_inv_freq_local", None)
+        rotary_cos_sin_local = getattr(
+            attention_params, "embed_positions_for_gpt_attention_local",
+            None)
 
         long_rope_rotary_inv_freq = getattr(attention_params,
                                             "long_rope_rotary_inv_freq",
@@ -1062,7 +1111,8 @@ def compute_cross_kv(encoder_output):
             hidden_size_per_head=self.attention_head_size,
             q_scaling=self.q_scaling,
             rotary_embedding_dim=self.rotary_embedding_dim,
-            rotary_embedding_base=self.rotary_embedding_base,
+            rotary_embedding_base=self.rotary_embedding_base
+            if not self.is_local else self.rotary_embedding_base_local,
             rotary_embedding_scale_type=self.rotary_embedding_scale_type,
             rotary_embedding_short_m_scale=attention_params.short_mscale,
             rotary_embedding_long_m_scale=attention_params.long_mscale,
@@ -1071,8 +1121,10 @@
             rotary_embedding_original_max_positions=self.
original_max_position_embeddings, position_embedding_type=self.position_embedding_type, - rotary_inv_freq=rotary_inv_freq, - rotary_cos_sin=rotary_cos_sin, + rotary_inv_freq=rotary_inv_freq + if not self.is_local else rotary_inv_freq_local, + rotary_cos_sin=rotary_cos_sin + if not self.is_local else rotary_cos_sin_local, kv_orig_quant_scale=kv_orig_quant_scale, kv_quant_orig_scale=kv_quant_orig_scale, attention_output_orig_quant_scale=self. diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index fdb13829f13..cbff1ed4c07 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -33,7 +33,8 @@ from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder from .falcon.config import FalconConfig from .falcon.model import FalconForCausalLM, FalconModel -from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig +from .gemma.config import (GEMMA2_ARCHITECTURE, GEMMA3_ARCHITECTURE, + GEMMA_ARCHITECTURE, GemmaConfig) from .gemma.model import GemmaForCausalLM from .gpt.config import GPTConfig from .gpt.model import GPTForCausalLM, GPTModel @@ -183,6 +184,7 @@ 'SkyworkForCausalLM': LLaMAForCausalLM, GEMMA_ARCHITECTURE: GemmaForCausalLM, GEMMA2_ARCHITECTURE: GemmaForCausalLM, + GEMMA3_ARCHITECTURE: GemmaForCausalLM, 'QWenLMHeadModel': QWenForCausalLM, 'QWenForCausalLM': QWenForCausalLM, 'Qwen2ForCausalLM': QWenForCausalLM, diff --git a/tensorrt_llm/models/gemma/config.py b/tensorrt_llm/models/gemma/config.py index 028c228d3c9..d95244f79d2 100644 --- a/tensorrt_llm/models/gemma/config.py +++ b/tensorrt_llm/models/gemma/config.py @@ -19,6 +19,7 @@ from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.convert_utils import infer_dtype from tensorrt_llm.models.modeling_utils import (Gemma2ConfigGroup, + Gemma3ConfigGroup, PretrainedConfig, QuantConfig) if TYPE_CHECKING: @@ -30,6 +31,7 @@ GEMMA_ARCHITECTURE = "GemmaForCausalLM" GEMMA2_ARCHITECTURE = "Gemma2ForCausalLM" +GEMMA3_ARCHITECTURE = "Gemma3ForCausalLM" class GemmaConfig(PretrainedConfig): @@ -48,6 +50,9 @@ def __init__( final_logit_softcapping: Optional[float] = None, attn_logit_softcapping: Optional[float] = None, mapping: Optional[Union[Mapping, dict]] = None, + sliding_window_pattern: int = None, + rope_local_base_freq: int = None, + sliding_window: int = None, **kwargs, ): use_parallel_embedding = False @@ -79,23 +84,29 @@ def __init__( self.mlp_bias = mlp_bias self.inter_layernorms = False - if self.is_gemma_2: + if self.is_gemma_2 or self.is_gemma_3: self.inter_layernorms = True - assert query_pre_attn_scalar is not None, "Gemma2 models must configure `query_pre_attn_scalar`" + assert query_pre_attn_scalar is not None, "Gemma2 / Gemma3 models must configure `query_pre_attn_scalar`" self.query_pre_attn_scalar = query_pre_attn_scalar self.final_logit_softcapping = final_logit_softcapping - self.attn_logit_softcapping = attn_logit_softcapping + if self.is_gemma_2: + self.attn_logit_softcapping = attn_logit_softcapping + if self.is_gemma_3: + self.sliding_window_pattern = sliding_window_pattern + self.rope_local_base_freq = rope_local_base_freq + self.sliding_window = sliding_window GEMMA_ADDED_FIELDS = { "rotary_base", "rotary_scaling", "attn_bias", "mlp_bias", "inter_layernorms" } GEMMA2_ADDED_FIELDS = Gemma2ConfigGroup.keys() + GEMMA3_ADDED_FIELDS = Gemma3ConfigGroup.keys() VERBATIM = { "num_hidden_layers", "num_attention_heads", "hidden_size", "intermediate_size", "vocab_size", "max_position_embeddings", "hidden_act", "use_parallel_embedding" - } 
| GEMMA2_ADDED_FIELDS + } | GEMMA2_ADDED_FIELDS | GEMMA3_ADDED_FIELDS @property def is_gemma_2(self) -> bool: @@ -106,6 +117,15 @@ def gemma2_config(self): return self.get_config_group(Gemma2ConfigGroup) return None + @property + def is_gemma_3(self) -> bool: + return self.architecture == GEMMA3_ARCHITECTURE + + def gemma3_config(self): + if self.is_gemma_3: + return self.get_config_group(Gemma3ConfigGroup) + return None + def to_dict(self): """Serialize the fields added in GemmaConfig""" @@ -118,7 +138,11 @@ def to_dict(self): **({ f: getattr(self, f) for f in self.GEMMA2_ADDED_FIELDS - } if self.is_gemma_2 else {}) + } if self.is_gemma_2 else {}), + **({ + f: getattr(self, f) + for f in self.GEMMA3_ADDED_FIELDS + } if self.is_gemma_3 else {}) } @classmethod @@ -148,6 +172,7 @@ def from_hugging_face( norm_epsilon=hf_config.rms_norm_eps, num_key_value_heads=getattr(hf_config, "num_key_value_heads", hf_config.num_attention_heads), + rotary_base=getattr(hf_config, "rope_theta", 10000.0), rotary_scaling=getattr(hf_config, "rotary_scaling", None), quantization=quant_config, mapping=mapping, diff --git a/tensorrt_llm/models/gemma/convert.py b/tensorrt_llm/models/gemma/convert.py index 9536891e34d..2f2151f67a5 100644 --- a/tensorrt_llm/models/gemma/convert.py +++ b/tensorrt_llm/models/gemma/convert.py @@ -317,6 +317,10 @@ def rename_to_trt_llm(self, name: str) -> Optional[str]: None), # merged with above (r"model.layers.(\d+).self_attn.o_proj.weight", r"layers.\1.attention.dense.weight"), + (r"model.layers.(\d+).self_attn.q_norm.weight", + r"layers.\1.attention.q_layernorm.weight"), + (r"model.layers.(\d+).self_attn.k_norm.weight", + r"layers.\1.attention.k_layernorm.weight"), (r"model.layers.(\d+).mlp.gate_proj.weight", r"layers.\1.mlp.fc.weight"), (r"model.layers.(\d+).mlp.up_proj.weight", @@ -795,6 +799,8 @@ def load_gemma_weights( "pre_feedforward_layernorm", "post_feedforward_layernorm", "model.norm.weight", + "q_norm.weight", + "k_norm.weight", )): param = param + 1.0 # upcasted to float32 in case of bfloat16 add_trt_llm_weight(weights, trt_llm_name, param, diff --git a/tensorrt_llm/models/gemma/model.py b/tensorrt_llm/models/gemma/model.py index 46b301177bd..f39ed97e512 100644 --- a/tensorrt_llm/models/gemma/model.py +++ b/tensorrt_llm/models/gemma/model.py @@ -23,8 +23,8 @@ from ..._common import default_net from ..._utils import pad_vocab_size -from ...functional import (AllReduceFusionOp, AllReduceParams, Tensor, cast, - recv, send) +from ...functional import (AllReduceFusionOp, AllReduceParams, LayerNormType, + Tensor, cast, recv, send) from ...layers import (Attention, AttentionMaskType, AttentionParams, ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams, LoraParams, PositionEmbeddingType, RmsNorm) @@ -56,13 +56,26 @@ def __init__(self, config: GemmaConfig, layer_idx: int): q_scaling = 1.0 max_attn_value = 0.0 + qk_layernorm = False + is_sliding = False + rotary_base = config.rotary_base + rotary_base_local = None gemma2_config = config.gemma2_config() + gemma3_config = config.gemma3_config() if gemma2_config: q_scaling = math.sqrt( gemma2_config.query_pre_attn_scalar) / math.sqrt( config.head_size) max_attn_value = config.attn_logit_softcapping or 0.0 + elif gemma3_config: + qk_layernorm = True + q_scaling = math.sqrt( + gemma3_config.query_pre_attn_scalar) / math.sqrt( + config.head_size) + is_sliding = bool( + (layer_idx + 1) % gemma3_config.sliding_window_pattern) + rotary_base_local = config.rope_local_base_freq self.attention = Attention( 
local_layer_idx=self.local_layer_idx, @@ -70,18 +83,22 @@ def __init__(self, config: GemmaConfig, layer_idx: int): num_attention_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, attention_head_size=config.head_size, + qk_layernorm=qk_layernorm, + layernorm_type=LayerNormType.RmsNorm, max_position_embeddings=config.max_position_embeddings, dtype=config.dtype, attention_mask_type=AttentionMaskType.causal, bias=config.attn_bias, position_embedding_type=PositionEmbeddingType.rope_gpt_neox, - rotary_embedding_base=config.rotary_base, + rotary_embedding_base=rotary_base, + rotary_embedding_base_local=rotary_base_local, rotary_embedding_scaling=config.rotary_scaling, tp_group=config.mapping.tp_group, tp_size=config.mapping.tp_size, quant_mode=config.quant_mode, q_scaling=q_scaling, max_attn_value=max_attn_value, + is_local=is_sliding, ) mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 1716003b0da..6612af446d7 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -63,12 +63,25 @@ def keys(cls): return {f.name for f in dataclasses.fields(cls)} +@dataclasses.dataclass(kw_only=True, frozen=True) +class Gemma3ConfigGroup: + query_pre_attn_scalar: float + final_logit_softcapping: Optional[float] + sliding_window_pattern: int + rope_local_base_freq: int + sliding_window: int + + @classmethod + def keys(cls): + return {f.name for f in dataclasses.fields(cls)} + + if TYPE_CHECKING: from typing import Type, TypeVar from typing_extensions import Self - ConfigGroups = Union[Gemma2ConfigGroup] + ConfigGroups = Union[Gemma2ConfigGroup, Gemma3ConfigGroup] """Groupings of config where, if one of said properties exists, we assume all of the properties exist (even if they are `None`)""" CG = TypeVar("CG", bound=ConfigGroups) @@ -1027,6 +1040,9 @@ def forward(self, else: assert False, "Context parallelism with non-remove-padding is not supported yet." 
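+        # Resolve the Gemma config-group membership once up front; the final-logit
+        # softcapping branch below uses these flags to pick the matching group.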
+        is_gemma_2_cg = self.config.has_config_group(Gemma2ConfigGroup)
+        is_gemma_3_cg = self.config.has_config_group(Gemma3ConfigGroup)
+
         kwargs = {
             'input_ids': input_ids,
             'position_ids': position_ids,
@@ -1080,9 +1096,10 @@ def forward(self,
             lm_logits *= getattr(self.config, 'output_multiplier_scale', 1)
             if self.mup_width_multiplier is not None:
                 lm_logits = lm_logits / self.mup_width_multiplier
-            if self.config.has_config_group(Gemma2ConfigGroup):
+            if is_gemma_2_cg or is_gemma_3_cg:
                 softcap = self.config.get_config_group(
-                    Gemma2ConfigGroup).final_logit_softcapping
+                    Gemma2ConfigGroup if not is_gemma_3_cg else
+                    Gemma3ConfigGroup).final_logit_softcapping
                 if softcap:
                     lm_logits = lm_logits * float(1 / softcap)
                     lm_logits = tanh(lm_logits) * float(softcap)
diff --git a/tensorrt_llm/quantization/quantize_by_modelopt.py b/tensorrt_llm/quantization/quantize_by_modelopt.py
index 6f721c30e07..2bd199b3a3b 100755
--- a/tensorrt_llm/quantization/quantize_by_modelopt.py
+++ b/tensorrt_llm/quantization/quantize_by_modelopt.py
@@ -142,6 +142,7 @@ def model_type_is_enc_dec(model_type):
     "QWen": "qwen",
     "Qwen2VLForConditionalGeneration": "qwen2_vl",
     "RecurrentGemma": "recurrentgemma",
+    "Gemma3": "gemma3",
     "Gemma2": "gemma2",
     "Gemma": "gemma",
     "MixtralForCausalLM": "llama",
diff --git a/tests/integration/defs/examples/test_gemma.py b/tests/integration/defs/examples/test_gemma.py
index 74d3becb1c4..f16bfb6deff 100644
--- a/tests/integration/defs/examples/test_gemma.py
+++ b/tests/integration/defs/examples/test_gemma.py
@@ -150,7 +150,7 @@ def test_llm_hf_gemma_quantization_1gpu(batch_size, data_type, gemma_model_root,
 @pytest.mark.parametrize("gemma_model_root", [
     "gemma-2b", "gemma-7b", "gemma-2b-torch", "gemma-7b-torch",
     "gemma-2b-keras", "gemma-7b-keras", "gemma-2b-it-flax", "gemma-7b-it-flax",
-    "gemma-2-9b-it", "gemma-2-27b-it"
+    "gemma-2-9b-it", "gemma-2-27b-it", "gemma-3-1b-it"
 ],
                          indirect=True)
 def test_llm_gemma_1gpu_summary(batch_size, data_type, gemma_model_root,
@@ -232,6 +232,13 @@ def test_llm_gemma_1gpu_summary(batch_size, data_type, gemma_model_root,
     else:
         summary_cmd.append(f"--vocab_file={vocab_file}")
 
+    model_name = os.path.basename(gemma_model_root)
+    if 'gemma-3-1b-it' in model_name:
+        max_attention_window_size = [512, 512, 512, 512, 512, 3100]
+        summary_cmd.append("--max_attention_window_size")
+        for window_size in max_attention_window_size:
+            summary_cmd.append(str(window_size))
+
     venv_check_call(llm_venv, summary_cmd)
diff --git a/tests/integration/test_lists/qa/examples_test_list.txt b/tests/integration/test_lists/qa/examples_test_list.txt
index 1ce26f43e9a..551c6132571 100644
--- a/tests/integration/test_lists/qa/examples_test_list.txt
+++ b/tests/integration/test_lists/qa/examples_test_list.txt
@@ -35,6 +35,7 @@ examples/test_exaone.py::test_llm_exaone_1gpu[disable_weight_only-exaone_3.0_7.8b_instruct-float16-nb:1]
 examples/test_exaone.py::test_llm_exaone_1gpu[enable_weight_only-exaone_deep_2.4b-float16-nb:1]
 examples/test_exaone.py::test_llm_exaone_2gpu[exaone_3.0_7.8b_instruct-float16-nb:1]
 examples/test_gemma.py::test_llm_gemma_1gpu_summary[gemma-2-27b-it-other-bfloat16-8]
+examples/test_gemma.py::test_llm_gemma_1gpu_summary[gemma-3-1b-it-other-bfloat16-8]
 examples/test_gemma.py::test_llm_hf_gemma_quantization_1gpu[gemma-2-27b-it-fp8-bfloat16-8]
 examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it]
 examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-27b-it]
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml
b/tests/integration/test_lists/test-db/l0_h100.yml index 9be1ebcc6bd..cb0f4935876 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -125,6 +125,7 @@ l0_h100: - examples/test_llama.py::test_llama_3_x_fp8_with_bf16_lora[llama-3.2-1b] - examples/test_qwen.py::test_llm_hf_qwen_multi_lora_1gpu[qwen2.5_1.5b_instruct] - examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] + - examples/test_gemma.py::test_llm_gemma_1gpu_summary[gemma-3-1b-it-other-bfloat16-8] - examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-4-mini-instruct-fp8-bfloat16] - unittest/trt/model_api/test_model_level_api.py # 9 mins on H100 - unittest/trt/model_api/test_model_api_multi_gpu.py # 0.5 mins on H100
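To make the window-size extrapolation described in the README section of this patch concrete, here is a minimal sketch (not the actual TensorRT-LLM runtime implementation; `expand_attention_window_sizes` is a hypothetical helper) of how a six-value `--max_attention_window_size` sub-pattern tiles across the 26 layers of Gemma 3 1B, reproducing the `mMaxAttentionWindowSize` pattern shown in the example log.

```python
from typing import List


def expand_attention_window_sizes(sub_pattern: List[int],
                                  num_layers: int) -> List[int]:
    """Tile the per-layer window-size sub-pattern cyclically across all layers."""
    assert sub_pattern, "need at least one window size"
    return [sub_pattern[i % len(sub_pattern)] for i in range(num_layers)]


if __name__ == "__main__":
    # Five 512-token local layers followed by one full-context global layer,
    # mirroring Gemma 3 1B's 5-local : 1-global attention pattern.
    sub_pattern = [512, 512, 512, 512, 512, 3100]
    per_layer = expand_attention_window_sizes(sub_pattern, num_layers=26)
    # Yields (512, 512, 512, 512, 512, 3100) * 4 + (512, 512), matching the
    # TRTGptModel log line in the README example above.
    print(per_layer)
```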