48 changes: 48 additions & 0 deletions examples/gemma/README.md
@@ -24,6 +24,8 @@
- [Run inference under INT8 KV caches for keras checkpoint](#run-inference-under-int8-kv-caches-for-keras-checkpoint)
- [Run Gemma 2](#run-gemma-2)
- [Run inference under bfloat16 for torch checkpoint](#run-inference-under-bfloat16-for-torch-checkpoint-1)
- [Run Gemma 3](#run-gemma-3)
- [Run inference under bfloat16 for HF checkpoint](#run-inference-under-bfloat16-for-hf-checkpoint-1)
- [Run Modelopt Quantization](#run-modelopt-quantization)
- [Requirements](#requirements)
- [Quantize Checkpoints](#quantize-checkpoints)
@@ -628,6 +630,52 @@ Average accuracy 0.697 - other (business, health, misc.)
Average accuracy: 0.630
```

### Run Gemma 3

Gemma 3's text generation model from HuggingFace is supported. The Gemma 3 1B model interleaves five local layers between each global layer: local layers use sliding-window attention with a short span of 512 tokens, while global layers attend to the full long context. TRT-LLM supports layerwise sliding-window attention, and the sliding window size for each layer can be passed in via the `--max_attention_window_size` parameter at runtime. If only a subpattern is provided, TRT-LLM extrapolates the complete per-layer pattern and prints the extrapolation logic to the terminal.
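
As a rough illustration of that extrapolation (a minimal sketch, not the actual TensorRT-LLM code; the helper name `tile_window_sizes` and the 26-layer count for Gemma 3 1B are assumptions inferred from the runtime log shown below):

```python
# Minimal sketch: tile a per-layer attention-window subpattern across all
# decoder layers, the way the runtime summarizes it in its log output.
def tile_window_sizes(subpattern, num_layers):
    return [subpattern[i % len(subpattern)] for i in range(num_layers)]

# Assuming 26 decoder layers for Gemma 3 1B:
windows = tile_window_sizes([512, 512, 512, 512, 512, 3100], 26)
# -> the 6-entry pattern repeated 4 times plus its first 2 entries,
#    matching "(512, 512, 512, 512, 512, 3100) * 4 + (512, 512)" in the log.
```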

#### Run inference under bfloat16 for HF checkpoint
```bash
git clone https://huggingface.co/google/gemma-3-1b-it

CKPT_PATH=gemma-3-1b-it/
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_1b_it_tensorrt_llm/bf16/tp1/
ENGINE_PATH=/tmp/gemma3/1b/bf16/1-gpu/
VOCAB_FILE_PATH=gemma-3-1b-it/tokenizer.model

python3 ./examples/gemma/convert_checkpoint.py \
--ckpt-type hf \
--model-dir ${CKPT_PATH} \
--dtype bfloat16 \
--world-size 1 \
--output-model-dir ${UNIFIED_CKPT_PATH}

trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
--gemm_plugin auto \
--max_batch_size 8 \
--max_input_len 3000 \
--max_seq_len 3100 \
--output_dir ${ENGINE_PATH}

python3 ./examples/summarize.py --test_trt_llm \
--vocab_file ${VOCAB_FILE_PATH} \
--engine_dir ${ENGINE_PATH} \
--batch_size 1 \
--max_ite 5 \
--max_attention_window_size 512 512 512 512 512 3100

...
[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: (512, 512, 512, 512, 512, 3100) * 4 + (512, 512)
...
[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM (total latency: 1.6197962760925293 sec)
[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 475)
[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 293.2467539349165)
[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[04/09/2025-18:28:26] [TRT-LLM] [I] rouge1: 22.780314381954003
[04/09/2025-18:28:26] [TRT-LLM] [I] rouge2: 4.331099231480823
[04/09/2025-18:28:26] [TRT-LLM] [I] rougeL: 15.26751867562475
[04/09/2025-18:28:26] [TRT-LLM] [I] rougeLsum: 20.14696930976001
```

### Run Modelopt Quantization

104 changes: 78 additions & 26 deletions tensorrt_llm/layers/attention.py
@@ -175,6 +175,13 @@ def __init__(self,
self.embed_positions = None
self.rotary_inv_freq = None
self.embed_positions_for_gpt_attention = None

# auxiliary params to support models with non-homogeneous attn layers requiring
# a different set of rope params, e.g. Gemma3.
self.embed_positions_local = None
self.rotary_inv_freq_local = None
self.embed_positions_for_gpt_attention_local = None

# long rope const parameters
self.long_rope_embed_positions = None
self.long_rope_rotary_inv_freq = None
@@ -186,10 +193,16 @@ def fill_attention_const_params_for_rope(
self,
embed_positions: Tensor = None,
rotary_inv_freq: Tensor = None,
embed_positions_for_gpt_attention: Tensor = None):
embed_positions_for_gpt_attention: Tensor = None,
embed_positions_local: Tensor = None,
rotary_inv_freq_local: Tensor = None,
embed_positions_for_gpt_attention_local: Tensor = None):
self.embed_positions = embed_positions
self.rotary_inv_freq = rotary_inv_freq
self.embed_positions_for_gpt_attention = embed_positions_for_gpt_attention
self.embed_positions_local = embed_positions_local
self.rotary_inv_freq_local = rotary_inv_freq_local
self.embed_positions_for_gpt_attention_local = embed_positions_for_gpt_attention_local
return self

def fill_attention_const_params_for_long_rope(
@@ -359,6 +372,7 @@ def __init__(self,
dtype=None,
position_embedding_type=PositionEmbeddingType.learned_absolute,
rotary_embedding_base=10000.0,
rotary_embedding_base_local=1.0,
rotary_embedding_scaling=None,
rotary_embedding_percentage=1.0,
rope_scaling_short_factors=None,
@@ -388,7 +402,8 @@ def __init__(self,
cp_size=1,
cp_rank=0,
max_seqlen_for_logn_scaling=8192,
use_logn_scaling=False):
use_logn_scaling=False,
is_local=False):
super().__init__()

self.local_layer_idx = local_layer_idx
@@ -417,6 +432,7 @@ def __init__(self,
self.cp_group = cp_group
self.cp_size = cp_size
self.cp_rank = cp_rank
self.is_local = is_local

self.num_layers = num_layers
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
@@ -437,6 +453,7 @@ def __init__(self,
self.max_distance = max_distance
self.num_buckets = num_buckets
self.rotary_embedding_base = rotary_embedding_base
self.rotary_embedding_base_local = rotary_embedding_base_local
self.rotary_embedding_scaling = rotary_embedding_scaling
self.rotary_embedding_scale_type = RotaryScalingType.none
self.rotary_embedding_scale = 1.0
@@ -656,26 +673,45 @@ def create_attention_const_params(model_cls, config):
model_cls.short_mscale = short_mscale
model_cls.long_mscale = long_mscale
else:
# Rotary const weights.
embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
max_position_embeddings,
rotary_embedding_dim,
)
rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
max_position_embeddings, rotary_embedding_dim,
rotary_embedding_base, rotary_embedding_scale,
rotary_embedding_scale_type, rotary_embedding_scaling)
model_cls.register_parameter(
'embed_positions',
Parameter(embed_positions, dtype='float32', is_buffer=True))
model_cls.register_parameter(
'rotary_inv_freq',
Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
model_cls.register_parameter(
'embed_positions_for_gpt_attention',
Parameter(embed_positions_for_gpt_attention,
dtype='float32',
is_buffer=True))

def register_rope_params(rotary_base, names_to_register):
# Rotary const weights.
embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
max_position_embeddings,
rotary_embedding_dim,
)
rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
max_position_embeddings, rotary_embedding_dim, rotary_base,
rotary_embedding_scale, rotary_embedding_scale_type,
rotary_embedding_scaling)
model_cls.register_parameter(
names_to_register[0],
Parameter(embed_positions, dtype='float32', is_buffer=True))
model_cls.register_parameter(
names_to_register[1],
Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
model_cls.register_parameter(
names_to_register[2],
Parameter(embed_positions_for_gpt_attention,
dtype='float32',
is_buffer=True))

register_rope_params(rotary_base=rotary_embedding_base,
names_to_register=[
'embed_positions', 'rotary_inv_freq',
'embed_positions_for_gpt_attention'
])

# For models with non-homogeneous attention layers requiring a second set of rope params, e.g. Gemma3.
rotary_embedding_base_local = getattr(config,
'rope_local_base_freq', None)
if rotary_embedding_base_local is not None:
register_rope_params(
rotary_base=rotary_embedding_base_local,
names_to_register=[
'embed_positions_local', 'rotary_inv_freq_local',
'embed_positions_for_gpt_attention_local'
])

@staticmethod
def fill_attention_params(model_cls, attention_params):
@@ -695,7 +731,15 @@ def fill_attention_params(model_cls, attention_params):
return attention_params.fill_attention_const_params_for_rope(
model_cls.embed_positions.value,
model_cls.rotary_inv_freq.value,
model_cls.embed_positions_for_gpt_attention.value)
model_cls.embed_positions_for_gpt_attention.value,
model_cls.embed_positions_local.value if hasattr(
model_cls, "embed_positions_local") else None,
model_cls.rotary_inv_freq_local.value if hasattr(
model_cls, "rotary_inv_freq_local") else None,
model_cls.embed_positions_for_gpt_attention_local.value
if hasattr(
model_cls,
"embed_positions_for_gpt_attention_local") else None)
# Fill nothing.
return attention_params

@@ -1020,6 +1064,11 @@ def compute_cross_kv(encoder_output):
# Rotary cos/sin cache.
rotary_cos_sin = getattr(attention_params,
"embed_positions_for_gpt_attention", None)
rotary_inv_freq_local = getattr(attention_params,
"rotary_inv_freq_local", None)
rotary_cos_sin_local = getattr(
attention_params, "embed_positions_for_gpt_attention_local",
None)

long_rope_rotary_inv_freq = getattr(attention_params,
"long_rope_rotary_inv_freq",
@@ -1062,7 +1111,8 @@ def compute_cross_kv(encoder_output):
hidden_size_per_head=self.attention_head_size,
q_scaling=self.q_scaling,
rotary_embedding_dim=self.rotary_embedding_dim,
rotary_embedding_base=self.rotary_embedding_base,
rotary_embedding_base=self.rotary_embedding_base
if not self.is_local else self.rotary_embedding_base_local,
rotary_embedding_scale_type=self.rotary_embedding_scale_type,
rotary_embedding_short_m_scale=attention_params.short_mscale,
rotary_embedding_long_m_scale=attention_params.long_mscale,
Expand All @@ -1071,8 +1121,10 @@ def compute_cross_kv(encoder_output):
rotary_embedding_original_max_positions=self.
original_max_position_embeddings,
position_embedding_type=self.position_embedding_type,
rotary_inv_freq=rotary_inv_freq,
rotary_cos_sin=rotary_cos_sin,
rotary_inv_freq=rotary_inv_freq
if not self.is_local else rotary_inv_freq_local,
rotary_cos_sin=rotary_cos_sin
if not self.is_local else rotary_cos_sin_local,
kv_orig_quant_scale=kv_orig_quant_scale,
kv_quant_orig_scale=kv_quant_orig_scale,
attention_output_orig_quant_scale=self.
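
For context on the attention changes above: the new `is_local` flag and the `*_local` rope tables let each layer pick between the global and the local (sliding-window) rope parameters. A minimal sketch of that per-layer selection, assuming the Gemma 3 convention that every `sliding_window_pattern`-th layer is global; the helper names and numeric base frequencies below are illustrative assumptions, not TensorRT-LLM APIs:

```python
# Minimal sketch: pick the rope base per layer for a model that mixes local
# (sliding-window) and global attention layers, as Gemma3 does.
def is_local_layer(layer_idx: int, sliding_window_pattern: int) -> bool:
    # Assumption: every `sliding_window_pattern`-th layer is global,
    # all other layers are local sliding-window layers.
    return (layer_idx + 1) % sliding_window_pattern != 0

def rope_base_for_layer(layer_idx, base, base_local, sliding_window_pattern):
    return base_local if is_local_layer(layer_idx, sliding_window_pattern) else base

# Illustrative values only: global base 1e6, local base 1e4, pattern of 6.
bases = [rope_base_for_layer(i, 1_000_000.0, 10_000.0, 6) for i in range(12)]
# -> five layers with the local base, then one with the global base, repeating.
```
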
4 changes: 3 additions & 1 deletion tensorrt_llm/models/__init__.py
@@ -33,7 +33,8 @@
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
from .falcon.config import FalconConfig
from .falcon.model import FalconForCausalLM, FalconModel
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
from .gemma.config import (GEMMA2_ARCHITECTURE, GEMMA3_ARCHITECTURE,
GEMMA_ARCHITECTURE, GemmaConfig)
from .gemma.model import GemmaForCausalLM
from .gpt.config import GPTConfig
from .gpt.model import GPTForCausalLM, GPTModel
@@ -183,6 +184,7 @@
'SkyworkForCausalLM': LLaMAForCausalLM,
GEMMA_ARCHITECTURE: GemmaForCausalLM,
GEMMA2_ARCHITECTURE: GemmaForCausalLM,
GEMMA3_ARCHITECTURE: GemmaForCausalLM,
'QWenLMHeadModel': QWenForCausalLM,
'QWenForCausalLM': QWenForCausalLM,
'Qwen2ForCausalLM': QWenForCausalLM,
35 changes: 30 additions & 5 deletions tensorrt_llm/models/gemma/config.py
@@ -19,6 +19,7 @@
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import infer_dtype
from tensorrt_llm.models.modeling_utils import (Gemma2ConfigGroup,
Gemma3ConfigGroup,
PretrainedConfig, QuantConfig)

if TYPE_CHECKING:
@@ -30,6 +31,7 @@

GEMMA_ARCHITECTURE = "GemmaForCausalLM"
GEMMA2_ARCHITECTURE = "Gemma2ForCausalLM"
GEMMA3_ARCHITECTURE = "Gemma3ForCausalLM"


class GemmaConfig(PretrainedConfig):
@@ -48,6 +50,9 @@ def __init__(
final_logit_softcapping: Optional[float] = None,
attn_logit_softcapping: Optional[float] = None,
mapping: Optional[Union[Mapping, dict]] = None,
sliding_window_pattern: int = None,
rope_local_base_freq: int = None,
sliding_window: int = None,
**kwargs,
):
use_parallel_embedding = False
@@ -79,23 +84,29 @@
self.mlp_bias = mlp_bias

self.inter_layernorms = False
if self.is_gemma_2:
if self.is_gemma_2 or self.is_gemma_3:
self.inter_layernorms = True
assert query_pre_attn_scalar is not None, "Gemma2 models must configure `query_pre_attn_scalar`"
assert query_pre_attn_scalar is not None, "Gemma2 / Gemma3 models must configure `query_pre_attn_scalar`"
self.query_pre_attn_scalar = query_pre_attn_scalar
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
if self.is_gemma_2:
self.attn_logit_softcapping = attn_logit_softcapping
if self.is_gemma_3:
self.sliding_window_pattern = sliding_window_pattern
self.rope_local_base_freq = rope_local_base_freq
self.sliding_window = sliding_window

GEMMA_ADDED_FIELDS = {
"rotary_base", "rotary_scaling", "attn_bias", "mlp_bias",
"inter_layernorms"
}
GEMMA2_ADDED_FIELDS = Gemma2ConfigGroup.keys()
GEMMA3_ADDED_FIELDS = Gemma3ConfigGroup.keys()
VERBATIM = {
"num_hidden_layers", "num_attention_heads", "hidden_size",
"intermediate_size", "vocab_size", "max_position_embeddings",
"hidden_act", "use_parallel_embedding"
} | GEMMA2_ADDED_FIELDS
} | GEMMA2_ADDED_FIELDS | GEMMA3_ADDED_FIELDS

@property
def is_gemma_2(self) -> bool:
@@ -106,6 +117,15 @@ def gemma2_config(self):
return self.get_config_group(Gemma2ConfigGroup)
return None

@property
def is_gemma_3(self) -> bool:
return self.architecture == GEMMA3_ARCHITECTURE

def gemma3_config(self):
if self.is_gemma_3:
return self.get_config_group(Gemma3ConfigGroup)
return None

def to_dict(self):
"""Serialize the fields added in GemmaConfig"""

@@ -118,7 +138,11 @@ def to_dict(self):
**({
f: getattr(self, f)
for f in self.GEMMA2_ADDED_FIELDS
} if self.is_gemma_2 else {})
} if self.is_gemma_2 else {}),
**({
f: getattr(self, f)
for f in self.GEMMA3_ADDED_FIELDS
} if self.is_gemma_3 else {})
}

@classmethod
@@ -148,6 +172,7 @@ def from_hugging_face(
norm_epsilon=hf_config.rms_norm_eps,
num_key_value_heads=getattr(hf_config, "num_key_value_heads",
hf_config.num_attention_heads),
rotary_base=getattr(hf_config, "rope_theta", 10000.0),
rotary_scaling=getattr(hf_config, "rotary_scaling", None),
quantization=quant_config,
mapping=mapping,
6 changes: 6 additions & 0 deletions tensorrt_llm/models/gemma/convert.py
@@ -317,6 +317,10 @@ def rename_to_trt_llm(self, name: str) -> Optional[str]:
None), # merged with above
(r"model.layers.(\d+).self_attn.o_proj.weight",
r"layers.\1.attention.dense.weight"),
(r"model.layers.(\d+).self_attn.q_norm.weight",
r"layers.\1.attention.q_layernorm.weight"),
(r"model.layers.(\d+).self_attn.k_norm.weight",
r"layers.\1.attention.k_layernorm.weight"),
(r"model.layers.(\d+).mlp.gate_proj.weight",
r"layers.\1.mlp.fc.weight"),
(r"model.layers.(\d+).mlp.up_proj.weight",
@@ -795,6 +799,8 @@ def load_gemma_weights(
"pre_feedforward_layernorm",
"post_feedforward_layernorm",
"model.norm.weight",
"q_norm.weight",
"k_norm.weight",
)):
param = param + 1.0 # upcasted to float32 in case of bfloat16
add_trt_llm_weight(weights, trt_llm_name, param,