 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (
     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    expert_weight_is_col_major, get_col_major_tma_aligned_tensor,
+    requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_moe_marlin_supports_layer, marlin_make_workspace_new,
     marlin_moe_permute_scales)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

 logger = init_logger(__name__)

@@ -505,10 +510,12 @@ def __init__(
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
             and self.input_quant.strategy == QuantizationStrategy.TOKEN)
         if not (per_tensor or per_channel):
-            raise ValueError(
-                "For FP8 Fused MoE layers, we require per tensor "
-                "or channelwise, dynamic per token quantization. Found "
-                f"{self.weight_quant}, {self.input_quant}")
+            assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
+            self.weight_block_size = self.weight_quant.block_structure
+            assert self.weight_quant.dynamic is not None
+        else:
+            self.weight_block_size = None
+        self.block_quant = self.weight_block_size is not None

         self.static_input_scales = not self.input_quant.dynamic
         if self.static_input_scales and per_channel:
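
A minimal standalone sketch of the classification the new branch above performs, using a hypothetical stand-in for the compressed-tensors weight-quantization args (the real object is `self.weight_quant`; `block_structure` is assumed to already hold a `[block_n, block_k]` pair, as the indexing in `create_weights` below implies):

from dataclasses import dataclass
from typing import Optional


@dataclass
class WeightQuantArgsStub:  # hypothetical stand-in, not the compressed-tensors class
    strategy: str                           # "tensor" | "channel" | "block"
    block_structure: Optional[list] = None  # e.g. [128, 128]


def resolve_block_quant(weight_quant: WeightQuantArgsStub,
                        per_tensor: bool, per_channel: bool):
    # Anything that is neither per-tensor nor per-channel is treated as
    # block-quantized and must carry a block size.
    if not (per_tensor or per_channel):
        assert weight_quant.strategy == "block"
        weight_block_size = weight_quant.block_structure
    else:
        weight_block_size = None
    return weight_block_size, weight_block_size is not None


# e.g. a DeepSeek-style FP8 checkpoint with 128x128 weight blocks:
print(resolve_block_quant(WeightQuantArgsStub("block", [128, 128]),
                          per_tensor=False, per_channel=False))
# -> ([128, 128], True)
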
@@ -519,7 +526,8 @@ def __init__(
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
-                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+                           or envs.VLLM_TEST_FORCE_FP8_MARLIN
+                           and not self.block_quant)
         # Disable marlin for rocm
         if current_platform.is_rocm():
             self.use_marlin = False
@@ -531,8 +539,9 @@ def __init__(
         # cutlass path
         self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
             self.weight_quant, self.input_quant)
-        self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
-            self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
+        self.use_cutlass = not self.block_quant and (
+            quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
+            or self.is_fp8_w8a8_sm100)
         self.disable_expert_map = False

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -547,6 +556,31 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,

         params_dtype = torch.float8_e4m3fn

+        if self.block_quant:
+            assert self.weight_block_size is not None
+            layer.weight_block_size = self.weight_block_size
+            tp_size = get_tensor_model_parallel_world_size()
+            block_n, block_k = (
+                self.weight_block_size[0],
+                self.weight_block_size[1],
+            )
+            # NOTE: To ensure proper alignment of the block-wise quantization
+            # scales, the output_size of the weights for both the gate and up
+            # layers must be divisible by block_n.
+            # Required by column parallel or enabling merged weights
+            if intermediate_size_per_partition % block_n != 0:
+                raise ValueError(
+                    f"The output_size of gate's and up's weight = "
+                    f"{intermediate_size_per_partition} is not divisible by "
+                    f"weight quantization block_n = {block_n}.")
+            if (tp_size > 1
+                    and intermediate_size_per_partition % block_k != 0):
+                # Required by row parallel
+                raise ValueError(
+                    f"The input_size of down's weight = "
+                    f"{intermediate_size_per_partition} is not divisible by "
+                    f"weight quantization block_k = {block_k}.")
+
         # WEIGHTS
         w13_weight = torch.nn.Parameter(torch.empty(
             num_experts,
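
To see why the divisibility checks above matter, consider hypothetical sizes: with block_n = block_k = 128, an intermediate_size_per_partition of 1536 passes (1536 % 128 == 0), while 1472 would raise. A small standalone check mirroring the same logic:

def validate_block_alignment(intermediate_size_per_partition: int,
                             block_n: int, block_k: int, tp_size: int) -> None:
    # Gate/up projections are column-parallel: their output dim must align
    # with the scale grid along block_n.
    if intermediate_size_per_partition % block_n != 0:
        raise ValueError(f"{intermediate_size_per_partition} not divisible "
                         f"by block_n={block_n}")
    # The down projection is row-parallel: with TP > 1 its input dim must
    # also align along block_k.
    if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
        raise ValueError(f"{intermediate_size_per_partition} not divisible "
                         f"by block_k={block_k}")


validate_block_alignment(1536, block_n=128, block_k=128, tp_size=2)  # ok
# validate_block_alignment(1472, 128, 128, 2)  # would raise ValueError
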
@@ -602,6 +636,27 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             set_weight_attrs(w13_weight_scale, extra_weight_attrs)
             set_weight_attrs(w2_weight_scale, extra_weight_attrs)

+        elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 *
+                ((intermediate_size_per_partition + block_n - 1) // block_n),
+                (hidden_size + block_k - 1) // block_k,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, (hidden_size + block_n - 1) // block_n,
+                (intermediate_size_per_partition + block_k - 1) // block_k,
+                dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add BLOCK quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
         # INPUT_SCALES
         if self.static_input_scales:
             w13_input_scale = torch.nn.Parameter(torch.ones(
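
The block-wise scale shapes above are just ceil-divisions of the weight shapes by the block size. A quick sketch with hypothetical sizes (hidden_size=4096, intermediate_size_per_partition=1536, 128x128 blocks, 8 experts), which gives w13 scales of shape (8, 24, 32) and w2 scales of shape (8, 32, 12):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def block_scale_shapes(num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
                       block_n: int, block_k: int):
    # w13 packs the gate and up projections, hence the factor of 2 on its
    # output axis.
    w13 = (num_experts,
           2 * ceil_div(intermediate_size_per_partition, block_n),
           ceil_div(hidden_size, block_k))
    # w2 maps intermediate -> hidden, so the two axes swap roles.
    w2 = (num_experts,
          ceil_div(hidden_size, block_n),
          ceil_div(intermediate_size_per_partition, block_k))
    return w13, w2


print(block_scale_shapes(8, 4096, 1536, 128, 128))
# -> ((8, 24, 32), (8, 32, 12))
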
@@ -706,6 +761,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             del layer.w2_input_scale

         if self.use_cutlass:
+            assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
             device = layer.w13_weight.device
             # ab_strides1 and c_strides2 are the same
             self.ab_strides1_c_strides2 = torch.full(
@@ -724,6 +780,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 device=device,
                 dtype=torch.int64)

+        if is_deep_gemm_e8m0_used() and self.block_quant:
+            assert layer.weight_block_size is not None
+            # Re-quantise the expert weights so their scales are UE8M0.
+            block_sz = tuple(layer.weight_block_size)
+            requant_weight_ue8m0_inplace(
+                layer.w13_weight.data,
+                layer.w13_weight_scale.data,
+                block_sz,
+            )
+            requant_weight_ue8m0_inplace(
+                layer.w2_weight.data,
+                layer.w2_weight_scale.data,
+                block_sz,
+            )
+
+            # Ensure column-major TMA alignment expected by DeepGEMM.
+            if expert_weight_is_col_major(layer.w13_weight_scale):
+                layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
+                    layer.w13_weight_scale)
+            if expert_weight_is_col_major(layer.w2_weight_scale):
+                layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
+                    layer.w2_weight_scale)
+
     def maybe_make_prepare_finalize(
             self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
         if self.use_marlin or self.rocm_aiter_moe_enabled:
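
Conceptually, the re-quantisation step above moves each block scale into UE8M0 (an exponent-only format, so every scale becomes a power of two) and folds the correction back into the FP8 weights so each block's effective value stays close to the original. A rough standalone sketch of that idea for a single 2-D weight (illustrative only, not vLLM's implementation; the rounding direction is an assumption):

import torch


def requant_ue8m0_sketch(w_fp8: torch.Tensor, scale: torch.Tensor,
                         block_n: int, block_k: int) -> None:
    # Round each (block_n x block_k) block scale up to a power of two and
    # rescale the corresponding FP8 block in place.
    for bi in range(scale.shape[0]):
        for bj in range(scale.shape[1]):
            s = scale[bi, bj]
            s_pow2 = torch.exp2(torch.ceil(torch.log2(s)))
            blk = w_fp8[bi * block_n:(bi + 1) * block_n,
                        bj * block_k:(bj + 1) * block_k]
            # Dequantize with the old scale, requantize with the new one.
            blk.copy_((blk.float() * s / s_pow2).to(w_fp8.dtype))
            scale[bi, bj] = s_pow2


w = torch.randn(256, 256).to(torch.float8_e4m3fn)
s = torch.rand(2, 2) + 0.5
requant_ue8m0_sketch(w, s, 128, 128)
print(torch.log2(s))  # integral exponents, i.e. representable as UE8M0
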
@@ -777,9 +856,10 @@ def select_gemm_impl(
             return experts

         # triton path
-        from vllm.model_executor.layers.fused_moe import TritonExperts
-        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-            BatchedTritonExperts)
+        from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
+            BatchedTritonOrDeepGemmExperts)
+        from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+            TritonOrDeepGemmExperts)

         assert not self.rocm_aiter_moe_enabled and not self.use_marlin

@@ -790,14 +870,16 @@ def select_gemm_impl(
             assert max_num_tokens_per_rank is not None

             logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
-            return BatchedTritonExperts(
+            return BatchedTritonOrDeepGemmExperts(
                 max_num_tokens=max_num_tokens_per_rank,
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
             )
         else:
-            logger.debug("TritonExperts(%s)", self.__class__.__name__)
-            return TritonExperts(self.moe_quant_config)
+            logger.debug("TritonOrDeepGemmExperts(%s)",
+                         self.__class__.__name__)
+            return TritonOrDeepGemmExperts(self.moe_quant_config,
+                                           allow_deep_gemm=True)

     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
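
The intent of swapping in the *OrDeepGemm* wrappers above is that the experts implementation can fall back to the Triton kernels whenever DeepGEMM is unavailable or the quantization layout does not suit it. A purely illustrative dispatcher capturing that shape of decision (the function name and criteria are assumptions for illustration, not vLLM's actual selection logic):

from typing import Optional


def pick_experts_backend(allow_deep_gemm: bool,
                         deep_gemm_available: bool,
                         block_shape: Optional[list]) -> str:
    # Hypothetical criteria: route 128x128 block-quantized FP8 experts to
    # DeepGEMM when it is allowed and present; otherwise stay on Triton.
    if (allow_deep_gemm and deep_gemm_available
            and block_shape == [128, 128]):
        return "deep_gemm"
    return "triton"


print(pick_experts_backend(True, True, [128, 128]))   # deep_gemm
print(pick_experts_backend(True, True, None))         # triton
print(pick_experts_backend(True, False, [128, 128]))  # triton
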
@@ -816,6 +898,7 @@ def get_fused_moe_quant_config(
             a2_scale=layer.w2_input_scale,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_channel_quant,
+            block_shape=layer.weight_block_size,
         )

     def apply(
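
Passing block_shape through the quant config lets downstream kernels map each weight element (n, k) to its scale entry at (n // block_n, k // block_k). A tiny dequantization sketch of that indexing (illustrative only, not a vLLM kernel):

import torch


def dequant_block_fp8(w_q: torch.Tensor, scale: torch.Tensor,
                      block_shape: list) -> torch.Tensor:
    # w_q: [N, K] FP8 weight; scale: [ceil(N/block_n), ceil(K/block_k)] fp32.
    block_n, block_k = block_shape
    # Broadcast each block scale over its block, then trim to the weight shape.
    s_full = scale.repeat_interleave(block_n, dim=0)
    s_full = s_full.repeat_interleave(block_k, dim=1)
    s_full = s_full[:w_q.shape[0], :w_q.shape[1]]
    return w_q.float() * s_full


w_q = torch.randn(256, 384).to(torch.float8_e4m3fn)
scale = torch.rand(2, 3)
print(dequant_block_fp8(w_q, scale, [128, 128]).shape)  # torch.Size([256, 384])
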