diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 82d60f9da7da..2149e89aa734 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -237,6 +237,7 @@ class AttentionLayer(Protocol):
     _v_scale: torch.Tensor
     _k_scale_float: float
     _v_scale_float: float
+    _prob_scale: torch.Tensor
 
     def forward(
         self,
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 37b6cadcb98a..8076c4791d3c 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -766,6 +766,12 @@ def forward(
                             query.dtype,
                             seq_lens,
                             make_attn_mask=causal_mask)  # type: ignore
+                    use_fp8_scales = (layer._q_scale and layer._k_scale
+                                      and layer._v_scale and layer._prob_scale
+                                      and self.kv_cache_dtype == "fp8")
+                    full_scales = (
+                        layer._q_scale, layer._k_scale, layer._v_scale,
+                        layer._prob_scale) if use_fp8_scales else None
                     self.triton_attn_func(
                         query,
                         key,
@@ -779,6 +785,7 @@ def forward(
                         self.scale,
                         attn_masks[0][None]
                         if attn_masks is not None else None,
+                        full_scales,
                     )
                 elif self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 68452f4c03b0..aa218cc37af9 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -90,6 +90,7 @@ def __init__(
         # FlashAttn doesn't support quantizing the kv-cache only
         # but requires q to be quantized as well.
         self._q_scale = torch.tensor(1.0, dtype=torch.float32)
+        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
 
         # We also keep the float32 versions of k/v_scale for attention
         # backends that don't support tensors (Flashinfer)
diff --git a/vllm/config.py b/vllm/config.py
index 41a30efea039..ac1ed939c319 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3766,6 +3766,17 @@ def _get_quantization_config(
             return quant_config
         return None
 
+    @staticmethod
+    def get_quantization_config(
+            model_config: ModelConfig,
+            load_config: LoadConfig) -> Optional[QuantizationConfig]:
+        import copy
+
+        # _get_quantization_config mutates the model_config it receives, so
+        # work on a deepcopy to leave the caller's object untouched.
+        return VllmConfig._get_quantization_config(copy.deepcopy(model_config),
+                                                   load_config)
+
     def with_hf_config(
         self,
         hf_config: PretrainedConfig,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 9cb2aa797be5..4d4a2b59067a 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1377,6 +1377,24 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False
 
+        if current_platform.is_rocm():
+            from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+            load_config = self.create_load_config()
+            quantization_config = VllmConfig.get_quantization_config(
+                model_config, load_config)
+            if isinstance(quantization_config, Fp8Config):
+                _raise_or_fallback(feature_name="fp8 for ROCm",
+                                   recommend_to_remove=False)
+                return False
+            from vllm.model_executor.layers.quantization.quark.quark import (
+                QuarkConfig)
+
+            if isinstance(quantization_config, QuarkConfig
+                          ) and quantization_config.has_fp8_layer_weights():
+                _raise_or_fallback(feature_name="Quark fp8 for ROCm",
+                                   recommend_to_remove=False)
+                return False
+
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
             fp8_attention = self.kv_cache_dtype.startswith("fp8")
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index be76785baccc..01056c37b86c 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -140,6 +140,11 @@ def get_cache_scale(self, name: str) -> Optional[str]:
             return name.replace(".k_proj.output_scale", ".attn.k_scale")
         if name.endswith(".output_scale") and ".v_proj" in name:
             return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
+
         # If no matches, return None
         return None
 
diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index 5d766c2c27ac..5dff8b09693c 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -38,6 +38,9 @@ def create_weights(self, layer: torch.nn.Module):
                                           requires_grad=False)
         layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                            requires_grad=False)
+        # Initialize P = softmax(QK^T) scales
+        layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0),
+                                              requires_grad=False)
 
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(
@@ -97,5 +100,38 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                     "may cause accuracy issues. Please make sure k/v_scale "
                     "scaling factors are available in the fp8 checkpoint.")
 
+        if layer.q_scale > 0.0:
+            q_scale = layer.q_scale
+            if current_platform.is_fp8_fnuz():
+                q_scale *= 2
+            layer.calculate_kv_scales = False
+        else:
+            q_scale = 1.0
+        if layer.prob_scale > 0.0:
+            prob_scale = layer.prob_scale
+            if current_platform.is_fp8_fnuz():
+                prob_scale *= 2
+        else:
+            prob_scale = 1.0
+
+        is_singleton_float = lambda x: isinstance(x, float) or isinstance(
+            x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
+        if not is_singleton_float(q_scale) or not is_singleton_float(
+                prob_scale):
+            raise ValueError("Only support per-tensor scaling factor "
+                             "for fp8-quantized Q/prob")
+
+        # These are used in the final Attention.forward()
+        layer._q_scale.copy_(q_scale)
+        layer._prob_scale.copy_(prob_scale)
+        if q_scale == 1.0 or prob_scale == 1.0:
+            logger.warning_once(
+                f"Using Q scale {q_scale} and prob scale {prob_scale} "
+                "with fp8 attention. This may cause accuracy issues. "
+                "Please make sure Q/prob scaling factors are "
+                "available in the fp8 checkpoint.")
+
         del layer.k_scale
         del layer.v_scale
+        del layer.q_scale
+        del layer.prob_scale
diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
index ca71da8b736a..c4b3dcfbe8e6 100644
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import fnmatch
-import re
 from typing import Any, Dict, List, Optional, cast
 
 import torch
@@ -125,6 +124,13 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
             for q_config in q_configs:
                 q_config["output_tensors"] = None
 
+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(Dict[str, Any],
+                                   layer_quant_config.get("*q_proj"))
+            if q_proj_q_config is not None:
+                q_proj_q_config["output_tensors"] = None
+
         return cls(quant_config=config,
                    kv_cache_group=kv_cache_group,
                    kv_cache_config=kv_cache_config,
@@ -289,29 +295,30 @@ def get_cache_scale(self, name: str) -> Optional[str]:
         :param name: param name
         :return: matching param name for KV cache scale in vLLM
         """
-        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
-            return None
-
-        kv_proj_names = [
-            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
-        ]
-        if name.endswith(".output_scale"):
-            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
-                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
-                return name.replace(kv_output_scale_name, ".attn.k_scale")
-
-            elif len(kv_proj_names) == 2:
-                for kv_proj_name in kv_proj_names:
-                    if kv_proj_name in name and kv_proj_name == "k_proj":
-                        return name.replace(".k_proj.output_scale",
-                                            ".attn.k_scale")
-                    elif kv_proj_name in name and kv_proj_name == "v_proj":
-                        return name.replace(".v_proj.output_scale",
-                                            ".attn.v_scale")
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
 
         # If no matches, return None
         return None
 
+    def has_fp8_layer_weights(self):
+        layer_quant_config = self.quant_config.get("layer_quant_config")
+        to_dict = lambda obj: cast(Dict[str, Any], obj) or {}
+        return any([
+            'fp8' in cast(
+                str,
+                to_dict(
+                    to_dict(to_dict(layer_quant_config).get(layer_name)).get(
+                        "weight")).get("dtype"))
+            for layer_name in ["*v_proj", "*k_proj", "*q_proj"]
+        ])
+
 
 class QuarkLinearMethod(LinearMethodBase):
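Note on the new scale remapping: both Fp8Config.get_cache_scale and QuarkConfig.get_cache_scale now rename q_proj/prob output scales from the checkpoint onto the attention layer's _q_scale/_prob_scale parameters, alongside the existing k/v handling. The snippet below is a standalone sketch of that remapping so it can be sanity-checked outside vLLM; the checkpoint parameter names are hypothetical examples, not taken from a real model, and the function name remap_cache_scale is not part of the patch.

from typing import Optional


def remap_cache_scale(name: str) -> Optional[str]:
    # Same mapping as the get_cache_scale changes above: projection output
    # scales in the checkpoint are renamed onto the attention layer's
    # k/v/q/prob scale parameters.
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None


# Hypothetical checkpoint parameter names, for illustration only.
assert remap_cache_scale(
    "model.layers.0.self_attn.q_proj.output_scale"
) == "model.layers.0.self_attn.attn.q_scale"
assert remap_cache_scale(
    "model.layers.0.self_attn.prob_output_scale"
) == "model.layers.0.self_attn.attn.prob_scale"
assert remap_cache_scale("model.layers.0.mlp.down_proj.output_scale") is None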