Skip to content

Commit f44b88d

Browse files
committed
[fix]: fix accuracy issues for dbo in deepseek
Signed-off-by: zhuohuan <[email protected]>
1 parent 8c24a6f commit f44b88d

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

vllm_ascend/models/deepseek_v2.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,8 @@ def _forward_ms_op_gate(
343343
def _forward_ms_op_tp_allgather(
344344
self,
345345
hidden_states: torch.Tensor,
346-
shared_output: torch.Tensor,
347346
chunk_hidden_states: torch.Tensor,
348347
num_tokens: int = 0,
349-
hidden_dim: int = 0,
350348
):
351349

352350
if self.tp_size > 1:
@@ -373,11 +371,6 @@ def _forward_ms_op_tp_allgather(
373371
else:
374372
final_hidden_states = hidden_states
375373

376-
if shared_output is not None:
377-
final_hidden_states = final_hidden_states + shared_output
378-
final_hidden_states = final_hidden_states.view(
379-
num_tokens, hidden_dim)
380-
381374
return final_hidden_states
382375

383376

@@ -744,7 +737,7 @@ def _forward_ms_layer(
744737

745738
num_token, hidden_dim = hidden_states[i].shape
746739
hidden_states[i] = hidden_states[i].view(-1, hidden_dim)
747-
#num_tokens.append(num_token)
740+
num_tokens.append(num_token)
748741
hidden_dims.append(hidden_dim)
749742
if self.mlp.n_shared_experts is not None:
750743
# TODO: we can move shared expert computation into next block if reduce results is false
@@ -780,7 +773,6 @@ def _forward_ms_layer(
780773
if padded_num_tokens > 0:
781774
hidden_states[i] = nn.functional.pad(
782775
hidden_states[i], (0, 0, 0, padded_num_tokens))
783-
num_tokens.append(padded_num_tokens)
784776
chunk_hidden_state = torch.tensor_split(hidden_states[i],
785777
self.mlp.tp_size,
786778
dim=0)
@@ -839,9 +831,13 @@ def _forward_ms_layer(
839831
with set_multistream_context(context, i):
840832
hidden_states[i] = self.mlp._forward_ms_op_tp_allgather(
841833
hidden_states[i], shared_outputs[i],
842-
chunk_hidden_states[i], num_tokens[i], hidden_dims[i])
834+
chunk_hidden_states[i], padded_num_tokens, hidden_dims[i])
843835
with torch.npu.stream(ms_metadata.communicate_stream):
844836
# last
837+
if shared_output is not None:
838+
hidden_states[i] = hidden_states[i] + shared_outputs[i]
839+
hidden_states[i] = hidden_states[i].view(
840+
num_tokens[i], hidden_dims[i])
845841
if isinstance(self.mlp, CustomDeepseekV2MLP
846842
) and hidden_states[i].dtype == torch.float16:
847843
# Fix FP16 overflow

0 commit comments

Comments (0)