@@ -13,6 +13,7 @@
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (
     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
@@ -31,6 +32,9 @@
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    expert_weight_is_col_major, get_col_major_tma_aligned_tensor,
+    requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_moe_marlin_supports_layer, marlin_make_workspace_new,
     marlin_moe_permute_scales)
@@ -45,6 +49,7 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

 logger = init_logger(__name__)

@@ -505,10 +510,12 @@ def __init__(
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
             and self.input_quant.strategy == QuantizationStrategy.TOKEN)
         if not (per_tensor or per_channel):
-            raise ValueError(
-                "For FP8 Fused MoE layers, we require per tensor "
-                "or channelwise, dynamic per token quantization. Found "
-                f"{self.weight_quant}, {self.input_quant}")
+            assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
+            self.weight_block_size = self.weight_quant.block_structure
+            assert self.weight_quant.dynamic is not None
+        else:
+            self.weight_block_size = None
+        self.block_quant = self.weight_block_size is not None

         self.static_input_scales = not self.input_quant.dynamic
         if self.static_input_scales and per_channel:
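
For context, a minimal sketch of the strategy dispatch introduced above, using invented stand-in names (`BlockQuantArgs`, `resolve_block_size`) in place of the real compressed-tensors config objects: any scheme that is neither per-tensor nor channel/token is assumed to be block-wise, and its block shape becomes `weight_block_size`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BlockQuantArgs:
    strategy: str                           # "tensor" | "channel" | "token" | "block"
    dynamic: bool
    block_structure: Optional[list] = None  # e.g. [128, 128]


def resolve_block_size(weight_quant: BlockQuantArgs,
                       input_quant: BlockQuantArgs) -> Optional[list]:
    per_tensor = (weight_quant.strategy == "tensor"
                  and input_quant.strategy == "tensor")
    per_channel = (weight_quant.strategy == "channel"
                   and input_quant.strategy == "token")
    if not (per_tensor or per_channel):
        # Anything else is treated as block-wise FP8 (e.g. DeepSeek-style
        # 128x128 weight blocks with dynamic per-token-group activations).
        assert weight_quant.strategy == "block"
        return weight_quant.block_structure
    return None


# weight_block_size is set, so block_quant becomes True:
print(resolve_block_size(BlockQuantArgs("block", True, [128, 128]),
                         BlockQuantArgs("token", True)))   # [128, 128]
```
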
@@ -519,7 +526,8 @@ def __init__(
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
-                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+                           or envs.VLLM_TEST_FORCE_FP8_MARLIN
+                           and not self.block_quant)
         # Disable marlin for rocm
         if current_platform.is_rocm():
             self.use_marlin = False
@@ -531,8 +539,9 @@ def __init__(
         # cutlass path
         self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
             self.weight_quant, self.input_quant)
-        self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
-            self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
+        self.use_cutlass = not self.block_quant and (
+            quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
+            or self.is_fp8_w8a8_sm100)
         self.disable_expert_map = False

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -547,6 +556,31 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,

         params_dtype = torch.float8_e4m3fn

+        if self.block_quant:
+            assert self.weight_block_size is not None
+            layer.weight_block_size = self.weight_block_size
+            tp_size = get_tensor_model_parallel_world_size()
+            block_n, block_k = (
+                self.weight_block_size[0],
+                self.weight_block_size[1],
+            )
+            # NOTE: To ensure proper alignment of the block-wise quantization
+            # scales, the output_size of the weights for both the gate and up
+            # layers must be divisible by block_n.
+            # Required by column parallel or enabling merged weights
+            if intermediate_size_per_partition % block_n != 0:
+                raise ValueError(
+                    f"The output_size of gate's and up's weight = "
+                    f"{intermediate_size_per_partition} is not divisible by "
+                    f"weight quantization block_n = {block_n}.")
+            if (tp_size > 1
+                    and intermediate_size_per_partition % block_k != 0):
+                # Required by row parallel
+                raise ValueError(
+                    f"The input_size of down's weight = "
+                    f"{intermediate_size_per_partition} is not divisible by "
+                    f"weight quantization block_k = {block_k}.")
+
         # WEIGHTS
         w13_weight = torch.nn.Parameter(torch.empty(
             num_experts,
@@ -602,6 +636,27 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
             set_weight_attrs(w2_weight_scale, extra_weight_attrs)

+        elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 *
+                ((intermediate_size_per_partition + block_n - 1) // block_n),
+                (hidden_size + block_k - 1) // block_k,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, (hidden_size + block_n - 1) // block_n,
+                (intermediate_size_per_partition + block_k - 1) // block_k,
+                dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add BLOCK quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
         # INPUT_SCALES
         if self.static_input_scales:
             w13_input_scale = torch.nn.Parameter(torch.ones(
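
The scale shapes above follow directly from ceil-division of each weight dimension by the block size, and the divisibility checks earlier in `create_weights` guarantee the gate/up dimension lines up with `block_n`. A small worked example with assumed dimensions (the numbers below are illustrative, not taken from the diff):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


# Assumed, illustrative dimensions (per expert, per tensor-parallel rank).
hidden_size = 7168
intermediate_size_per_partition = 2048
block_n, block_k = 128, 128              # weight_block_size

# Must hold, otherwise create_weights() above raises ValueError.
assert intermediate_size_per_partition % block_n == 0

# w13 (gate+up) weights are [E, 2*I, H] -> scales [E, 2*ceil(I/bn), ceil(H/bk)]
w13_scale_shape = (2 * ceil_div(intermediate_size_per_partition, block_n),
                   ceil_div(hidden_size, block_k))
# w2 (down) weights are [E, H, I]       -> scales [E, ceil(H/bn), ceil(I/bk)]
w2_scale_shape = (ceil_div(hidden_size, block_n),
                  ceil_div(intermediate_size_per_partition, block_k))
print(w13_scale_shape)  # (32, 56)
print(w2_scale_shape)   # (56, 16)
```
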
@@ -706,6 +761,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             del layer.w2_input_scale

         if self.use_cutlass:
+            assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
             device = layer.w13_weight.device
             # ab_strides1 and c_strides2 are the same
             self.ab_strides1_c_strides2 = torch.full(
@@ -724,6 +780,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 device=device,
                 dtype=torch.int64)

+        if is_deep_gemm_e8m0_used() and self.block_quant:
+            assert layer.weight_block_size is not None
+            # Re-quantise the expert weights so their scales are UE8M0.
+            block_sz = tuple(layer.weight_block_size)
+            requant_weight_ue8m0_inplace(
+                layer.w13_weight.data,
+                layer.w13_weight_scale.data,
+                block_sz,
+            )
+            requant_weight_ue8m0_inplace(
+                layer.w2_weight.data,
+                layer.w2_weight_scale.data,
+                block_sz,
+            )
+
+            # Ensure column-major TMA alignment expected by DeepGEMM.
+            if expert_weight_is_col_major(layer.w13_weight_scale):
+                layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
+                    layer.w13_weight_scale)
+            if expert_weight_is_col_major(layer.w2_weight_scale):
+                layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
+                    layer.w2_weight_scale)
+
     def maybe_make_prepare_finalize(
             self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
         if self.use_marlin or self.rocm_aiter_moe_enabled:
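
A toy illustration of what the UE8M0 re-quantization above does conceptually: snap each block's scale up to a power of two and re-quantize the block against the new scale, so the scales can be stored as 8-bit exponents for DeepGEMM. This is only a sketch of the idea, not the `requant_weight_ue8m0_inplace` implementation:

```python
import torch


def requant_block_ue8m0(w_block: torch.Tensor, scale: torch.Tensor):
    """w_block: dequantized fp32 weight block; scale: its original fp32 scale."""
    # UE8M0 scales are unsigned power-of-two values (exponent only, no mantissa).
    new_scale = torch.exp2(torch.ceil(torch.log2(scale)))
    # Re-quantize against the new scale, clamped to the fp8 e4m3 range.
    q = (w_block / new_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return q, new_scale


w = torch.randn(128, 128)
q, s = requant_block_ue8m0(w, torch.tensor(0.013))
print(s.item())  # 0.015625 == 2**-6
```
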
@@ -777,9 +856,10 @@ def select_gemm_impl(
             return experts

         # triton path
-        from vllm.model_executor.layers.fused_moe import TritonExperts
-        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-            BatchedTritonExperts)
+        from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
+            BatchedTritonOrDeepGemmExperts)
+        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+            TritonOrDeepGemmExperts)

         assert not self.rocm_aiter_moe_enabled and not self.use_marlin

@@ -790,14 +870,16 @@ def select_gemm_impl(
             assert max_num_tokens_per_rank is not None

             logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
-            return BatchedTritonExperts(
+            return BatchedTritonOrDeepGemmExperts(
                 max_num_tokens=max_num_tokens_per_rank,
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
             )
         else:
-            logger.debug("TritonExperts(%s)", self.__class__.__name__)
-            return TritonExperts(self.moe_quant_config)
+            logger.debug("TritonOrDeepGemmExperts(%s)",
+                         self.__class__.__name__)
+            return TritonOrDeepGemmExperts(self.moe_quant_config,
+                                           allow_deep_gemm=True)

     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
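
The `*OrDeepGemm*` wrappers selected above defer the kernel choice until apply time rather than fixing it at construction. A minimal, illustrative sketch of that dispatch pattern (all names below are invented, not the vLLM classes):

```python
from typing import Optional


def run_triton_experts(num_tokens: int) -> str:     # stand-in for the Triton kernel
    return f"triton({num_tokens} tokens)"


def run_deep_gemm_experts(num_tokens: int) -> str:  # stand-in for the DeepGEMM kernel
    return f"deep_gemm({num_tokens} tokens)"


class TritonOrDeepGemmSketch:
    """Falls back to Triton unless the DeepGEMM path applies."""

    def __init__(self, block_shape: Optional[list], allow_deep_gemm: bool = True):
        self.block_shape = block_shape
        self.allow_deep_gemm = allow_deep_gemm

    def apply(self, num_tokens: int) -> str:
        # Assumption for this sketch: the DeepGEMM path only covers
        # 128x128 block-quantized FP8 weights.
        if self.allow_deep_gemm and self.block_shape == [128, 128]:
            return run_deep_gemm_experts(num_tokens)
        return run_triton_experts(num_tokens)


print(TritonOrDeepGemmSketch([128, 128]).apply(64))  # deep_gemm(64 tokens)
print(TritonOrDeepGemmSketch(None).apply(64))        # triton(64 tokens)
```
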
@@ -816,6 +898,7 @@ def get_fused_moe_quant_config(
             a2_scale=layer.w2_input_scale,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_channel_quant,
+            block_shape=layer.weight_block_size,
         )

     def apply(