From 2123013724c80f9ab6a65d0687c0341cf3e1f702 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 12 Jun 2024 01:51:04 +0000 Subject: [PATCH 1/5] marlin 24 support --- .../compressed_tensors/compressed_tensors.py | 39 +++-- .../compressed_tensors/schemes/__init__.py | 1 + .../schemes/compressed_tensors_w4a16_24.py | 134 ++++++++++++++++++ .../quantization/compressed_tensors/utils.py | 8 ++ 4 files changed, 169 insertions(+), 13 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index e134a26efa3d..0bfc5dd092cf 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -7,17 +7,20 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, CompressedTensorsW4A16, + CompressedTensorsScheme, CompressedTensorsW4A16, CompressedTensors24, CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match) + CompressionFormat, QuantizationArgs, QuantizationStrategy, + find_first_name_or_class_match) class CompressedTensorsConfig(QuantizationConfig): - def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): + def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str], + quant_format: str): self.ignore = ignore self.layer_quant_details = layer_quant_details + self.quant_format = quant_format def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) @@ -46,6 +49,7 @@ def get_quant_method( def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": layer_quant_details: Dict[str, Any] = dict() ignore: List[str] = config.get("ignore", None) + quant_format: str = config.get("format", None) # The quant_config has multiple config_groups, each containing # an input_activations key with details about how the activations are @@ -69,7 +73,9 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": except Exception: layer_quant_details[target]["input_activations"] = None - return cls(layer_quant_details=layer_quant_details, ignore=ignore) + return cls(layer_quant_details=layer_quant_details, + ignore=ignore, + quant_format=quant_format) @classmethod def get_config_filenames(cls) -> List[str]: @@ -110,15 +116,21 @@ def _get_schema(self, weight_quant: BaseModel, input_quant: BaseModel) -> "CompressedTensorsScheme": if self._is_w4a16(weight_quant, input_quant): - return CompressedTensorsW4A16(num_bits=weight_quant.num_bits, - strategy=weight_quant.strategy, - group_size=weight_quant.group_size) - - if self._is_static_tensor_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8StaticTensor() - - if self._is_dynamic_token_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8DynamicToken() + if self.quant_format == CompressionFormat.marlin_24.value: + return CompressedTensors24(strategy=weight_quant.strategy, + num_bits=weight_quant.num_bits, + group_size=weight_quant.group_size) + if self.quant_format 
== CompressionFormat.pack_quantized.value: + return CompressedTensorsW4A16(num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + group_size=weight_quant.group_size) + + if self.quant_format == CompressionFormat.int_quantized.value: + if self._is_static_tensor_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8StaticTensor() + + if self._is_dynamic_token_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8DynamicToken() raise NotImplementedError("Scheme not supported.") @@ -165,6 +177,7 @@ def create_weights(self, layer: torch.nn.Module, scheme = self.quantization_config.get_scheme(layer=layer) scheme.create_weights( layer=layer, + input_size=input_size, input_size_per_partition=input_size_per_partition, output_partition_sizes=output_partition_sizes, input_size=input_size, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index dc84d000803f..90a9e44f1772 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -2,6 +2,7 @@ from .compressed_tensors_unquantized import ( # noqa: F401 CompressedTensorsUnquantized) from .compressed_tensors_w4a16 import CompressedTensorsW4A16 # noqa: F401 +from .compressed_tensors_w4a16_24 import CompressedTensors24 # noqa: F401 from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501 CompressedTensorsW8A8DynamicToken) from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py new file mode 100644 index 000000000000..e67d77d4b25d --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -0,0 +1,134 @@ +from typing import Callable, List, Optional + +import torch +from torch.nn import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensors24"] + + +class CompressedTensors24(CompressedTensorsScheme): + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None): + self.strategy = strategy + self.group_size = group_size + self.num_bits = num_bits + self.tile_size = 16 + + if self.strategy == "group" and self.group_size is None: + raise ValueError( + "group_size must be given when using strategy group") + + def create_weights(self, layer: torch.nn.Module, input_size: int, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + pack_factor = 32 // self.num_bits + output_size_per_partition = sum(output_partition_sizes) + + qweight = Parameter( + torch.empty( + input_size_per_partition // self.tile_size // 2, + output_size_per_partition * self.tile_size // pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": pack_factor, + "marlin_tile_size": 
self.tile_size, + }, + ) + + layer.register_parameter("weight_packed", qweight) + set_weight_attrs(qweight, {"weight_loader": weight_loader}) + + input_groups = (1 if self.group_size is None else + input_size_per_partition // self.group_size) + + scales = Parameter( + torch.empty( + input_groups, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "output_dim": 1, + "input_dim": None if input_groups == 1 else 0, + }, + ) + layer.register_parameter("scale_packed", scales) + set_weight_attrs(scales, {"weight_loader": weight_loader}) + + weight_shape = Parameter(torch.empty(2, dtype=torch.int64), + requires_grad=False) + + layer.register_parameter("weight_shape", weight_shape) + set_weight_attrs(weight_shape, {"weight_loader": weight_loader}) + + meta = Parameter( + torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + dtype=torch.int16, + ), + requires_grad=False, + ) + set_weight_attrs( + meta, + { + "input_dim": 0, + "packed_dim": 1, + "pack_factor": 1, + "output_dim": 1, + "marlin_tile_size": 2, + }, + ) + layer.register_parameter("meta", meta) + set_weight_attrs(meta, {"weight_loader": weight_loader}) + + max_workspace_size = ( + output_size_per_partition // + GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL + workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int), + requires_grad=False) + layer.workspace = workspace + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + qweight = layer.weight_packed + meta = layer.meta + scales = layer.scale_packed + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, + workspace, self.num_bits, size_m, + size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + return output diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index fcc664910184..b2bec9b603d1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -6,6 +6,14 @@ from torch.nn import Module +class CompressionFormat(Enum): + dense = "dense" + sparse_bitmask = "sparse-bitmask" + int_quantized = "int-quantized" + pack_quantized = "pack-quantized" + marlin_24 = "marlin-24" + + class QuantizationType(str, Enum): """ Enum storing quantization type options From 195630abd9fef732858cb214e883f24e75b4e1ec Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 13 Jun 2024 17:31:27 +0000 Subject: [PATCH 2/5] clean-up; rebase; update tests --- tests/quantization/test_compressed_tensors.py | 22 ++++++++++++++++--- .../compressed_tensors/compressed_tensors.py | 13 ++++++----- .../schemes/compressed_tensors_w4a16_24.py | 6 ++--- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 5670498f2d1e..d1ab0368ee2b 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -8,7 +8,7 @@ from vllm import SamplingParams from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsLinearMethod, CompressedTensorsW4A16, + CompressedTensors24, 
CompressedTensorsLinearMethod, CompressedTensorsW4A16, CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) @@ -51,8 +51,7 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner): def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner): model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2" - with vllm_runner(model_path, enforce_eager=True, - dtype=torch.float16) as llm: + with vllm_runner(model_path, dtype=torch.float16) as llm: model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 layer = model.model.layers[0] @@ -83,3 +82,20 @@ def test_compressed_tensors_w4a16(vllm_runner, w4a16_args): assert qkv_proj.weight_packed.dtype is torch.int32 assert qkv_proj.weight_scale.dtype is torch.float16 assert qkv_proj.weight_packed.pack_factor == 8 + + +def test_compressed_tensors_w4a16_marlin24(vllm_runner): + model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t" + with vllm_runner(model_path) as llm: + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensors24) + assert qkv_proj.weight_packed.dtype is torch.int32 + + sampling_params = SamplingParams() + output = llm.generate("Hello world!", sampling_params=sampling_params) + assert output diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 0bfc5dd092cf..d618f1828f33 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -7,7 +7,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, CompressedTensorsW4A16, CompressedTensors24, + CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A16, CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat, QuantizationArgs, QuantizationStrategy, @@ -121,9 +121,10 @@ def _get_schema(self, weight_quant: BaseModel, num_bits=weight_quant.num_bits, group_size=weight_quant.group_size) if self.quant_format == CompressionFormat.pack_quantized.value: - return CompressedTensorsW4A16(num_bits=weight_quant.num_bits, - strategy=weight_quant.strategy, - group_size=weight_quant.group_size) + return CompressedTensorsW4A16( + num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + group_size=weight_quant.group_size) if self.quant_format == CompressionFormat.int_quantized.value: if self._is_static_tensor_w8a8(weight_quant, input_quant): @@ -132,7 +133,8 @@ def _get_schema(self, weight_quant: BaseModel, if self._is_dynamic_token_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8DynamicToken() - raise NotImplementedError("Scheme not supported.") + raise NotImplementedError( + "No compressed-tensors compatible scheme was found.") def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": @@ -180,7 +182,6 @@ def create_weights(self, layer: torch.nn.Module, input_size=input_size, input_size_per_partition=input_size_per_partition, 
output_partition_sizes=output_partition_sizes, - input_size=input_size, output_size=output_size, params_dtype=params_dtype, weight_loader=weight_loader) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index e67d77d4b25d..1d57636277ec 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -53,11 +53,11 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, "packed_dim": 1, "pack_factor": pack_factor, "marlin_tile_size": self.tile_size, + "weight_loader": weight_loader }, ) layer.register_parameter("weight_packed", qweight) - set_weight_attrs(qweight, {"weight_loader": weight_loader}) input_groups = (1 if self.group_size is None else input_size_per_partition // self.group_size) @@ -75,10 +75,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, { "output_dim": 1, "input_dim": None if input_groups == 1 else 0, + "weight_loader": weight_loader }, ) layer.register_parameter("scale_packed", scales) - set_weight_attrs(scales, {"weight_loader": weight_loader}) weight_shape = Parameter(torch.empty(2, dtype=torch.int64), requires_grad=False) @@ -102,10 +102,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, "pack_factor": 1, "output_dim": 1, "marlin_tile_size": 2, + "weight_loader": weight_loader }, ) layer.register_parameter("meta", meta) - set_weight_attrs(meta, {"weight_loader": weight_loader}) max_workspace_size = ( output_size_per_partition // From 472c0d61f5f57681d2882302c23d1e8ad21f745a Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 17 Jun 2024 14:49:14 +0000 Subject: [PATCH 3/5] improve the scheme name --- tests/quantization/test_compressed_tensors.py | 7 ++++--- .../compressed_tensors/compressed_tensors.py | 12 +++++++----- .../compressed_tensors/schemes/__init__.py | 2 +- .../schemes/compressed_tensors_w4a16_24.py | 2 +- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index d1ab0368ee2b..602a06e9c313 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -8,8 +8,9 @@ from vllm import SamplingParams from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensors24, CompressedTensorsLinearMethod, CompressedTensorsW4A16, - CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) + CompressedTensorsW4A16Sparse24, CompressedTensorsLinearMethod, + CompressedTensorsW4A16, CompressedTensorsW8A8DynamicToken, + CompressedTensorsW8A8StaticTensor) def test_compressed_tensors_w8a8_static_setup(vllm_runner): @@ -93,7 +94,7 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner): qkv_proj = layer.self_attn.qkv_proj assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) - assert isinstance(qkv_proj.scheme, CompressedTensors24) + assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24) assert qkv_proj.weight_packed.dtype is torch.int32 sampling_params = SamplingParams() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 
d618f1828f33..4f91d7306102 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -7,8 +7,9 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A16, - CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) + CompressedTensorsW4A16Sparse24, CompressedTensorsScheme, + CompressedTensorsW4A16, CompressedTensorsW8A8DynamicToken, + CompressedTensorsW8A8StaticTensor) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat, QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match) @@ -117,9 +118,10 @@ def _get_schema(self, weight_quant: BaseModel, if self._is_w4a16(weight_quant, input_quant): if self.quant_format == CompressionFormat.marlin_24.value: - return CompressedTensors24(strategy=weight_quant.strategy, - num_bits=weight_quant.num_bits, - group_size=weight_quant.group_size) + return CompressedTensorsW4A16Sparse24( + strategy=weight_quant.strategy, + num_bits=weight_quant.num_bits, + group_size=weight_quant.group_size) if self.quant_format == CompressionFormat.pack_quantized.value: return CompressedTensorsW4A16( num_bits=weight_quant.num_bits, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 90a9e44f1772..393243ebb3fa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -2,7 +2,7 @@ from .compressed_tensors_unquantized import ( # noqa: F401 CompressedTensorsUnquantized) from .compressed_tensors_w4a16 import CompressedTensorsW4A16 # noqa: F401 -from .compressed_tensors_w4a16_24 import CompressedTensors24 # noqa: F401 +from .compressed_tensors_w4a16_24 import CompressedTensorsW4A16Sparse24 # noqa: F401 from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501 CompressedTensorsW8A8DynamicToken) from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index 1d57636277ec..0027c7f8b26a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -13,7 +13,7 @@ __all__ = ["CompressedTensors24"] -class CompressedTensors24(CompressedTensorsScheme): +class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): def __init__(self, strategy: str, From 5549e7ee544ab6c6e9e65ac6f45e702f97280faa Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 17 Jun 2024 14:50:47 +0000 Subject: [PATCH 4/5] ruff --- .../compressed_tensors/schemes/compressed_tensors_w4a16_24.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index 0027c7f8b26a..d7e04ddb8d94 100644 --- 
a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -10,7 +10,7 @@ GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) from vllm.model_executor.utils import set_weight_attrs -__all__ = ["CompressedTensors24"] +__all__ = ["CompressedTensorsW4A16Sparse24"] class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): From 8bafa1580b0491ce6be7fa5675096f355fd8242d Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 17 Jun 2024 14:53:13 +0000 Subject: [PATCH 5/5] fix isort --- tests/quantization/test_compressed_tensors.py | 4 ++-- .../quantization/compressed_tensors/compressed_tensors.py | 4 ++-- .../quantization/compressed_tensors/schemes/__init__.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 602a06e9c313..611c6b8b7fb9 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -8,8 +8,8 @@ from vllm import SamplingParams from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsW4A16Sparse24, CompressedTensorsLinearMethod, - CompressedTensorsW4A16, CompressedTensorsW8A8DynamicToken, + CompressedTensorsLinearMethod, CompressedTensorsW4A16, + CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 4f91d7306102..92a84b3c0dd8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -7,8 +7,8 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsW4A16Sparse24, CompressedTensorsScheme, - CompressedTensorsW4A16, CompressedTensorsW8A8DynamicToken, + CompressedTensorsScheme, CompressedTensorsW4A16, + CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat, QuantizationArgs, QuantizationStrategy, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 393243ebb3fa..3c95aa11fc76 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -2,7 +2,8 @@ from .compressed_tensors_unquantized import ( # noqa: F401 CompressedTensorsUnquantized) from .compressed_tensors_w4a16 import CompressedTensorsW4A16 # noqa: F401 -from .compressed_tensors_w4a16_24 import CompressedTensorsW4A16Sparse24 # noqa: F401 +from .compressed_tensors_w4a16_24 import ( # noqa: F401 + CompressedTensorsW4A16Sparse24) from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501 CompressedTensorsW8A8DynamicToken) from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501
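
Editor's note (not part of the patch series): a minimal offline-inference sketch of how the new marlin-24 path could be exercised once these patches are applied, mirroring the added test_compressed_tensors_w4a16_marlin24 test. It assumes the nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t checkpoint referenced in that test (whose compressed-tensors config reports "format": "marlin-24") is reachable, and that a CUDA GPU supported by the Marlin 2:4 kernels is available.

    from vllm import LLM, SamplingParams

    # Loading a compressed-tensors checkpoint whose quantization config
    # reports "format": "marlin-24" should route W4A16 linear layers
    # through CompressedTensorsW4A16Sparse24 and the gptq_marlin_24_gemm
    # kernel introduced above.
    llm = LLM(model="nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t")

    sampling_params = SamplingParams(max_tokens=32)
    outputs = llm.generate(["Hello world!"], sampling_params)
    print(outputs[0].outputs[0].text)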