
Commit 1354375

nikhil-arm authored and mgoin committed

[Feat]: Add support for Dynamic Quant 4 bit CPU kleidiai kernels (vllm-project#17112)

Signed-off-by: Nikhil Gupta <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Noam Gat <[email protected]>

1 parent 8d4fb9f commit 1354375
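This change lets vLLM serve compressed-tensors checkpoints with 4-bit weights and dynamic 8-bit per-token activations on CPU, dispatching to PyTorch's KleidiAI-backed dynamic 4-bit ops on Arm. A minimal usage sketch, assuming a hypothetical checkpoint name and an Arm CPU build of vLLM (the Arm path requires float32 activations, hence dtype="float32"):

# Sketch only: the model id below is a placeholder for any checkpoint
# quantized with a matching compressed-tensors W4A8 recipe.
from vllm import LLM, SamplingParams

llm = LLM(
    model="my-org/Llama-3-8B-w4a8-compressed-tensors",  # hypothetical
    dtype="float32",  # Arm KleidiAI kernels require float32 activations
)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)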

File tree: 5 files changed, +269 −11 lines changed

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
Lines changed: 31 additions & 5 deletions

@@ -26,9 +26,10 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
     CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
-    CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
-    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
+    CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4,
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
@@ -74,7 +75,7 @@ def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
 
     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
-        return [torch.float16, torch.bfloat16]
+        return [torch.float32, torch.float16, torch.bfloat16]
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -299,6 +300,22 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
         # Only symmetric weight quantization supported.
         return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
 
+    def _is_dynamic_token_w4a8_int(self, weight_quant: BaseModel,
+                                   input_quant: BaseModel) -> bool:
+        is_weight_4_bits = weight_quant.num_bits == 4
+        is_activation_8_bits = input_quant.num_bits == 8
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.GROUP.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_token = (weight_strategy and input_quant.strategy
+                    == QuantizationStrategy.TOKEN.value)
+        is_dynamic = not weight_quant.dynamic and input_quant.dynamic
+
+        # Both symmetric and asymmetric input quantization supported.
+        # Only symmetric weight quantization supported.
+        return (is_weight_4_bits and is_activation_8_bits and is_token
+                and weight_quant.symmetric and is_dynamic)
+
     def _is_fp8_w8a8(self, weight_quant: BaseModel,
                      input_quant: BaseModel) -> bool:
         # Confirm weights and activations quantized.
@@ -374,7 +391,6 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
     def _get_scheme_from_parts(
             self, weight_quant: BaseModel,
             input_quant: BaseModel) -> "CompressedTensorsScheme":
-
         # Detect If Mixed Precision
         if self._is_fp4a16_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A16Fp4()
@@ -443,6 +459,16 @@ def _get_scheme_from_parts(
                 is_static_input_scheme=False,
                 input_symmetric=input_quant.symmetric)
 
+        if self._is_dynamic_token_w4a8_int(weight_quant, input_quant):
+            is_static_input_scheme = (input_quant
+                                      and not input_quant.dynamic)
+            return CompressedTensorsW4A8Int(
+                num_bits=weight_quant.num_bits,
+                strategy=weight_quant.strategy,
+                group_size=weight_quant.group_size,
+                is_static_input_scheme=is_static_input_scheme,
+                input_symmetric=input_quant.symmetric)
+
         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")
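The new _is_dynamic_token_w4a8_int check accepts configs whose weights are 4-bit symmetric with a group or channel strategy (static scales) and whose input activations are 8-bit per-token and dynamic. A hedged sketch of such a config pair, written as a Python dict; the field names mirror the BaseModel attributes the check reads, and the group size is just an example value:

# Illustrative only: values the new detection logic would match.
w4a8_int_example = {
    "weights": {
        "num_bits": 4,
        "strategy": "group",   # "channel" also matches
        "group_size": 32,      # example value
        "symmetric": True,     # weights must be symmetric
        "dynamic": False,
    },
    "input_activations": {
        "num_bits": 8,
        "strategy": "token",
        "symmetric": True,     # asymmetric inputs are also accepted
        "dynamic": True,
    },
}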

vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -3,6 +3,7 @@
 
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
+from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
 from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
                                           CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
@@ -20,5 +21,5 @@
     "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
     "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
     "CompressedTensors24", "CompressedTensorsW4A16Fp4",
-    "CompressedTensorsW4A4Fp4"
+    "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int"
 ]
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py (new file)
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Callable, Optional
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
+    MPLinearLayerConfig, choose_mp_linear_kernel)
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           ModelWeightParameter)
+from vllm.scalar_type import scalar_types
+
+logger = init_logger(__name__)
+
+__all__ = ["CompressedTensorsW4A8Int"]
+W4A8_SUPPORTED_TYPES_MAP = {
+    4: scalar_types.int4,
+}
+W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
+
+
+class CompressedTensorsW4A8Int(CompressedTensorsScheme):
+    _kernel_backends_being_used: set[str] = set()
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None,
+                 is_static_input_scheme: bool = False,
+                 input_symmetric: bool = True):
+        self.strategy = strategy
+        self.group_size = -1 if group_size is None else group_size
+        self.is_static_input_scheme = is_static_input_scheme
+        self.input_symmetric = input_symmetric
+
+        if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
+            raise ValueError(
+                f"Unsupported num_bits = {num_bits}."
+                f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}")
+        self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 1
+
+    def create_weights(self, layer: torch.nn.Module, output_size: int,
+                       input_size: int, output_partition_sizes: list[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        row_parallel = (input_size != input_size_per_partition)
+
+        # Compute effective group_size
+        if self.group_size == -1:
+            effective_group_size = (input_size_per_partition
+                                    if row_parallel else input_size)
+        else:
+            effective_group_size = self.group_size
+
+        # Ensure group_size divides input_size_per_partition
+        assert input_size_per_partition % effective_group_size == 0, (
+            f"input_size_per_partition {input_size_per_partition}"
+            f" not divisible by group_size {effective_group_size}")
+
+        # Determine scale partitioning
+        is_channelwise = (self.group_size == -1)
+        repeat_scales = (is_channelwise and row_parallel)
+        partition_scales = not repeat_scales
+
+        mp_linear_kernel_config = MPLinearLayerConfig(
+            full_weight_shape=(input_size, output_size),
+            partition_weight_shape=(input_size_per_partition,
+                                    output_size_per_partition),
+            weight_type=self.quant_type,
+            act_type=params_dtype,
+            group_size=effective_group_size,
+            zero_points=False,
+            has_g_idx=False,
+        )
+
+        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
+        if kernel_type.__name__ not in self._kernel_backends_being_used:
+            logger.info("Using %s for CompressedTensorsW4A8Int",
+                        kernel_type.__name__)
+            self._kernel_backends_being_used.add(kernel_type.__name__)
+
+        scales_and_zp_size = input_size_per_partition // effective_group_size
+
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=torch.int8),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
+        layer.register_parameter("weight", weight)
+
+        weight_scale_args = {
+            "weight_loader":
+            weight_loader,
+            "data":
+            torch.empty(output_size_per_partition,
+                        scales_and_zp_size,
+                        dtype=params_dtype)
+        }
+
+        if partition_scales:
+            weight_scale = GroupQuantScaleParameter(output_dim=0,
+                                                    input_dim=1,
+                                                    **weight_scale_args)
+        else:
+            weight_scale = ChannelQuantScaleParameter(output_dim=0,
+                                                      **weight_scale_args)
+
+        layer.register_parameter("weight_packed", weight)
+        layer.register_parameter("weight_scale", weight_scale)
+
+        self.kernel = kernel_type(mp_linear_kernel_config,
+                                  w_q_param_name="weight_packed",
+                                  w_s_param_name="weight_scale",
+                                  w_zp_param_name=None,
+                                  w_gidx_param_name=None)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        self.kernel.process_weights_after_loading(layer)
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        return self.kernel.apply_weights(layer, x, bias)
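The scale bookkeeping above is easiest to see with numbers. A small sketch, assuming an example 4096-wide input partition (the shape is made up, not taken from the diff):

# Sketch of scales_and_zp_size in create_weights(), not library code.
input_size_per_partition = 4096       # assumed example width

# Group-wise strategy: one scale per group of input features per output row.
group_size = 32
groupwise_scales_per_row = input_size_per_partition // group_size    # -> 128

# Channel-wise strategy (group_size == -1): the effective group spans the
# whole input dimension, so each output channel keeps a single scale.
effective_group_size = input_size_per_partition
channelwise_scales_per_row = input_size_per_partition // effective_group_size  # -> 1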

vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
Lines changed: 9 additions & 5 deletions

@@ -10,6 +10,8 @@
     BitBLASLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import (  # noqa: E501
     ConchLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import (  # noqa: E501
+    Dynamic4bitLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import (  # noqa: E501
     ExllamaLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import (  # noqa: E501
@@ -25,6 +27,7 @@
     MacheteLinearKernel,
     AllSparkLinearKernel,
     MarlinLinearKernel,
+    Dynamic4bitLinearKernel,
     BitBLASLinearKernel,
     ConchLinearKernel,
     ExllamaLinearKernel,
@@ -56,20 +59,21 @@ def choose_mp_linear_kernel(
         if current_platform is None:
             raise ValueError("Cannot determine compute capability")
         _cc = current_platform.get_device_capability()
-        compute_capability = _cc[0] * 10 + _cc[1]
+        if _cc is not None:
+            compute_capability = _cc[0] * 10 + _cc[1]
 
     failure_reasons = []
     for kernel in _POSSIBLE_KERNELS:
         if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
             failure_reasons.append(
                 f' {kernel.__name__} disabled by environment variable')
             continue
-
-        if kernel.get_min_capability() > compute_capability:
+        if (compute_capability is not None
+                and kernel.get_min_capability() > compute_capability):
             failure_reasons.append(
                 f"{kernel.__name__} requires capability "
-                f"{kernel.get_min_capability()}, current compute capability "
-                f"is {compute_capability}")
+                f"{kernel.get_min_capability()}, current compute "
+                f" capability is {compute_capability}")
             continue
 
         can_implement, failure_reason = kernel.can_implement(config)
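Kernel selection can still be overridden from the environment: any class name listed in VLLM_DISABLED_KERNELS is skipped by choose_mp_linear_kernel. A hedged sketch, assuming the variable takes a comma-separated list of kernel class names and must be set before vLLM is imported:

# Illustrative only: steer selection away from the new CPU kernel.
import os
os.environ["VLLM_DISABLED_KERNELS"] = "Dynamic4bitLinearKernel"  # assumed format

from vllm import LLM  # imported after the variable is set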
vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py (new file)
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import torch
+
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.platforms import CpuArchEnum, current_platform
+from vllm.scalar_type import scalar_types
+
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
+
+
+class Dynamic4bitLinearKernel(MPLinearKernel):
+    SUPPORTED_QUANT_TYPES = [scalar_types.int4]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 1
+
+    @classmethod
+    def can_implement(cls,
+                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
+        if not current_platform.is_cpu():
+            return False, "Only CPU is supported"
+        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
+            return False, f"Unsupported quant type {c.weight_type}"
+        if current_platform.get_cpu_architecture(
+        ) == CpuArchEnum.ARM and c.act_type not in [
+                torch.float32,
+        ]:
+            return False, "Dynamic4bitLinearKernel on Arm requires"\
+                " Float32 activations"
+        if c.full_weight_shape[0] % c.group_size != 0:
+            return False, f"Group size ({c.group_size}) does not evenly divide"\
+                " the number of input features "\
+                f"({c.full_weight_shape[0]})"
+        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
+            try:
+                # Attempt to retrieve the operation
+                _ = torch.ops.aten._dyn_quant_matmul_4bit
+            except AttributeError:
+                return False, f"PyTorch {torch.__version__} does not support"\
+                    " _dyn_quant_matmul_4bit. Install a newer version"
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        c = self.config
+        packed_weight = getattr(layer, self.w_q_name)
+        packed_weight = packed_weight.add(8)
+        uint8_packed = (packed_weight[::, 1::2] << 4
+                        | packed_weight[::, ::2]).to(torch.uint8)
+
+        scales = getattr(layer, self.w_s_name)
+        block_size = c.group_size
+
+        # Handle scaling factors for partitioned weights
+        if block_size == c.partition_weight_shape[0]:
+            scales = scales.to(
+                torch.float32
+            )  # Float32 & Bfloat16 variants requires float32 scales
+            scales = scales.view(-1, 1)  # Channel-wise scales
+            if layer.bias is not None:
+                layer.bias = layer.bias.to(
+                    torch.float32
+                )  # Float32 & Bfloat16 variants requires float32 bias
+        else:
+            # KleidiAI kernel requires bfloat16 scales with groupwise scheme
+            scales = scales.to(torch.bfloat16)
+
+        # Repack weights as per kernel requirement
+        w = torch.ops.aten._dyn_quant_pack_4bit_weight(
+            uint8_packed, scales, layer.bias, block_size,
+            c.partition_weight_shape[0], c.partition_weight_shape[1])
+        replace_parameter(layer, self.w_q_name,
+                          torch.nn.Parameter(w, requires_grad=False))
+        setattr(layer, self.w_s_name, None)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        c = self.config
+        x_2d = x.reshape(-1, x.shape[-1])
+        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
+
+        w_q = getattr(layer, self.w_q_name)
+        output = torch.ops.aten._dyn_quant_matmul_4bit(
+            x_2d, w_q, c.group_size, c.partition_weight_shape[0],
+            c.partition_weight_shape[1])
+        return output.reshape(out_shape)
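The two aten ops can also be exercised directly, which is a convenient way to sanity-check the packing scheme used above. A sketch of the channel-wise path, assuming an Arm CPU with a PyTorch build that exposes these ops; shapes and values are made up:

# Sketch only: mirrors process_weights_after_loading() for the
# channel-wise case (block_size == in_features), outside of vLLM.
import torch

out_features, in_features = 64, 128
group_size = in_features            # channel-wise: one group spans each row

# Fake signed int4 weights stored as int8 in [-8, 7], plus per-row scales.
w_int4 = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
scales = torch.rand(out_features, 1, dtype=torch.float32)

# Shift to unsigned nibbles and pack two per byte (odd columns -> high nibble).
w_u8 = (w_int4 + 8).to(torch.uint8)
packed = (w_u8[:, 1::2] << 4) | w_u8[:, ::2]

w_packed = torch.ops.aten._dyn_quant_pack_4bit_weight(
    packed, scales, None, group_size, in_features, out_features)

x = torch.rand(2, in_features, dtype=torch.float32)
y = torch.ops.aten._dyn_quant_matmul_4bit(
    x, w_packed, group_size, in_features, out_features)
print(y.shape)  # expected: torch.Size([2, 64])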
