diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index dad16112082c..521724765beb 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -290,29 +290,30 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None,
                  output_sizes: Optional[list[int]] = None,
                  prefix: str = ""):
-        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config, prefix)
-
-        self.gather_output = gather_output
-
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
-        assert self.quant_method is not None
-        self.output_size_per_partition = divide(self.output_size, tp_size)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = input_size
+        self.output_size_per_partition = divide(output_size, self.tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
         # If QKV or MergedColumn, use output size of each partition.
         if hasattr(self, "output_sizes"):
             self.output_partition_sizes = [
-                divide(output_size, tp_size)
+                divide(output_size, self.tp_size)
                 for output_size in self.output_sizes
             ]
 
+        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+                         quant_config, prefix)
+
+        self.gather_output = gather_output
+
         if output_sizes is None:
             output_sizes = [output_size]
 
+        assert self.quant_method is not None
         self.quant_method.create_weights(
             layer=self,
-            input_size_per_partition=self.input_size,
+            input_size_per_partition=self.input_size_per_partition,
             output_partition_sizes=self.output_partition_sizes,
             input_size=self.input_size,
             output_size=self.output_size,
@@ -1044,22 +1045,24 @@ def __init__(self,
                  reduce_results: bool = True,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
+        # Divide the weight matrix along the first dimension.
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.tp_size)
+        self.output_size_per_partition = output_size
+        self.output_partition_sizes = [output_size]
+
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config, prefix)
 
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
 
-        # Divide the weight matrix along the last dimension.
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
-
         self.quant_method.create_weights(
             layer=self,
             input_size_per_partition=self.input_size_per_partition,
-            output_partition_sizes=[self.output_size],
+            output_partition_sizes=self.output_partition_sizes,
             input_size=self.input_size,
             output_size=self.output_size,
             params_dtype=self.params_dtype,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 8849ba292822..a43b2e597c1e 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -13,15 +13,17 @@
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
-from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
+from vllm.model_executor.layers.quantization.awq import (AWQConfig,
+                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
-    marlin_permute_scales, moe_awq_to_marlin_zero_points,
-    verify_marlin_supported, verify_marlin_supports_shape)
+    check_marlin_supports_layer, marlin_make_empty_g_idx,
+    marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
+    moe_awq_to_marlin_zero_points, verify_marlin_supported,
+    verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
@@ -40,18 +42,17 @@ class AWQMarlinConfig(QuantizationConfig):
         8: scalar_types.uint8,
     }
 
-    def __init__(self,
-                 weight_bits: int,
-                 group_size: int,
-                 zero_point: bool,
+    def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
                  lm_head_quantized: bool,
-                 modules_to_not_convert: Optional[List[str]] = None) -> None:
+                 modules_to_not_convert: Optional[List[str]],
+                 full_config: Dict[str, Any]) -> None:
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
         self.zero_point = zero_point
         self.lm_head_quantized = lm_head_quantized
         self.weight_bits = weight_bits
         self.modules_to_not_convert = modules_to_not_convert or []
+        self.full_config = full_config
 
         if self.weight_bits not in self.TYPE_MAP:
             raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
@@ -96,7 +97,7 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
         modules_to_not_convert = cls.get_from_keys_or(
             config, ["modules_to_not_convert"], None)
         return cls(weight_bits, group_size, zero_point, lm_head_quantized,
-                   modules_to_not_convert)
+                   modules_to_not_convert, config)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -124,6 +125,13 @@ def get_quant_method(self, layer: torch.nn.Module,
                 (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                 return UnquantizedLinearMethod()
+            # Check if the layer is supported by AWQMarlin.
+            if not check_marlin_supports_layer(layer, self.group_size):
+                logger.warning_once(
+                    f"Layer '{prefix}' is not supported by AWQMarlin. "
+                    "Falling back to unoptimized AWQ kernels.")
+                return AWQConfig.from_config(
+                    self.full_config).get_quant_method(layer, prefix)
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return AWQMoEMethod(self)
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 56fa597e2013..b9460e7d7985 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -16,6 +16,8 @@
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import (
     GPTQMarlinConfig)
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    check_marlin_supports_layer)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 
@@ -87,8 +89,8 @@ def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
             modules_to_not_convert = []
         elif linear_quant_method == "awq":
             has_zp = cls.get_from_keys(config, ["zero_point"])
-            modules_to_not_convert = cls.get_from_keys(
-                config, ["modules_to_not_convert"])
+            modules_to_not_convert = cls.get_from_keys_or(
+                config, ["modules_to_not_convert"], None)
         else:
             raise ValueError("moe_wna16 only support gptq and awq.")
 
@@ -135,7 +137,8 @@ def get_quant_method(self, layer: torch.nn.Module,
                 return GPTQConfig.from_config(
                     self.full_config).get_quant_method(layer, prefix)
             elif self.linear_quant_method == "awq":
-                if self.use_marlin:
+                if self.use_marlin and check_marlin_supports_layer(
+                        layer, self.group_size):
                     return AWQMarlinConfig.from_config(
                         self.full_config).get_quant_method(layer, prefix)
                 else:
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index 3beba3083244..05e37251aa16 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -6,6 +6,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
@@ -135,6 +136,20 @@ def check_marlin_supports_shape(output_size_per_partition: int,
     return True, None
 
 
+def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
+        -> bool:
+    output_size_per_partition = getattr(layer, "output_size_per_partition",
+                                        None) or layer.output_size
+    input_size_per_partition = getattr(layer, "input_size_per_partition",
+                                       None) or layer.input_size
+
+    return check_marlin_supports_shape(
+        output_size_per_partition=output_size_per_partition,
+        input_size_per_partition=input_size_per_partition,
+        input_size=layer.input_size,
+        group_size=group_size)[0]
+
+
 def marlin_make_workspace(output_size_per_partition: int,
                           device: torch.device) -> torch.Tensor:
     max_workspace_size = (output_size_per_partition //
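
The linear.py changes compute input_size_per_partition and output_size_per_partition before super().__init__() runs, so quant_config.get_quant_method (called from LinearBase.__init__) can see the sharded shapes, and the new check_marlin_supports_layer helper reads exactly those attributes to decide whether AWQMarlin can handle a layer; when it cannot, AWQMarlinConfig.get_quant_method now logs a one-time warning and falls back to the plain AWQ kernels. The sketch below illustrates that fallback decision with a stand-in shape check; the helper names (marlin_shape_ok, pick_awq_backend) and the 64/128 divisibility constants are illustrative assumptions, not vLLM's actual check_marlin_supports_shape logic.

```python
# Minimal, self-contained sketch of the fallback decision added above.
# The 64/128 constants mirror Marlin's usual thread-tile divisibility
# requirements but are assumptions here, not vLLM's exact values.
MIN_THREAD_N = 64   # assumed minimum tile along the output dimension
MIN_THREAD_K = 128  # assumed minimum tile along the input dimension


def marlin_shape_ok(output_size_per_partition: int,
                    input_size_per_partition: int,
                    input_size: int,
                    group_size: int) -> bool:
    """Rough stand-in for the shape test behind check_marlin_supports_layer."""
    if output_size_per_partition % MIN_THREAD_N != 0:
        return False
    if input_size_per_partition % MIN_THREAD_K != 0:
        return False
    # With grouped quantization, each shard must hold whole groups.
    if group_size < input_size and input_size_per_partition % group_size != 0:
        return False
    return True


def pick_awq_backend(output_size_per_partition: int,
                     input_size_per_partition: int,
                     input_size: int,
                     group_size: int) -> str:
    # Mirrors the new control flow: prefer the Marlin kernel, otherwise
    # fall back to the unoptimized AWQ kernels instead of raising.
    if marlin_shape_ok(output_size_per_partition,
                       input_size_per_partition, input_size, group_size):
        return "awq_marlin"
    return "awq"


if __name__ == "__main__":
    # A per-rank input dimension that is not a multiple of 128 (as can
    # happen after tensor-parallel sharding) now routes to plain AWQ.
    print(pick_awq_backend(4096, 192, 384, 128))    # -> awq
    print(pick_awq_backend(4096, 4096, 8192, 128))  # -> awq_marlin
```

In other words, a layer whose sharded dimensions stop satisfying Marlin's divisibility requirements is now routed to the unoptimized AWQ path (with a one-time warning) instead of failing during weight creation.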