
Commit e26c4e2

[TRTLLM-6656][chore] Validate FP8 support for Gemma3
Signed-off-by: Balaram Buddharaju <[email protected]>
Parent: f923974

6 files changed: +39 −1 lines


tensorrt_llm/_torch/models/modeling_gemma3vl.py (6 additions, 1 deletion)

@@ -134,11 +134,16 @@ def get_sub_model_config(
         "text_config", "vision_config"
     ], f"Expected subconfig name to be either 'text_config' or 'vision_config'. Got {name} instead."
     pretrained_config = getattr(model_config.pretrained_config, name)
+    # ModelOpt currently doesn't quantize the vision part. Without setting quant config to None,
+    # weight loading fails for vision part.
+    quant_config = model_config.quant_config if name == "text_config" else None
+    # FlashInfer backend supports custom mask which is needed for bidirectional mask in decoder.
     preferred_backend = "FLASHINFER" if name == "text_config" else "TRTLLM"
     sub_model_config: ModelConfig[Gemma3Config] = dataclasses.replace(
         model_config,
         pretrained_config=pretrained_config,
-        attn_backend=preferred_backend)
+        attn_backend=preferred_backend,
+        quant_config=quant_config)
     # Make sure some fields that are not explicitly included in the sub config, but present
     # in the top-level config, are replicated.
     if (hasattr(sub_model_config.pretrained_config, "torch_dtype")

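For context, the effect of the get_sub_model_config change above can be sketched with a simplified stand-in for TensorRT-LLM's ModelConfig: only the text sub-config keeps the checkpoint's FP8 quant config, while the vision tower's sub-config drops it so that its unquantized weights still load. SimpleModelConfig and make_sub_config below are hypothetical names used purely for illustration, not the real classes.

# Minimal sketch, assuming a simplified stand-in for ModelConfig.
import dataclasses
from types import SimpleNamespace
from typing import Any, Optional

@dataclasses.dataclass
class SimpleModelConfig:
    pretrained_config: Any
    quant_config: Optional[str] = None
    attn_backend: str = "TRTLLM"

def make_sub_config(model_config: SimpleModelConfig, name: str) -> SimpleModelConfig:
    # Only the language model is quantized by ModelOpt, so the vision
    # sub-config drops the quant config to let its unquantized weights load.
    quant_config = model_config.quant_config if name == "text_config" else None
    # FlashInfer is preferred for the text decoder because it supports the
    # custom bidirectional mask Gemma3 needs.
    preferred_backend = "FLASHINFER" if name == "text_config" else "TRTLLM"
    return dataclasses.replace(model_config,
                               pretrained_config=getattr(
                                   model_config.pretrained_config, name),
                               quant_config=quant_config,
                               attn_backend=preferred_backend)

top_level = SimpleModelConfig(
    pretrained_config=SimpleNamespace(text_config="<text cfg>",
                                      vision_config="<vision cfg>"),
    quant_config="FP8")
assert make_sub_config(top_level, "text_config").quant_config == "FP8"
assert make_sub_config(top_level, "vision_config").quant_config is None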
tests/integration/defs/accuracy/references/cnn_dailymail.yaml (3 additions, 0 deletions)

@@ -5,6 +5,9 @@ google/gemma-3-1b-it:
     accuracy: 20.699
 google/gemma-3-27b-it:
   - accuracy: 28.90
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 27.90
 gpt2:
   - accuracy: 18.408
   - quant_algo: W8A16

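The reference files in this directory are keyed by Hugging Face model name, with one list entry per quantization variant. Purely as an illustration of how the FP8 entry added above can be picked out (the actual lookup lives in the accuracy harness and may differ), a PyYAML snippet:

# Illustrative only: read the reference file and select the FP8 variant.
import yaml

with open("tests/integration/defs/accuracy/references/cnn_dailymail.yaml") as f:
    refs = yaml.safe_load(f)

entries = refs["google/gemma-3-27b-it"]
fp8_entry = next(e for e in entries if e.get("quant_algo") == "FP8")
print(fp8_entry["kv_cache_quant_algo"], fp8_entry["accuracy"])  # FP8 27.9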
tests/integration/defs/accuracy/references/gsm8k.yaml (6 additions, 0 deletions)

@@ -141,8 +141,14 @@ speakleash/Bielik-11B-v2.2-Instruct:
     accuracy: 40.41
 google/gemma-3-1b-it:
   - accuracy: 25.52 # score getting from lm-eval with HF implementation
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 23.96
 google/gemma-3-27b-it:
   - accuracy: 91.66
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 90.66
 mistralai/Ministral-8B-Instruct-2410:
   - accuracy: 79.25
   - quant_algo: FP8

tests/integration/defs/accuracy/references/mmlu.yaml (3 additions, 0 deletions)

@@ -114,6 +114,9 @@ google/gemma-3-1b-it:
     accuracy: 37.5
 google/gemma-3-27b-it:
   - accuracy: 77.80
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 76.80
 Qwen/Qwen2-0.5B-Instruct:
   - accuracy: 45.30
   - quant_algo: FP8

tests/integration/defs/accuracy/test_llm_api_pytorch.py (20 additions, 0 deletions)

@@ -751,6 +751,24 @@ def test_auto_dtype(self):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

+    def test_fp8_prequantized(self):
+        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        enable_partial_reuse=False,
+                                        dtype="fp8")
+        prequantized_model_path = f"{llm_models_root()}/gemma/gemma-3-27b-it-fp8/"
+        with LLM(prequantized_model_path,
+                 kv_cache_config=kv_cache_config,
+                 attn_backend="FLASHINFER",
+                 cuda_graph_config=None) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+

 class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "google/gemma-3-1b-it"

@@ -784,6 +802,8 @@ def test_fp8_prequantized(self):
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             task = CnnDailymail(self.MODEL_NAME)
             task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)

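Outside the test harness, the same prequantized checkpoint can be exercised directly through the LLM API, mirroring the new 27B test above. The sketch below assumes the usual TensorRT-LLM import locations and uses a placeholder checkpoint path; it is not part of this commit.

# Sketch of running the FP8-prequantized Gemma3 checkpoint directly.
# The checkpoint path is a placeholder; import paths follow the LLM API used
# in the test above and may vary across TensorRT-LLM versions.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

kv_cache_config = KvCacheConfig(
    enable_block_reuse=False,    # WAR for Gemma3's non-inclusive sliding window
    enable_partial_reuse=False,
    dtype="fp8")                 # FP8 KV cache, matching kv_cache_quant_algo: FP8

with LLM("/path/to/gemma-3-27b-it-fp8",      # placeholder path
         kv_cache_config=kv_cache_config,
         attn_backend="FLASHINFER",          # needed for the bidirectional decoder mask
         cuda_graph_config=None) as llm:
    outputs = llm.generate(["Summarize: FP8 support for Gemma3 was validated."])
    print(outputs[0].outputs[0].text)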
tests/integration/test_lists/test-db/l0_h100.yml (1 addition, 0 deletions)

@@ -203,6 +203,7 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
   - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized
+  - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_fp8_prequantized
   - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]
