
Commit 02a3ce2

bnellnm authored and mgoin committed
[Kernels] Support blocked fp8 quantization for compressed tensors MoE (#25219)
Signed-off-by: Bill Nell <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
1 parent 9cae377 commit 02a3ce2

4 files changed: +112 -29 lines changed

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 96 additions & 13 deletions
@@ -13,6 +13,7 @@
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (
     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
@@ -31,6 +32,9 @@
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    expert_weight_is_col_major, get_col_major_tma_aligned_tensor,
+    requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_moe_marlin_supports_layer, marlin_make_workspace_new,
     marlin_moe_permute_scales)
@@ -45,6 +49,7 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
 
 logger = init_logger(__name__)
 
@@ -505,10 +510,12 @@ def __init__(
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
             and self.input_quant.strategy == QuantizationStrategy.TOKEN)
         if not (per_tensor or per_channel):
-            raise ValueError(
-                "For FP8 Fused MoE layers, we require per tensor "
-                "or channelwise, dynamic per token quantization. Found "
-                f"{self.weight_quant}, {self.input_quant}")
+            assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
+            self.weight_block_size = self.weight_quant.block_structure
+            assert self.weight_quant.dynamic is not None
+        else:
+            self.weight_block_size = None
+        self.block_quant = self.weight_block_size is not None
 
         self.static_input_scales = not self.input_quant.dynamic
         if self.static_input_scales and per_channel:
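Note: the rewritten constructor no longer rejects schemes that are neither per-tensor nor per-channel; anything else is assumed to be block-quantized weights, and weight_block_size is taken from the config's block_structure. A condensed sketch of that decision, using stand-in config objects rather than the real compressed-tensors classes:

from dataclasses import dataclass
from enum import Enum
from typing import Optional

class QuantizationStrategy(str, Enum):   # stand-in for the real enum
    TENSOR = "tensor"
    CHANNEL = "channel"
    TOKEN = "token"
    BLOCK = "block"

@dataclass
class QuantArgs:                         # stand-in for the real quant args
    strategy: QuantizationStrategy
    block_structure: Optional[list] = None
    dynamic: bool = True

def resolve_weight_block_size(weight_quant: QuantArgs,
                              input_quant: QuantArgs) -> Optional[list]:
    per_tensor = (weight_quant.strategy == QuantizationStrategy.TENSOR
                  and input_quant.strategy == QuantizationStrategy.TENSOR)
    per_channel = (weight_quant.strategy == QuantizationStrategy.CHANNEL
                   and input_quant.strategy == QuantizationStrategy.TOKEN)
    if not (per_tensor or per_channel):
        # Everything else must be block-quantized weights.
        assert weight_quant.strategy == QuantizationStrategy.BLOCK
        return weight_quant.block_structure
    return None

print(resolve_weight_block_size(
    QuantArgs(QuantizationStrategy.BLOCK, block_structure=[128, 128]),
    QuantArgs(QuantizationStrategy.TOKEN)))   # -> [128, 128]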
@@ -519,7 +526,8 @@ def __init__(
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
-                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+                           or envs.VLLM_TEST_FORCE_FP8_MARLIN
+                           and not self.block_quant)
         # Disable marlin for rocm
         if current_platform.is_rocm():
             self.use_marlin = False
@@ -531,8 +539,9 @@ def __init__(
         # cutlass path
         self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
             self.weight_quant, self.input_quant)
-        self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
-            self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
+        self.use_cutlass = not self.block_quant and (
+            quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
+            or self.is_fp8_w8a8_sm100)
         self.disable_expert_map = False
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -547,6 +556,31 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
 
         params_dtype = torch.float8_e4m3fn
 
+        if self.block_quant:
+            assert self.weight_block_size is not None
+            layer.weight_block_size = self.weight_block_size
+            tp_size = get_tensor_model_parallel_world_size()
+            block_n, block_k = (
+                self.weight_block_size[0],
+                self.weight_block_size[1],
+            )
+            # NOTE: To ensure proper alignment of the block-wise quantization
+            # scales, the output_size of the weights for both the gate and up
+            # layers must be divisible by block_n.
+            # Required by column parallel or enabling merged weights
+            if intermediate_size_per_partition % block_n != 0:
+                raise ValueError(
+                    f"The output_size of gate's and up's weight = "
+                    f"{intermediate_size_per_partition} is not divisible by "
+                    f"weight quantization block_n = {block_n}.")
+            if (tp_size > 1
+                    and intermediate_size_per_partition % block_k != 0):
+                # Required by row parallel
+                raise ValueError(
+                    f"The input_size of down's weight = "
+                    f"{intermediate_size_per_partition} is not divisible by "
+                    f"weight quantization block_k = {block_k}.")
+
         # WEIGHTS
         w13_weight = torch.nn.Parameter(torch.empty(
             num_experts,
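The divisibility constraints above can be exercised on their own. A minimal standalone sketch with hypothetical sizes, mirroring the checks in the hunk rather than calling into vLLM:

def validate_moe_block_shape(intermediate_size_per_partition: int,
                             weight_block_size: list,
                             tp_size: int) -> None:
    # weight_block_size follows the [block_n, block_k] convention above.
    block_n, block_k = weight_block_size
    # Gate/up projections are column-parallel, so their output dim must
    # align with block_n for the merged w13 scales to line up.
    if intermediate_size_per_partition % block_n != 0:
        raise ValueError(f"{intermediate_size_per_partition} not divisible "
                         f"by block_n={block_n}")
    # The down projection is row-parallel; its sharded input dim must
    # align with block_k once tensor parallelism splits it.
    if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
        raise ValueError(f"{intermediate_size_per_partition} not divisible "
                         f"by block_k={block_k}")

validate_moe_block_shape(1536, [128, 128], tp_size=2)   # OK: 1536 % 128 == 0
validate_moe_block_shape(1500, [128, 128], tp_size=2)   # raises ValueError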
@@ -602,6 +636,27 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
             set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
+        elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 *
+                ((intermediate_size_per_partition + block_n - 1) // block_n),
+                (hidden_size + block_k - 1) // block_k,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, (hidden_size + block_n - 1) // block_n,
+                (intermediate_size_per_partition + block_k - 1) // block_k,
+                dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
         # INPUT_SCALES
         if self.static_input_scales:
             w13_input_scale = torch.nn.Parameter(torch.ones(
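For reference, the scale shapes allocated in this hunk are plain ceiling divisions of each weight dimension by the block size, with w13 doubled because it packs the gate and up projections. A small sketch with hypothetical dimensions:

import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

num_experts, hidden_size, intermediate = 8, 4096, 1536   # hypothetical sizes
block_n, block_k = 128, 128

w13_weight_scale = torch.ones(num_experts,
                              2 * ceil_div(intermediate, block_n),
                              ceil_div(hidden_size, block_k),
                              dtype=torch.float32)
w2_weight_scale = torch.ones(num_experts,
                             ceil_div(hidden_size, block_n),
                             ceil_div(intermediate, block_k),
                             dtype=torch.float32)
print(w13_weight_scale.shape)   # torch.Size([8, 24, 32])
print(w2_weight_scale.shape)    # torch.Size([8, 32, 12])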
@@ -706,6 +761,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             del layer.w2_input_scale
 
         if self.use_cutlass:
+            assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
             device = layer.w13_weight.device
             # ab_strides1 and c_strides2 are the same
             self.ab_strides1_c_strides2 = torch.full(
@@ -724,6 +780,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 device=device,
                 dtype=torch.int64)
 
+        if is_deep_gemm_e8m0_used() and self.block_quant:
+            assert layer.weight_block_size is not None
+            # Re-quantise the expert weights so their scales are UE8M0.
+            block_sz = tuple(layer.weight_block_size)
+            requant_weight_ue8m0_inplace(
+                layer.w13_weight.data,
+                layer.w13_weight_scale.data,
+                block_sz,
+            )
+            requant_weight_ue8m0_inplace(
+                layer.w2_weight.data,
+                layer.w2_weight_scale.data,
+                block_sz,
+            )
+
+            # Ensure column-major TMA alignment expected by DeepGEMM.
+            if expert_weight_is_col_major(layer.w13_weight_scale):
+                layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
+                    layer.w13_weight_scale)
+            if expert_weight_is_col_major(layer.w2_weight_scale):
+                layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
+                    layer.w2_weight_scale)
+
     def maybe_make_prepare_finalize(
             self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
         if self.use_marlin or self.rocm_aiter_moe_enabled:
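The UE8M0 path constrains every block scale to a power of two (UE8M0 has an 8-bit exponent and no mantissa); requant_weight_ue8m0_inplace then rewrites the fp8 weight blocks against the adjusted scales. A conceptual sketch of the scale rounding only, under the assumption that scales are rounded up to the next power of two; this is not the vLLM helper itself:

import torch

def round_scales_up_to_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    # UE8M0 can only represent powers of two, so snap each block scale
    # up to the next power of two (rounding up keeps the block's dynamic
    # range covered at the cost of slightly coarser quantization).
    return torch.exp2(torch.ceil(torch.log2(scale)))

old_scales = torch.tensor([0.7, 1.0, 3.2])
print(round_scales_up_to_ue8m0(old_scales))   # tensor([1., 1., 4.])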
@@ -777,9 +856,10 @@ def select_gemm_impl(
             return experts
 
         # triton path
-        from vllm.model_executor.layers.fused_moe import TritonExperts
-        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-            BatchedTritonExperts)
+        from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
+            BatchedTritonOrDeepGemmExperts)
+        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+            TritonOrDeepGemmExperts)
 
         assert not self.rocm_aiter_moe_enabled and not self.use_marlin
 
@@ -790,14 +870,16 @@ def select_gemm_impl(
             assert max_num_tokens_per_rank is not None
 
             logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
-            return BatchedTritonExperts(
+            return BatchedTritonOrDeepGemmExperts(
                 max_num_tokens=max_num_tokens_per_rank,
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
             )
         else:
-            logger.debug("TritonExperts(%s)", self.__class__.__name__)
-            return TritonExperts(self.moe_quant_config)
+            logger.debug("TritonOrDeepGemmExperts(%s)",
+                         self.__class__.__name__)
+            return TritonOrDeepGemmExperts(self.moe_quant_config,
+                                           allow_deep_gemm=True)
 
     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
@@ -816,6 +898,7 @@ def get_fused_moe_quant_config(
             a2_scale=layer.w2_input_scale,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_channel_quant,
+            block_shape=layer.weight_block_size,
         )
 
     def apply(

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 8 additions & 14 deletions
@@ -33,10 +33,10 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     apply_fp8_block_linear, check_aiter_fp8_linear_support,
     create_fp8_input_scale, create_fp8_scale_parameter,
-    create_fp8_weight_parameter, get_col_major_tma_aligned_tensor,
-    maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
-    process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
-    validate_fp8_block_shape)
+    create_fp8_weight_parameter, expert_weight_is_col_major,
+    get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block,
+    process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy,
+    requant_weight_ue8m0_inplace, validate_fp8_block_shape)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
     prepare_moe_fp8_layer_for_marlin)
@@ -64,12 +64,6 @@
 logger = init_logger(__name__)
 
 
-def _is_col_major(x: torch.Tensor) -> bool:
-    assert x.dim() == 3
-    b, m, n = x.shape
-    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
-
-
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
@@ -660,10 +654,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
         # DeepGemm scales need to be transposed and aligned. We try to do
         # it ahead of time for performance reasons.
         if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
-            if _is_col_major(layer.w13_weight_scale_inv):
+            if expert_weight_is_col_major(layer.w13_weight_scale_inv):
                 layer.w13_weight_scale_inv = \
                     get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
-            if _is_col_major(layer.w2_weight_scale_inv):
+            if expert_weight_is_col_major(layer.w2_weight_scale_inv):
                 layer.w2_weight_scale_inv = \
                     get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
 
@@ -811,10 +805,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
             )
 
         # Ensure column-major TMA alignment expected by DeepGEMM.
-        if _is_col_major(layer.w13_weight_scale_inv):
+        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
             layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
                 layer.w13_weight_scale_inv)
-        if _is_col_major(layer.w2_weight_scale_inv):
+        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
             layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
                 layer.w2_weight_scale_inv)
 
vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 6 additions & 0 deletions
@@ -1014,3 +1014,9 @@ def apply_fp8_block_linear(layer: torch.nn.Module, input: torch.Tensor,
         cutlass_block_fp8_supported=cutlass_block_fp8_supported,
         use_aiter_and_is_supported=use_aiter_and_is_supported,
     )
+
+
+def expert_weight_is_col_major(x: torch.Tensor) -> bool:
+    assert x.dim() == 3
+    b, m, n = x.shape
+    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
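The helper promoted here simply inspects strides: for a 3-D per-expert scale tensor of shape (E, m, n), a column-major layout per expert means stride 1 along dim 1 and stride m along dim 2. A quick standalone check with hypothetical sizes:

import torch

def expert_weight_is_col_major(x: torch.Tensor) -> bool:
    # Same check as the function added above.
    assert x.dim() == 3
    b, m, n = x.shape
    return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m

e, m, n = 8, 24, 32
row_major = torch.empty(e, m, n)                   # strides (m*n, n, 1)
col_major = torch.empty(e, n, m).transpose(1, 2)   # strides (m*n, 1, m)
print(expert_weight_is_col_major(row_major))   # False
print(expert_weight_is_col_major(col_major))   # True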

vllm/model_executor/warmup/deep_gemm_warmup.py

Lines changed: 2 additions & 2 deletions
@@ -53,9 +53,9 @@ def _extract_data_from_fused_moe_module(
     """
     assert isinstance(m, FusedMoE)
     w13 = m.w13_weight
-    w13_s = m.w13_weight_scale_inv
+    w13_s = getattr(m, "w13_weight_scale_inv", m.w13_weight_scale)
     w2 = m.w2_weight
-    w2_s = m.w2_weight_scale_inv
+    w2_s = getattr(m, "w2_weight_scale_inv", m.w2_weight_scale)
     num_topk = m.top_k
 
     assert isinstance(w13, torch.Tensor)
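The warmup change accounts for the two scale-attribute spellings: the Fp8 path registers w13_weight_scale_inv / w2_weight_scale_inv, while the compressed-tensors block path above registers w13_weight_scale / w2_weight_scale. A minimal sketch of the getattr fallback, with a stand-in module rather than the real FusedMoE:

import torch

class FakeMoE(torch.nn.Module):
    # Stand-in layer that only defines the compressed-tensors style names.
    def __init__(self):
        super().__init__()
        self.w13_weight_scale = torch.ones(8, 24, 32)
        self.w2_weight_scale = torch.ones(8, 32, 12)

m = FakeMoE()
# Prefer the *_scale_inv spelling when present, otherwise fall back.
w13_s = getattr(m, "w13_weight_scale_inv", m.w13_weight_scale)
w2_s = getattr(m, "w2_weight_scale_inv", m.w2_weight_scale)
print(w13_s.shape, w2_s.shape)   # torch.Size([8, 24, 32]) torch.Size([8, 32, 12])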
