@@ -343,10 +343,8 @@ def _forward_ms_op_gate(
     def _forward_ms_op_tp_allgather(
         self,
         hidden_states: torch.Tensor,
-        shared_output: torch.Tensor,
         chunk_hidden_states: torch.Tensor,
         num_tokens: int = 0,
-        hidden_dim: int = 0,
     ):

         if self.tp_size > 1:
@@ -373,11 +371,6 @@ def _forward_ms_op_tp_allgather(
         else:
             final_hidden_states = hidden_states

-            if shared_output is not None:
-                final_hidden_states = final_hidden_states + shared_output
-            final_hidden_states = final_hidden_states.view(
-                num_tokens, hidden_dim)
-
         return final_hidden_states

@@ -744,7 +737,7 @@ def _forward_ms_layer(

                 num_token, hidden_dim = hidden_states[i].shape
                 hidden_states[i] = hidden_states[i].view(-1, hidden_dim)
-                # num_tokens.append(num_token)
+                num_tokens.append(num_token)
                 hidden_dims.append(hidden_dim)
                 if self.mlp.n_shared_experts is not None:
                     # TODO: we can move shared expert computation into next block if reduce results is false
@@ -780,7 +773,6 @@ def _forward_ms_layer(
                 if padded_num_tokens > 0:
                     hidden_states[i] = nn.functional.pad(
                         hidden_states[i], (0, 0, 0, padded_num_tokens))
-                num_tokens.append(padded_num_tokens)
                 chunk_hidden_state = torch.tensor_split(hidden_states[i],
                                                         self.mlp.tp_size,
                                                         dim=0)
@@ -839,9 +831,13 @@ def _forward_ms_layer(
             with set_multistream_context(context, i):
                 hidden_states[i] = self.mlp._forward_ms_op_tp_allgather(
                     hidden_states[i], shared_outputs[i],
-                    chunk_hidden_states[i], num_tokens[i], hidden_dims[i])
+                    chunk_hidden_states[i], padded_num_tokens, hidden_dims[i])
             with torch.npu.stream(ms_metadata.communicate_stream):
                 # last
+                if shared_output is not None:
+                    hidden_states[i] = hidden_states[i] + shared_outputs[i]
+                hidden_states[i] = hidden_states[i].view(
+                    num_tokens[i], hidden_dims[i])
                 if isinstance(self.mlp, CustomDeepseekV2MLP
                               ) and hidden_states[i].dtype == torch.float16:
                     # Fix FP16 overflow