Commit 609b545

DomBrown authored and dominicshanshan committed

[https://nvbugs/5461712] [fix] Disable deep_gemm for Qwen3 due to accuracy issues (#7170)

Signed-off-by: Dom Brown <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>

1 parent: 08faae4

File tree: 5 files changed, +30 −6 lines

tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 9 additions & 0 deletions

@@ -158,6 +158,9 @@ def __init__(
 
         self.fuse_qk_norm_rope = fuse_qk_norm_rope
 
+        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712)
+        disable_deep_gemm = True
+
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -171,6 +174,7 @@ def __init__(
             dtype=config.torch_dtype,
             dense_bias=config.attention_bias,
             config=model_config,
+            disable_deep_gemm=disable_deep_gemm,
         )
 
         self.q_norm = RMSNorm(hidden_size=self.head_dim,
@@ -244,13 +248,18 @@ def __init__(
             layer_idx=layer_idx,
         )
 
+        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712)
+        disable_deep_gemm = True
+
         self.mlp = GatedMLP(
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             bias=config.mlp_bias if hasattr(config, "mlp_bias") else False,
             dtype=config.torch_dtype,
             config=model_config,
+            disable_deep_gemm=disable_deep_gemm,
         )
+
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
                                        dtype=config.torch_dtype)
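
A minimal, self-contained sketch of the plumbing pattern this file relies on: the Qwen3 attention and decoder-layer constructors hard-code a local disable_deep_gemm flag and pass it down through their submodule constructors until it reaches the linear layers that pick the GEMM kernel. The classes below are simplified stand-ins, not the real TensorRT-LLM modules.

    # Simplified stand-ins illustrating the constructor pass-through of the flag.
    class Linear:
        def __init__(self, in_features: int, out_features: int,
                     disable_deep_gemm: bool = False):
            # The leaf module only records the flag; the kernel choice happens later.
            self.disable_deep_gemm = disable_deep_gemm

    class GatedMLP:
        def __init__(self, hidden_size: int, intermediate_size: int,
                     disable_deep_gemm: bool = False):
            # Forward the flag to every Linear this MLP owns.
            self.gate_up_proj = Linear(hidden_size, 2 * intermediate_size,
                                       disable_deep_gemm=disable_deep_gemm)
            self.down_proj = Linear(intermediate_size, hidden_size,
                                    disable_deep_gemm=disable_deep_gemm)

    class Qwen3DecoderLayerSketch:
        def __init__(self, hidden_size: int, intermediate_size: int):
            # Qwen3 has accuracy issues with deep_gemm, so the opt-out is hard-coded here.
            disable_deep_gemm = True
            self.mlp = GatedMLP(hidden_size, intermediate_size,
                                disable_deep_gemm=disable_deep_gemm)

    layer = Qwen3DecoderLayerSketch(hidden_size=4096, intermediate_size=12288)
    assert layer.mlp.down_proj.disable_deep_gemm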

tensorrt_llm/_torch/modules/attention.py

Lines changed: 8 additions & 2 deletions

@@ -116,6 +116,7 @@ def __init__(
         config: Optional[ModelConfig] = None,
         q_scaling: float = 1.0,
         attention_chunk_size: Optional[int] = None,
+        disable_deep_gemm: bool = False,
     ):
         """
         Initialize the Attention module.
@@ -134,6 +135,7 @@ def __init__(
             config (Optional[ModelConfig]): The model configuration.
             q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V, qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
             attention_chunk_size (Optional[int]): See [Chunked Attention] below.
+            disable_deep_gemm (bool): Whether to disable deep_gemm for linear layers.
         """
         super().__init__()
         self.layer_idx = layer_idx
@@ -215,7 +217,9 @@ def __init__(
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            disable_deep_gemm=disable_deep_gemm,
+        )
         self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                 [self.hidden_size])
 
@@ -230,7 +234,9 @@ def __init__(
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             lora=self.o_lora,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            disable_deep_gemm=disable_deep_gemm,
+        )
 
         self.quant_config = config.get_quant_config()
         self.attn_backend = config.attn_backend

tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 9 additions & 3 deletions

@@ -29,7 +29,9 @@ def __init__(self,
                  overridden_tp_size: Optional[int] = None,
                  reduce_output: bool = True,
                  layer_idx: Optional[int] = None,
-                 use_cute_dsl_blockscaling_mm: bool = False):
+                 use_cute_dsl_blockscaling_mm: bool = False,
+                 disable_deep_gemm: bool = False):
+
         super().__init__()
         self.layer_idx = layer_idx
         self.hidden_size = hidden_size
@@ -67,7 +69,9 @@ def __init__(self,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             allreduce_strategy=config.allreduce_strategy,
             force_dynamic_quantization=config.force_dynamic_quantization,
-            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)
+            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
+            disable_deep_gemm=disable_deep_gemm,
+        )
 
         self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
                                    [self.hidden_size])
@@ -85,7 +89,9 @@ def __init__(self,
             lora=self.down_lora,
             allreduce_strategy=config.allreduce_strategy,
             force_dynamic_quantization=config.force_dynamic_quantization,
-            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)
+            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
+            disable_deep_gemm=disable_deep_gemm,
+        )
 
         # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
         # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora

tensorrt_llm/_torch/modules/linear.py

Lines changed: 3 additions & 1 deletion

@@ -613,7 +613,7 @@ def apply(self, module: Linear, input: torch.Tensor,
             input = input.to(torch.bfloat16) * module.input_scale
         assert input.dtype == torch.bfloat16
 
-        if get_sm_version() == 100:
+        if get_sm_version() == 100 and not module.disable_deep_gemm:
             if module.use_cute_dsl_blockscaling_mm:
                 # TODO (@lmin): replace with cute_dsl gemm
                 act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
@@ -1595,6 +1595,7 @@ def __init__(
         allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
         force_dynamic_quantization: bool = False,
         use_cute_dsl_blockscaling_mm: bool = False,
+        disable_deep_gemm: bool = False,
     ):
         from ..distributed import AllReduce
 
@@ -1612,6 +1613,7 @@ def __init__(
         self.gather_output = gather_output
         self.force_dynamic_quantization = force_dynamic_quantization
         self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm
+        self.disable_deep_gemm = disable_deep_gemm
 
         local_in_features = in_features
         local_out_features = out_features
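
The functional change sits in the FP8 block-scales apply path of Linear: on SM 100 the module previously always took the deep_gemm / CuTe DSL route, and the new check lets a layer opt back into the default GEMM path. A rough sketch of that dispatch decision, where the function and backend names are hypothetical stand-ins for the real TensorRT-LLM helpers and kernels:

    # Sketch of the SM-100 dispatch added above; helper and backend names are hypothetical.
    def pick_fp8_block_scales_backend(sm_version: int,
                                      use_cute_dsl_blockscaling_mm: bool,
                                      disable_deep_gemm: bool) -> str:
        """Return which GEMM backend this sketch would use for FP8 block scales."""
        if sm_version == 100 and not disable_deep_gemm:
            # Blackwell fast path: CuTe DSL block-scaling MM or deep_gemm.
            return "cute_dsl" if use_cute_dsl_blockscaling_mm else "deep_gemm"
        # Per-layer opt-out (e.g. Qwen3) or other architectures fall back here.
        return "default_fp8_gemm"

    # With the Qwen3 opt-out, the Blackwell deep_gemm path is skipped:
    assert pick_fp8_block_scales_backend(100, False, disable_deep_gemm=True) == "default_fp8_gemm"
    assert pick_fp8_block_scales_backend(100, False, disable_deep_gemm=False) == "deep_gemm"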

tests/integration/test_lists/test-db/l0_b200.yml

Lines changed: 1 addition & 0 deletions

@@ -37,6 +37,7 @@ l0_b200:
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # To cover NVBUGS 5461712
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
   - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
   - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
