From e524783a8710a60d4bc99e5e7db927d11b5b4267 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Tue, 29 Apr 2025 20:52:04 +0000
Subject: [PATCH] Fix noisy warning for uncalibrated q_scale/p_scale

Signed-off-by: mgoin
---
 vllm/model_executor/layers/quantization/kv_cache.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index 5dff8b09693c..67723c7c91cc 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -124,11 +124,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
         layer._prob_scale.copy_(prob_scale)
-        if q_scale == 1.0 or prob_scale == 1.0:
+        if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
+                                              or prob_scale == 1.0):
             logger.warning_once(
-                f"Using Q scale {q_scale} and prob scale {prob_scale} "
-                "with fp8 attention. This may cause accuracy issues. "
-                "Please make sure Q/prob scaling factors are "
+                f"Using uncalibrated q_scale {q_scale} and/or prob_scale "
+                f"{prob_scale} with fp8 attention. This may cause accuracy "
+                "issues. Please make sure q/prob scaling factors are "
                 "available in the fp8 checkpoint.")
 
         del layer.k_scale
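
Note (not part of the applied patch): below is a minimal standalone sketch of
the guarded warning this change introduces, for sanity-checking the condition
outside vLLM. The helper name and the use of the stdlib logger are
illustrative assumptions; the patch itself uses vLLM's logger.warning_once
inside process_weights_after_loading, as shown in the hunk above.

    import logging

    logger = logging.getLogger(__name__)

    def maybe_warn_uncalibrated_scales(kv_cache_dtype: str, q_scale: float,
                                       prob_scale: float) -> None:
        # Warn only when fp8 attention is actually in use: with any other
        # kv-cache dtype, the default scales of 1.0 are expected, so the
        # old unconditional check produced noise on every model load.
        if kv_cache_dtype == "fp8" and (q_scale == 1.0 or prob_scale == 1.0):
            # Sketch uses plain logger.warning; vLLM's logger additionally
            # provides warning_once to deduplicate repeated messages.
            logger.warning(
                f"Using uncalibrated q_scale {q_scale} and/or prob_scale "
                f"{prob_scale} with fp8 attention. This may cause accuracy "
                "issues. Please make sure q/prob scaling factors are "
                "available in the fp8 checkpoint.")

    maybe_warn_uncalibrated_scales("auto", 1.0, 1.0)  # silent: kv cache not fp8
    maybe_warn_uncalibrated_scales("fp8", 1.0, 0.5)   # warns: q_scale uncalibrated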