 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
-from .moe_backend import MoEBackendSelection
+from .moe_backend import MoEBackend, MoEBackendSelection
 from .moe_load_balancer import get_moe_load_balancer
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                            DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
@@ -234,8 +234,8 @@ def __init__(
         self.enable_dummy_allreduce = os.environ.get(
             "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
 
-        # Select MoE backend based on configuration
-        self.moe_backend = None  # Will be initialized after weights are created
+        # MoE backend will be lazily initialized when first accessed (see moe_backend property)
+        self._moe_backend = None
 
     def _check_configs(self):
         assert self._weights_created
@@ -365,8 +365,18 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()
 
-        # Initialize MoE backend after weights are created
-        self.moe_backend = MoEBackendSelection.select_backend(self)
+    @property
+    def moe_backend(self) -> MoEBackend:
+        """
+        Lazily initialize and return the MoE backend.
+
+        The backend is selected based on hardware capabilities and quantization
+        configuration, which are only available after weights are created.
+        """
+        if self._moe_backend is None:
+            assert self._weights_created, "Weights must be created before accessing moe_backend"
+            self._moe_backend = MoEBackendSelection.select_backend(self)
+        return self._moe_backend
 
     def dummy_allreduce(self):
         """
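Note: a minimal, standalone sketch of the lazy-initialization property pattern adopted in the hunk above, under assumed names (the `Module` class and the stand-in backend object below are illustrative, not part of this change):

class Module:
    def __init__(self):
        self._weights_created = False
        self._backend = None  # selection deferred until weights exist

    def create_weights(self):
        self._weights_created = True

    @property
    def backend(self):
        # First access performs the selection; later accesses reuse the cached object.
        if self._backend is None:
            assert self._weights_created, "create weights before accessing backend"
            self._backend = object()  # stands in for MoEBackendSelection.select_backend(self)
        return self._backend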
@@ -422,8 +432,6 @@ def forward_chunk(
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()
 
-        use_deepseek_fp8_block_scale = False
-        use_w4_group_scaling = False
         weight_dtype = self.w3_w1_weight.dtype
 
         token_selected_experts, token_final_scales = self.routing_method.apply(
@@ -578,9 +586,8 @@ def forward_chunk(
                     x_sf = x_sf.view((x_row, -1))
 
             elif self.has_deepseek_fp8_block_scales:
-                use_deepseek_fp8_block_scale = True
+                pass
             elif self.has_w4afp8:
-                use_w4_group_scaling = True
                 weight_dtype = torch.quint4x2
             else:
                 raise ValueError(
@@ -603,12 +610,12 @@ def forward_chunk(
                 sizes=None if use_dp_padding else all_rank_num_tokens)
             x_row = x.shape[0]
 
-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
+        # ep_size = self.ep_size
+        # ep_rank = self.ep_rank
         w3_w1_weight = self.w3_w1_weight
         w2_weight = self.w2_weight
-        cluster_size = self.cluster_size
-        cluster_rank = self.cluster_rank
+        # cluster_size = self.cluster_size
+        # cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales
 
         if use_postquant_alltoall:
@@ -697,8 +704,9 @@ def forward_chunk(
         #     tuner_top_k=tuner_top_k,
         # )
 
-        # Use the selected backend to compute MoE with the same parameters as fused_moe
+        # Use backend interface with module as first parameter for automatic configuration extraction
         final_hidden_states = self.moe_backend.run_moe(
+            self,  # Module as first parameter
             x,
             token_selected_slots,
             token_final_scales,
@@ -710,21 +718,11 @@ def forward_chunk(
             quant_scales=quant_scales,
             input_sf=x_sf,
             swizzled_input_sf=False,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=ep_size,
-            ep_rank=ep_rank,
-            cluster_size=cluster_size,
-            cluster_rank=cluster_rank,
-            enable_alltoall=use_all_to_all,
-            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-            use_w4_group_scaling=use_w4_group_scaling,
+            # Only need to pass runtime-variable parameters
             min_latency_mode=False,
-            tune_max_num_tokens=self.tune_max_num_tokens,
+            use_fused_finalize=True,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
-            module=
-            self,  # Additional parameter for backend to access module properties
         )
 
         # print(
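Note: a hedged sketch of the "automatic configuration extraction" this hunk refers to. With the module passed as the first positional argument, a backend can read the static parallelism settings from the module rather than receiving them as keyword arguments; the class and method body below are illustrative assumptions, not the actual MoEBackend implementation in moe_backend.py.

class SketchBackend:
    def run_moe(self, module, x, *args, **kwargs):
        # Static configuration is pulled from the module itself ...
        tp_size, tp_rank = module.tp_size, module.tp_rank
        ep_size, ep_rank = module.ep_size, module.ep_rank
        cluster_size, cluster_rank = module.cluster_size, module.cluster_rank
        # ... while per-call values (inputs, scales, tuner hints) arrive via args/kwargs.
        raise NotImplementedError("illustrative sketch only")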