
Commit 2c27c69

kylesayrs authored and Alvant committed
[Misc] compressed-tensors code reuse (vllm-project#7277)
Signed-off-by: Alvant <[email protected]>
1 parent 71e2eb1 · commit 2c27c69

File tree: 8 files changed, 13 additions and 85 deletions
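In short, this commit deletes vLLM's locally duplicated quantization definitions (CompressionFormat, QuantizationType, QuantizationStrategy, QuantizationArgs) and imports them from the compressed-tensors package instead, pinning that dependency at 0.5.0. A minimal sketch of the resulting import pattern, assuming compressed-tensors==0.5.0 is installed; the asserted values below are taken from the deleted local enums shown in the utils.py diff further down, not from the commit itself:

# The shared definitions now come from the compressed-tensors package
# rather than the local copies that used to live in
# vllm/model_executor/layers/quantization/compressed_tensors/utils.py.
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)

# Illustrative only: these values mirror the deleted local enum members.
assert QuantizationType.FLOAT.value == "float"
assert QuantizationStrategy.CHANNEL.value == "channel"
assert CompressionFormat.float_quantized.value == "float-quantized"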

requirements-common.txt

Lines changed: 1 addition & 0 deletions
@@ -23,3 +23,4 @@ pyzmq
 librosa # Required for audio processing
 soundfile # Required for audio processing
 gguf == 0.9.1
+compressed-tensors == 0.5.0

requirements-test.txt

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ peft
 requests
 ray
 sentence-transformers # required for embedding
-compressed-tensors==0.4.0 # required for compressed-tensors
+compressed-tensors==0.5.0 # required for compressed-tensors
 timm # required for internvl test

 # TODO: Add this after fully implementing llava(mantis)

tests/quantization/test_compressed_tensors.py

Lines changed: 1 addition & 2 deletions
@@ -5,13 +5,12 @@

 import pytest
 import torch
+from compressed_tensors.quantization import QuantizationType

 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationType)


 @pytest.mark.parametrize("model_args", [

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,10 @@
 from typing import Any, Dict, List, Optional

 import torch
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy,
+                                             QuantizationType)
 from pydantic import BaseModel

 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
@@ -13,8 +17,7 @@
     CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
     CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    CompressionFormat, QuantizationArgs, QuantizationStrategy,
-    QuantizationType, find_matched_target, is_activation_quantization_format,
+    find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py

Lines changed: 1 addition & 2 deletions
@@ -1,11 +1,10 @@
 from typing import Callable, List, Optional

 import torch
+from compressed_tensors.quantization import QuantizationStrategy

 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

Lines changed: 1 addition & 2 deletions
@@ -1,12 +1,11 @@
 from typing import Callable, List, Optional

 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn import Parameter

 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

Lines changed: 1 addition & 2 deletions
@@ -1,12 +1,11 @@
 from typing import Callable, List, Optional

 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn import Parameter

 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_int8_linear, convert_to_channelwise)
 from vllm.model_executor.parameter import (BasevLLMParameter,

vllm/model_executor/layers/quantization/compressed_tensors/utils.py

Lines changed: 2 additions & 74 deletions
@@ -1,85 +1,13 @@
 import re
-from enum import Enum
-from typing import Any, Dict, Iterable, Optional
+from typing import Iterable, Optional

-from pydantic import BaseModel, Field
+from compressed_tensors import CompressionFormat
 from torch.nn import Module

 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     FUSED_LAYER_NAME_MAPPING)


-class CompressionFormat(Enum):
-    dense = "dense"
-    sparse_bitmask = "sparse-bitmask"
-    naive_quantized = "naive-quantized"
-    float_quantized = "float-quantized"
-    int_quantized = "int-quantized"
-    pack_quantized = "pack-quantized"
-    marlin_24 = "marlin-24"
-
-
-class QuantizationType(str, Enum):
-    """
-    Enum storing quantization type options
-    """
-
-    INT = "int"
-    FLOAT = "float"
-
-
-class QuantizationStrategy(str, Enum):
-    """
-    Enum storing quantization strategy options
-    """
-
-    TENSOR = "tensor"
-    CHANNEL = "channel"
-    GROUP = "group"
-    BLOCK = "block"
-    TOKEN = "token"
-
-
-class QuantizationArgs(BaseModel):
-    """
-    User facing arguments used to define a quantization config
-    for weights or activations
-
-    :param num_bits: quantization bit depth
-    :param type: dtype to quantized to, either int or float
-    :param symmetric: whether or not quantization scale is symmetric
-    :param strategy: string determining the scope of scale/zero-point to apply
-    :param group_size: group length to use for the group strategy
-    :param block_structure: 2d block structure to use for the block
-        strategy, must be of the format "2x4", "8x16", etc.
-    :param dynamic: set True to perform dynamic quantization -
-        values will not be calibrated during calibration phase,
-        instead during inference new quantization ranges will be
-        observed with every sample. Defaults to False for static
-        quantization. Note that enabling dynamic quantization
-        will change the default observer to a memoryless one
-    """
-
-    num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT
-    symmetric: bool = True
-    group_size: Optional[int] = None
-    strategy: Optional[QuantizationStrategy] = None
-    block_structure: Optional[str] = None
-    dynamic: bool = False
-    observer: str = Field(
-        default="minmax",
-        description=("The class to use to compute the quantization param - "
-                     "scale and zero-point'"),
-    )
-    observer_kwargs: Dict[str, Any] = Field(
-        default_factory=dict,
-        description=
-        ("optional dict of kwargs to be passed directly to torch quantization "
-         "Observers constructor excluding quantization range or symmetry"),
-    )
-
-
 def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
         CompressionFormat.naive_quantized.value,
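After this change, utils.py keeps only the layer-matching helpers and builds its activation-format list from the library's CompressionFormat enum. A hedged usage sketch, assuming vllm and compressed-tensors are both installed; the diff above truncates the format list, so only naive_quantized is used here:

from compressed_tensors import CompressionFormat
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    is_activation_quantization_format)

# naive-quantized appears in _ACTIVATION_QUANTIZATION_FORMATS above,
# so the helper should report it as an activation-quantization format.
assert is_activation_quantization_format(
    CompressionFormat.naive_quantized.value)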
