diff --git a/tests/quantization/test_gptq_dynamic.py b/tests/quantization/test_gptq_dynamic.py
new file mode 100644
index 000000000000..c6f34fef2743
--- /dev/null
+++ b/tests/quantization/test_gptq_dynamic.py
@@ -0,0 +1,68 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests whether gptq models with dynamic quantization can be loaded.
+
+Run `pytest tests/quantization/test_gptq_dynamic.py --forked`.
+"""
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinLinearMethod)
+from vllm.model_executor.layers.quantization.utils.gptq_utils import (
+    get_dynamic_override)
+
+PROMPT = "On the surface of Mars, we found"
+
+# The first layer is quantized using bits=4, group_size=128
+# The second layer is quantized using bits=8, group_size=32
+# All other layers (layer index >= 2) are not quantized
+MODEL_QUANT = [
+    ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue",
+     True),
+    ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse",
+     False),
+]
+
+
+@pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
+def test_gptq_with_dynamic(vllm_runner, model_id: str,
+                           use_marlin_kernel: bool):
+
+    vllm_model = vllm_runner(model_id, dtype=torch.float16, max_model_len=2048)
+
+    linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else (
+        GPTQLinearMethod)
+
+    for name, submodule in (vllm_model.model.llm_engine.model_executor.
+                            driver_worker.model_runner.model.named_modules()):
+        if name == "lm_head":
+            assert isinstance(submodule.quant_method, linear_method_cls)
+        elif name == 'model.layers.0.self_attn.qkv_proj':
+            # The first layer is quantized using bits=4, group_size=128
+            # desc_act=True
+            assert isinstance(submodule.quant_method, linear_method_cls)
+            config = submodule.quant_method.quant_config
+            assert config.weight_bits == 4
+            assert config.group_size == 128
+            assert config.desc_act
+        elif name == 'model.layers.1.self_attn.qkv_proj':
+            # The second layer is quantized using bits=8, group_size=32
+            # desc_act=False
+            assert isinstance(submodule.quant_method, linear_method_cls)
+            config = submodule.quant_method.quant_config
+            assert get_dynamic_override(config, layer_name=name,
+                                        key="bits") == 8
+            assert get_dynamic_override(config,
+                                        layer_name=name,
+                                        key="group_size") == 32
+            assert not get_dynamic_override(
+                config, layer_name=name, key="desc_act")
+        elif (name == 'model.layers.2.self_attn.qkv_proj'
+              or name == 'model.layers.2.mlp.gate_up_proj'):
+            # All other layers (layer index >= 2) are not quantized
+            assert isinstance(submodule.quant_method, UnquantizedLinearMethod)
+
+    del vllm_model
diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py
index ec60d8a57559..20435a287e37 100644
--- a/tests/quantization/test_lm_head.py
+++ b/tests/quantization/test_lm_head.py
@@ -3,7 +3,6 @@
 Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
""" -from typing import Tuple import pytest import torch @@ -17,31 +16,31 @@ PROMPT = "On the surface of Mars, we found" -MODELS_QUANT = [( - "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse", - True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False), - ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)] +MODELS_QUANT = [ + ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", True), + ("ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", False), + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False), + ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False) +] -@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT) +@pytest.mark.parametrize("model_id, lm_head_quantized", MODELS_QUANT) def test_lm_head( vllm_runner, - model_lm_head_quant: Tuple[str, bool], + model_id: str, + lm_head_quantized: bool, ) -> None: - model, lm_head_quantized = model_lm_head_quant - - with vllm_runner(model, dtype=torch.float16, + with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as vllm_model: def check_model(model): lm_head_layer = model.lm_head - if lm_head_quantized: - assert isinstance(lm_head_layer.linear_method, + assert isinstance(lm_head_layer.quant_method, (GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod)) else: - assert isinstance(lm_head_layer.linear_method, + assert isinstance(lm_head_layer.quant_method, UnquantizedEmbeddingMethod) vllm_model.apply_model(check_model) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 9826aeb9dc27..7f68dae9717c 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1039,7 +1039,7 @@ def _get_logits( embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[torch.Tensor]: # Get the logits for the next tokens. - logits = lm_head.linear_method.apply(lm_head, hidden_states) + logits = lm_head.quant_method.apply(lm_head, hidden_states) if embedding_bias is not None: logits += embedding_bias diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 0565c6e8be38..9b1742998578 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -108,9 +108,9 @@ def _get_logits( embedding_bias: Optional[torch.Tensor], ) -> Optional[torch.Tensor]: # Get the logits for the next tokens. 
-        logits = lm_head.linear_method.apply(lm_head,
-                                             hidden_states,
-                                             bias=embedding_bias)
+        logits = lm_head.quant_method.apply(lm_head,
+                                            hidden_states,
+                                            bias=embedding_bias)
 
         # Gather logits for TP
         logits = self._gather_logits(logits)
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 0cb77a7546d1..6d1f0cc2eb4d 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -3,16 +3,17 @@
 import enum
 from enum import Enum
 from fractions import Fraction
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.layers.quantization.utils.gptq_utils import (
+    get_linear_quant_method)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedColumnParameter,
@@ -32,7 +33,33 @@ def __init__(
         group_size: int,
         desc_act: bool,
         lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
     ) -> None:
+        # GPTQModel uses the `dynamic` config property to allow per-module
+        # quantization config so each module can be individually optimized.
+        # Format is Dict[str, Dict] where key is a regex string that can
+        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
+        # matching of a module.
+        # Default to positive match, override base quant config mode, if no
+        # prefix is used. Value is in dict format of field key and override
+        # value.
+        # Negative matching will skip quantization init for this module
+        # entirely:
+        # non-quantized inference. More details and quantization examples can be
+        # found at: https://github.com/ModelCloud/GPTQModel
+        # Example:
+        #  # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
+        #  # last 1/4 of the layers 16-21 has 8bit and group_size 64
+        #  dynamic = {
+        #      #`.*\.` matches the layers_node prefix
+        #      # positive match layer 10-15
+        #      r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        #      # positive match layer 16-21
+        #      r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        #      r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
+        #  }
+        self.dynamic = dynamic
+
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act
@@ -47,7 +74,8 @@ def __repr__(self) -> str:
         return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
                 f"desc_act={self.desc_act}),"
-                f"lm_head_quantized={self.lm_head_quantized}")
+                f"lm_head_quantized={self.lm_head_quantized}), "
+                f"dynamic={self.dynamic}")
 
     @classmethod
     def get_name(cls) -> str:
@@ -68,19 +96,20 @@ def get_config_filenames(cls) -> List[str]:
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
-        return cls(weight_bits, group_size, desc_act, lm_head_quantized)
+        return cls(weight_bits, group_size, desc_act, lm_head_quantized,
+                   dynamic)
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["GPTQLinearMethod"]:
-        if (isinstance(layer, LinearBase) or
-                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
-            return GPTQLinearMethod(self)
-        return None
+        return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
 
 
 class ExllamaState(Enum):
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 84c53b2c16d5..0a9d86b008db 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -9,17 +9,21 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.gptq_utils import (
+    get_linear_quant_method)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported, marlin_moe_permute_scales,
     marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedColumnParameter,
@@ -47,12 +51,41 @@ def __init__(
         desc_act: bool,
         is_sym: bool,
         lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
     ) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
             desc_act = False
 
+        # GPTQModel uses the `dynamic` config property to allow per-module
+        # quantization config so each module can be individually optimized.
+        # Format is Dict[str, Dict] where key is a regex string that can
+        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
+        # matching of a module.
+        # Default to positive match, override base quant config mode, if no
+        # prefix is used. Value is in dict format of field key and override
+        # value.
+        # Negative matching will skip quantization init for this module
+        # entirely:
+        # non-quantized inference. More details and quantization examples can be
+        # found at: https://github.com/ModelCloud/GPTQModel
+        # Example:
+        #  # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
+        #  # last 1/4 of the layers 16-21 has 8bit and group_size 64
+        #  dynamic = {
+        #      #`.*\.` matches the layers_node prefix
+        #      # positive match layer 10-15
+        #      r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        #      # positive match layer 16-21
+        #      r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        #      r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
+        #  }
+        self.dynamic = dynamic
+
+        self.weight_bits = weight_bits
+        self.is_sym = is_sym
+        self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
         self.desc_act = desc_act
 
@@ -68,7 +101,8 @@ def __repr__(self) -> str:
         return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
                 f"desc_act={self.desc_act}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+                f"lm_head_quantized={self.lm_head_quantized}), "
+                f"dynamic={self.dynamic}")
 
     @classmethod
     def get_name(cls) -> str:
@@ -88,6 +122,9 @@ def get_config_filenames(cls) -> List[str]:
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
         desc_act = cls.get_from_keys(config, ["desc_act"])
@@ -95,7 +132,7 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
         return cls(weight_bits, group_size, desc_act, is_sym,
-                   lm_head_quantized)
+                   lm_head_quantized, dynamic)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -120,17 +157,15 @@ def override_quantization_method(cls, hf_quant_cfg,
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]:
-        if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
-                                             and self.lm_head_quantized):
-            return GPTQMarlinLinearMethod(self)
-        elif isinstance(layer, FusedMoE):
+    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod",
+                        UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]:
+        if isinstance(layer, FusedMoE):
             return GPTQMarlinMoEMethod(self)
-        return None
+        return get_linear_quant_method(self, layer, prefix,
+                                       GPTQMarlinLinearMethod)
 
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
-        # Extract data from quant config.
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
@@ -143,7 +178,7 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         if quant_method != "gptq":
             return False
 
-        # If we cannot find the info needed in the config, cannot convert.
+        # Marlin conversion is only valid if required properties are found
         if (num_bits is None or group_size is None or sym is None
                 or desc_act is None):
             return False
diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py
new file mode 100644
index 000000000000..5b0e6299f473
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/gptq_utils.py
@@ -0,0 +1,94 @@
+# SPDX-License-Identifier: Apache-2.0
+import re
+from copy import deepcopy
+from typing import Dict, Optional, Union
+
+import torch
+
+from vllm.config import QuantizationConfig
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, UnquantizedEmbeddingMethod)
+
+
+# Match dynamic rules with module name (prefix) and override quantize
+# config if module (prefix) matches a rule
+def override_config(config: QuantizationConfig, prefix: str):
+    weight_bits = get_dynamic_override(config, prefix, "bits",
+                                       config.weight_bits)
+    if isinstance(weight_bits, int):
+        config.weight_bits = weight_bits
+    group_size = get_dynamic_override(config, prefix, "group_size",
+                                      config.group_size)
+    if isinstance(group_size, int):
+        config.group_size = group_size
+    desc_act = get_dynamic_override(config, prefix, "desc_act",
+                                    config.desc_act)
+    if isinstance(desc_act, bool):
+        config.desc_act = desc_act
+
+    config.pack_factor = 32 // config.weight_bits  # packed into int32
+    if config.get_name() == "gptq_marlin":
+        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
+        if isinstance(is_sym, bool):
+            config.is_sym = is_sym
+
+        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
+            raise ValueError("Unsupported quantization config: "
+                             f"bits={config.weight_bits}, sym={config.is_sym}")
+
+        config.quant_type = config.TYPE_MAP[(config.weight_bits,
+                                             config.is_sym)]
+    elif config.get_name() == "gptq":
+        if config.weight_bits not in [2, 3, 4, 8]:
+            raise ValueError(
+                "Currently, only 2/3/4/8-bit weight quantization is "
+                f"supported for GPTQ, but got {config.weight_bits} bits.")
+
+
+def get_dynamic_override(
+        config: QuantizationConfig,
+        layer_name: str,
+        key: Optional[str] = None,
+        default_value: Union[int, bool,
+                             None] = None) -> Union[Dict, int, bool, None]:
+    for pattern, pattern_dict in config.dynamic.items():
+        # Negative match: matched modules are excluded from quantized init
+        if pattern.startswith("-:"):
+            if re.match(pattern.removeprefix("-:"), layer_name):
+                return False
+        # Positive match: matched modules have their quant properties
+        # overridden relative to the base quant config
+        elif re.match(pattern.removeprefix("+:"), layer_name):
+            if key is None:
+                return pattern_dict
+            else:
+                return pattern_dict.get(key, default_value)
+    return default_value
+
+
+def get_linear_quant_method(
+    config: QuantizationConfig,
+    layer: torch.nn.Module,
+    prefix: str,
+    linear_method_cls: type,
+):
+    cloned_config = deepcopy(config)
+    parallel_lm_head_quantized = isinstance(
+        layer, ParallelLMHead) and cloned_config.lm_head_quantized
+    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
+        # False = skip module, None = no override, else = Positive match
+        if get_dynamic_override(  # noqa: E712
+                cloned_config,  # noqa: E712
+                layer_name=prefix) == False:  # noqa: E712
+            if parallel_lm_head_quantized:
+                return UnquantizedEmbeddingMethod()
+            return UnquantizedLinearMethod()
+
+        if prefix:
+            # Dynamic per module/layer rules may override base config
+            override_config(cloned_config, prefix=prefix)
+
+        return linear_method_cls(cloned_config)
+    return None
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index e409094dd535..f65dfc3cb329 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -226,24 +226,24 @@ def __init__(self,
                                                    self.tp_size)
         self.embedding_dim = embedding_dim
 
-        linear_method = None
+        quant_method = None
         if quant_config is not None:
-            linear_method = quant_config.get_quant_method(self, prefix=prefix)
-            if linear_method is None:
-                linear_method = UnquantizedEmbeddingMethod()
+            quant_method = quant_config.get_quant_method(self, prefix=prefix)
+            if quant_method is None:
+                quant_method = UnquantizedEmbeddingMethod()
 
         # If we are making an embedding layer, then our quantization linear
         # method must implement the embedding operation. If we are another
         # layer type like ParallelLMHead, this is not important.
         is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
-        linear_method_implements_embedding = method_has_implemented_embedding(
-            type(linear_method))
-        if is_embedding_layer and not linear_method_implements_embedding:
+        quant_method_implements_embedding = method_has_implemented_embedding(
+            type(quant_method))
+        if is_embedding_layer and not quant_method_implements_embedding:
             raise NotImplementedError(
-                f"The class {type(linear_method).__name__} must implement "
+                f"The class {type(quant_method).__name__} must implement "
                 "the 'embedding' method, see UnquantizedEmbeddingMethod.")
 
-        self.linear_method: QuantizeMethodBase = linear_method
+        self.quant_method: QuantizeMethodBase = quant_method
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -260,13 +260,13 @@ def __init__(self,
             self.shard_indices.added_vocab_end_index -
             self.shard_indices.added_vocab_start_index)
 
-        self.linear_method.create_weights(self,
-                                          self.embedding_dim,
-                                          [self.num_embeddings_per_partition],
-                                          self.embedding_dim,
-                                          self.num_embeddings_padded,
-                                          params_dtype=params_dtype,
-                                          weight_loader=self.weight_loader)
+        self.quant_method.create_weights(self,
+                                         self.embedding_dim,
+                                         [self.num_embeddings_per_partition],
+                                         self.embedding_dim,
+                                         self.num_embeddings_padded,
+                                         params_dtype=params_dtype,
+                                         weight_loader=self.weight_loader)
 
     @classmethod
     def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
@@ -412,8 +412,8 @@ def forward(self, input_):
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.linear_method.embedding(self,
-                                                       masked_input.long())
+        output_parallel = self.quant_method.embedding(self,
+                                                      masked_input.long())
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
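
Editorial note, not part of the patch: the sketch below is a minimal, standalone illustration of how the `+:`/`-:` regex rules in a GPTQ `dynamic` config are resolved, mirroring the behavior of `get_dynamic_override` in vllm/model_executor/layers/quantization/utils/gptq_utils.py. The rule table, the `resolve` helper, and the module names are hypothetical examples, not vLLM APIs.

# Illustrative sketch only; assumes Python 3.9+ (str.removeprefix).
import re
from typing import Dict, Union

# Hypothetical per-module overrides, in the same Dict[str, Dict] format as the
# `dynamic` field of a GPTQ quantize config.
DYNAMIC: Dict[str, Dict[str, Union[int, bool]]] = {
    # positive match: layers 10-15 get an 8-bit override
    r"+:.*\.(?:1[0-5])\..*": {"bits": 8},
    # negative match: lm_head is skipped (runs non-quantized)
    r"-:lm_head": {},
}


def resolve(layer_name: str, key: str, default: Union[int, bool, None]):
    """Return False to skip the module, the override value on a positive
    match, or the default when no rule matches."""
    for pattern, overrides in DYNAMIC.items():
        if pattern.startswith("-:"):
            # negative rule: a match excludes the module from quantization
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        elif re.match(pattern.removeprefix("+:"), layer_name):
            # positive rule (also the default when no prefix is given):
            # override the base quant config field if present
            return overrides.get(key, default)
    return default


if __name__ == "__main__":
    print(resolve("model.layers.3.mlp.gate_up_proj", "bits", 4))     # 4 (base)
    print(resolve("model.layers.12.self_attn.qkv_proj", "bits", 4))  # 8
    print(resolve("lm_head", "bits", 4))                             # False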