@@ -85,7 +85,7 @@ def _moe_problem_size(
         M = a1.size(0)
     else:
         assert a1.dim() == 3
-        # assert a1.size(0) == E, f"{a1.size(0)} == {E}"
+        assert a1.size(0) == E, f"{a1.size(0)} == {E}"
         M = a1.size(1)  # This is max_num_tokens
 
     assert topk_ids.dim() == 2
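This hunk turns a previously commented-out shape check into a hard assert: when the activations arrive in the 3-D batched-experts layout, the leading dimension must equal the number of experts. A minimal sketch of the shape convention being enforced (hypothetical helper, not the vLLM implementation):

```python
# Hypothetical sketch of the shape convention the assert above enforces.
import torch

def moe_problem_size_sketch(a1: torch.Tensor, num_experts: int,
                            topk_ids: torch.Tensor) -> tuple[int, int, int]:
    """Infer (E, M, topk) from the activation and routing tensors."""
    if a1.dim() == 2:
        # Standard format: a1 is (M, K) -- one row per token.
        M = a1.size(0)
    else:
        # Batched-experts format: a1 is (E, max_num_tokens, K), so the
        # leading dimension must match the number of experts.
        assert a1.dim() == 3
        assert a1.size(0) == num_experts, f"{a1.size(0)} == {num_experts}"
        M = a1.size(1)  # max_num_tokens per expert
    assert topk_ids.dim() == 2
    return num_experts, M, topk_ids.size(1)
```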
@@ -536,11 +536,12 @@ def apply(
         global_num_experts: int,
         expert_map: Optional[torch.Tensor],
         a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-    ):
+    ) -> None:
         """
         This function computes the intermediate result of a Mixture of Experts
         (MoE) layer using two sets of weights, w1 and w2.
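The docstring describes the classic two-GEMM expert computation. As a rough, dense toy of what "two sets of weights, w1 and w2" means here (assuming the common fused gate/up layout and SiLU activation; quantization, scales, and routing are omitted):

```python
# Conceptual sketch of the two-GEMM expert computation the apply() docstring
# describes (dense toy version; not the vLLM kernel).
import torch
import torch.nn.functional as F

E, M, K, N = 4, 16, 32, 64          # experts, tokens, hidden, intermediate
a = torch.randn(E, M, K)            # per-expert activations
w1 = torch.randn(E, 2 * N, K)       # fused gate/up projection
w2 = torch.randn(E, K, N)           # down projection

h = torch.einsum("emk,enk->emn", a, w1)      # first GEMM: (E, M, 2N)
gate, up = h.chunk(2, dim=-1)
h = F.silu(gate) * up                        # SiLU-and-mul activation
out = torch.einsum("emn,ekn->emk", h, w2)    # second GEMM: back to (E, M, K)
```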
@@ -674,22 +675,22 @@ def _allocate_buffers(
 
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = torch.zeros(prod(workspace13_shape),
-                                  device=device,
-                                  dtype=workspace_dtype)
-        workspace2 = torch.zeros(prod(workspace2_shape),
-                                 device=device,
-                                 dtype=workspace_dtype)
+        workspace13 = self.workspace13_buffer.get(workspace13_shape,
+                                                  device=device,
+                                                  dtype=workspace_dtype)
+        workspace2 = self.workspace2_buffer.get(workspace2_shape,
+                                                device=device,
+                                                dtype=workspace_dtype)
 
         # Construct the entire output that can then be processed in chunks.
         if num_chunks == 1 and prod(workspace13_shape) >= prod(
                 fused_out_shape):
             # Reuse workspace13 for the output in the non-chunked case.
             fused_out = _resize_cache(workspace13, fused_out_shape)
         else:
-            fused_out = torch.empty(fused_out_shape,
-                                    device=device,
-                                    dtype=out_dtype)
+            fused_out = self.fused_out_buffer.get(fused_out_shape,
+                                                  device=device,
+                                                  dtype=out_dtype)
 
         return workspace13, workspace2, fused_out
 
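The allocation hunk swaps fresh `torch.zeros`/`torch.empty` calls for persistent buffers (`workspace13_buffer`, `workspace2_buffer`, `fused_out_buffer`) fetched via `.get(shape, device, dtype)`. The buffer class itself is not part of this diff; a minimal sketch of a grow-only cached buffer with that interface might look like:

```python
# Minimal sketch of a reusable workspace buffer with the get() interface the
# hunk assumes; the real vLLM helper may differ.
from math import prod
from typing import Optional
import torch

class CachedBuffer:
    def __init__(self):
        self._buf: Optional[torch.Tensor] = None

    def get(self, shape, device, dtype) -> torch.Tensor:
        numel = prod(shape)
        # Grow (or re-create on device/dtype change); otherwise reuse the
        # existing allocation and view its prefix with the requested shape.
        if (self._buf is None or self._buf.numel() < numel
                or self._buf.device != device or self._buf.dtype != dtype):
            self._buf = torch.empty(numel, device=device, dtype=dtype)
        return self._buf[:numel].view(*shape)
```

Reusing one allocation across forward passes avoids allocator churn. Note that, unlike the old `torch.zeros`, a cache like this hands back uninitialized memory, which is safe only if the kernels fully overwrite the workspaces before reading them.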
@@ -785,7 +786,10 @@ def forward(
         - torch.Tensor: The output tensor after applying the MoE layer.
         """
 
-        output = hidden_states if inplace else torch.zeros_like(hidden_states)
+        if inplace and self.shared_experts is None:
+            output = hidden_states
+        else:
+            output = torch.zeros_like(hidden_states)
 
         local_num_experts = w1.size(0)
         if global_num_experts == -1:
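The output buffer now aliases `hidden_states` only when `inplace` is requested and there are no shared experts. The reason shows up later in this diff: the shared experts are fed `hidden_states` after the fused output has been produced, so writing the MoE result in place would clobber their input. A tiny illustration of the hazard (toy code, not vLLM):

```python
# Tiny runnable demo of the aliasing hazard the hunk guards against
# (illustrative only; the in-place write stands in for finalize()).
import torch

hidden_states = torch.ones(4, 8)
output = hidden_states          # "inplace": output aliases the input
output.mul_(0)                  # finalize() writes the MoE result here
# A shared-experts pass reading hidden_states now sees zeros, not the
# original activations -- hence the extra `shared_experts is None` check.
assert hidden_states.abs().sum() == 0
```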
@@ -799,8 +803,6 @@ def forward(
             (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
              _expert_topk_weights) = self.prepare_finalize.prepare(
                  hidden_states,
-                 a1_scale,
-                 a2_scale,
                  topk_weights,
                  topk_ids,
                  global_num_experts,
@@ -810,10 +812,9 @@ def forward(
             )
         else:
             # Overlap shared expert compute with all2all dispatch.
-            receiver = self.prepare_finalize.prepare_async(
+            dbo_maybe_run_recv_hook()
+            hook, receiver = self.prepare_finalize.prepare_async(
                 hidden_states,
-                a1_scale,
-                a2_scale,
                 topk_weights,
                 topk_ids,
                 global_num_experts,
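`prepare_async` now returns a `(hook, receiver)` pair instead of a bare receiver, and any outstanding dual-batch-overlap receive hook is drained via `dbo_maybe_run_recv_hook()` before the next dispatch launches. A self-contained toy of the split-phase pattern (the names and threading here are illustrative, not vLLM's actual all2all machinery):

```python
# Toy split-phase (hook, receiver) pattern: start the communication,
# overlap other work, then wait and collect. Hypothetical stand-in only.
import queue
import threading

def prepare_async(data):
    q: "queue.Queue[list[int]]" = queue.Queue(maxsize=1)
    t = threading.Thread(target=lambda: q.put([x * 2 for x in data]))
    t.start()
    return t.join, q.get            # hook waits; receiver unpacks

hook, receiver = prepare_async([1, 2, 3])
overlapped = sum(range(1000))       # shared-expert compute would go here
hook()                              # wait for the "dispatch" to land
assert receiver() == [2, 4, 6]
```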
@@ -838,6 +839,8 @@ def forward(
         topk_weights = (topk_weights if _expert_topk_weights is None else
                         _expert_topk_weights)
 
+        fused_out = None
+
         if a1q.numel() == 0:
             # This happens when none of the tokens from the all2all reach this
             # EP rank. Also, note that this is only relevant for CUDAGraph
@@ -853,7 +856,7 @@ def forward(
             CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
             num_chunks = cdiv(M, CHUNK_SIZE)
         else:
-            CHUNK_SIZE = M #a1q.size(0)
+            CHUNK_SIZE = M  # a1q.size(0)
             num_chunks = 1
 
         def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
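Chunked execution splits the M tokens into `cdiv(M, CHUNK_SIZE)` pieces that pass through the shared workspaces one at a time. A quick illustration of the arithmetic behind `num_chunks` and `input_chunk_range` (only the helper's signature appears in this diff; its body below is assumed from the names):

```python
# Sketch of the chunking arithmetic: ceiling division plus half-open ranges.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)               # ceiling division

M, CHUNK_SIZE = 10, 4
num_chunks = cdiv(M, CHUNK_SIZE)    # 3 chunks: [0,4), [4,8), [8,10)

def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
    s = chunk_idx * CHUNK_SIZE
    e = min(s + CHUNK_SIZE, M)      # the last chunk may be short
    return s, e

assert [input_chunk_range(i) for i in range(num_chunks)] == \
       [(0, 4), (4, 8), (8, 10)]
```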
@@ -892,12 +895,8 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
                 activation=activation,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
-                w1_scale=w1_scale,
-                w2_scale=w2_scale,
-                w1_zp=w1_zp,
-                w2_zp=w2_zp,
                 a1q_scale=_chunk_scales(a1q_scale, s, e),
-                a2_scale=_chunk_scales(a2_scale, e, e),
+                a2_scale=_chunk_scales(self.fused_experts.a2_scale, e, e),
                 workspace13=workspace13,
                 workspace2=workspace2,
                 expert_tokens_meta=c_expert_tokens_meta,
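With the weight scales and `a2_scale` removed from the per-call plumbing, each chunk now pulls `a2_scale` from `self.fused_experts` instead. `_chunk_scales` itself is not shown in this diff; a plausible sketch, assuming per-token scales get sliced to the chunk while per-tensor scales pass through unchanged:

```python
# Plausible sketch of a _chunk_scales helper; not the verified vLLM
# implementation.
from typing import Optional
import torch

def chunk_scales_sketch(scales: Optional[torch.Tensor], start: int,
                        end: int) -> Optional[torch.Tensor]:
    if scales is not None and scales.numel() > 1:
        return scales[start:end]    # per-token scales: take the chunk's rows
    return scales                   # None or per-tensor scale: unchanged
```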
@@ -918,7 +917,7 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
                 self.fused_experts.finalize_weight_and_reduce_impl(),
             )
             if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
+                shared_output = self.shared_experts(hidden_states)
         else:
             recv_hook = self.prepare_finalize.finalize_async(
                 output,
@@ -930,7 +929,7 @@ def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
             )
 
             if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
+                shared_output = self.shared_experts(hidden_states)
 
             assert recv_hook is not None
             dbo_register_recv_hook(recv_hook)